lint · pytorch/pytorch@e2c811b · GitHub

Commit e2c811b

lint

1 parent 51c476e commit e2c811b

File tree

1 file changed, +18 −10 lines changed

aten/src/ATen/native/cudnn/MHA.cpp

Lines changed: 18 additions & 10 deletions
@@ -325,21 +325,25 @@ void alloc_with_matching_layout(
       at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
 }
 
-void permute_to_matching_layout(const Tensor& output, Tensor& grad_output)
-{
+void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
   const int dims = output.sizes().size();
   std::vector<int64_t> outer_to_inner(dims);
   std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
   const auto o_strides = output.strides();
   std::stable_sort(
-      outer_to_inner.begin(), outer_to_inner.end(), [&o_strides](int idx1, int idx2) {
+      outer_to_inner.begin(),
+      outer_to_inner.end(),
+      [&o_strides](int idx1, int idx2) {
         return o_strides[idx1] > o_strides[idx2];
       });
   std::vector<int64_t> inverse(dims);
   for (int d = 0; d < dims; d++) {
-    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) - outer_to_inner.begin();
+    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
+        outer_to_inner.begin();
   }
-  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner)).contiguous().permute(at::IntArrayRef(inverse));
+  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
+                    .contiguous()
+                    .permute(at::IntArrayRef(inverse));
 }
 
 bool same_strides(const Tensor& t1, const Tensor& t2) {
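
For readers skimming the re-wrapped permute_to_matching_layout, the index arithmetic is easy to miss: it sorts dimension indices by descending stride to get the outer-to-inner order, builds the inverse permutation, and uses the pair to materialize grad_output in output's memory layout. The standalone sketch below is hypothetical and not part of this patch (the example strides are made up); it reproduces just that permutation logic on a plain std::vector:

// Hypothetical standalone sketch, not part of the patch: reproduces the
// permutation arithmetic from permute_to_matching_layout on plain vectors.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Strides of a hypothetical 4-D tensor whose middle two dims are stored
  // transposed relative to their logical order.
  std::vector<int64_t> o_strides = {1024, 64, 256, 1};
  const int dims = o_strides.size();

  // Sort dimension indices from largest stride (outermost) to smallest
  // (innermost), exactly as the function does with std::stable_sort.
  std::vector<int64_t> outer_to_inner(dims);
  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
  std::stable_sort(
      outer_to_inner.begin(),
      outer_to_inner.end(),
      [&o_strides](int idx1, int idx2) {
        return o_strides[idx1] > o_strides[idx2];
      });

  // Invert the permutation: inverse[d] is where dim d ended up, so
  // permute(outer_to_inner) followed by permute(inverse) restores the
  // logical dim order while .contiguous() in between fixes the layout.
  std::vector<int64_t> inverse(dims);
  for (int d = 0; d < dims; d++) {
    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
        outer_to_inner.begin();
  }

  for (int d = 0; d < dims; d++) {
    std::cout << outer_to_inner[d] << " " << inverse[d] << "\n";
  }
  // For these strides both permutations come out as {0, 2, 1, 3}.
}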
@@ -348,13 +352,14 @@ bool same_strides(const Tensor& t1, const Tensor& t2) {
   const auto t1strides = t1.strides();
   const auto t2strides = t2.strides();
   const int dim = t1strides.size();
-  if (dim != (int) t2strides.size()) {
+  if (dim != (int)t2strides.size()) {
     return false;
   }
   const auto t1sizes = t1.sizes();
   const auto t2sizes = t2.sizes();
-
-  // we are going through strides backward here, but if both are backward it's comparable
+
+  // we are going through strides backward here, but if both are backward it's
+  // comparable
   for (int i = 0; i < dim; i++) {
     if (t1sizes[i] > 1) {
       t1_strides_no_ones.push_back(t1strides[i]);
@@ -363,7 +368,11 @@ bool same_strides(const Tensor& t1, const Tensor& t2) {
       t2_strides_no_ones.push_back(t2strides[i]);
     }
   }
-  return std::equal(t1_strides_no_ones.begin(), t1_strides_no_ones.end(), t2_strides_no_ones.begin(), t2_strides_no_ones.end());
+  return std::equal(
+      t1_strides_no_ones.begin(),
+      t1_strides_no_ones.end(),
+      t2_strides_no_ones.begin(),
+      t2_strides_no_ones.end());
 }
 } // namespace
 
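
Beyond the line wrapping, same_strides encodes a subtle point spelled out in the re-wrapped comment: strides on size-1 dimensions never affect memory layout, so they are filtered out before comparison. The standalone sketch below is hypothetical and not part of the patch (names and sample shapes are made up); it mirrors the same check on plain vectors:

// Hypothetical standalone sketch, not part of the patch: mirrors the
// comparison in same_strides, skipping strides of size-1 dimensions.
#include <cstdint>
#include <iostream>
#include <vector>

bool same_strides_demo(
    const std::vector<int64_t>& sizes1,
    const std::vector<int64_t>& strides1,
    const std::vector<int64_t>& sizes2,
    const std::vector<int64_t>& strides2) {
  if (strides1.size() != strides2.size()) {
    return false;
  }
  // Collect only the strides of non-singleton dims, as the real function
  // does with t1_strides_no_ones / t2_strides_no_ones.
  std::vector<int64_t> s1, s2;
  for (size_t i = 0; i < strides1.size(); i++) {
    if (sizes1[i] > 1) {
      s1.push_back(strides1[i]);
    }
    if (sizes2[i] > 1) {
      s2.push_back(strides2[i]);
    }
  }
  return s1 == s2;  // std::equal over both ranges in the real function
}

int main() {
  // Two (2, 1, 3) tensors: the size-1 dim carries different strides, yet
  // the layouts are identical, so the comparison reports a match.
  std::cout << same_strides_demo({2, 1, 3}, {3, 3, 1}, {2, 1, 3}, {3, 99, 1})
            << "\n";  // prints 1
}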

@@ -738,7 +747,6 @@ void run_cudnn_SDP_bprop(
     TORCH_WARN_ONCE(
         "cuDNN SDPA backward got grad_output.strides() != output.strides(), "
         "attempting to materialize a grad_output with matching strides...");
-    TORCH_WARN_ONCE("output: ", o.strides(), " grad_output: ", dO_.strides());
     permute_to_matching_layout(o, dO_);
   }
   TORCH_INTERNAL_ASSERT(

0 commit comments
