System Info
- transformers version: 4.50.3
- Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
- Python version: 3.12.9
- Huggingface_hub version: 0.29.3
- Safetensors version: 0.5.3
- Accelerate version: 1.0.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: Yes
- GPU type: NVIDIA A100-SXM4-80GB
Who can help?
I am getting 2 graph breaks when tuning the llama3-8b model with torch.compile. Both fail due to dynamic control flow issues, with the following error message:

Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow.

The thing is, both of these lines were last touched by PRs that removed graph breaks, so clearly something changed in torch.compile that is breaking them again:
1. if attention_mask is not None and (attention_mask == 0.0).any() in modeling_llama. This was fixed just 4 months ago in this PR.
2. max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) in modeling_flash_attention_utils. This was fixed as part of this PR 6 months ago.
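For context, here is a minimal standalone repro of the kind of data-dependent branch that triggers this error (the function and variable names are made up for illustration; this is not the actual transformers code):

```python
import torch

def fn(x, attention_mask):
    # Python branch on tensor data, analogous to the check in modeling_llama.
    # Dynamo cannot evaluate (attention_mask == 0.0).any() at trace time.
    if attention_mask is not None and (attention_mask == 0.0).any():
        return x + 1.0
    return x

x = torch.randn(1, 4)
attention_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

# Without fullgraph=True this shows up as a graph break; with it, the
# dynamic control flow error is raised when the compiled fn is first called.
compiled = torch.compile(fn, fullgraph=True)
try:
    compiled(x, attention_mask)
except Exception as e:
    print(type(e).__name__, ":", str(e)[:200])
```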
The error message suggests using torch.cond, but it is not trivial to apply in either case. I attempted to fix #1, but the return values of the if and else branches are None and attention_mask respectively, while cond expects the outputs of the two branches to be of the same type and shape (see the sketch below).
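To illustrate the constraint, here is a minimal sketch of the cond calling convention, using the functorch.experimental.control_flow.cond import named in the error message. The branch functions and the zeros placeholder are made up for illustration:

```python
import torch
from functorch.experimental.control_flow import cond  # also exposed as torch.cond in recent releases

def true_fn(mask):
    # Both branches must return tensors of matching type and shape.
    return mask * 1.0

def false_fn(mask):
    # Returning None here (as the modeling code effectively wants to do)
    # is not accepted, so a same-shape placeholder would be needed instead.
    return torch.zeros_like(mask)

mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
out = cond((mask == 0.0).any(), true_fn, false_fn, (mask,))
print(out)
```

Whether downstream code can treat such a placeholder the same way it currently treats None is exactly the non-trivial part.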
For #2 it is the torch.diff clause that is causing the problem. I haven't looked closely enough, but the multiple condition checks might make a clean separation into cond syntax slightly difficult; see the sketch below for how I read the structure of that condition.
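For completeness, a hypothetical helper (purely for illustration, not a proposed fix) that splits the #2 condition into its Python-level and tensor-level parts:

```python
import torch

def composite_check(max_length_q, query_length, position_ids):
    # Hypothetical restructuring of the condition in modeling_flash_attention_utils.
    # The first two checks are plain Python values and short-circuit without
    # touching tensor data.
    if max_length_q is not None:
        return True
    if query_length == 1:
        return False
    # Only this clause depends on tensor data; it is the torch.diff check
    # that would need cond (or some other rewrite) to avoid the graph break.
    return bool(not (torch.diff(position_ids, dim=-1) >= 0).all())
```

Under torch.compile the bool(...) on the last line still forces tensor data to be evaluated at trace time, so this only isolates the data-dependent clause rather than removing it; it mainly shows why folding the whole condition into cond is not a one-liner.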
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
accelerate launch --num_processes=1 -m tuning.sft_trainer --output_dir ./train_output --max_steps 5 --learning_rate 2e-5 --training_data_path=/data/data/ei_5.jsonl --save_steps=50 --torch_dtype bfloat16 --logging_strategy steps --logging_steps 1 --per_device_train_batch_size 8 --max_seq_length 1024 --include_tokens_per_second true --data_formatter_template "### Input: {input} \n\n### Response: {output}" --response_template "\n### Response:" --torch_compile True --model_name_or_path /data/models/llama3-8b --use_flash_attn True --packing
Checked with PyTorch 2.5 and 2.6, with and without padding_free, with no change in graph breaks.
Expected behavior
Expected no graph breaks.
Curious if anyone knows why this is popping up now and how to fix it. Is cond the only way, or is there some other approach? (I am willing to fix it, but I'm not sure how.)
Thanks!