`SymInt` input doesn't get optimized out from `torch.compiled()` graph even if unused · Issue #108446 · pytorch/pytorch · GitHub

Open
kkontny opened this issue Sep 1, 2023 · 3 comments
Labels
good first issue · module: dynamic shapes · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

kkontny commented Sep 1, 2023

🐛 Describe the bug

We have a Dynamo backend, defined similarly to IPEX, which traces and freezes the model:

import logging

import torch
from torch._dynamo import register_backend
from .common import fake_tensor_unsupported

log = logging.getLogger(__name__)

@register_backend
@fake_tensor_unsupported
def aio(model, inputs):
    model.print_readable()
    try:
        with torch.no_grad():
            traced_model = torch.jit.trace(model.eval(), inputs)
            frozen_model = torch.jit.freeze(traced_model)
        return frozen_model
    except Exception as ex:
        log.warning("JIT trace failed during the optimize process.")
        log.warning(ex)
        return model
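
For context (a minimal sketch, not part of the original report), a backend registered this way is selected by name in torch.compile, e.g.:

import torch

# Hypothetical toy module; the actual report uses LlamaForCausalLM (see the
# script below).
toy = torch.nn.Linear(4, 4).eval()
compiled = torch.compile(toy, backend="aio", dynamic=True)
compiled(torch.randn(2, 4))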

I'm running the Llama model from the Transformers repo at tag v4.30.1 with the following script:

import argparse
import os
import sys
import time
import datetime
import torch._dynamo.config
import transformers
import torch
import torch._dynamo

from torch.autograd.profiler import profile
import traceback as tb
import logging

default_input_texts = ("Below is an instruction that describes a task." 
                      "Write a response that appropriately completes the request.\r\n\r\n" 
                      "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:")


def parse_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("-m", "--model_path",
                      type=str, required=True,
                      help="Recovered Model path")
  parser.add_argument("-a", "--aio",
                      dest="aio", action='store_true',
                      help="Use AIO backend")
  parser.set_defaults(aio=False)
  parser.add_argument("-i", "--input_prompt",
                      type=str, default=default_input_texts,
                      help="Input prompt")
  return parser.parse_args()

def main():
  args = parse_args()

  torch._dynamo.config.cache_size_limit = 128

  print("Loading model and tokenizer...")
  alpaca_model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
  alpaca_tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
  alpaca_model.config.pad_token_id = alpaca_tokenizer.pad_token_id = 0 #unk
  alpaca_model.config.bos_token_id = 1
  alpaca_model.config.eos_token_id = 2

  print("Torch compile...")
  alpaca_model = alpaca_model.eval()
  alpaca_model = torch.compile(alpaca_model, backend="aio", dynamic=True, fullgraph=False)

  inputs = alpaca_tokenizer(args.input_prompt, return_tensors="pt")
  outputs = alpaca_model.generate(inputs=inputs.input_ids, max_new_tokens=100)

  output_text = alpaca_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

  print("-- Alpaca output --")
  print("{}\n\n".format(output_text))

if __name__ == "__main__":
  main()

One of the graphs that torch.compile() produces is:

class GraphModule(torch.nn.Module):
    def forward(self, s0 : torch.SymInt, L_attention_mask_ : torch.Tensor):
        l_attention_mask_ = L_attention_mask_
        
        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:737, code: position_ids = attention_mask.long().cumsum(-1) - 1
        long = l_attention_mask_.long()
        cumsum = long.cumsum(-1);  long = None
        sub = cumsum - 1;  cumsum = None
        
        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:738, code: position_ids.masked_fill_(attention_mask == 0, 1)
        eq = l_attention_mask_ == 0;  l_attention_mask_ = None
        masked_fill_ = sub.masked_fill_(eq, 1);  eq = None
        return (sub,)

Here the s0 : torch.SymInt argument isn't used anywhere in the graph body. I think it should be optimized out by dead code elimination; I tried calling eliminate_dead_code on the model, but it doesn't do anything. This is troublesome since torch.jit.trace doesn't support SymInt inputs.
This occurs many times in this model; I pasted only one subgraph where it occurs since it is short.

The problem doesn't occur on the v2.0.0 tag, but happens on commit 400c4de53bb7b36066aef381313ed71e4a877e95.
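
A minimal standalone sketch of the same pattern (hypothetical; inspect_backend and position_ids_from_mask are made-up names, and the body mirrors the modeling_llama lines quoted above) is:

import torch

def inspect_backend(gm, example_inputs):
    # Print every graph input; with dynamic=True one of the placeholders
    # should be a SymInt (e.g. s0) even though the graph body never uses it.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            print(node.name, node.meta.get("example_value"))
    return gm.forward

@torch.compile(backend=inspect_backend, dynamic=True)
def position_ids_from_mask(attention_mask):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

position_ids_from_mask(torch.ones(1, 8, dtype=torch.long))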

Versions

main branch

cc @ezyang @anijain2305 @chauhang @penguinwu @msaroufim @wconstab @bdhirsh

@kkontny kkontny changed the title SymInt input doesn't get optimised out from torch.compiled() graph SymInt input doesn't get optimized out from torch.compiled() graph even if unused Sep 1, 2023
ezyang commented Sep 5, 2023

eliminate_dead_code doesn't eliminate dead inputs. The addition of the SymInt input is intentional. You'll need to muck around with

    def add_symbol_bindings(self, arg: GraphArg):
        # Insert implicit size vars as necessary.  With dynamic shapes, we 
        # maintain the invariant that every sizevar gets a direct SymInt input
        # into the graph.  This means downstream graph transforms can assume
        # every size variable is explicitly bound and accessible, instead of
        # having to pull it out implicitly from tensors.
        
        if self.export:
            return

I think I'd be OK with adding another config knob for this but this is NOT the happy path for Dynamo.
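
For a backend that has to hand the graph to torch.jit.trace, one possible workaround (a sketch only, assuming the SymInt placeholders really are unused; drop_symint_inputs is a made-up helper, not a PyTorch API) is to erase the dead SymInt placeholders inside the backend and filter the matching arguments at call time:

import torch

def drop_symint_inputs(gm: torch.fx.GraphModule, example_inputs):
    # Indices of the inputs that are real tensors; anything else (ints or
    # SymInts standing in for size variables) will be dropped.
    keep = [i for i, x in enumerate(example_inputs) if isinstance(x, torch.Tensor)]
    tensor_inputs = [example_inputs[i] for i in keep]

    # Erase the matching placeholders, but only if nothing in the graph
    # actually uses them (erase_node raises otherwise).
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for i, node in enumerate(placeholders):
        if i not in keep and not node.users:
            gm.graph.erase_node(node)
    gm.recompile()

    with torch.no_grad():
        traced = torch.jit.trace(gm.eval(), tensor_inputs)
        frozen = torch.jit.freeze(traced)

    # Dynamo still calls the compiled artifact with the full original
    # argument list (including the SymInt), so filter again at runtime.
    def runner(*args):
        return frozen(*[args[i] for i in keep])

    return runner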

ezyang added the module: dynamic shapes and triaged labels on Sep 5, 2023
@pavan-msys

Hi @kkontny

Can I work on this issue?

kkontny commented May 13, 2025

@pavan-msys Sure. I think I have another solution to the original problem; however, the issue is still valid.
