System Info
transformers - 4.57.1
Tracing the GPT2 model causes torch.cat to receive empty tensors
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Repro code:
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from torch._functorch.aot_autograd import aot_export_joint_simple
from typing import Sequence, Any

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "GPT2 is a model developed by."
input_id = tokenizer(prompt, return_tensors="pt").input_ids
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to("cuda")

def custom_backend(gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any):
    # Keep only the tensor inputs that Dynamo passes to the backend.
    torch_inputs = [inp for inp in sample_inputs if isinstance(inp, torch.Tensor)]

    def print_cat_inputs(*args, **kwargs):
        # Log every torch.cat call made while the graph is being exported.
        print("\n[cat] Intercepted torch.cat call")
        print("args:", args)
        print("kwargs:", kwargs)
        if len(args) > 0 and isinstance(args[0], (list, tuple)):
            cat_inputs = args[0]
            for idx, t in enumerate(cat_inputs):
                if isinstance(t, torch.Tensor):
                    print(f" - Input {idx}: shape={t.shape}, device={t.device}, dtype={t.dtype}")
                else:
                    print(f" - Input {idx}: non-tensor input: {t}")
        else:
            print("[cat] Could not detect list of tensors")
        return original_cat(*args, **kwargs)

    # Temporarily route torch.cat through the logger while exporting.
    original_cat = torch.cat
    torch.cat = print_cat_inputs
    try:
        gm = aot_export_joint_simple(
            gm,
            torch_inputs,
            trace_joint=False,
        )
    finally:
        # Restore the real torch.cat so later backend calls do not wrap the wrapper.
        torch.cat = original_cat
    return gm

cur_input = input_id.to("cuda")
model.forward = torch.compile(model.forward, backend=custom_backend)
# Compilation, and therefore the logging above, happens on the first call.
model(cur_input)
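The repro intercepts torch.cat by assigning to the module attribute by hand. A minimal sketch of a scoped alternative, assuming unittest.mock is acceptable in the debugging setup: the hypothetical helper record_cat_inputs wraps torch.cat, records the shape of every tensor passed to it, and restores the original function when the with-block exits, even if tracing raises.

```python
import contextlib
from typing import Any, Iterator, List, Tuple
from unittest import mock

import torch


@contextlib.contextmanager
def record_cat_inputs() -> Iterator[List[List[Tuple[int, ...]]]]:
    """Hypothetical helper: temporarily wrap torch.cat and record input shapes."""
    calls: List[List[Tuple[int, ...]]] = []
    original_cat = torch.cat  # captured before patching, so the wrapper calls the real op

    def wrapped_cat(tensors, *args: Any, **kwargs: Any):
        # One list of shape tuples per torch.cat call (non-tensor entries are skipped).
        calls.append([tuple(t.shape) for t in tensors if isinstance(t, torch.Tensor)])
        return original_cat(tensors, *args, **kwargs)

    # patch.object puts the original attribute back when the block exits.
    with mock.patch.object(torch, "cat", wrapped_cat):
        yield calls


# Eager-mode usage: a 1-D empty tensor next to a 2-D tensor.
with record_cat_inputs() as calls:
    torch.cat([torch.empty(0), torch.ones(2, 3)], dim=0)
print(calls)  # [[(0,), (2, 3)]]
```

Inside custom_backend, the aot_export_joint_simple call could sit in such a with-block instead of the manual swap. Only the torch.cat attribute is patched; the aliases torch.concat and torch.concatenate are separate objects and are not intercepted.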
Expected behavior
The prints above show that the concatenation receives empty tensors. This happens with transformers 4.57.1 but not with transformers 4.52.4.
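Some context on why eager mode tolerates such inputs: according to the torch.cat documentation, every input must either have the same shape as the others outside the concatenation dimension or be a 1-D empty tensor of size (0,), which is skipped. The sketch below assumes the empty inputs seen in the trace are of that (0,) form; the shapes printed by the repro would confirm or refute this. The tensor sizes are illustrative and not taken from the repro.

```python
import torch

# Illustrative (batch, heads, seq_len, head_dim) sizes, not from the repro.
key_states = torch.randn(1, 12, 5, 64)

# Legacy rule: a 1-D empty tensor of size (0,) is ignored by torch.cat,
# so the result is just key_states, even though the ranks differ.
legacy_empty = torch.empty(0)
skipped = torch.cat([legacy_empty, key_states], dim=-2)

# An empty tensor that matches the other dimensions is also valid;
# it simply contributes zero rows along the concatenation dimension.
shaped_empty = torch.empty(1, 12, 0, 64)
concatenated = torch.cat([shaped_empty, key_states], dim=-2)

print(skipped.shape, concatenated.shape)
```

Both calls succeed in eager mode and return tensors equal to key_states.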