diff --git a/gemma/model.py b/gemma/model.py
index 87280d2..6edbd52 100644
--- a/gemma/model.py
+++ b/gemma/model.py
@@ -469,6 +469,7 @@ def generate(
 
         batch_size = len(prompts)
         prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
+        prompt_length = [len(p) for p in prompt_tokens]
         min_prompt_len = min(len(p) for p in prompt_tokens)
         max_prompt_len = max(len(p) for p in prompt_tokens)
         max_seq_len = max_prompt_len + output_len
@@ -511,6 +512,7 @@
         top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
         output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
             device)
+        eos_flags_tensor = torch.tensor([False] * batch_size).to(device)
 
         # Prefill up to min_prompt_len tokens, then treat other prefill as
         # decode and ignore output.
@@ -543,6 +545,16 @@
                 device)
             output_index = output_index + 1
 
+            # Check if all sequences have reached EOS.
+            batch_eos_idx = (next_token_ids == self.tokenizer.eos_id).nonzero(
+                as_tuple=True)[0]
+            for eos_idx in batch_eos_idx:
+                if output_index >= prompt_length[eos_idx]:
+                    eos_flags_tensor[eos_idx] = True
+
+            if eos_flags_tensor.all():
+                break
+
         # Detokenization.
         token_ids = token_ids_tensor.tolist()
         results = []
diff --git a/scripts/run_xla.py b/scripts/run_xla.py
index 4240fa7..d3beeb6 100644
--- a/scripts/run_xla.py
+++ b/scripts/run_xla.py
@@ -134,6 +134,7 @@ def generate(
     input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                         tokenizer.pad_id,
                                         dtype=torch.int64)
+    prompt_length = [len(p) for p in prompt_tokens]
     for i, p in enumerate(prompt_tokens):
         token_ids_tensor[i, :len(p)] = torch.tensor(p)
         input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
@@ -152,9 +153,10 @@
     top_ps_tensor = torch.FloatTensor(top_ps).to(device)
     top_ks_tensor = torch.LongTensor(top_ks).to(device)
     output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
+    eos_flags_tensor = torch.tensor([False] * batch_size).to(device)
+
     if not USE_CUDA:
         xm.mark_step()
-
     # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
     for i in range(max_seq_len - min_prompt_len):
         next_token_ids = model(
@@ -184,6 +186,16 @@
         if not USE_CUDA:
             xm.mark_step()
 
+        # Check if all sequences have reached EOS.
+        batch_eos_idx = (next_token_ids == tokenizer.eos_id).nonzero(
+            as_tuple=True)[0]
+        for eos_idx in batch_eos_idx:
+            if output_index >= prompt_length[eos_idx]:
+                eos_flags_tensor[eos_idx] = True
+
+        if eos_flags_tensor.all():
+            break
+
     # Detokenization.
     token_ids = token_ids_tensor.tolist()
     results = []
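
For reference, the early-stopping logic both hunks add can be exercised in isolation. The following is a minimal sketch, not part of either file: EOS_ID, prompt_length, and fake_decode_step are made-up stand-ins for the tokenizer's eos_id, the per-prompt token counts, and a real decode step, so only the flag-tracking and break condition mirror the diff.

```python
import torch

EOS_ID = 1                           # hypothetical EOS token id
batch_size = 3
prompt_length = [2, 4, 3]            # per-sequence prompt token counts
min_prompt_len = min(prompt_length)
max_seq_len = max(prompt_length) + 6

# One flag per sequence, set once that sequence emits EOS past its prompt.
eos_flags_tensor = torch.tensor([False] * batch_size)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64)

def fake_decode_step(step: int) -> torch.Tensor:
    """Stand-in decoder: sequence i emits EOS at step 2*i + 1, else token 7."""
    return torch.tensor(
        [EOS_ID if step == 2 * i + 1 else 7 for i in range(batch_size)])

for step in range(max_seq_len - min_prompt_len):
    next_token_ids = fake_decode_step(step)
    output_index = output_index + 1

    # Same check as in the diff: a sequence only counts as finished if EOS
    # shows up at or beyond its own prompt length (i.e. it was generated,
    # not merely echoed prompt content).
    batch_eos_idx = (next_token_ids == EOS_ID).nonzero(as_tuple=True)[0]
    for eos_idx in batch_eos_idx:
        if output_index >= prompt_length[eos_idx]:
            eos_flags_tensor[eos_idx] = True

    if eos_flags_tensor.all():
        print(f"stopped after {step + 1} decode steps instead of "
              f"{max_seq_len - min_prompt_len}")
        break
```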