`SymInt` input doesn't get optimized out from `torch.compiled()` graph even if unused · Issue #108446 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content
SymInt input doesn't get optimized out from torch.compiled() graph even if unused #108446
Open
@kkontny

Description

@kkontny

🐛 Describe the bug

We have a Dynamo backend, defined similarly to the IPEX one, which traces and freezes the model:

import importlib
import logging

import torch
from torch._dynamo import register_backend
from .common import fake_tensor_unsupported

@register_backend
@fake_tensor_unsupported
def aio(model, input):
    """Dynamo backend that TorchScript-traces and freezes the captured graph.

    Args:
        model: The graph module handed over by Dynamo. Its readable form is
            printed before tracing.
        input: The example inputs supplied together with ``model``; they are
            forwarded to ``torch.jit.trace``. (The name shadows the builtin
            but is kept so the backend signature stays unchanged.)

    Returns:
        The frozen traced module on success, or ``model`` unchanged when
        tracing or freezing raises.
    """
    # Obtain the logger locally so the function does not depend on a
    # module-level ``log`` object existing.
    log = logging.getLogger(__name__)

    model.print_readable()
    try:
        with torch.no_grad():
            # Trace with the example inputs this backend was actually given.
            traced_model = torch.jit.trace(model.eval(), input)
            frozen_model = torch.jit.freeze(traced_model)
        return frozen_model
    except Exception as ex:
        # Best-effort backend: fall back to the untouched graph module.
        log.warning("JIT trace failed during the optimize process.")
        # Log the exception itself instead of the ``None`` returned by print().
        log.warning("%s", ex)
        return model

I'm running the Llama model from the Transformers repo (tag v4.30.1) with the following script:

import argparse
import os
import sys
import time
import datetime
import torch._dynamo.config
import transformers
import torch
import torch._dynamo

from torch.autograd.profiler import profile
import traceback as tb
import logging

# Default Alpaca-style prompt used when no --input_prompt is given.
# The pieces are joined by implicit string concatenation, so the separating
# space after "task." has to be written explicitly.
default_input_texts = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\r\n\r\n"
    "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
)


def parse_args():
    """Parse the command-line options of the reproduction script.

    Options:
        -m / --model_path: path of the model to load (mandatory).
        -a / --aio: boolean flag, ``False`` unless given.
        -i / --input_prompt: prompt text, defaults to ``default_input_texts``.

    Returns:
        The ``argparse.Namespace`` holding ``model_path``, ``aio`` and
        ``input_prompt``.
    """
    parser = argparse.ArgumentParser()
    # The path is passed straight to ``from_pretrained`` by the caller, so a
    # missing value is rejected here with a clear usage error rather than
    # surfacing later as ``None``.
    parser.add_argument(
        "-m", "--model_path",
        type=str, required=True,
        help="Recovered Model path",
    )
    # ``store_true`` already defaults to False; no separate set_defaults needed.
    parser.add_argument(
        "-a", "--aio",
        dest="aio", action="store_true",
        help="Use AIO backend",
    )
    parser.add_argument(
        "-i", "--input_prompt",
        type=str, default=default_input_texts,
        help="Input prompt",
    )
    return parser.parse_args()

def main():
    """Load the Llama model, compile it with the ``aio`` backend and generate text.

    Reads the command-line options via ``parse_args()``, loads model and
    tokenizer from ``args.model_path``, wraps the model with ``torch.compile``
    (dynamic shapes, graph breaks allowed) and prints the decoded output of a
    100-new-token generation for ``args.input_prompt``.
    """
    args = parse_args()

    torch._dynamo.config.cache_size_limit = 128

    print("Loading model and tokenizer...")
    alpaca_model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
    alpaca_tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
    alpaca_model.config.pad_token_id = alpaca_tokenizer.pad_token_id = 0  # unk
    alpaca_model.config.bos_token_id = 1
    alpaca_model.config.eos_token_id = 2

    print("Torch compile...")
    alpaca_model = alpaca_model.eval()
    # The backend name must match the function registered with
    # @register_backend, i.e. "aio".
    # NOTE(review): ``args.aio`` is parsed but not consulted here; the backend
    # is always applied - confirm whether the flag should gate it.
    alpaca_model = torch.compile(
        alpaca_model, backend="aio", dynamic=True, fullgraph=False
    )

    inputs = alpaca_tokenizer(args.input_prompt, return_tensors="pt")
    outputs = alpaca_model.generate(inputs=inputs.input_ids, max_new_tokens=100)

    # Only the first decoded sequence is shown.
    output_text = alpaca_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    print("-- Alpaca output --")
    print("{}\n\n".format(output_text))

One of the graphs that torch.compile() produces is:

# Graph quoted from the report as produced by torch.compile(); the code is
# kept verbatim, only explanatory comments are added.
class GraphModule(torch.nn.Module):
    # NOTE: `s0` is declared as an input but is not referenced anywhere in the
    # body of forward() below; only `L_attention_mask_` is used.
    def forward(self, s0 : torch.SymInt, L_attention_mask_ : torch.Tensor):
        l_attention_mask_ = L_attention_mask_
        
        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:737, code: position_ids = attention_mask.long().cumsum(-1) - 1
        long = l_attention_mask_.long()
        cumsum = long.cumsum(-1);  long = None
        sub = cumsum - 1;  cumsum = None
        
        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:738, code: position_ids.masked_fill_(attention_mask == 0, 1)
        eq = l_attention_mask_ == 0;  l_attention_mask_ = None
        # The call's result is bound to `masked_fill_` but never read; `sub`
        # itself is what gets returned.
        masked_fill_ = sub.masked_fill_(eq, 1);  eq = None
        return (sub,)

Here the second argument is s0 : torch.SymInt, which isn't used later. I think it should be optimized out by dead code elimination; I tried calling eliminate_dead_code on the model, but it doesn't do anything. This is troublesome since torch.jit.trace doesn't support SymInt inputs.
This bug occurs many times in this model; I pasted only one subgraph where it occurs since it is short.

The problem doesn't occur on the v2.0.0 tag, but it happens on commit 400c4de53bb7b36066aef381313ed71e4a877e95.

Versions

main branch

cc @ezyang @anijain2305 @chauhang @penguinwu @msaroufim @wconstab @bdhirsh

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0