Reland 2 of Merge more symbolic meta kernels and symint changes from … · csarofeen/pytorch@978b46d · GitHub

Commit 978b46d

albanD authored and pytorchmergebot committed
Reland 2 of Merge more symbolic meta kernels and symint changes from branch (pytorch#86334) (pytorch#86488)
symintify split_with_sizes, dropout, fused_fake_obs_quant
meta for padding_2d ops
add meta_bernoulli_
meta kernel for at::gather
get pytorch_struct to pass: meta for scatter_add, fix backward
symintify split ops

Pull Request resolved: pytorch#86488
Approved by: https://github.com/ezyang
1 parent 55663b7 · commit 978b46d
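The commit message groups several symbolic-shape (SymInt) and meta-kernel changes. As a rough illustration only (not part of the commit), the sketch below shows the kind of symbolic tracing these kernels unblock; it assumes make_fx with tracing_mode="symbolic" from torch.fx.experimental.proxy_tensor behaves as on master around this commit, and the function and inputs are hypothetical:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, index):
        # split with a size list containing a symbolic size (x.size(0) - 1),
        # plus gather and scatter_add, all of which gain symbolic support here.
        top, rest = torch.split(x, [1, x.size(0) - 1], dim=0)
        picked = torch.gather(rest, 0, index)
        return torch.zeros_like(rest).scatter_add(0, index, picked)

    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3),
                                             torch.zeros(1, 3, dtype=torch.long))
    print(gm.graph)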

File tree

17 files changed: +324 -65 lines changed

aten/src/ATen/FunctionalInverses.cpp

Lines changed: 11 additions & 11 deletions

@@ -177,31 +177,31 @@ Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const T
   return base.slice_scatter_symint(mutated_view, dim, start, end, step);
 }
 
-Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, int64_t split_size, int64_t dim) {
+Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
   // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
   // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i
   // on top of the base tensor.
   // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
-  dim = at::maybe_wrap_dim(dim, base.sizes().size());
-  auto dim_size = base.size(dim);
-  auto start = mutated_view_idx * split_size;
-  auto end = start + split_size;
+  dim = at::maybe_wrap_dim(dim, base.dim());
+  auto dim_size = base.sym_size(dim);
+  auto start = split_size * mutated_view_idx;
+  auto end = split_size + start;
   if (end > dim_size) end = dim_size;
   // Pessimism: we can't reapply views for slice_scatter.
-  return base.slice_scatter(mutated_view, dim, start, end, 1);
+  return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
 }
 
-Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, at::IntArrayRef split_sizes, int64_t dim) {
-  dim = at::maybe_wrap_dim(dim, base.sizes().size());
-  auto dim_size = base.size(dim);
-  int64_t start = 0;
+Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
+  dim = at::maybe_wrap_dim(dim, base.dim());
+  auto dim_size = base.sym_size(dim);
+  c10::SymInt start = 0;
   for (auto i = 0; i < mutated_view_idx; ++i) {
     start += split_sizes[i];
   }
   auto end = start + split_sizes[mutated_view_idx];
   if (end > dim_size) end = dim_size;
   // Pessimism: we can't reapply views for slice_scatter.
-  return base.slice_scatter(mutated_view, dim, start, end, 1);
+  return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
 }
 
 Tensor FunctionalInverses::squeeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
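A minimal Python sketch (not part of the commit) of the slice arithmetic the inverse relies on: view i produced by split(base, split_size, dim) covers the index range [i * split_size, min(i * split_size + split_size, dim_size)) along dim, so the functional inverse can write a mutated view back into the base with slice_scatter over exactly that range.

    import torch

    base = torch.arange(10.).reshape(5, 2)
    split_size, i = 2, 2                             # last chunk is ragged (size 1)
    view = base.split(split_size, dim=0)[i]
    start = i * split_size                           # 4
    end = min(start + split_size, base.size(0))      # clamped to dim_size = 5
    # Writing a mutated copy of the view back into the base:
    restored = base.slice_scatter(view * 10, dim=0, start=start, end=end, step=1)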

aten/src/ATen/functorch/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE(special_multigammaln);
   OP_DECOMPOSE(special_polygamma);
   OP_DECOMPOSE(special_softmax);
-  OP_DECOMPOSE2(split, sizes);
+  m.impl("split.sizes", native::split_symint);
   OP_DECOMPOSE(square);
   OP_DECOMPOSE(numpy_T);
   OP_DECOMPOSE(reshape_as);
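Under the FuncTorchBatched key, split.sizes is now routed to native::split_symint, which (per the TensorShape.cpp change below) simply forwards to split_with_sizes. A minimal sketch (not part of the commit) of the user-visible equivalence this preserves:

    import torch

    x = torch.randn(6, 3)
    # split with a size list and split_with_sizes produce the same views.
    assert all(torch.equal(a, b)
               for a, b in zip(torch.split(x, [2, 4], dim=0),
                               x.split_with_sizes([2, 4], dim=0)))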

aten/src/ATen/native/Dropout.cpp

Lines changed: 4 additions & 4 deletions

@@ -12,17 +12,17 @@ template<bool inplace>
 using Ctype = typename std::conditional<inplace, Tensor&, Tensor>::type;
 
 Tensor make_feature_noise(const Tensor& input) {
-  auto input_sizes = input.sizes();
+  auto input_sizes = input.sym_sizes();
   TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
-  std::vector<int64_t> sizes;
+  std::vector<c10::SymInt> sizes;
   sizes.reserve(input.dim());
   sizes.push_back(input_sizes[0]);
   sizes.push_back(input_sizes[1]);
   for (const auto i : c10::irange(2, input.dim())) {
     (void)i; //Suppress unused variable warning
     sizes.push_back(1);
   }
-  return input.new_empty(sizes);
+  return input.new_empty_symint(sizes);
 }
 
 bool is_fused_kernel_acceptable(const Tensor& input, double p) {
@@ -46,7 +46,7 @@ Tensor multiply(const Tensor& input, const Tensor& noise) {
 template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
 Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
   TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
-  if (p == 0 || !train || input.numel() == 0) {
+  if (p == 0 || !train || input.sym_numel() == 0) {
     return input;
   }
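make_feature_noise builds the dropout mask with shape [N, C, 1, 1, ...], which is why feature dropout zeroes whole channels; the change only swaps concrete sizes for SymInt so the same logic works under symbolic shapes. A short illustration of the channel-wise behaviour (not part of the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4, 4)
    out = F.dropout2d(x, p=0.5, training=True)
    # Each (n, c) plane is either all zeros or the original plane scaled by 1 / (1 - p).
    channel_dropped = (out == 0).flatten(2).all(dim=-1)   # shape (2, 3)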

aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 1 addition & 1 deletion

@@ -1365,7 +1365,7 @@ Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, cons
   if (sparse_grad) {
     return at::_gather_sparse_backward(self, dim, index, grad);
   }
-  auto result = grad.new_zeros(self.sizes());
+  auto result = grad.new_zeros_symint(self.sym_sizes());
   // for composite compliance, use out-of-place variant of
   // `scatter_add` if index tensor is a Tensor Subclass.
   if (isTensorSubclassLike(index)) {
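For dense gradients, gather_backward allocates zeros shaped like self and scatter-adds the incoming grad back at index along dim; the change only makes that allocation symbolic-shape aware. A small Python check of the same identity (not part of the commit):

    import torch

    inp = torch.randn(4, 3, requires_grad=True)
    index = torch.tensor([[0, 1, 2], [3, 2, 1]])
    out = torch.gather(inp, 0, index)
    out.sum().backward()
    # gather's backward is a scatter_add of the output grad (here all ones).
    manual = torch.zeros_like(inp).scatter_add(0, index, torch.ones_like(out))
    assert torch.equal(inp.grad, manual)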

aten/src/ATen/native/TensorShape.cpp

Lines changed: 7 additions & 7 deletions

@@ -723,19 +723,19 @@ std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
   TORCH_CHECK(chunks > 0,
            "chunk expects `chunks` to be greater than 0, got: ", chunks);
 
-  const auto dim_size = self.size(dim);
-  int64_t split_size = (dim_size + chunks - 1) / chunks;
+  const auto dim_size = self.sym_size(dim);
+  auto split_size = (dim_size + chunks - 1) / chunks;
 
   // We need to call split_with_sizes in the case where split_size and dimension size are 0, because
   // a call to split would discard the number of chunks (because we can have an arbitrary number of
   // 0-sized chunks adding up to 0). So, call split_with_sizes with the correct number of chunks,
   // eventually we will do this for all cases.
   if (split_size == 0 && dim_size == 0) {
-    std::vector<int64_t> split_sizes(chunks, split_size);
+    std::vector<c10::SymInt> split_sizes(chunks, split_size);
     split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
-    return self.split_with_sizes(split_sizes, dim);
+    return self.split_with_sizes_symint(split_sizes, dim);
   } else {
-    return self.split(split_size, dim);
+    return self.split_symint(split_size, dim);
   }
 }
 
@@ -2273,8 +2273,8 @@ std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
   return splits;
 }
 
-std::vector<Tensor> split(const Tensor& self, IntArrayRef sizes, int64_t dim) {
-  return at::split_with_sizes(self, sizes, dim);
+std::vector<Tensor> split_symint(const Tensor& self, c10::SymIntArrayRef sizes, int64_t dim) {
+  return at::split_with_sizes_symint(self, sizes, dim);
 }
 
 std::vector<Tensor> unsafe_split(const Tensor& self, int64_t split_size, int64_t dim) {
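chunk computes its per-chunk size by ceiling division, (dim_size + chunks - 1) / chunks, so the last chunk can be smaller, and both branches now go through the symint split entry points. A quick numeric check of that arithmetic (not part of the commit):

    import torch

    x = torch.randn(7, 3)
    chunks = 3
    split_size = -(-x.size(0) // chunks)       # ceiling division: (7 + 3 - 1) // 3 == 3
    pieces = x.chunk(chunks, dim=0)
    assert [p.size(0) for p in pieces] == [3, 3, 1]
    assert all(torch.equal(a, b) for a, b in zip(pieces, x.split(split_size, dim=0)))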

aten/src/ATen/native/native_functions.yaml

Lines changed: 11 additions & 9 deletions

@@ -4782,34 +4782,36 @@
     CUDA: softmax_backward_cuda_out
     MPS: softmax_backward_mps_out
 
-- func: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
+- func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
   variants: function, method
   device_check: NoCheck
   device_guard: False
   dispatch:
     CompositeExplicitAutograd: unsafe_split
   autogen: unsafe_split.Tensor_out
 
-- func: split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
+- func: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
   variants: function, method
   device_check: NoCheck
   device_guard: False
   dispatch:
     CompositeExplicitAutograd: split
 
-- func: split.sizes(Tensor(a -> *) self, int[] split_size, int dim=0) -> Tensor(a)[]
+- func: split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
   variants: function, method
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: split_symint
 
-- func: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
+- func: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
   variants: function, method
   device_check: NoCheck
   device_guard: False
   dispatch:
     CompositeExplicitAutograd: unsafe_split_with_sizes
   autogen: unsafe_split_with_sizes.out
 
-- func: split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]
+- func: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
   variants: function, method
   device_check: NoCheck
   device_guard: False
@@ -12810,13 +12812,13 @@
     CompositeExplicitAutogradNonFunctional: slice_copy_Tensor
   tags: view_copy
 
-- func: split_copy.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
+- func: split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
   variants: function
   dispatch:
     CompositeExplicitAutogradNonFunctional: split_copy_Tensor
   tags: view_copy
 
-- func: split_with_sizes_copy(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
+- func: split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
   variants: function
   dispatch:
     CompositeExplicitAutogradNonFunctional: split_with_sizes_copy
@@ -13022,13 +13024,13 @@
     CompositeExplicitAutograd: slice_copy_Tensor_out
 
 
-- func: split_copy.Tensor_out(Tensor self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+- func: split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
   variants: function
   dispatch:
     CompositeExplicitAutograd: split_copy_Tensor_out
 
 
-- func: split_with_sizes_copy.out(Tensor self, int[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+- func: split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
   variants: function
   dispatch:
     CompositeExplicitAutograd: split_with_sizes_copy_out

aten/src/ATen/native/quantized/cpu/fused_obs_fake_quant.cpp

Lines changed: 1 addition & 1 deletion

@@ -236,7 +236,7 @@ at::Tensor fused_moving_avg_obs_fake_quant(
     const int64_t ch_axis,
     bool per_row_fake_quant,
     bool symmetric_quant) {
-  if (self.numel() == 0) {
+  if (self.sym_numel() == 0) {
     return self.clone();
   }
   const auto res = at::_fused_moving_avg_obs_fq_helper(

c10/core/SymFloatNodeImpl.cpp

Lines changed: 4 additions & 0 deletions

@@ -13,4 +13,8 @@ c10::SymIntNode SymFloatNodeImpl::ceil() {
   TORCH_CHECK(false, "NYI");
 }
 
+c10::SymIntNode SymFloatNodeImpl::floor() {
+  TORCH_CHECK(false, "NYI");
+}
+
 } // namespace c10

c10/core/SymFloatNodeImpl.h

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target {
     TORCH_CHECK(false, "NYI");
   };
   virtual SymIntNode ceil();
+  virtual SymIntNode floor();
   virtual std::string str() {
     TORCH_CHECK(false, "NYI");
   };

functorch/test/test_aotdispatch.py

Lines changed: 0 additions & 4 deletions

@@ -785,7 +785,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('fmax', ''),  # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
     xfail('fmin', ''),  # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
     xfail('frexp', ''),  # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
-    xfail('gather', ''),  # aten.gather.default - couldn't find symbolic meta function/decomposition
     xfail('gradient', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
@@ -975,7 +974,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('round', 'decimals_0'),  # aten.round.decimals - couldn't find symbolic meta function/decomposition
     xfail('round', 'decimals_3'),  # aten.round.decimals - couldn't find symbolic meta function/decomposition
     xfail('round', 'decimals_neg_3'),  # aten.round.decimals - couldn't find symbolic meta function/decompos...
-    xfail('scatter_add', ''),  # aten.scatter_add.default - couldn't find symbolic meta function/decomposition
     xfail('scatter', ''),  # aten.scatter.src - couldn't find symbolic meta function/decomposition
     xfail('scatter_reduce', 'amax'),  # aten.scatter_reduce.two - couldn't find symbolic meta function/decom...
     xfail('scatter_reduce', 'amin'),  # aten.scatter_reduce.two - couldn't find symbolic meta function/decom...
@@ -993,8 +991,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('special.polygamma', 'special_polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic ...
     xfail('special.xlog1py', ''),  # aten.special_xlog1py.default - couldn't find symbolic meta function/deco...
     xfail('split', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('split', 'list_args'),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('split_with_sizes', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('squeeze', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('std', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('std_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
