torch.compile graph break when tuning llama with FA2 #37199
@SilverSoldier

Description


System Info

  • transformers version: 4.50.3
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.12.9
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

I am getting 2 graph breaks when tuning a llama3-8b model with torch.compile. Both are failing due to dynamic control flow issues with the following error message:

Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow.

The thing is, both of these lines were actually last touched by PRs to remove graph breaks, so clearly something changed in torch.compile that is breaking them again.

  1. if attention_mask is not None and (attention_mask == 0.0).any() in modeling_llama
    This was fixed just 4 months ago in this PR. (A minimal repro sketch of this kind of data-dependent branch follows the list below.)
  2. max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) in modeling_flash_attention_utils
    This was fixed as part of this PR 6 months ago.

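For context, here is a minimal, self-contained sketch (my own, not the transformers code) of how this kind of data-dependent branch behaves under torch.compile: by default Dynamo inserts a graph break at the if, and with fullgraph=True it should raise the same "Dynamic control flow" error quoted above.

```python
import torch

def mask_check(attention_mask):
    # Rough stand-in for the modeling_llama condition: the branch depends on
    # runtime tensor values, which Dynamo cannot fold into a static graph.
    if attention_mask is not None and (attention_mask == 0.0).any():
        return attention_mask
    return None

# fullgraph=True turns the graph break into a hard error so it is easy to see;
# the default (fullgraph=False) just breaks the graph at the `if`.
compiled = torch.compile(mask_check, fullgraph=True)

mask = torch.ones(2, 8)
mask[0, -2:] = 0.0
compiled(mask)  # expected to raise the "Dynamic control flow is not supported" error
```
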
The error message suggests using torch.cond, but it is not trivial to apply in either case. I attempted to fix #1, but the return values of the if and else branches are None and attention_mask respectively, while cond expects the outputs of the two branches to be of the same type and shape (see the sketch below).
For #2 it is the torch.diff clause that is causing the problem. I haven't looked closely enough, but the multiple condition checks might make a clean separation into cond syntax slightly difficult.
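
To make the #1 constraint concrete, here is a rough sketch of the kind of rewrite I attempted (illustrative only, not the exact code): wrapping the branch in cond trips over the requirement that both branches return tensors with matching metadata, since one path effectively yields the mask and the other None.

```python
import torch
from functorch.experimental.control_flow import cond

def keep_mask(attention_mask):
    return attention_mask

def drop_mask(attention_mask):
    # The original code effectively produces None on this path; cond cannot
    # express that, because both branches must return tensors of the same
    # type and shape.
    return None

def mask_check(attention_mask):
    return cond((attention_mask == 0.0).any(), keep_mask, drop_mask, (attention_mask,))

mask = torch.ones(2, 8)
mask[0, -2:] = 0.0
# Expected to error during compilation: the two branch outputs do not match.
torch.compile(mask_check, fullgraph=True)(mask)
```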

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

accelerate launch --num_processes=1  -m tuning.sft_trainer --output_dir ./train_output --max_steps 5 --learning_rate 2e-5 --training_data_path=/data/data/ei_5.jsonl --save_steps=50 --torch_dtype bfloat16 --logging_strategy steps --logging_steps 1 --per_device_train_batch_size 8 --max_seq_length 1024 --include_tokens_per_second true --data_formatter_template "### Input: {input} \n\n### Response: {output}" --response_template "\n### Response:" --torch_compile True --model_name_or_path /data/models/llama3-8b --use_flash_attn True --packing

Checked with PyTorch 2.5 and 2.6, with and without padding_free; no change in the graph breaks.
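
As a side note, a quicker way I would expect to surface the same breaks without running the full SFT job (a sketch, assuming a single forward pass hits the same code paths; the model path is the placeholder from the command above) is torch._dynamo.explain:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/data/models/llama3-8b"  # same placeholder path as in the command above
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("### Input: hi\n\n### Response: hello", return_tensors="pt").to("cuda")

# Run one forward pass under Dynamo's explain mode and print each break reason;
# these should point at the two lines listed above.
explanation = torch._dynamo.explain(model)(**inputs, labels=inputs["input_ids"])
print(explanation.graph_break_count)
for reason in explanation.break_reasons:
    print(reason)
```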

Expected behavior

Expected no graph breaks.

Curious if anyone knows why this is popping up now and how to fix it: is it only through cond, or is there some other way? (I am willing to fix it but am not sure how.)
Thanks!
