1 file changed: 10 additions(+), 9 deletions(-)

@@ -127,15 +127,6 @@ def forward_cpu(
         elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
 
-        # Note: this is a workaround for
-        # https://github.com/pytorch/pytorch/pull/151218
-        @torch.compile(dynamic=True)
-        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            q = torch.empty_like(probs)
-            q.exponential_()
-            return probs.div(q).argmax(dim=-1).view(-1)
-
         if len(generators) != logits.shape[0]:
             return compiled_random_sample(logits), logits_to_return
         else:
@@ -148,6 +139,16 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
             return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
 
+# Note: this is a workaround for
+# https://github.com/pytorch/pytorch/pull/151218
+@torch.compile(dynamic=True)
+def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: torch.Tensor | None,
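For context (not part of the diff): the change moves the torch.compile-wrapped helper out of forward_cpu to module scope, so the wrapped function object is created once rather than being redefined on every call. The sampling trick itself is unchanged — with q_i drawn i.i.d. from Exp(1), argmax_i(p_i / q_i) is the Gumbel-max trick in disguise, so it samples from the categorical distribution given by probs. The sketch below, which uses only stock PyTorch and is purely illustrative (all names in it are made up), checks that equivalence empirically against torch.multinomial.

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n_draws = 200_000

# Exponential-race sampling, the same trick used by compiled_random_sample:
# draw q ~ Exp(1) per category, then take argmax of probs / q.
q = torch.empty(n_draws, probs.numel()).exponential_()
race_samples = (probs / q).argmax(dim=-1)

# Reference: direct categorical sampling.
ref_samples = torch.multinomial(probs, n_draws, replacement=True)

# Both empirical frequencies should be close to [0.1, 0.2, 0.3, 0.4].
print(torch.bincount(race_samples, minlength=4) / n_draws)
print(torch.bincount(ref_samples, minlength=4) / n_draws)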