Conversation

@xiangze-arm xiangze-arm commented Nov 7, 2025

Purpose

Move the compiled_random_sample function to module level to avoid repeated compilation.
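For context, a minimal sketch of the pattern behind this change (illustrative, not the exact vLLM code; the bare @torch.compile decorator and the function bodies shown here are assumptions):

import torch

# Before (illustrative): defining and wrapping the function inside the hot
# path re-invokes the torch.compile decorator on every forward call. Dynamo
# caches the compiled code, but the wrapping itself adds per-call overhead.
def forward_cpu_before(logits: torch.Tensor) -> torch.Tensor:
    @torch.compile
    def _sample(logits: torch.Tensor) -> torch.Tensor:
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        q.exponential_()
        return probs.div(q).argmax(dim=-1).view(-1)

    return _sample(logits)

# After: wrap once at module import and reuse the same compiled callable.
@torch.compile
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)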

Test Plan

Run examples/offline_inference/simple_profiling.py

Test Result

torch.compile no longer shows up inside the sampler in the trace timeline.
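For anyone reproducing the check, a minimal standalone profiling sketch in the same spirit (illustrative only; the shapes and output file name are made up, and the actual verification used examples/offline_inference/simple_profiling.py):

import torch
from torch.profiler import ProfilerActivity, profile

@torch.compile
def sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)

logits = torch.randn(8, 32000)
sample(logits)  # warm up so the one-time compilation stays out of the trace

with profile(activities=[ProfilerActivity.CPU]) as prof:
    sample(logits)

# Inspect the exported trace in chrome://tracing or Perfetto.
prof.export_chrome_trace("sampler_trace.json")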

Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly moves the compiled_random_sample function to the module level to prevent it from being recompiled on every call, which is a good performance improvement. I've added a comment to use an in-place operation for a minor performance gain and for consistency with other parts of the code. More critically, I've identified a pre-existing bug in how this function is utilized within forward_cpu. When only a subset of requests provides custom random generators, they are currently ignored, and the sampling falls back to the default generator for all requests. This is a correctness issue that should be addressed. I've detailed this in a critical review comment.

Comment on lines +145 to +149
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
Severity: critical

While moving this function to the module level is a great performance optimization, its usage in forward_cpu reveals a pre-existing critical bug. forward_cpu calls this function when len(generators) != logits.shape[0]. This condition is met when some, but not all, requests have custom generators (0 < len(generators) < logits.shape[0]). In this scenario, compiled_random_sample is invoked, which uses the default torch generator for all requests, thereby silently ignoring the user-provided generators. This behavior is incorrect and inconsistent with random_sample, which correctly handles this mixed-generator case.

To fix this, the logic in forward_cpu should be adjusted. The most direct fix is to ensure the compiled path is only taken when no custom generators are provided, and let the un-compiled path handle all cases with generators. This would look like:

# In forward_cpu method

if not generators:
    return compiled_random_sample(logits), logits_to_return
else:
    # This logic correctly handles all cases with generators, including partial ones.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    if len(generators) != probs.shape[0]:
        # If not all requests have a generator, initialize with default.
        q.exponential_()
    # Overwrite with per-request generators where available.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return

Since this bug is in code not directly modified by this PR, I recommend creating a follow-up issue or pull request to address this critical correctness issue.
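As background on the sampling identity this function relies on: dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax draws exactly from each row's categorical distribution (the exponential-race trick). A quick empirical sanity check (illustrative, not part of the PR):

import torch

torch.manual_seed(0)
logits = torch.randn(100_000, 5)
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
samples = probs.div(q).argmax(dim=-1)

# Empirical token frequencies should closely match the softmax
# probabilities averaged over rows.
print(torch.bincount(samples, minlength=5) / samples.numel())
print(probs.mean(dim=0))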

probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
Severity: high

For consistency with other sampling implementations in this file (e.g., random_sample and the else branch in forward_cpu) and to avoid an unnecessary tensor allocation, consider using the in-place div_ operation.

Suggested change
-return probs.div(q).argmax(dim=-1).view(-1)
+return probs.div_(q).argmax(dim=-1).view(-1)
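For illustration (a hypothetical snippet, not from the PR), the allocation difference the suggestion targets:

import torch

probs = torch.rand(4, 8)
q = torch.rand(4, 8)

out = probs.div(q)  # out-of-place: allocates a new tensor for the quotient
probs.div_(q)       # in-place: reuses probs' storage; safe here because
                    # probs is a freshly materialized softmax output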

@bigPYJ1151 (Member)

Hi @xiangze-arm, can you show the recompile log? I verified the main branch with TORCH_LOGS=recompiles and only got 1 recompilation for bs > 1, which is as expected.
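For reference, a sketch of how such a check can be run (the TORCH_LOGS environment variable is standard PyTorch; the toy function and shapes are illustrative):

# Equivalent to: TORCH_LOGS=recompiles python your_script.py
import os
os.environ["TORCH_LOGS"] = "recompiles"  # must be set before torch is imported

import torch

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2

f(torch.randn(4))  # first compilation, specialized to this shape
f(torch.randn(8))  # expected single recompile: Dynamo makes the batch
                   # dimension dynamic (the "bs > 1" case noted above)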

@xiangze-arm (Contributor, Author)

Hi @bigPYJ1151, I checked the log and you are right: there are no actual recompilations, just repeated invocations of the torch.compile decorator.
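A small illustration of that distinction (assumed torch.compile behavior, not code from the PR): each call to the decorator builds a fresh wrapper object even though the underlying compiled code is cached, so the per-call cost is wrapping, not recompilation.

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return x.softmax(dim=-1)

a = torch.compile(f)
b = torch.compile(f)
print(a is b)  # False: every torch.compile call returns a new wrapper,
               # which is the overhead visible in the sampler trace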

Sampler timeline without this PR:

[screenshot: trace timeline showing torch.compile invocations inside the sampler]

Sampler timeline with this PR:

[screenshot: trace timeline with the torch.compile invocations gone]

@bigPYJ1151 bigPYJ1151 left a comment

OK, then this PR can reduce some overhead.
LGTM, thanks for the fix :)

@bigPYJ1151 bigPYJ1151 enabled auto-merge (squash) November 7, 2025 08:51
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Nov 7, 2025
@bigPYJ1151 bigPYJ1151 merged commit 7bdb42b into vllm-project:main Nov 7, 2025
49 checks passed
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Labels: ready, v1
