Description
🐛 Describe the bug
In huggingface/diffusers#9133, some inconsistencies were discovered in the output from FLUX. I was not able to reproduce them, but @bghira saw the issue consistently. We chatted offline and discovered that we were running different versions of PyTorch: @bghira had the latest nightly, while I had v2.3.1. After updating to the nightly, I could also reproduce the issue consistently.
Bisect replay log:
git bisect start
# status: waiting for both good and bad commits
# bad: [a9d34138dfd5c049305487fc9f1b8bae0cadc98e] [traced-graph][sparse] add to_dense() operation to sparse export test (#133175)
git bisect bad a9d34138dfd5c049305487fc9f1b8bae0cadc98e
# good: [63d5e9221bedd1546b7d364b5ce4171547db12a9] [EZ] Pin scipy to 1.12 for Py-3.12 (#127322)
git bisect good 63d5e9221bedd1546b7d364b5ce4171547db12a9
# good: [86a2d67bb9db7dae8ff4589930dd505a6c5b4ec6] Simplify guards using info from previous guards (#121463)
git bisect good 86a2d67bb9db7dae8ff4589930dd505a6c5b4ec6
# skip: [6bfc6e08759cf1fd7cf89916124285bf131b7168] Add back private function torch.cuda.amp.autocast_mode._cast (#127433)
git bisect skip 6bfc6e08759cf1fd7cf89916124285bf131b7168
# skip: [bb1468d50660a7c3c1c635925688f406e1d7bd5f] Updates state dict in state dict loader (#127617)
git bisect skip bb1468d50660a7c3c1c635925688f406e1d7bd5f
# bad: [3174e6cb8e2d37210c7569e51dc6a9522110e0f3] [Temp][CI] Run older MPS tests/Mac builds on MacOS 13 (#127428)
git bisect bad 3174e6cb8e2d37210c7569e51dc6a9522110e0f3
# good: [0e6367dd44d2b514609c880fc7644ebef2c8ab89] Factor var_to_range assignments to _update_var_to_range helper (#124283)
git bisect good 0e6367dd44d2b514609c880fc7644ebef2c8ab89
# bad: [37d2ecd12322825767e2357d0d0b98f6af48cbe3] Only log toplevel torchscript calls. (#125714)
git bisect bad 37d2ecd12322825767e2357d0d0b98f6af48cbe3
# skip: [07d3af8e6af8e02bdbd489d5590175c4f2d931d3] Added ARC test jobs to all build jobs in the unstable bucket (#125142)
git bisect skip 07d3af8e6af8e02bdbd489d5590175c4f2d931d3
# bad: [97509c8eb2aef89c8bf8429018aa6ce4a8269fde] Revert "[Inductor][Quant] Fix PT2E Dynamic Quant regression (#125207)"
git bisect bad 97509c8eb2aef89c8bf8429018aa6ce4a8269fde
# good: [720e5f306dce7d1b1103ec4ed0de3b9d7bc6155c] Update CODEOWNERS - Dataloader (#125181)
git bisect good 720e5f306dce7d1b1103ec4ed0de3b9d7bc6155c
# good: [0302dc68bf76a0af6dd4bb0488aaf22998374a0e] [Reland] Fakify script object inputs and attributes for non-strict ex… (#125490)
git bisect good 0302dc68bf76a0af6dd4bb0488aaf22998374a0e
# good: [6f1e3a6bf73327a351dc8a8c08635bd727b3134f] [DCP] Always flatten mapping even if no tensors present (#125335)
git bisect good 6f1e3a6bf73327a351dc8a8c08635bd727b3134f
# bad: [13462ecd27e693cd6facddc1bda92b204ed2e15e] Update preserve_node_meta to reset torch.fx.traceback.current_meta (#125500)
git bisect bad 13462ecd27e693cd6facddc1bda92b204ed2e15e
# good: [939b701d3ad615c75e4b1f92eb81a6c0f29c1343] SymInt-ify mem-efficient attention forward op signature (#125418)
git bisect good 939b701d3ad615c75e4b1f92eb81a6c0f29c1343
# bad: [50073127b5e49b2b75a912d73a70ecb61890a32d] [tp] add some test for shard output layouts for rowwise parallel (#125713)
git bisect bad 50073127b5e49b2b75a912d73a70ecb61890a32d
# bad: [58e045d03c25b3b50fcdfbe2fce75965ee869606] [MPS] Fix strided ELU op (#125692)
git bisect bad 58e045d03c25b3b50fcdfbe2fce75965ee869606
# bad: [baf36f6d11bf1dd7c8bcd99292a6f50238bd4955] Pad bandwidth bound split k kernels on a100 (#125650)
git bisect bad baf36f6d11bf1dd7c8bcd99292a6f50238bd4955
# good: [3fb53bb6a7ca55dfc85b0b657b58af8c578b5e5b] [MPS] Fix strided mse_loss (#125696)
git bisect good 3fb53bb6a7ca55dfc85b0b657b58af8c578b5e5b
# bad: [ba275486791622a7493f032d02ad420fa9497d30] [MPS] Remove in place views (causes too many crashes) (#124895)
git bisect bad ba275486791622a7493f032d02ad420fa9497d30
# first bad commit: [ba275486791622a7493f032d02ad420fa9497d30] [MPS] Remove in place views (causes too many crashes) (#124895)
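The log above lists one good/bad/skip verdict per commit. One way to produce such verdicts automatically is `git bisect run`, which runs a command at each commit and reads its exit status: 0 marks the commit good, 125 skips it, and any other code from 1 to 127 marks it bad. The script below is a sketch of such a check built from the MRE further down. The file name (e.g. `bisect_check.py`) and the helper names are hypothetical, and since PyTorch has to be rebuilt at every commit, in practice this would be called from a wrapper that builds first and exits with 125 when the build fails.

```python
import sys

import torch
import torch.nn as nn

# Exit codes interpreted by `git bisect run`:
# 0 = good, 125 = skip (cannot test this commit), other codes 1..127 = bad.
GOOD, BAD, SKIP = 0, 1, 125


def offset_mismatches(device: str, atol: float = 1e-5) -> int:
    """Count elements where BatchNorm2d on `device` disagrees with CPU for a sliced (offset) input."""
    torch.manual_seed(0)
    x = torch.randn(100, 100, 35, 45)
    bn_ref = nn.BatchNorm2d(100, affine=False)
    bn_dev = nn.BatchNorm2d(100, affine=False, device=device)
    ref = bn_ref(x[5:])                    # CPU reference on the sliced view
    out = bn_dev(x.to(device)[5:]).cpu()   # same slice, computed on `device`
    return int(((ref - out).abs() > atol).sum())


def bisect_exit_code(device: str = "mps") -> int:
    # A commit that cannot run the check (no MPS backend) is skipped, not judged.
    if device == "mps" and not torch.backends.mps.is_available():
        return SKIP
    return GOOD if offset_mismatches(device) == 0 else BAD


if __name__ == "__main__":
    sys.exit(bisect_exit_code())
```

It would be invoked as `git bisect run python bisect_check.py` (subject to the rebuild caveat above). Reporting a missing MPS backend as a skip rather than as bad keeps an unsuitable machine from producing false verdicts.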
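The culprit was identified as ba27548 (PR #124895).
After some digging, I found that the problem is related to tensor storage offsets, i.e. inputs that are views starting partway into their underlying storage, such as a slice x[5:].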
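To make the offset idea concrete, here is a toy model in plain Python (not PyTorch or MPS code). A strided view reads element `index` from `buffer[offset + sum(i * s)]`, and `x[1:]` on a row-major `(N, C, H, W)` tensor keeps the parent's strides but starts `strides[0]` elements into the buffer. The helpers `contiguous_strides`, `read`, and `channel_mean` are hypothetical, and the "offset ignored" case is only an assumed illustration of how a kernel could go wrong; this report does not say that is exactly what the MPS path does.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a tensor of the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def read(buffer, strides, offset, index):
    """Fetch one element of a strided view: buffer[offset + sum(i * s)]."""
    return buffer[offset + sum(i * s for i, s in zip(index, strides))]


def channel_mean(buffer, shape, strides, offset, c):
    """Mean of channel c over the N, H, W dims, as BatchNorm2d computes it in training mode."""
    n, _, h, w = shape
    vals = [read(buffer, strides, offset, (b, c, y, x))
            for b in range(n) for y in range(h) for x in range(w)]
    return sum(vals) / len(vals)


# Tiny stand-in for the MRE input: x has shape (N, C, H, W) = (4, 2, 1, 3).
shape = (4, 2, 1, 3)
buffer = list(range(prod(shape)))     # flat storage holding 0, 1, ..., 23
strides = contiguous_strides(shape)   # (6, 3, 3, 1)

# x[1:] shares the storage and strides but starts one batch entry further in.
view_shape = (shape[0] - 1,) + shape[1:]
view_offset = 1 * strides[0]          # 6

print(channel_mean(buffer, view_shape, strides, view_offset, 0))  # 13.0 (offset applied)
print(channel_mean(buffer, view_shape, strides, 0, 0))            # 7.0 (offset ignored)
```

With the offset dropped, the channel statistics are computed over the wrong rows (including the sliced-off first batch entry), which is the kind of error that would make normalized outputs diverge.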
MRE:
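import torch
import torch.nn as nn
# Identical non-affine BatchNorm2d layers on CPU and MPS (both in training mode).
bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
x_cpu = torch.randn(100, 100, 35, 45)
x_mps = x_cpu.to('mps')
# Full tensors: storage offset 0.
output_cpu = bn_cpu(x_cpu)
output_mps = bn_mps(x_mps)
# Slicing the batch dimension yields views with a nonzero storage offset.
output_offset_cpu = bn_cpu(x_cpu[5:])
output_offset_mps = bn_mps(x_mps[5:])
# Number of elements whose CPU and MPS outputs differ by more than 1e-5.
print(f"{torch.sum(abs(output_cpu - output_mps.cpu()) > 1e-5) = }")
print(f"{torch.sum(abs(output_offset_cpu - output_offset_mps.cpu()) > 1e-5) = }")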
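The failing case in the MRE passes `x_mps[5:]`, a view whose data starts partway into the parent tensor's storage. The CPU-only snippet below (added for illustration, not part of the original report) shows that offset directly: the slice is still contiguous, but its `storage_offset()` is `5 * 100 * 35 * 45`. It also shows why `.contiguous()` would not change anything here (it returns the already-contiguous view itself), whereas `.clone()` copies into fresh storage at offset 0; whether cloning avoids the bug on MPS is not verified here.

```python
import torch

x = torch.randn(100, 100, 35, 45)
view = x[5:]  # slicing the batch dimension creates a view: no copy, shared storage

# The view starts 5 batch entries into x's storage: 5 * (100 * 35 * 45) elements.
print(view.storage_offset())  # 787500
print(view.is_contiguous())   # True: being contiguous does not imply offset 0

# .contiguous() returns the view itself here, so the offset is unchanged...
print(view.contiguous().storage_offset())  # 787500
# ...while .clone() copies into new storage that starts at offset 0.
print(view.clone().storage_offset())  # 0
```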
I already have a fix and a regression test prepared, and will open a PR soon.
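The regression test itself is not included in this report. A hypothetical test in the same spirit might compare MPS against CPU for several batch slices, with `start=0` serving as the offset-free baseline; the class name, shapes, and tolerances below are assumptions. It uses `unittest.skipUnless` so that it is skipped on machines without MPS, and `torch.testing.assert_close` for the tolerance check.

```python
import unittest

import torch
import torch.nn as nn


@unittest.skipUnless(torch.backends.mps.is_available(), "requires the MPS backend")
class TestBatchNormOffsetViews(unittest.TestCase):
    def test_batch_slices_match_cpu(self):
        torch.manual_seed(0)
        x = torch.randn(16, 8, 7, 9)
        for start in (0, 1, 5, 15):  # start=0 is the offset-free baseline
            with self.subTest(start=start):
                bn_cpu = nn.BatchNorm2d(8, affine=False)
                bn_mps = nn.BatchNorm2d(8, affine=False, device="mps")
                expected = bn_cpu(x[start:])
                # Slicing after the transfer gives an MPS view with a nonzero storage offset.
                actual = bn_mps(x.to("mps")[start:]).cpu()
                torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    unittest.main()
```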
Versions
Collecting environment information...
PyTorch version: 2.5.0a0+gita9d3413
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.6 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Max
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.5.0a0+gita9d3413
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.5.0a0+gita9d3413 dev_0
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen