Fix test_ops for tiny backend #9302
Changes from 91 commits
@@ -190,8 +190,9 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1:
     print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
   input, weight, bias = unwrap(input), unwrap(weight), unwrap(bias) if bias is not None else None
-  if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding))
-  return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding))
+  # TODO: fix test_biased_conv2d fails without realize()
+  if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).realize())
+  return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding).realize())
 
 @torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
 def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
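The workaround forces eager evaluation of the lazy convolution result. As a minimal sketch of what `.realize()` does in plain tinygrad (illustrative only, not code from this PR):

```python
from tinygrad import Tensor

x = Tensor.randn(1, 3, 8, 8)
w = Tensor.randn(4, 3, 3, 3)
y = x.conv2d(w)  # lazy: only builds the compute graph, nothing runs yet
y = y.realize()  # schedules and executes the kernels now; buffer holds data
print(y.shape)   # (1, 4, 6, 6)
```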
@@ -205,6 +206,28 @@ def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
   grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
   return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
 
+@torch.library.impl("aten::slice.Tensor", "privateuseone")
+def slice_tensor(self, dim=0, start=None, end=None, step=1):
+  slices = [slice(None)] * unwrap(self).ndim
+  slices[dim] = slice(start, end, step)
+  return wrap(unwrap(self)[slices])
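For reference, `aten::slice.Tensor` is exactly what Python slice syntax lowers to, so the list-of-slices construction above mirrors the frontend behavior. A quick equivalence check with stock torch on CPU:

```python
import torch

t = torch.arange(12).reshape(3, 4)
a = torch.ops.aten.slice.Tensor(t, 1, 1, 3, 1)  # dim=1, start=1, end=3, step=1
b = t[:, 1:3]                                   # the sugar for the same call
assert torch.equal(a, b)
```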
+
+@torch.library.impl("aten::slice_backward", "privateuseone")
+def slice_backward(grad_out, input_sizes, dim=0, start=None, end=None, step=1):
+  grad_input = Tensor.zeros(input_sizes).contiguous()
+  slices = [slice(None)] * len(input_sizes)
+  slices[dim] = slice(start, end, step)
+  grad_input[slices] = unwrap(grad_out)
+  return wrap(grad_input)
+
+@torch.library.impl("aten::select_backward", "privateuseone")
+def select_backward(grad_out, input_sizes, dim, index):
+  grad_input = Tensor.zeros(input_sizes).contiguous()
+  slices = [slice(None)] * len(input_sizes)
+  slices[dim] = index
+  grad_input[slices] = unwrap(grad_out)
+  return wrap(grad_input)

Review thread on the slice_backward/select_backward additions:

> Are there decompositions for the backward? Always better to use those than write custom code.

> I do see some of these functions listed in the decomp list in torch code, in …
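Following the reviewer's suggestion: PyTorch ships reference decompositions for many backward ops, and they can be looked up programmatically. A minimal sketch, assuming these composites could be reused instead of hand-written backwards (`torch._decomp.get_decompositions` is the real lookup API; wiring its output into a PrivateUse1 backend is an assumption here, not something this PR does):

```python
import torch
from torch._decomp import get_decompositions

# Each entry maps an aten op to a composite built from existing aten ops,
# so a backend that covers the primitives gets these backwards for free.
decomps = get_decompositions([
  torch.ops.aten.slice_backward,
  torch.ops.aten.select_backward,
])
for op, fn in decomps.items():
  print(op, "->", fn.__name__)
```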
+
 def avg_pool(self, kernel_size, stride=[], padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
   return wrap(unwrap(self).avg_pool2d(kernel_size, stride if stride != [] else None, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad))
 
@@ -235,6 +258,9 @@ def upsample(self, size, align_corners=False, mode=None): return wrap(Tensor.int
   torch.library.impl(f"aten::upsample_nearest{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest"))
   torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
 
+@torch.library.impl("aten::cumsum", "privateuseone")
+def cumsum(self, dim): return wrap(unwrap(self).cumsum(dim).contiguous())
+

(Anish9901 marked a review conversation on the cumsum lines as resolved.)
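To exercise the new cumsum path end to end through the torch frontend, something like the following should work (a sketch: the device string "tiny" and the import that registers the backend are assumptions about how this backend is loaded):

```python
import torch
import extra.torch_backend.backend  # noqa: F401  (assumed registration import)

t = torch.arange(6, dtype=torch.float32, device="tiny").reshape(2, 3)
print(t.cumsum(1).cpu())
# tensor([[ 0.,  1.,  3.],
#         [ 3.,  7., 12.]])
```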
@torch.library.impl("aten::scatter_add.out", "privateuseone") | ||
def scatter_add(self, dim, index, src, out): | ||
self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out) | ||
|