bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#… · pytorch/pytorch@b9a22b3 · GitHub

Commit b9a22b3

hellopahe, Skylion007, and malfet authored and committed
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#146623)
This PR addresses an issue in the MPS backend's `_scaled_dot_product_attention_math_mps`: a 3D input of shape (num_heads, seq_len, query_dim) is not automatically treated as (1, num_heads, seq_len, query_dim), even though CPU and CUDA infer the missing batch dimension. The fix adds a small utility function that ensures a 4D shape before building the graph and squeezes the result back to 3D afterwards.

The issue was found in hiyouga/LLaMA-Factory#6835: in [transformers qwen2_vl](https://github.com/huggingface/transformers/blob/1590c664306766f32ba68c50e67f14d61b16925d/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L373C14-L373C93), 3D q/k/v tensors are passed into the SDPA function, which leads to an error on MPS. Since this pattern may appear elsewhere in the transformers codebase, it makes more sense for consistency to keep the same behavior across all platforms.

---

reproduce code:

```
import torch
import torch.nn.functional as F

head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1

q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)

oo_cpu = F.scaled_dot_product_attention(
    q.to("cpu"), k.to("cpu"), v.to("cpu"), attention_mask.to("cpu"), dropout_p=0.0
)

if torch.backends.mps.is_available():
    oo_mps = F.scaled_dot_product_attention(
        q.to("mps"), k.to("mps"), v.to("mps"), attention_mask.to("mps"), dropout_p=0.0
    )
    assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```

error outputs:

```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
    oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```

hardware and envs:

```
torch        2.6.0
apple m3 max
```

---

Pull Request resolved: #146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
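The repro relies on the convention, already followed by CPU and CUDA, that 3D q/k/v are treated as if a leading batch dimension of 1 were present. A quick sanity check of that convention on CPU (illustrative sketch, not part of the commit):

```python
import torch
import torch.nn.functional as F

# 3D q/k/v: (num_heads, seq_len, head_dim)
q, k, v = (torch.randn(16, 16, 80) for _ in range(3))

# CPU computes the same result whether or not the batch dim of 1 is explicit.
out_3d = F.scaled_dot_product_attention(q, k, v)
out_4d = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
).squeeze(0)
print(torch.allclose(out_3d, out_4d, atol=1e-6))  # expected: True
```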
1 parent 17a8085 commit b9a22b3

File tree

2 files changed (+109, -54 lines)

aten/src/ATen/native/mps/operations/Attention.mm

Lines changed: 76 additions & 54 deletions
```diff
@@ -19,6 +19,15 @@
 namespace at {
 namespace native {
 
+// expand potential 3d to 4d tensor
+static inline std::tuple<Tensor, bool> ensure_4d(const Tensor& x) {
+  if (x.dim() == 3) {
+    return {x.unsqueeze(0), true};
+  } else {
+    return {x, false};
+  }
+}
+
 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor& query,
                                                                   const Tensor& key,
                                                                   const Tensor& value,
@@ -39,6 +48,11 @@
   TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(),
               "_scaled_dot_product_attention_math_for_mps: query, key, and value must not be nested");
 
+  // Ensure 4D tensors
+  auto [q_, sq] = ensure_4d(query);
+  auto [k_, sk] = ensure_4d(key);
+  auto [v_, sv] = ensure_4d(value);
+
   using namespace mps;
   struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@@ -49,67 +63,70 @@
     MPSGraphTensor* outputTensor = nil;
     MPSGraphTensor* attnTensor = nil;
   };
-  int64_t batchSize = query.size(0);
-  int64_t num_head = query.size(1);
-  int64_t qSize = query.size(2);
-  int64_t headSize = query.size(3);
-  int64_t maxSeqLength = key.size(2);
+  int64_t batchSize = q_.size(0);
+  int64_t num_head = q_.size(1);
+  int64_t qSize = q_.size(2);
+  int64_t headSize = q_.size(3);
+  int64_t maxSeqLength = k_.size(2);
   auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
   auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
   auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
   @autoreleasepool {
-    auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
+    auto mkey = __func__ + getTensorsStringKey({q_, k_, v_}) + ":" + std::to_string(is_causal) + ":" +
         std::to_string(attn_mask.has_value());
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&](auto mpsGraph, auto graph) {
-      auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, query);
-      auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, key);
-      auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, value);
-      auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
-      auto scaleTensor = [mpsGraph constantWithScalar:scale_factor shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
+    auto cachedGraph =
+        LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = q_, k_ = k_, v_ = v_](auto mpsGraph, auto graph) {
+          auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
+          auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
+          auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
+          auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
+          auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
+                                                    shape:getMPSShape({1})
+                                                 dataType:MPSDataTypeFloat32];
 
-      auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
+          auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
 
-      if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
-        // TODO: In MacOS15 beta, there is a MPSGraph issue when the SDPA sequence gets remapped to use
-        // an improved kernel for the computation, causing NaNs in the result. This identity prevents the remapping.
-        // Limit the availability check once a fix lands.
-        maskedMM = [mpsGraph identityWithTensor:maskedMM name:nil];
-      }
+          if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
+            // TODO: In MacOS15 beta, there is a MPSGraph issue when the SDPA sequence gets remapped to use
+            // an improved kernel for the computation, causing NaNs in the result. This identity prevents the remapping.
+            // Limit the availability check once a fix lands.
+            maskedMM = [mpsGraph identityWithTensor:maskedMM name:nil];
+          }
 
-      // upcasting to float32 if needed to improve precision when multiplying by the scale factor
-      if ([maskedMM dataType] != MPSDataTypeFloat32) {
-        maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
-      }
-      maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
-      if ([maskedMM dataType] != qTensor.dataType) {
-        maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
-      }
+          // upcasting to float32 if needed to improve precision when multiplying by the scale factor
+          if ([maskedMM dataType] != MPSDataTypeFloat32) {
+            maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
+          }
+          maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
+          if ([maskedMM dataType] != qTensor.dataType) {
+            maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
+          }
 
-      if (is_causal) {
-        auto causalMask = [mpsGraph constantWithScalar:1.0f
-                                                 shape:getMPSShape({qSize, maxSeqLength})
-                                              dataType:MPSDataTypeBool];
-        causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
-        auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
-        maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
-                                   truePredicateTensor:maskedMM
-                                  falsePredicateTensor:minusInf
-                                                  name:nil];
-      } else if (attn_mask) {
-        graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
-        maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
-      }
-      auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
-      auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
-      graph->qTensor = qTensor;
-      graph->kTensor = kTensor;
-      graph->vTensor = vTensor;
-      graph->outputTensor = output;
-      graph->attnTensor = sm;
-    });
-    auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
-    auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
-    auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
+          if (is_causal) {
+            auto causalMask = [mpsGraph constantWithScalar:1.0f
+                                                     shape:getMPSShape({qSize, maxSeqLength})
+                                                  dataType:MPSDataTypeBool];
+            causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
+            auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
+            maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
+                                       truePredicateTensor:maskedMM
+                                      falsePredicateTensor:minusInf
+                                                      name:nil];
+          } else if (attn_mask) {
+            graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
+            maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
+          }
+          auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
+          auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
+          graph->qTensor = qTensor;
+          graph->kTensor = kTensor;
+          graph->vTensor = vTensor;
+          graph->outputTensor = output;
+          graph->attnTensor = sm;
+        });
+    auto qPlaceholder = Placeholder(cachedGraph->qTensor, q_);
+    auto kPlaceholder = Placeholder(cachedGraph->kTensor, k_);
+    auto vPlaceholder = Placeholder(cachedGraph->vTensor, v_);
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
     auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
     NSDictionary* feeds = nil;
@@ -122,8 +139,13 @@
     NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
     runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
   }
-  return {out, attn};
+
+  // Squeeze back to 3D
+  auto final_out = (sq ? out.squeeze(0) : out);
+  auto final_attn = (sq ? attn.squeeze(0) : attn);
+
+  return {std::move(final_out), std::move(final_attn)};
 }
 
 } // namespace native
-} // namespace at
+} // namespace at
```
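In Python terms, the change above is equivalent to unsqueezing 3D inputs to 4D before running the kernel and squeezing the result back afterwards. A minimal user-side sketch of that same round trip, which can also serve as a workaround on torch builds that predate this fix (the function name `sdpa_mps_3d` is illustrative, not part of the patch):

```python
import torch
import torch.nn.functional as F

def sdpa_mps_3d(q, k, v, attn_mask=None, dropout_p=0.0):
    # Mirror of the new ensure_4d helper: add a batch dim of 1 so the MPS math
    # backend sees 4D inputs, then drop it again to preserve the 3D layout.
    out = F.scaled_dot_product_attention(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
        attn_mask=attn_mask, dropout_p=dropout_p,
    )
    return out.squeeze(0)

if torch.backends.mps.is_available():
    q = torch.randn(16, 16, 80, device="mps")
    k = torch.randn(16, 16, 80, device="mps")
    v = torch.randn(16, 16, 80, device="mps")
    print(sdpa_mps_3d(q, k, v).shape)  # torch.Size([16, 16, 80])
```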

test/test_mps.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -9915,6 +9915,39 @@ def test_sdpa_mask_fp16_L6(self):
     def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
         self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
 
+    def _test_sdpa_3d_input(self, dtype):
+        head_num, seq_len, embed_dim = 16, 16, 80
+
+        q = torch.randn(head_num, seq_len, embed_dim, dtype=dtype)
+        k = torch.randn(head_num, seq_len, embed_dim, dtype=dtype)
+        v = torch.randn(head_num, seq_len, embed_dim, dtype=dtype)
+        attention_mask = torch.ones(1, seq_len, seq_len, dtype=dtype)
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(
+                q.to("mps"),
+                k.to("mps"),
+                v.to("mps"),
+                attention_mask.to("mps"),
+                dropout_p=0.0
+            )
+
+            y_ref = F.scaled_dot_product_attention(
+                q.to("cpu"),
+                k.to("cpu"),
+                v.to("cpu"),
+                attention_mask.to("cpu"),
+                dropout_p=0.0
+            )
+
+            self._compare_tensors(y.cpu(), y_ref)
+
+    def test_sdpa_3d_input_fp32(self):
+        self._test_sdpa_3d_input(torch.float32)
+
+    def test_sdpa_3d_input_fp16(self):
+        self._test_sdpa_3d_input(torch.float16)
+
 
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):
```
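For readers who want to verify the fix without running the full test_mps.py harness, a self-contained variant of the new test is sketched below (tolerances are illustrative; the suite's `_compare_tensors` helper applies its own):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

if torch.backends.mps.is_available():
    q, k, v = (torch.randn(16, 16, 80) for _ in range(3))
    mask = torch.ones(1, 16, 16)

    # Pin the math backend, as the new test does, and compare MPS against CPU.
    with sdpa_kernel([SDPBackend.MATH]):
        y_mps = F.scaled_dot_product_attention(
            q.to("mps"), k.to("mps"), v.to("mps"), mask.to("mps"), dropout_p=0.0)
        y_cpu = F.scaled_dot_product_attention(q, k, v, mask, dropout_p=0.0)

    torch.testing.assert_close(y_mps.cpu(), y_cpu, atol=1e-5, rtol=1e-5)
    print("MPS and CPU agree for 3D SDPA inputs")
```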
