[CPU]Avoid repeated random sample compile #28260
Conversation
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Code Review
This pull request correctly moves the compiled_random_sample function to the module level to prevent it from being recompiled on every call, which is a good performance improvement. I've added a comment to use an in-place operation for a minor performance gain and for consistency with other parts of the code. More critically, I've identified a pre-existing bug in how this function is utilized within forward_cpu. When only a subset of requests provides custom random generators, they are currently ignored, and the sampling falls back to the default generator for all requests. This is a correctness issue that should be addressed. I've detailed this in a critical review comment.
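For context, a minimal sketch of the assumed pre-PR shape (illustrative only, not the actual vLLM sampler code): decorating the function inside the hot path re-invokes the torch.compile decorator on every call, whereas the module-level version quoted in the hunk below is wrapped only once at import time.

```python
import torch

# Assumed pre-PR shape (illustrative only): the torch.compile decorator is re-applied
# on every call, creating a fresh compiled wrapper each time even when no real
# recompilation happens.
def forward_cpu(logits: torch.Tensor) -> torch.Tensor:
    @torch.compile
    def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        q.exponential_()
        return probs.div(q).argmax(dim=-1).view(-1)

    return compiled_random_sample(logits)
```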
```python
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)
```
While moving this function to the module level is a great performance optimization, its usage in forward_cpu reveals a pre-existing critical bug. forward_cpu calls this function when len(generators) != logits.shape[0]. This condition is met when some, but not all, requests have custom generators (0 < len(generators) < logits.shape[0]). In this scenario, compiled_random_sample is invoked, which uses the default torch generator for all requests, thereby silently ignoring the user-provided generators. This behavior is incorrect and inconsistent with random_sample, which correctly handles this mixed-generator case.
To fix this, the logic in forward_cpu should be adjusted. The most direct fix is to ensure the compiled path is only taken when no custom generators are provided, and let the un-compiled path handle all cases with generators. This would look like:
```python
# In forward_cpu method
if not generators:
    return compiled_random_sample(logits), logits_to_return
else:
    # This logic correctly handles all cases with generators, including partial ones.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    if len(generators) != probs.shape[0]:
        # If not all requests have a generator, initialize with default.
        q.exponential_()
    # Overwrite with per-request generators where available.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
```

Since this bug is in code not directly modified by this PR, I recommend creating a follow-up issue or pull request to address this critical correctness issue.
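To make the partial-generator case concrete, here is a small self-contained sketch (hypothetical shapes and seeds, with `generators` keyed by request row index as in the snippet above):

```python
import torch

# Hypothetical example: 4 requests, only rows 1 and 3 carry a seeded generator;
# rows 0 and 2 fall back to the default RNG.
logits = torch.randn(4, 8)
generators = {1: torch.Generator().manual_seed(0),
              3: torch.Generator().manual_seed(42)}

probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
if len(generators) != probs.shape[0]:
    q.exponential_()  # default RNG noise for rows without a generator
for i, generator in generators.items():
    q[i].exponential_(generator=generator)  # seeded rows are overwritten deterministically
print(probs.div_(q).argmax(dim=-1).view(-1))  # one sampled token id per request
```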
```python
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
```
For consistency with other sampling implementations in this file (e.g., random_sample and the else branch in forward_cpu) and to avoid an unnecessary tensor allocation, consider using the in-place div_ operation.
```diff
- return probs.div(q).argmax(dim=-1).view(-1)
+ return probs.div_(q).argmax(dim=-1).view(-1)
```
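As a small illustration of the difference (not from the PR itself): `div` allocates a new result tensor, while `div_` writes into `probs`, which is safe here because `probs` is not used afterwards.

```python
import torch

probs = torch.rand(4, 8)
q = torch.rand(4, 8)

out = probs.div(q)    # out-of-place: allocates a new tensor, probs is unchanged
res = probs.div_(q)   # in-place: reuses probs' storage, no extra allocation
assert res.data_ptr() == probs.data_ptr()
```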
Hi @xiangze-arm, can you show the recompile log? I verified the main branch with
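(For reference, one common way to surface actual Dynamo recompilations, assuming a recent PyTorch 2.x, is the recompile logging switch; this sketch is illustrative and not part of the thread.)

```python
# Either set the environment variable before launching:
#   TORCH_LOGS="recompiles" python examples/offline_inference/simple_profiling.py
# or enable it programmatically:
import torch._logging

torch._logging.set_logs(recompiles=True)
```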
Hi @bigPYJ1151, I checked the log and you are right, there are no actual recompilations; they are just repeated invocations of the torch.compile decorator. Sampler timeline without this PR:
Sampler timeline with this PR:
bigPYJ1151 left a comment
OK, then this PR can reduce some overhead.
LGTM, thanks for the fix :)
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>


Purpose
Move compiled_random_sample function to module level to avoid repeated compilation.
Test Plan
Run examples/offline_inference/simple_profiling.py
Test Result
torch.compile no longer shows up inside the sampler in the trace timeline.
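A minimal sketch of how such a check could look with torch.profiler (illustrative only, not the PR's actual test script; the bare @torch.compile decoration and the tensor shapes are assumptions):

```python
import torch
from torch.profiler import ProfilerActivity, profile

@torch.compile
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)

logits = torch.randn(8, 32000)
compiled_random_sample(logits)  # warm-up: the one-time compile happens outside the profiled region

# Profile repeated sampling calls on CPU and export a Chrome trace; with the function
# compiled once at module level, torch.compile machinery should not reappear per call.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        compiled_random_sample(logits)

prof.export_chrome_trace("sampler_trace.json")
```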