Implement `np.diff` for single order differences (#50569) · pytorch/pytorch@b18eeaa · GitHub

Commit b18eeaa

soulitzer authored and facebook-github-bot committed
Implement np.diff for single order differences (#50569)
Summary:
Implements `np.diff` for single order differences only:
- method and function variants for `diff`, and a function variant for `diff_out`
- supports the out variant, but not in-place since the shape changes
- adds an OpInfo entry and a test in `test_torch`
- automatic autograd because we are using the `Math` dispatch

_Update: we only support Tensors for prepend and append in this PR. See discussion below and comments for more details._

Currently there is a quirk in the C++ API based on how this is implemented: it is not possible to specify scalar prepend and append values without also specifying all 4 arguments. That is because the goal is to match NumPy's diff signature of `diff(int n=1, int dim=-1, Union[Scalar, Tensor] prepend=None, Union[Scalar, Tensor] append=None)`, where all arguments are optional, positional, and in the correct order.

There are a couple of blockers. One is C++ ambiguity, which prevents us from simply doing `diff(int n=1, int dim=-1, Scalar? prepend=None, Tensor? append=None)` etc. for all combinations of {Tensor, Scalar} x {Tensor, Scalar}. Why not give prepend and append no default args and then write out the whole power set of {Tensor, Scalar, omitted} x {Tensor, Scalar, omitted}, you might ask? Aside from having to write 18 overloads, this is actually illegal because arguments with defaults must come after arguments without defaults; it would mean having to write `diff(prepend, append, n, dim)`, which is not desired. Finally, writing out the entire power set of all arguments n, dim, prepend, append is out of the question because that would involve 2 * 2 * 3 * 3 = 36 combinations, and if we include the out variant, that would be 72 overloads!

With this in mind, the current implementation still declares `diff(int n=1, int dim=-1, Scalar? prepend=None, Tensor? append=None)`, but also makes use of `cpp_no_default_args`. The idea is to have only one of the 4 {Tensor, Scalar} x {Tensor, Scalar} overloads provide default arguments for the C++ API, and to mark the remaining 3 overloads with `cpp_no_default_args`. With this, the Python API works as expected, but some calls such as `diff(prepend=1)` won't work in the C++ API. We can optionally add 18 more overloads that cover the {dim, n, no-args} x {scalar-tensor, tensor-scalar, scalar-scalar} x {out, non-out} cases for the C++ API.

_[edit: counting is hard - just realized this number is still wrong. We should count the cases we do cover instead and subtract that from the total: (2 * 2 * 3 * 3) - (3 + 2^4) = 17. The 3 comes from the 3 of 4 combinations of {tensor, scalar}^2 that we declare to be `cpp_no_default_args`, and the one remaining case that has default arguments covers 2^4 cases. So the actual count is 34 additional overloads to support all possible calls.]_

_[edit: thanks to #50767, hacky_wrapper is no longer necessary; it is removed in the latest commit.]_

hacky_wrapper was also necessary here because `Tensor?` causes dispatch to look for the `const optional<Tensor>&` schema but also generates a `const Tensor&` declaration in Functions.h. hacky_wrapper allows us to define our function as taking `const Tensor&` but wraps the arguments in optional for us, which avoids errors at both link and load time.

_[edit: rewrote the above to improve clarity and correct the fact that we actually need 18 more overloads (26 total), not 18 in total, to complete the C++ API.]_

Pull Request resolved: #50569

Reviewed By: H-Huang

Differential Revision: D26176105

Pulled By: soulitzer

fbshipit-source-id: cd8e77cc2de1117c876cd71c29b312887daca33f
1 parent e54cbb8 commit b18eeaa
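
To make the summary above concrete, here is a minimal Python sketch of the single-order diff semantics this PR implements (Tensor-only prepend/append, n = 1). `diff_sketch` is an illustrative name, not part of the PR; the real implementation is in `aten/src/ATen/native/ReduceOps.cpp` below.

```python
import torch

def diff_sketch(input, n=1, dim=-1, prepend=None, append=None):
    # Illustrative sketch of the semantics only, not the ATen implementation.
    assert n == 1, "this PR only supports n = 1"
    parts = [t for t in (prepend, input, append) if t is not None]
    x = torch.cat(parts, dim=dim) if len(parts) > 1 else input
    out_len = x.size(dim) - 1
    hi, lo = x.narrow(dim, 1, out_len), x.narrow(dim, 0, out_len)
    # bool tensors have no subtraction, so the elementwise difference is a logical XOR
    return torch.logical_xor(hi, lo) if x.dtype == torch.bool else hi - lo

t = torch.tensor([1, 3, 2])
print(diff_sketch(t))                               # tensor([ 2, -1])
print(diff_sketch(t, append=torch.tensor([4, 5])))  # tensor([ 2, -1,  2,  1])
```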

File tree

10 files changed: +261 −1 lines changed

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ _(aten, diag_embed) \
 _(aten, diagflat) \
 _(aten, diagonal) \
 _(aten, fill_diagonal_) \
+_(aten, diff) \
 _(aten, digamma) \
 _(aten, dim) \
 _(aten, dist) \

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 84 additions & 0 deletions
@@ -413,6 +413,90 @@ Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor&
   return result.scatter_add_(dim, indices, grad);
 }
 
+static Tensor prepend_append_on_dim(const Tensor& self, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append, int64_t dim) {
+  // Helper for diff that handles prepending and appending when at least one is present
+  TORCH_INTERNAL_ASSERT(prepend.has_value() || append.has_value(), "either prepend or append must have a value");
+  if (!prepend.has_value() && append.has_value()) {
+    return at::cat({self, append.value()}, dim);
+  } else if (prepend.has_value() && !append.has_value()) {
+    return at::cat({prepend.value(), self}, dim);
+  } else {
+    return at::cat({prepend.value(), self, append.value()}, dim);
+  }
+}
+
+static inline void diff_check_compatible_shape(const Tensor& self, const c10::optional<Tensor>& other, int64_t dim) {
+  // Helper for diff that checks whether the shape of the tensor to prepend or append
+  // is compatible with that of input
+  if (other.has_value()) {
+    int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim(), false);
+
+    TORCH_CHECK(
+        other.value().dim() == self.dim(),
+        "diff expects prepend or append to be the same dimension as input");
+
+    for (int i = 0; i < other.value().dim(); i++) {
+      TORCH_CHECK(
+          other.value().size(i) == self.size(i) || i == wrapped_dim,
+          "diff expects the shape of tensor to prepend or append to match that of"
+          " input except along the differencing dimension;"
+          " input.size(", i, ") = ", self.size(i), ", but got"
+          " tensor.size(", i, ") = ", other.value().size(i));
+    }
+  }
+}
+
+static inline void diff_check(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append) {
+  // Helper for diff that checks whether its parameters are valid
+  TORCH_CHECK(
+      n == 1,
+      "diff only supports n = 1 currently. Please file an issue at"
+      " https://github.com/pytorch/pytorch/issues/new?assignees=&labels=&template=feature-request.md"
+      " if your use case requires supporting higher-order differences");
+
+  TORCH_CHECK(
+      self.dim() >= 1,
+      "diff expects input to be at least one-dimensional");
+
+  diff_check_compatible_shape(self, prepend, dim);
+  diff_check_compatible_shape(self, append, dim);
+}
+
+static inline Tensor diff_helper(const Tensor& self, int64_t n, int64_t dim) {
+  auto out_len = self.size(dim) - 1;
+  if (self.dtype() == at::kBool) {
+    return at::logical_xor(at::narrow(self, dim, 1, out_len), at::narrow(self, dim, 0, out_len));
+  }
+  return at::narrow(self, dim, 1, out_len) - at::narrow(self, dim, 0, out_len);
+}
+
+Tensor diff(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append) {
+  diff_check(self, n, dim, prepend, append);
+  if (!prepend.has_value() && !append.has_value()) {
+    return diff_helper(self, n, dim);
+  } else {
+    auto a = prepend_append_on_dim(self, prepend, append, dim);
+    return diff_helper(a, n, dim);
+  }
+}
+
+static inline Tensor& diff_out_helper(const Tensor& self, int64_t n, int64_t dim, Tensor& result) {
+  auto out_len = self.size(dim) - 1;
+  if (self.dtype() == at::kBool) {
+    return at::logical_xor_out(result, at::narrow(self, dim, 1, out_len), at::narrow(self, dim, 0, out_len));
+  }
+  return at::sub_out(result, at::narrow(self, dim, 1, out_len), at::narrow(self, dim, 0, out_len));
+}
+
+Tensor& diff_out(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append, Tensor& result) {
+  diff_check(self, n, dim, prepend, append);
+  if (!prepend.has_value() && !append.has_value()) {
+    return diff_out_helper(self, n, dim, result);
+  } else {
+    auto a = prepend_append_on_dim(self, prepend, append, dim);
+    return diff_out_helper(a, n, dim, result);
+  }
+}
 
 // ALL REDUCE #################################################################
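One design choice worth noting in `diff_helper` above: bool tensors have no subtraction, so the difference of adjacent elements is computed with `logical_xor` instead. A small illustration of the equivalent Python behaviour, assuming a build that includes this commit:

```python
import torch

b = torch.tensor([True, False, False, True])
# For bool inputs, "adjacent elements differ" is exactly an elementwise XOR.
print(torch.diff(b))                     # tensor([ True, False,  True])
print(torch.logical_xor(b[1:], b[:-1]))  # tensor([ True, False,  True])
```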

aten/src/ATen/native/native_functions.yaml

Lines changed: 10 additions & 0 deletions
@@ -1365,6 +1365,16 @@
 - func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
   variants: method
 
+- func: diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor
+  variants: function, method
+  dispatch:
+    Math: diff
+
+- func: diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    Math: diff_out
+
 - func: div.Tensor(Tensor self, Tensor other) -> Tensor
   variants: function, method
   dispatch:
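The schema above is what produces the Python-facing bindings: a free function and a method variant for `diff`, plus an `out=` overload, all dispatched through the `Math` key so autograd comes for free. Roughly, the calls it enables look like this (illustrative values):

```python
import torch

x = torch.tensor([1., 4., 9., 16.])

y1 = torch.diff(x)        # function variant
y2 = x.diff()             # method variant
out = torch.empty(3)
torch.diff(x, out=out)    # out= variant (function only, no in-place since the shape changes)

assert torch.equal(y1, y2) and torch.equal(y1, out)
```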

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions
@@ -290,6 +290,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: fill_diagonal_
    .. automethod:: fmax
    .. automethod:: fmin
+   .. automethod:: diff
    .. automethod:: digamma
    .. automethod:: digamma_
    .. automethod:: dim

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
@@ -470,6 +470,7 @@ Other Operations
     diag_embed
     diagflat
     diagonal
+    diff
     einsum
     flatten
     flip

test/test_torch.py

Lines changed: 82 additions & 1 deletion
@@ -23,7 +23,7 @@
     do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
     skipCUDAMemoryLeakCheckIf, BytesIOContext,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
-    wrapDeterministicFlagAPITest, DeterministicGuard)
+    wrapDeterministicFlagAPITest, DeterministicGuard, make_tensor)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
@@ -4131,6 +4131,87 @@ def logcumsumexp(a, axis):
                 'expected scalar_type Double but found Float'):
             torch.logcumsumexp(b, axis, out=inplace_out)
 
+    def _test_diff_numpy(self, t, dims=None):
+        # Helper for test_diff to compare with NumPy reference implementation
+        def to_np(t):
+            if t.dtype == torch.bfloat16:
+                return t.to(dtype=torch.float, device="cpu").numpy()
+            else:
+                return t.cpu().numpy()
+
+        for dim in dims if dims else range(t.dim()):
+            prepend = t.narrow(dim, 0, 1)
+            append = t.narrow(dim, 0, 1)
+            np_t = to_np(t)
+
+            # test when prepend and append's size along dim is 1
+            actual = torch.diff(t, dim=dim, prepend=prepend, append=append)
+            expected = torch.from_numpy(np.diff(np_t, axis=dim, prepend=to_np(prepend), append=to_np(append)))
+            self.assertEqual(actual, expected.to(t.dtype))
+
+            # test when prepend and append's size along dim != 1
+            actual = torch.diff(t, dim=dim, prepend=t, append=t)
+            expected = torch.from_numpy(np.diff(np_t, axis=dim, prepend=np_t, append=np_t))
+            self.assertEqual(actual, expected.to(t.dtype))
+
+    # All tensors appear contiguous on XLA
+    @onlyOnCPUAndCUDA
+    @dtypes(*torch.testing.get_all_dtypes())
+    def test_diff_noncontig(self, device, dtype):
+        shapes = (
+            (1,),
+            (1, 5),
+            (3, 5),
+            (1, 5, 1),
+            (2, 3, 5))
+
+        for shape in shapes:
+            contig = make_tensor(shape, device, dtype, low=-9, high=9)
+
+            non_contig = torch.empty(shape + (2, 2), device=device, dtype=dtype)[..., 0]
+            non_contig = non_contig.select(-1, -1)
+            non_contig.copy_(contig)
+            self.assertTrue(not non_contig.is_contiguous() or shape == (1,))
+
+            self._test_diff_numpy(non_contig)
+
+    # RngNormal not implemented for type f16 for XLA
+    @dtypes(*torch.testing.get_all_dtypes(include_half=False))
+    @dtypesIfCPU(*torch.testing.get_all_dtypes())
+    @dtypesIfCUDA(*torch.testing.get_all_dtypes())
+    def test_diff(self, device, dtype):
+        shapes = (
+            (1,),
+            (1, 5),
+            (3, 5),
+            (1, 5, 1),
+            (2, 3, 5))
+
+        for shape in shapes:
+            contig = make_tensor(shape, device, dtype, low=-9, high=9)
+            self._test_diff_numpy(contig)
+
+        t = torch.ones(2, 3)
+
+        with self.assertRaisesRegex(
+                RuntimeError, 'diff expects prepend or append to be the same dimension as input'):
+            invalid_prepend = torch.tensor([1, 2, 3], device=device, dtype=dtype)
+            t.diff(dim=0, prepend=invalid_prepend)
+
+        with self.assertRaisesRegex(
+                RuntimeError, 'diff expects the shape of tensor to prepend or append to match that of input'):
+            invalid_prepend = torch.tensor([[0, 1]], device=device, dtype=dtype)
+            t.diff(dim=0, prepend=invalid_prepend)
+
+        with self.assertRaisesRegex(
+                RuntimeError, 'diff only supports n = 1 currently'):
+            torch.diff(t, n=2)
+
+        with self.assertRaisesRegex(
+                RuntimeError, 'diff expects input to be at least one-dimensional'):
+            scalar = torch.tensor(2, device=device, dtype=dtype)
+            torch.diff(scalar)
+
     def _test_large_cum_fn_helper(self, x, fn):
         x_cpu = x.cpu().float()
         expected = fn(x_cpu)
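The error-path assertions in `test_diff` exercise the shape rules from `diff_check_compatible_shape`: prepend/append must have the same number of dimensions as the input and match its shape everywhere except along `dim`. A short sketch of what passes and what raises, assuming a build with this commit:

```python
import torch

t = torch.ones(2, 3)

# OK: prepend matches t everywhere except along the differencing dimension (dim=0).
print(torch.diff(t, dim=0, prepend=torch.zeros(1, 3)).shape)   # torch.Size([2, 3])

# Raises: a 1-D prepend does not have the same number of dimensions as the 2-D input.
try:
    torch.diff(t, dim=0, prepend=torch.tensor([1., 2., 3.]))
except RuntimeError as e:
    print(e)  # "diff expects prepend or append to be the same dimension as input"
```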

torch/_tensor_docs.py

Lines changed: 7 additions & 0 deletions
@@ -1158,6 +1158,13 @@ def add_docstr_all(method, docstr):
 In-place version of :meth:`~Tensor.floor_divide`
 """)
 
+add_docstr_all('diff',
+               r"""
+diff(n=1, dim=-1, prepend=None, append=None) -> Tensor
+
+See :func:`torch.diff`
+""")
+
 add_docstr_all('digamma',
                r"""
 digamma() -> Tensor

torch/_torch_docs.py

Lines changed: 39 additions & 0 deletions
@@ -2642,6 +2642,45 @@ def merge_dicts(*dicts):
             [ 1.0500,  0.7336, -0.3836, -1.1015]]])
 """.format(**common_args))
 
+add_docstr(torch.diff, r"""
+diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor
+
+Computes the n-th forward difference along the given dimension.
+
+The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order
+differences are calculated by using :func:`torch.diff` recursively.
+
+.. note::  Only `n = 1` is currently supported
+
+Args:
+    input (Tensor): the tensor to compute the differences on
+    n (int, optional): the number of times to recursively compute the difference
+    dim (int, optional): the dimension to compute the difference along.
+        Default is the last dimension.
+    prepend, append (Tensor, optional): values to prepend or append to
+        :attr:`input` along :attr:`dim` before computing the difference.
+        Their dimensions must be equivalent to that of input, and their shapes
+        must match input's shape except on :attr:`dim`.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.tensor([1, 3, 2])
+    >>> torch.diff(a)
+    tensor([ 2, -1])
+    >>> b = torch.tensor([4, 5])
+    >>> torch.diff(a, append=b)
+    tensor([ 2, -1,  2,  1])
+    >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]])
+    >>> torch.diff(c, dim=0)
+    tensor([[2, 2, 2]])
+    >>> torch.diff(c, dim=1)
+    tensor([[1, 1],
+            [1, 1]])
+""".format(**common_args))
+
 add_docstr(torch.digamma, r"""
 digamma(input, *, out=None) -> Tensor

torch/overrides.py

Lines changed: 1 addition & 0 deletions
@@ -364,6 +364,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.diag: lambda input, diagonal=0, out=None: -1,
         torch.diag_embed: lambda input, diagonal=0, out=None: -1,
         torch.diagflat: lambda input, offset=0: -1,
+        torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
         torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
         torch.digamma: lambda input, out=None: -1,
         torch.dist: lambda input, other, p=2: -1,

torch/testing/_internal/common_methods_invocations.py

Lines changed: 35 additions & 0 deletions
@@ -623,6 +623,33 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad):
                 0, torch.tensor(0, dtype=torch.int64, device=device))),
     )
 
+def sample_inputs_diff(op_info, device, dtype, requires_grad):
+    test_cases = (
+        ((1,), 0, None, None),
+        ((S,), 0, None, None),
+        ((S, 1), 0, None, None),
+        ((S, 1), 1, None, None),
+        ((S, S), 0, None, None),
+        ((S, S), 1, None, None),
+        ((S, S), 0, (1, S), (2, S)),
+        ((S, S), 0, None, (2, S)),
+        ((S, S, S), 1, None, None),
+        ((S, S, S), 1, (S, 1, S), (S, 1, S)),)
+
+    sample_inputs = []
+    for size, dim, size_prepend, size_append in test_cases:
+        args = (make_tensor(size, device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad), 1, dim,
+                make_tensor(size_prepend, device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad) if size_prepend else None,
+                make_tensor(size_append, device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad) if size_append else None)
+        sample_inputs += [SampleInput(args)]
+
+    return tuple(sample_inputs)
 
 def sample_inputs_index_select(op_info, device, dtype, requires_grad):
     return (SampleInput((make_tensor((S, S, S), device, dtype,
@@ -1435,6 +1462,11 @@ def sample_inputs_masked_select(op_info, device, dtype, requires_grad):
              SkipInfo('TestCommon', 'test_variant_consistency_jit',
                       device_type='cuda', dtypes=[torch.float16]),
          )),
+    OpInfo('diff',
+           op=torch.diff,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_diff,
+           test_inplace_grad=False),
     UnaryUfuncInfo('exp',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
                    dtypes=all_types_and_complex_and(torch.bool, torch.half),
@@ -2323,6 +2355,9 @@ def __len__(self):
 def ident(x):
     return x
 
+# Do NOT add to this list. Method tests are being DEPRECATED and replaced by OpInfos.
+# See https://github.com/pytorch/pytorch/wiki/Writing-tests-in-PyTorch-1.8
+#
 # (
 #   method name,
 #   input size/constructing fn,

0 commit comments

Comments
 (0)
0