Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements #141909


Closed · gazonk opened this issue Dec 2, 2024 · 10 comments
Labels: high priority · module: mps (Related to Apple Metal Performance Shaders framework) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone: 2.6.0

Comments

@gazonk commented Dec 2, 2024

🚀 The feature, motivation and pitch

Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements

Alternatives

No response

Additional context

Filed as suggested by the error message.
I'm on an Apple M2 Max MacBook Pro with 96 GB of RAM.
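
For scale, the requested output overshoots the MPS guard by roughly 22%:

>>> 72250 * 72250   # elements in the requested bmm output
5220062500
>>> 2 ** 32         # MPS backend guard threshold
4294967296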

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

@gazonk (Author) commented Dec 2, 2024

ComfyUI Error Report

Error Details

  • Node ID: 10
  • Node Type: VAEEncode
  • Exception Type: RuntimeError
  • Exception Message: Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements.File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues

Stack Trace

  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 323, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)

  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 198, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)

  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 169, in _map_node_over_list
    process_inputs(input_dict, i)

  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 158, in process_inputs
    results.append(getattr(obj, func)(**inputs))

  File "/Users/renauddumeur/work/repos/ComfyUI/nodes.py", line 320, in encode
    t = vae.encode(pixels[:,:,:,:3])

  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/sd.py", line 415, in encode
    out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()

  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/models/autoencoder.py", line 179, in encode
    z = self.encoder(x)

  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)

  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)

  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 531, in forward
    h = self.mid.attn_1(h)

  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)

  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)

  File "/Users/renauddumeur/work/repos/ComfyUI/
8000
comfy/ldm/modules/diffusionmodules/model.py", line 287, in forward
    h_ = self.optimized_attention(q, k, v)

  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 206, in normal_attention
    r1 = slice_attention(q, k, v)

  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 180, in slice_attention
    s1 = torch.bmm(q[:, i:end], k) * scale
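
The failing frame is ComfyUI's sliced VAE attention, and the reported shape (72250,72250) indicates the guard fired on a full-size query slice. As a rough illustration of the slicing pattern (a schematic sketch only, not ComfyUI's actual slice_attention, whose slice sizing is dynamic):

import torch

def sliced_attention(q, k, v, slice_rows=4096):
    # q, k, v: (batch, seq_len, channels). Slicing q keeps each score
    # matrix at (slice_rows, seq_len) rather than (seq_len, seq_len).
    scale = q.size(-1) ** -0.5
    out = torch.empty_like(q)
    for i in range(0, q.size(1), slice_rows):
        s = torch.bmm(q[:, i:i + slice_rows], k.transpose(1, 2)) * scale
        out[:, i:i + slice_rows] = torch.bmm(s.softmax(dim=-1), v)
    return out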

System Information

  • ComfyUI Version: v0.3.6
  • Arguments: main.py
  • OS: posix
  • Python Version: 3.10.15 (main, Oct 3 2024, 02:24:49) [Clang 14.0.6 ]
  • Embedded Python: false
  • PyTorch Version: 2.5.1

Devices

  • Name: mps
    • Type: mps
    • VRAM Total: 103079215104
    • VRAM Free: 67738386432
    • Torch VRAM Total: 103079215104
    • Torch VRAM Free: 67738386432

Logs

2024-12-02T21:50:27.541065 - [START] Security scan
2024-12-02T21:50:27.828194 - [DONE] Security scan
2024-12-02T21:50:27.869554 - ## ComfyUI-Manager: installing dependencies done.
2024-12-02T21:50:27.869601 - ** ComfyUI startup time: 2024-12-02 21:50:27.869590
2024-12-02T21:50:27.869661 - ** Platform: Darwin
2024-12-02T21:50:27.869702 - ** Python version: 3.10.15 (main, Oct  3 2024, 02:24:49) [Clang 14.0.6 ]
2024-12-02T21:50:27.869742 - ** Python executable: /opt/homebrew/anaconda3/envs/ComfyUI/bin/python
2024-12-02T21:50:27.869779 - ** ComfyUI Path: /Users/renauddumeur/work/repos/ComfyUI
2024-12-02T21:50:27.869844 - ** Log path: /Users/renauddumeur/work/repos/ComfyUI/comfyui.log
2024-12-02T21:50:28.164864 - Prestartup times for custom nodes:
2024-12-02T21:50:28.164931 -    0.6 seconds: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/ComfyUI-Manager
2024-12-02T21:50:28.914905 - Total VRAM 98304 MB, total RAM 98304 MB
2024-12-02T21:50:28.914987 - pytorch version: 2.5.1
2024-12-02T21:50:28.915088 - Set vram state to: SHARED
2024-12-02T21:50:28.915120 - Device: mps
2024-12-02T21:50:29.463868 - Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention
2024-12-02T21:50:30.093069 - [Prompt Server] web root: /Users/renauddumeur/work/repos/ComfyUI/web
2024-12-02T21:50:30.233520 - /opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
2024-12-02T21:50:30.325684 - ### Loading: ComfyUI-Manager (V2.54)
2024-12-02T21:50:30.378119 - ### ComfyUI Revision: 2876 [8e4118c0] | Released on '2024-12-01'
2024-12-02T21:50:30.415357 - [ComfyUI-Manager] default cache updated: https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/model-list.json
2024-12-02T21:50:30.432146 - [ComfyUI-Manager] default cache updated: https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/extension-node-map.json
2024-12-02T21:50:30.438374 - [ComfyUI-Manager] default cache updated: https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/github-stats.json
2024-12-02T21:50:30.438796 - [ComfyUI-Manager] default cache updated: https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/alter-list.json
2024-12-02T21:50:30.451791 - [ComfyUI-Manager] default cache updated: https://raw.githubusercontent.com/ltdrdata/ComfyUI-Manager/main/custom-node-list.json
2024-12-02T21:50:31.223634 - Import times for custom nodes:
2024-12-02T21:50:31.223758 -    0.0 seconds: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/websocket_image_save.py
2024-12-02T21:50:31.223792 -    0.0 seconds: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/ComfyUI-GGUF
2024-12-02T21:50:31.223818 -    0.1 seconds: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/ComfyUI-Manager
2024-12-02T21:50:31.223841 -    0.8 seconds: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/ComfyUI-VideoHelperSuite
2024-12-02T21:50:31.226794 - Starting server

2024-12-02T21:50:31.227068 - To see the GUI go to: http://127.0.0.1:8188
2024-12-02T21:50:47.166743 - FETCH DATA from: /Users/renauddumeur/work/repos/ComfyUI/custom_nodes/ComfyUI-Manager/extension-node-map.json [DONE]
2024-12-02T21:50:53.836094 - got prompt
2024-12-02T21:50:53.950551 - model weight dtype torch.float16, manual cast: None
2024-12-02T21:50:53.952820 - model_type EPS
2024-12-02T21:50:54.170683 - Using split attention in VAE
2024-12-02T21:50:54.171637 - Using split attention in VAE
2024-12-02T21:50:54.248098 - Requested to load SD1ClipModel
2024-12-02T21:50:54.248179 - Loading 1 new model
2024-12-02T21:50:54.250134 - loaded completely 0.0 235.84423828125 True
2024-12-02T21:50:54.359190 - Requested to load AutoencoderKL
2024-12-02T21:50:54.359380 - Loading 1 new model
2024-12-02T21:50:54.435614 - loaded completely 0.0 319.11416244506836 True
2024-12-02T21:50:56.450926 - !!! Exception during processing !!! Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements.File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues
2024-12-02T21:50:56.454699 - Traceback (most recent call last):
  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 323, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 198, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 169, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/Users/renauddumeur/work/repos/ComfyUI/execution.py", line 158, in process_inputs
    results.append(getattr(obj, func)(**inputs))
  File "/Users/renauddumeur/work/repos/ComfyUI/nodes.py", line 320, in encode
    t = vae.encode(pixels[:,:,:,:3])
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/sd.py", line 415, in encode
    out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/models/autoencoder.py", line 179, in encode
    z = self.encoder(x)
  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 531, in forward
    h = self.mid.attn_1(h)
  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/ComfyUI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 287, in forward
    h_ = self.optimized_attention(q, k, v)
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 206, in normal_attention
    r1 = slice_attention(q, k, v)
  File "/Users/renauddumeur/work/repos/ComfyUI/comfy/ldm/modules/diffusionmodules/model.py", line 180, in slice_attention
    s1 = torch.bmm(q[:, i:end], k) * scale
RuntimeError: Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements.File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues

2024-12-02T21:50:56.455158 - Prompt executed in 2.61 seconds

Attached Workflow

Please make sure that workflow does not contain any sensitive information such as API keys or passwords.

{"last_node_id":11,"last_link_id":22,"nodes":[{"id":3,"type":"KSampler","pos":[1015,130],"size":[315,262],"flags":{},"order":5,"mode":0,"inputs":[{"name":"model","type":"MODEL","link":12},{"name":"positive","type":"CONDITIONING","link":13},{"name":"negative","type":"CONDITIONING","link":14},{"name":"latent_image","type":"LATENT","link":15}],"outputs":[{"name":"LATENT","type":"LATENT","links":[18]}],"properties":{"Node name for S&R":"KSampler"},"widgets_values":[348692293144324,"randomize",8,6.5,"euler_ancestral","normal",0.5]},{"id":7,"type":"CLIPTextEncode","pos":[515,460],"size":[400,200],"flags":{},"order":3,"mode":0,"inputs":[{"name":"clip","type":"CLIP","link":17}],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[14]}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["text, watermark, nsfw, nude"]},{"id":8,"type":"VAEDecode","pos":[1430,130],"size":[210,46],"flags":{},"order":6,"mode":0,"inputs":[{"name":"samples","type":"LATENT","link":18},{"name":"vae","type":"VAE","link":19}],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[20]}],"properties":{"Node name for S&R":"VAEDecode"},"widgets_values":[]},{"id":9,"type":"SaveImage","pos":[1740,130],"size":[315,58],"flags":{},"order":7,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":20}],"outputs":[],"properties":{},"widgets_values":["ComfyUI"]},{"id":10,"type":"VAEEncode","pos":[515,790],"size":[210,46],"flags":{},"order":4,"mode":0,"inputs":[{"name":"pixels","type":"IMAGE","link":21},{"name":"vae","type":"VAE","link":22}],"outputs":[{"name":"LATENT","type":"LATENT","links":[15]}],"properties":{"Node name for S&R":"VAEEncode"},"widgets_values":[]},{"id":11,"type":"LoadImage","pos":[100,358],"size":[315,314],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[21]},{"name":"MASK","type":"MASK","links":null}],"properties":{"Node name for S&R":"LoadImage"},"widgets_values":["DSC_0014-1.jpeg","image"]},{"id":6,"type":"CLIPTextEncode","pos":[515,130],"size":[400,200],"flags":{},"order":2,"mode":0,"inputs":[{"name":"clip","type":"CLIP","link":16}],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[13]}],"properties":{"Node name for S&R":"CLIPTextEncode"},"widgets_values":["He is chewing something so delicious that his hair rise"]},{"id":4,"type":"CheckpointLoaderSimple","pos":[100,130],"size":[315,98],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","type":"MODEL","links":[12]},{"name":"CLIP","type":"CLIP","links":[16,17]},{"name":"VAE","type":"VAE","links":[19,22]}],"properties":{"Node name for S&R":"CheckpointLoaderSimple"},"widgets_values":["aniverse_v15Pruned.safetensors"]}],"links":[[12,4,0,3,0,"MODEL"],[13,6,0,3,1,"CONDITIONING"],[14,7,0,3,2,"CONDITIONING"],[15,10,0,3,3,"LATENT"],[16,4,1,6,0,"CLIP"],[17,4,1,7,0,"CLIP"],[18,3,0,8,0,"LATENT"],[19,4,2,8,1,"VAE"],[20,8,0,9,0,"IMAGE"],[21,11,0,10,0,"IMAGE"],[22,4,2,10,1,"VAE"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.2100000000000002,"offset":[-115.00663521250681,151.71115512689147]}},"version":0.4}

Additional Context

(Please add any additional context or steps to reproduce the error here)

@malfet added the high priority, module: regression (It used to work, and now it doesn't), and module: mps (Related to Apple Metal Performance Shaders framework) labels Dec 3, 2024
@malfet (Contributor) commented Dec 3, 2024

Just curious: did it work in 2.4?

@malfet (Contributor) commented Dec 3, 2024

Unrelated, but a very funny (and unhelpful) error:

% /usr/bin/python3 -c "import torch;print(torch.empty(72250,72250, device='mps', dtype=torch.float16))" 
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/nshulga/git/pytorch/pytorch/torch/_tensor.py", line 568, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/Users/nshulga/git/pytorch/pytorch/torch/_tensor_str.py", line 708, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/Users/nshulga/git/pytorch/pytorch/torch/_tensor_str.py", line 625, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/Users/nshulga/git/pytorch/pytorch/torch/_tensor_str.py", line 339, in _tensor_str
    self = self.float()
RuntimeError: Invalid buffer size: 19.45 GB

Fixed in #141927
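
The 19.45 GB figure lines up with the float32 copy that printing makes via self.float() (a back-of-envelope, assuming 4 bytes per float32 element and 2 per float16):

# 72250 * 72250 elements = 5_220_062_500
# * 4 bytes (float32 copy made for printing) ≈ 19.45 GiB
# * 2 bytes (the original float16 tensor)    ≈  9.72 GiB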

@gazonk (Author) commented Dec 3, 2024

> Just curious: did it work in 2.4?

I don't know. I started using ComfyUI very recently.

@malfet added this to the 2.6.0 milestone Dec 6, 2024
@hvaara (Contributor) commented Dec 7, 2024

Minimal repro

import torch
x1 = torch.randn(1, 72250, 1, device="mps")
x2 = torch.randn(1, 1, 72250, device="mps")
res = torch.bmm(x1, x2)

The repro also fails on v2.4.1, where PyTorch crashed with:

/AppleInternal/Library/BuildRoots/4b66fb3c-7dd0-11ef-b4fb-4a83e32a47e1/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:850: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: total bytes of NDArray > 2**32'
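
Until a fix lands, one user-side workaround (a sketch, assuming the full result still fits in memory) is to chunk the first output dimension so each partial bmm stays under the guard:

import torch

def chunked_bmm(x1, x2, chunk=8192):
    # Each partial output is (B, chunk, N); with N = 72250 that is about
    # 5.9e8 elements per chunk, comfortably below the 2**32 guard.
    return torch.cat([torch.bmm(x1[:, i:i + chunk], x2)
                      for i in range(0, x1.size(1), chunk)], dim=1)

x1 = torch.randn(1, 72250, 1, device="mps")
x2 = torch.randn(1, 1, 72250, device="mps")
res = chunked_bmm(x1, x2)  # the full float32 result is still ~19.5 GB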

@hvaara (Contributor) commented Dec 7, 2024

Commenting out

TORCH_CHECK(
    batch1.size(1) * batch2.size(2) <= pow(2, 32),
    "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ",
    batch1.size(1),
    ",",
    batch2.size(2),
    ", needs to be less than 2**32 elements.",
    "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues");

at main reproduces the error from v2.4.1.

@janeyx99 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review and module: regression (It used to work, and now it doesn't) labels Dec 9, 2024
@janeyx99 (Contributor) commented Dec 9, 2024

From triage discussion: we've determined it's high priority to use a fallback or tiling in the large-matrix case.

@jhavukainen (Collaborator) commented:

Added a fix in #143095 that solves the minimal repro case (thanks @hvaara). I don't have a ComfyUI setup to test with, but I'm optimistic that this should get it through the bmm call at least.
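
In Python terms, the tiling idea looks roughly like the sketch below (illustrative only; the actual fix in #143095 tiles inside the MPS kernels, and the tile size here is arbitrary):

import torch

def tiled_bmm(a, b, tile=16384):
    # Tile both output dimensions so every partial bmm output holds at
    # most tile * tile elements, regardless of the full output size.
    B, M, _ = a.shape
    N = b.size(2)
    out = torch.empty(B, M, N, device=a.device, dtype=a.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            out[:, i:i + tile, j:j + tile] = torch.bmm(
                a[:, i:i + tile], b[:, :, j:j + tile])
    return out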

aditew01 pushed a commit to aditew01/pytorch that referenced this issue Dec 18, 2024
…ytorch#143095)

The previous tiling implementation worked for up to 2^32 total elements per single batch entry. This extends the functionality to support the dimensions encountered in ComfyUI (output shape: 1,72250,72250).

Fixes pytorch#141909
Pull Request resolved: pytorch#143095
Approved by: https://github.com/kulinseth
pytorchbot pushed a commit that referenced this issue Jan 10, 2025
…143095)

The previous tiling implementation worked for up to 2^32 total elements per single batch entry. This extends the functionality to support the dimensions encountered in ComfyUI (output shape: 1,72250,72250).

Fixes #141909
Pull Request resolved: #143095
Approved by: https://github.com/kulinseth

(cherry picked from commit afa313e)
kit1980 pushed a commit that referenced this issue Jan 10, 2025
…144558)

Extend bmm tiling to work up to 2^32 elem in any single output dim (#143095)

The previous tiling implementation worked for up to 2^32 total elements per single batch entry. This extends the functionality to support the dimensions encountered in ComfyUI (output shape: 1,72250,72250).

Fixes #141909
Pull Request resolved: #143095
Approved by: https://github.com/kulinseth

(cherry picked from commit afa313e)

Co-authored-by: Joona Havukainen <jhavukainen@apple.com>
@atalman (Contributor) commented Jan 20, 2025

Observing a RuntimeError: Invalid buffer size: 19.45 GB error with RC 2.6 and the latest nightly:

import torch
print(torch.__version__)
x1 = torch.randn(1, 72250, 1, device="mps")
x2 = torch.randn(1, 1, 72250, device="mps")
res = torch.bmm(x1, x2)

Output RC 2.6:

python3 test1.py
2.6.0
Traceback (most recent call last):
  File "/Users/atalman/Downloads/release26/pytorch/test/test1.py", line 5, in <module>
    res = torch.bmm(x1, x2)
RuntimeError: Invalid buffer size: 19.45 GB

Output Nightly:

python3 test1.py                                                                                                     
2.7.0.dev20250120
Traceback (most recent call last):
  File "/Users/atalman/Downloads/release26/pytorch/test/test1.py", line 5, in <module>
    res = torch.bmm(x1, x2)
RuntimeError: Invalid buffer size: 19.45 GB

@malfet (Contributor) commented Jan 21, 2025

@atalman which machine did you run this on?
It works fine on a Mac with 64 GB of RAM, but fails on one with 32 GB...

% pip install torch==2.5.1  --index-url https://download.pytorch.org/whl/test
...
% python
Python 3.11.11 (main, Dec  3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
2.5.1
>>> x1 = torch.randn(1, 72250, 1, device="mps")
>>> x2 = torch.randn(1, 1, 72250, device="mps")
>>> res = torch.bmm(x1, x2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Output size of the matrix multiplication is larger than currently supported by the MPS backend: 72250,72250, needs to be less than 2**32 elements.File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues
% pip install torch==2.6.0  --index-url https://download.pytorch.org/whl/test
...
% python                                                                     
Python 3.11.11 (main, Dec  3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
2.6.0
>>> x1 = torch.randn(1, 72250, 1, device="mps")
>>> x2 = torch.randn(1, 1, 72250, device="mps")
>>> res = torch.bmm(x1, x2)
>>> res
tensor([[[-0.0011,  0.0028,  0.0079,  ...,  0.0073,  0.0055,  0.0015],
         [ 0.0826, -0.2130, -0.6030,  ..., -0.5581, -0.4169, -0.1109],
         [-0.0501,  0.1292,  0.3658,  ...,  0.3386,  0.2529,  0.0673],
         ...,
         [-0.0556,  0.1432,  0.4053,  ...,  0.3752,  0.2802,  0.0746],
         [-0.0408,  0.1052,  0.2980,  ...,  0.2758,  0.2060,  0.0548],
         [ 0.1382, -0.3562, -1.0085,  ..., -0.9335, -0.6972, -0.1856]]],
       device='mps:0')
