Revert "min/max support for SymInt/Floats, finish as_strided/scatter/… · csarofeen/pytorch@811b8e0 · GitHub

Commit 811b8e0

Revert "min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (pytorch#86643)"
This reverts commit 86f914e. pytorch#86643 was reverted on behalf of https://github.com/osalpekar because it needs to be undone in order to cleanly revert pytorch#86488; it should be safe to re-land later.
1 parent f1fdb6e commit 811b8e0


12 files changed: +32 -92 lines changed


aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 4 additions & 3 deletions

@@ -16,6 +16,7 @@
 
 #include <c10/cuda/CUDAMathCompat.h>
 
+#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 
@@ -367,8 +368,8 @@ __global__ void transform_bias_rescale_qkv_add_padding_kernel(
 }
 
 Tensor collapse_dims_1_and_2(const Tensor& sizes) {
-  auto sizes_dim1 = at::native::narrow_symint(sizes, 1, 0, 1);
-  auto sizes_dim2 = at::native::narrow_symint(sizes, 1, 1, 1);
+  auto sizes_dim1 = at::native::narrow(sizes, 1, 0, 1);
+  auto sizes_dim2 = at::native::narrow(sizes, 1, 1, 1);
 
   return (sizes_dim1 * sizes_dim2).contiguous();
 }
@@ -450,7 +451,7 @@ __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
   auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_size_tensor());
   auto offsets =
      NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
-  at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel())
+  at::native::narrow(offsets, 0, sizes.numel() + 1, sizes.numel())
      .copy_(sizes.reshape({-1}));
   auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
   const auto offsets_ptr = metadata.data_ptr<int>();

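Editorial note on the narrow_symint → narrow swap above: at::native::narrow slices `length` elements starting at `start` along `dim`, which is presumably why the NonSymbolicBC.h include is re-added. The following is a small, hedged libtorch sketch of what collapse_dims_1_and_2 computes; it uses the public at::narrow (assumed equivalent here) and invented size values, not the commit's code.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical nested-size tensor: three constituents, each described by
  // [dim1, dim2, dim3].
  auto sizes = torch::tensor({2, 4, 8, 3, 4, 8, 5, 4, 8}).reshape({3, 3});
  // Take column 0 and column 1, mirroring collapse_dims_1_and_2.
  auto sizes_dim1 = at::narrow(sizes, /*dim=*/1, /*start=*/0, /*length=*/1);
  auto sizes_dim2 = at::narrow(sizes, /*dim=*/1, /*start=*/1, /*length=*/1);
  auto collapsed = (sizes_dim1 * sizes_dim2).contiguous();
  std::cout << collapsed << "\n";  // [3, 1] tensor holding 8, 12, 20
}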
c10/core/SymInt.cpp

Lines changed: 0 additions & 15 deletions

@@ -136,21 +136,6 @@ bool SymInt::operator>=(SymInt sci) const {
   return res[0]->ge(res[1])->bool_();
 }
 
-SymInt SymInt::min(SymInt sci) const {
-  if (!is_symbolic() && !sci.is_symbolic()) {
-    return std::min(data_, sci.data_);
-  }
-  auto res = normalize_symints(*this, sci);
-  return SymInt::toSymInt(res[0]->min(res[1]));
-}
-SymInt SymInt::max(SymInt sci) const {
-  if (!is_symbolic() && !sci.is_symbolic()) {
-    return std::max(data_, sci.data_);
-  }
-  auto res = normalize_symints(*this, sci);
-  return SymInt::toSymInt(res[0]->max(res[1]));
-}
-
 void SymInt::operator*=(SymInt sci) {
   *this = *this * sci;
 }

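The deleted SymInt::min/max follow the same pattern as the surviving comparison operators: a concrete fast path when neither operand is symbolic, and a fallback that normalizes both operands to SymIntNodes and delegates to the node-level virtuals (also removed, in c10/core/SymIntNodeImpl.h below). A standalone illustration of that pattern with a hypothetical MaybeSymbolic type, not the real c10::SymInt:

#include <algorithm>
#include <cstdint>
#include <iostream>

struct MaybeSymbolic {
  int64_t value;
  bool symbolic;

  MaybeSymbolic min(const MaybeSymbolic& other) const {
    if (!symbolic && !other.symbolic) {
      return {std::min(value, other.value), false};  // fast path: plain ints
    }
    // The real SymInt::min normalized both sides to SymIntNodes and called
    // res[0]->min(res[1]); here we simply mark the result as symbolic.
    return {std::min(value, other.value), true};
  }
};

int main() {
  MaybeSymbolic a{4, false};
  MaybeSymbolic b{7, false};
  std::cout << a.min(b).value << "\n";  // prints 4
}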
c10/core/SymInt.h

Lines changed: 0 additions & 3 deletions

@@ -170,9 +170,6 @@ class C10_API SymInt {
   void operator*=(SymInt sci);
   void operator+=(SymInt sci);
 
-  SymInt min(SymInt sci) const;
-  SymInt max(SymInt sci) const;
-
   SymInt operator*(int64_t sci) const;
   bool operator<(int64_t sci) const;
   bool operator==(int64_t sci) const;

c10/core/SymIntNodeImpl.h

Lines changed: 0 additions & 6 deletions

@@ -63,12 +63,6 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
   virtual SymIntNode ceil() {
     TORCH_CHECK(false, "NYI");
   };
-  virtual SymIntNode min(const SymIntNode& other) {
-    TORCH_CHECK(false, "NYI");
-  };
-  virtual SymIntNode max(const SymIntNode& other) {
-    TORCH_CHECK(false, "NYI");
-  };
   virtual SymIntNode clone() {
     TORCH_CHECK(false, "NYI");
   };

test/functorch/test_aotdispatch.py

Lines changed: 2 additions & 0 deletions

@@ -1062,6 +1062,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('nn.functional.interpolate', 'trilinear'),  # Cannot call sizes() on tensor with symbolic sizes/st...
     xfail('nn.functional.kl_div', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.l1_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('nn.functional.linear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.local_response_norm', ''),  # aten.fill.Scalar - couldn't find symbolic meta functio...
     xfail('nn.functional.max_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices_backward.default - couldn't find s...
@@ -1136,6 +1137,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('special.polygamma', 'special_polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic ...
     xfail('special.xlog1py', ''),  # aten.special_xlog1py.default - couldn't find symbolic meta function/deco...
     xfail('split', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('squeeze', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('std', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('std_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('stft', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides

test/test_proxy_tensor.py

Lines changed: 1 addition & 0 deletions

@@ -1056,6 +1056,7 @@ def f(a, b, c, d, e):
     xfail('argmin', ''),  # aten.argmin.default - couldn't find symbolic meta function/decomposition
     xfail('argsort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('argwhere', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
+    xfail('as_strided_scatter', ''),  # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('bernoulli', ''),  # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
     xfail('bucketize', ''),  # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition

tools/autograd/derivatives.yaml

Lines changed: 4 additions & 4 deletions

@@ -1493,19 +1493,19 @@
   result: auto_element_wise
 
 - name: squeeze(Tensor(a) self) -> Tensor(a)
-  self: unsqueeze_to(grad, self.sym_sizes())
+  self: unsqueeze_to(grad, self.sizes())
   result: auto_linear
 
 - name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
-  self: unsqueeze_to(grad, dim, self.sym_sizes())
+  self: unsqueeze_to(grad, dim, self.sizes())
   result: auto_linear
 
 - name: squeeze_(Tensor(a!) self) -> Tensor(a!)
-  self: unsqueeze_to(grad, self.sym_sizes())
+  self: unsqueeze_to(grad, self.sizes())
   result: auto_linear
 
 - name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
-  self: unsqueeze_to(grad, dim, self.sym_sizes())
+  self: unsqueeze_to(grad, dim, self.sizes())
   result: auto_linear
 
 - name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor

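The four squeeze derivatives above switch from self.sym_sizes() back to self.sizes(); either way, the backward reconstructs the pre-squeeze shape by re-inserting every size-1 dimension into the incoming gradient via unsqueeze_to (whose reverted definition appears in FunctionsManual.cpp below). A small libtorch sketch of that behaviour, assuming a standard libtorch setup; it mirrors the helper's loop rather than calling the commit's code:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::randn({2, 1, 3, 1});
  auto out = self.squeeze();           // shape [2, 3]
  auto grad = torch::ones_like(out);   // gradient w.r.t. out, shape [2, 3]

  // Re-insert every size-1 dim of the original shape, as unsqueeze_to does.
  auto restored = grad;
  auto sizes = self.sizes();
  for (int64_t dim = 0; dim < static_cast<int64_t>(sizes.size()); ++dim) {
    if (sizes[dim] == 1) {
      restored = restored.unsqueeze(dim);
    }
  }
  std::cout << restored.sizes() << "\n";  // [2, 1, 3, 1]
}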
torch/_subclasses/fake_tensor.py

Lines changed: 0 additions & 1 deletion

@@ -885,7 +885,6 @@ def wrap(e, device=None):
     def functions_with_cpp_meta_impl_that_support_symint(self):
         return [
             aten.empty_strided.default,
-            aten.as_strided_scatter.default,
             aten.as_strided.default,
             aten.zeros.default,
             aten.detach.default,

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 13 additions & 22 deletions

@@ -848,26 +848,23 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   return at::stack(grads_tensors, dim);
 }
 
-Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
+Tensor unsqueeze_to(const Tensor& self, IntArrayRef sizes) {
   auto result = self;
 
-  int64_t nDims = sym_sizes.size();
+  int64_t nDims = sizes.size();
   for (const auto dim : c10::irange(nDims)) {
-    if (sym_sizes[dim] == 1) {
+    if (sizes[dim] == 1) {
       result = result.unsqueeze(dim);
     }
   }
   return result;
 }
 
-Tensor unsqueeze_to(
-    const Tensor& self,
-    int64_t dim,
-    c10::SymIntArrayRef sym_sizes) {
-  dim = at::maybe_wrap_dim(dim, sym_sizes.size());
+Tensor unsqueeze_to(const Tensor& self, int64_t dim, IntArrayRef sizes) {
+  dim = at::maybe_wrap_dim(dim, sizes.size());
   // in NumPy it's not an error to unsqueeze a scalar, but we still need to
   // avoided unsqueezing in the backward.
-  if (sym_sizes.size() > 0 && sym_sizes[dim] == 1) {
+  if (sizes.size() > 0 && sizes[dim] == 1) {
     return self.unsqueeze(dim);
   }
   return self;
@@ -2839,27 +2836,21 @@ Tensor as_strided_backward(
 
   // Step (1): create underlying tensor as "storage"
   auto shared_offset =
-      // TODO: symint-ify. Do we need a min() and max() for SymInts?
-      input_geometry.sym_storage_offset().min(sym_storage_offset);
+      std::min(input_geometry.sym_storage_offset(), sym_storage_offset);
   auto inp_effective_offset =
       input_geometry.sym_storage_offset() - shared_offset;
   auto out_effective_offset = sym_storage_offset - shared_offset;
-  auto base_size1 =
-      _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
-  auto base_size2 =
-      _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
-  auto base_size = base_size1.max(base_size2);
-  auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));
+  auto base_size = std::max(
+      _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
+      _min_storage_size(out_sizes_, out_strides_, out_effective_offset));
+  auto storage = grad.new_empty_symint(c10::SymIntArrayRef(base_size));
+  storage.zero_();
 
   // prepare indices tensor if we will do index_add_ later
   c10::optional<at::Tensor> flatten_full_indices;
   if (inp_maybe_overlap || out_maybe_overlap) {
     flatten_full_indices =
-        // TODO: should we symint-ify arange? Need SymScalar.
-        at::arange(
-            0,
-            base_size.guard_int(__FILE__, __LINE__),
-            grad.options().dtype(at::kLong));
+        at::arange(0, base_size, grad.options().dtype(at::kLong));
   }
 
   // Step (2): use output geometry to scatter gradients into storage

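In the as_strided_backward hunk above, the revert replaces SymInt min/max calls with std::min/std::max and recreates the zero-filled base storage via new_empty_symint followed by zero_(). The underlying idea: both the input and the output views must fit in one shared storage, so the base size is the larger of the two minimal storage footprints once their offsets are aligned. A hedged, standalone sketch of that arithmetic; the min_storage_size formula here (1 + offset + sum of (size-1)*stride) is an assumption standing in for the file's _min_storage_size helper, and the example geometries are invented:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed minimal-storage-size bound; not a copy of _min_storage_size.
int64_t min_storage_size(const std::vector<int64_t>& sizes,
                         const std::vector<int64_t>& strides,
                         int64_t offset) {
  int64_t storage = 1;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) return 0;               // empty tensor needs no storage
    storage += (sizes[i] - 1) * strides[i];
  }
  return storage + offset;
}

int main() {
  // Input view:  sizes [2, 3], strides [3, 1], offset 2.
  // Output view: sizes [4],    strides [2],    offset 0.
  int64_t inp_offset = 2, out_offset = 0;
  int64_t shared_offset = std::min(inp_offset, out_offset);           // 0
  int64_t inp_eff = inp_offset - shared_offset;                       // 2
  int64_t out_eff = out_offset - shared_offset;                       // 0
  int64_t base_size = std::max(min_storage_size({2, 3}, {3, 1}, inp_eff),
                               min_storage_size({4}, {2}, out_eff));  // max(8, 7) = 8
  std::cout << "shared_offset=" << shared_offset
            << " base_size=" << base_size << "\n";
}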
torch/csrc/autograd/FunctionsManual.h

Lines changed: 2 additions & 2 deletions

@@ -215,11 +215,11 @@ at::Tensor logcumsumexp_backward(
     at::Tensor result,
     int64_t dim);
 at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
-at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
+at::Tensor unsqueeze_to(const at::Tensor& self, at::IntArrayRef sizes);
 at::Tensor unsqueeze_to(
     const at::Tensor& self,
     int64_t dim,
-    c10::SymIntArrayRef sym_sizes);
+    at::IntArrayRef sizes);
 std::vector<at::Tensor> cat_tensors_backward(
     const at::Tensor& grad,
     const std::vector<std::vector<c10::SymInt>>& sizes,
