
Commit 4424e75

xiangze-arm authored and xuebwang-amd committed
[CPU] Avoid repeated random sample compile (vllm-project#28260)

Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent: a1eb363

File tree

1 file changed: +10 -9 lines


vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 10 additions & 9 deletions
@@ -127,15 +127,6 @@ def forward_cpu(
         elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
 
-        # Note: this is a workaround for
-        # https://github.com/pytorch/pytorch/pull/151218
-        @torch.compile(dynamic=True)
-        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            q = torch.empty_like(probs)
-            q.exponential_()
-            return probs.div(q).argmax(dim=-1).view(-1)
-
         if len(generators) != logits.shape[0]:
             return compiled_random_sample(logits), logits_to_return
         else:
@@ -148,6 +139,16 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
             return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
 
+# Note: this is a workaround for
+# https://github.com/pytorch/pytorch/pull/151218
+@torch.compile(dynamic=True)
+def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: torch.Tensor | None,
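Why the move matters: decorating a function defined inside forward_cpu builds a fresh torch.compile wrapper on every call, which the commit title says caused repeated compilation on CPU; hoisting the definition to module scope wraps it once. Below is a minimal sketch of the two patterns, assuming only torch is installed; the names sample_module_level and forward_cpu_nested are illustrative, not from vLLM.

import torch


# Module-scope definition: torch.compile wraps this function exactly once,
# and its compiled artifact is reused on every subsequent call.
@torch.compile(dynamic=True)
def sample_module_level(logits: torch.Tensor) -> torch.Tensor:
    # Exponential-race (Gumbel-max style) trick: dividing probabilities by
    # i.i.d. Exponential(1) noise and taking the argmax draws an index from
    # the categorical distribution defined by probs.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)


def forward_cpu_nested(logits: torch.Tensor) -> torch.Tensor:
    # Anti-pattern removed by this commit: a new function object is created
    # and handed to torch.compile on each call, repeating the wrapping work
    # (and, per the commit title, the compilation) every time.
    @torch.compile(dynamic=True)
    def sample(inner_logits: torch.Tensor) -> torch.Tensor:
        probs = inner_logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        q.exponential_()
        return probs.div(q).argmax(dim=-1).view(-1)

    return sample(logits)

For example, calling sample_module_level(torch.randn(4, 32)) twice reuses one compiled graph, while calling forward_cpu_nested twice goes through torch.compile's wrapping path both times.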