@@ -325,21 +325,25 @@ void alloc_with_matching_layout(
325
325
at::IntArrayRef (shape), at::IntArrayRef (ordered_strides), 0 );
326
326
}
327
327
328
- void permute_to_matching_layout (const Tensor& output, Tensor& grad_output)
329
- {
328
+ void permute_to_matching_layout (const Tensor& output, Tensor& grad_output) {
330
329
const int dims = output.sizes ().size ();
331
330
std::vector<int64_t > outer_to_inner (dims);
332
331
std::iota (outer_to_inner.begin (), outer_to_inner.end (), 0 );
333
332
const auto o_strides = output.strides ();
334
333
std::stable_sort (
335
- outer_to_inner.begin (), outer_to_inner.end (), [&o_strides](int idx1, int idx2) {
334
+ outer_to_inner.begin (),
335
+ outer_to_inner.end (),
336
+ [&o_strides](int idx1, int idx2) {
336
337
return o_strides[idx1] > o_strides[idx2];
337
338
});
338
339
std::vector<int64_t > inverse (dims);
339
340
for (int d = 0 ; d < dims; d++) {
340
- inverse[d] = std::find (outer_to_inner.begin (), outer_to_inner.end (), d) - outer_to_inner.begin ();
341
+ inverse[d] = std::find (outer_to_inner.begin (), outer_to_inner.end (), d) -
342
+ outer_to_inner.begin ();
341
343
}
342
- grad_output = grad_output.permute (at::IntArrayRef (outer_to_inner)).contiguous ().permute (at::IntArrayRef (inverse));
344
+ grad_output = grad_output.permute (at::IntArrayRef (outer_to_inner))
345
+ .contiguous ()
346
+ .permute (at::IntArrayRef (inverse));
343
347
}
344
348
345
349
bool same_strides (const Tensor& t1, const Tensor& t2) {
@@ -348,13 +352,14 @@ bool same_strides(const Tensor& t1, const Tensor& t2) {
348
352
const auto t1strides = t1.strides ();
349
353
const auto t2strides = t2.strides ();
350
354
const int dim = t1strides.size ();
351
- if (dim != (int ) t2strides.size ()) {
355
+ if (dim != (int )t2strides.size ()) {
352
356
return false ;
353
357
}
354
358
const auto t1sizes = t1.sizes ();
355
359
const auto t2sizes = t2.sizes ();
356
-
357
- // we are going through strides backward here, but if both are backward it's comparable
360
+
361
+ // we are going through strides backward here, but if both are backward it's
362
+ // comparable
358
363
for (int i = 0 ; i < dim; i++) {
359
364
if (t1sizes[i] > 1 ) {
360
365
t1_strides_no_ones.push_back (t1strides[i]);
@@ -363,7 +368,11 @@ bool same_strides(const Tensor& t1, const Tensor& t2) {
363
368
t2_strides_no_ones.push_back (t2strides[i]);
364
369
}
365
370
}
366
- return std::equal (t1_strides_no_ones.begin (), t1_strides_no_ones.end (), t2_strides_no_ones.begin (), t2_strides_no_ones.end ());
371
+ return std::equal (
372
+ t1_strides_no_ones.begin (),
373
+ t1_strides_no_ones.end (),
374
+ t2_strides_no_ones.begin (),
375
+ t2_strides_no_ones.end ());
367
376
}
368
377
} // namespace
369
378
@@ -738,7 +747,6 @@ void run_cudnn_SDP_bprop(
738
747
TORCH_WARN_ONCE (
739
748
" cuDNN SDPA backward got grad_output.strides() != output.strides(), "
740
749
" attempting to materialize a grad_output with matching strides..." );
741
- TORCH_WARN_ONCE (" output: " , o.strides (), " grad_output: " , dO_.strides ());
742
750
permute_to_matching_layout (o, dO_);
743
751
}
744
752
TORCH_INTERNAL_ASSERT (
0 commit comments