Description
🐛 Describe the bug
We have a Dynamo backend, defined similarly to the IPEX one, which traces and freezes the model:
```python
import importlib
import logging

import torch
from torch._dynamo import register_backend
from .common import fake_tensor_unsupported

log = logging.getLogger(__name__)


@register_backend
@fake_tensor_unsupported
def aio(model, inputs):
    model.print_readable()
    try:
        with torch.no_grad():
            traced_model = torch.jit.trace(model.eval(), inputs)
            frozen_model = torch.jit.freeze(traced_model)
        return frozen_model
    except Exception as ex:
        log.warning("JIT trace failed during the optimize process.")
        log.warning(ex)
        return model
```
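For a smaller repro that doesn't need the full Llama model, the same backend shape can be registered inline. The snippet below is a hypothetical sketch (the backend name `aio_debug`, the toy function `f`, and the absolute import path `torch._dynamo.backends.common` are my own stand-ins for the relative `.common` import above): it mirrors the position_ids computation from the subgraph pasted further down and logs what Dynamo hands the backend, where a `torch.SymInt` may show up among the example inputs when `dynamic=True`.

```python
# Hypothetical minimal sketch (names `aio_debug` and `f` are illustrative only):
# an inline backend that logs the example inputs Dynamo passes, so an unexpected
# torch.SymInt among them is easy to spot without running the full model.
import torch
from torch._dynamo import register_backend
from torch._dynamo.backends.common import fake_tensor_unsupported


@register_backend
@fake_tensor_unsupported
def aio_debug(gm, example_inputs):
    # With dynamic=True this list may contain a torch.SymInt for a symbolic size.
    print([type(x).__name__ for x in example_inputs])
    gm.print_readable()
    return gm  # fall back to running the FX graph as-is


def f(attention_mask):
    # Same pattern as modeling_llama.py:737-738 in the subgraph pasted below.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids


compiled = torch.compile(f, backend="aio_debug", dynamic=True, fullgraph=False)
compiled(torch.ones(1, 7, dtype=torch.long))
```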
I'm running the Llama model from the Transformers repo at tag v4.30.1 with the following script:
```python
import argparse
import os
import sys
import time
import datetime
import torch._dynamo.config
import transformers
import torch
import torch._dynamo
from torch.autograd.profiler import profile
import traceback as tb
import logging

default_input_texts = ("Below is an instruction that describes a task. "
                       "Write a response that appropriately completes the request.\r\n\r\n"
                       "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path",
                        type=str, required=True,
                        help="Recovered Model path")
    parser.add_argument("-a", "--aio",
                        dest="aio", action='store_true',
                        help="Use AIO backend")
    parser.set_defaults(aio=False)
    parser.add_argument("-i", "--input_prompt",
                        type=str, default=default_input_texts,
                        help="Input prompt")
    return parser.parse_args()


def main():
    args = parse_args()
    torch._dynamo.config.cache_size_limit = 128

    print("Loading model and tokenizer...")
    alpaca_model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
    alpaca_tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
    alpaca_model.config.pad_token_id = alpaca_tokenizer.pad_token_id = 0  # unk
    alpaca_model.config.bos_token_id = 1
    alpaca_model.config.eos_token_id = 2

    print("Torch compile...")
    alpaca_model = alpaca_model.eval()
    alpaca_model = torch.compile(alpaca_model, backend="aio", dynamic=True, fullgraph=False)

    inputs = alpaca_tokenizer(args.input_prompt, return_tensors="pt")
    outputs = alpaca_model.generate(inputs=inputs.input_ids, max_new_tokens=100)
    output_text = alpaca_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    print("-- Alpaca output --")
    print("{}\n\n".format(output_text))


if __name__ == "__main__":
    main()
```
One of the graphs that torch.compile() produces is:
```python
class GraphModule(torch.nn.Module):
    def forward(self, s0 : torch.SymInt, L_attention_mask_ : torch.Tensor):
        l_attention_mask_ = L_attention_mask_

        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:737, code: position_ids = attention_mask.long().cumsum(-1) - 1
        long = l_attention_mask_.long()
        cumsum = long.cumsum(-1);  long = None
        sub = cumsum - 1;  cumsum = None

        # File: /onspecta/transformers/src/transformers/models/llama/modeling_llama.py:738, code: position_ids.masked_fill_(attention_mask == 0, 1)
        eq = l_attention_mask_ == 0;  l_attention_mask_ = None
        masked_fill_ = sub.masked_fill_(eq, 1);  eq = None
        return (sub,)
```
Here the second argument, `s0 : torch.SymInt`, isn't used anywhere in the graph. I think it should be optimized out by dead-code elimination, but calling `eliminate_dead_code` on the model's graph doesn't do anything. This is troublesome since `torch.jit.trace` doesn't support `SymInt` inputs.
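For context on why the `eliminate_dead_code` attempt is a no-op: FX dead-code elimination skips nodes it treats as impure, and placeholder (and output) nodes fall into that bucket, so an unused input like `s0` is never dropped from the graph signature. A sketch of that attempt, assuming a standard FX GraphModule (the helper name `strip_dead_nodes` is illustrative):

```python
# torch.fx.Graph.eliminate_dead_code() removes unused intermediate nodes but never
# touches placeholder/output nodes, so the unused SymInt input s0 stays in the
# graph signature and still reaches torch.jit.trace.
import torch

def strip_dead_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    gm.graph.eliminate_dead_code()  # no-op for the s0 placeholder
    gm.recompile()                  # regenerate forward() from the (unchanged) graph
    return gm
```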
This bug occurs many times in this model; I pasted only one subgraph where it occurs since it is short (a sketch of a quick check for the other occurrences follows below).
The problem doesn't occur on the v2.0.0 tag, but it does happen on 400c4de53bb7b36066aef381313ed71e4a877e95.
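Since the pattern repeats across many subgraphs, one way to flag every occurrence is to look for placeholders with no users on each GraphModule the backend receives. A small sketch under that assumption (the helper name `unused_placeholders` is illustrative):

```python
# Names of graph inputs that nothing in the graph consumes, e.g. the s0 SymInt above.
# Intended to be called from the backend on the incoming GraphModule, for example
# log.warning(unused_placeholders(model)) before tracing.
import torch

def unused_placeholders(gm: torch.fx.GraphModule):
    return [n.name for n in gm.graph.nodes
            if n.op == "placeholder" and len(n.users) == 0]
```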
Versions
main branch
cc @ezyang @anijain2305 @chauhang @penguinwu @msaroufim @wconstab @bdhirsh