
Empty tensor in torch model trace for concat operation #42027

@apbose

Description


System Info

transformers: 4.57.1
Tracing the GPT-2 model (GPT2LMHeadModel) through torch.compile and AOT export leads to torch.cat receiving empty tensors.
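The system info lists only the transformers version. Because the repro goes through torch.compile and AOTAutograd, the installed torch version is likely relevant too. The sketch below collects both using only the standard library. `collect_versions` is a hypothetical helper name and is not part of either library.

```python
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("torch", "transformers")):
    """Hypothetical helper: map each package name to its installed version, or None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in this environment.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```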

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Repro code:

import torch
from typing import Any, Sequence

from torch._functorch.aot_autograd import aot_export_joint_simple
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "GPT2 is a model developed by."
input_id = tokenizer(prompt, return_tensors="pt").input_ids
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to("cuda")

def custom_backend(gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any):
    # Keep only the tensor inputs that Dynamo passes to the backend.
    torch_inputs = [inp for inp in sample_inputs if isinstance(inp, torch.Tensor)]
    # Save the real torch.cat so the wrapper below can delegate to it.
    original_cat = torch.cat

    def print_cat_inputs(*args, **kwargs):
        # Log the shape, device, and dtype of every input passed to torch.cat.
        print("\n[cat] Intercepted torch.cat call")
        print("args:", args)
        print("kwargs:", kwargs)

        if len(args) > 0 and isinstance(args[0], (list, tuple)):
            cat_inputs = args[0]
            for idx, t in enumerate(cat_inputs):
                if isinstance(t, torch.Tensor):
                    print(f"  - Input {idx}: shape={t.shape}, device={t.device}, dtype={t.dtype}")
                else:
                    print(f"  - Input {idx}: non-tensor input: {t}")
        else:
            print("[cat] Could not detect list of tensors")

        return original_cat(*args, **kwargs)

    # Monkey-patch torch.cat globally (it is not restored afterwards).
    torch.cat = print_cat_inputs
    gm = aot_export_joint_simple(
        gm,
        torch_inputs,
        trace_joint=False,
    )
    return gm

cur_input = input_id.to("cuda")
model.forward = torch.compile(model.forward, backend=custom_backend)

# Run one forward pass so that torch.compile actually invokes custom_backend.
with torch.no_grad():
    model(cur_input)
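The repro leaves torch.cat patched for the rest of the process. If only the calls made while tracing are of interest, a context manager can scope the patch and collect shapes instead of printing them. This is a minimal sketch: `record_cat_inputs` and `calls_with_empty_inputs` are hypothetical helpers, and like the repro they only see calls that look up `torch.cat` as a module attribute at call time.

```python
import contextlib

import torch


@contextlib.contextmanager
def record_cat_inputs():
    """Hypothetical helper: temporarily wrap torch.cat and record input shapes.

    Yields a list that receives one entry per torch.cat call; each entry is the
    list of shapes (as tuples) of that call's tensor inputs.
    """
    original_cat = torch.cat
    calls = []

    def wrapped_cat(*args, **kwargs):
        # torch.cat takes the tensor sequence first, positionally or as `tensors=`.
        tensors = args[0] if args else kwargs.get("tensors", ())
        calls.append([tuple(t.shape) for t in tensors if isinstance(t, torch.Tensor)])
        return original_cat(*args, **kwargs)

    torch.cat = wrapped_cat
    try:
        yield calls
    finally:
        # Restore the real torch.cat even if the body raises.
        torch.cat = original_cat


def calls_with_empty_inputs(calls):
    """Keep only the recorded calls where some input has a zero-size dimension."""
    return [shapes for shapes in calls if any(0 in shape for shape in shapes)]


if __name__ == "__main__":
    with record_cat_inputs() as calls:
        torch.cat([torch.zeros(1, 2, 0, 4), torch.ones(1, 2, 3, 4)], dim=2)
        torch.cat([torch.ones(2), torch.ones(3)])
    print(calls_with_empty_inputs(calls))  # [[(1, 2, 0, 4), (1, 2, 3, 4)]]
```

In `custom_backend`, the `aot_export_joint_simple(...)` call could be wrapped in `with record_cat_inputs() as calls:` and `calls_with_empty_inputs(calls)` inspected afterwards. Unlike the repro, this stops intercepting once tracing finishes, so any torch.cat calls made later in eager code (for example after a graph break) are not recorded.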

Expected behavior

The print statements in the repro show that torch.cat receives empty tensors. This happens with transformers 4.57.1 but not with transformers 4.52.4, where no empty tensors reach torch.cat; the 4.52.4 behavior is the expected one.
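For context, an empty input is not an error for torch.cat itself. A tensor whose size along the concatenation dimension is 0 contributes nothing, and torch.cat also accepts a 1-D tensor of shape (0,) as a legacy special case. Whether the empty inputs cause problems therefore likely depends on the backend that consumes the exported graph. The sketch below shows that dropping such inputs leaves the result unchanged. `split_empty_cat_inputs` is a hypothetical helper, and the shapes are illustrative. They loosely follow a (batch, heads, seq_len, head_dim) attention layout and are not taken from the issue.

```python
import torch


def split_empty_cat_inputs(tensors, dim):
    """Hypothetical helper: separate inputs that contribute nothing to torch.cat.

    Returns (non_empty, empty_indices). An input counts as empty if it is a
    1-D tensor of shape (0,) (legacy special case) or has size 0 along `dim`.
    """
    non_empty, empty_indices = [], []
    for i, t in enumerate(tensors):
        # Check the 1-D case first: t.shape[dim] would fail for a (0,) tensor.
        if (t.dim() == 1 and t.numel() == 0) or t.shape[dim] == 0:
            empty_indices.append(i)
        else:
            non_empty.append(t)
    return non_empty, empty_indices


if __name__ == "__main__":
    # Illustrative shapes only: (batch, heads, seq_len, head_dim).
    legacy_empty = torch.tensor([])       # shape (0,)
    zero_len = torch.zeros(1, 12, 0, 64)  # size 0 along dim 2
    new = torch.randn(1, 12, 7, 64)

    inputs = [legacy_empty, zero_len, new]
    full = torch.cat(inputs, dim=2)
    kept, dropped = split_empty_cat_inputs(inputs, dim=2)

    # Dropping the empty inputs does not change the concatenated result.
    assert torch.equal(full, torch.cat(kept, dim=2))
    print(full.shape, dropped)  # torch.Size([1, 12, 7, 64]) [0, 1]
```

If every input is empty, `kept` is an empty list and `torch.cat(kept)` raises, so a caller would need to handle that case separately.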
