Implementing NumPy-like function torch.heaviside() by Kiyosora · Pull Request #42523 · pytorch/pytorch

Status: Closed · wants to merge 7 commits
Changes from 1 commit
Commit: remove scale implement
Kiyosora committed Aug 29, 2020
commit ca2a5c6f842dfc8825a3a23ba620d42bc5447577
38 changes: 12 additions & 26 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -932,45 +932,31 @@ Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scal
   return self - (other * alpha);
 }

-Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& val) {
-  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !val.is_complex(),
-              "heaviside is not implemented for complex tensors.");
-  TORCH_CHECK(self.dtype() == val.dtype() && result.dtype() == self.dtype(),
+Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(),
               "heaviside is not yet implemented for tensors with different dtypes.");

-  auto iter = TensorIterator::binary_op(result, self, val, /*check_mem_overlap=*/true);
+  auto iter = TensorIterator::binary_op(result, self, values, /*check_mem_overlap=*/true);
   heaviside_stub(iter.device_type(), iter);
   return result;
 }

-Tensor& heaviside_out(Tensor& result, const Tensor& self, Scalar val) {
-  Tensor val_t = wrapped_scalar_tensor_and_check_convert(val, self);
-  return at::heaviside_out(result, self, val_t);
-}
-
-Tensor heaviside(const Tensor& self, const Tensor& val) {
-  TORCH_CHECK(!self.is_complex() && !val.is_complex(),
-              "heaviside is not implemented for complex tensors.");
-  TORCH_CHECK(self.dtype() == val.dtype(),
+Tensor heaviside(const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype(),
               "heaviside is not yet implemented for tensors with different dtypes.");

   Tensor result;
-  auto iter = TensorIterator::binary_op(result, self, val);
+  auto iter = TensorIterator::binary_op(result, self, values);
   heaviside_stub(iter.device_type(), iter);
   return iter.output();
 }

-Tensor heaviside(const Tensor& self, Scalar val) {
-  Tensor val_t = wrapped_scalar_tensor_and_check_convert(val, self);
-  return at::heaviside(self, val_t);
-}
-
-Tensor& heaviside_(Tensor& self, const Tensor& val) {
-  return at::heaviside_out(self, self, val);
-}
-
-Tensor& heaviside_(Tensor& self, Scalar val) {
-  return at::heaviside_out(self, self, val);
+Tensor& heaviside_(Tensor& self, const Tensor& values) {
+  return at::heaviside_out(self, self, values);
 }

 // TODO: Deduplicate this with the TensorIterator logic. This would
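For readers following the kernel change: the element-wise behavior that heaviside_stub dispatches to matches NumPy's np.heaviside definition. A minimal Python sketch of that semantics using torch.where (an illustrative equivalence, not code from this PR):

```python
import torch

def heaviside_reference(input, values):
    # np.heaviside semantics: 0 where input < 0, 1 where input > 0,
    # and the corresponding entry of `values` where input == 0.
    zeros = torch.zeros_like(input)
    ones = torch.ones_like(input)
    return torch.where(input < 0, zeros, torch.where(input > 0, ones, values))

# heaviside_reference(torch.tensor([-1.5, 0., 2.0]), torch.tensor([0.5, 0.5, 0.5]))
# -> tensor([0.0000, 0.5000, 1.0000])
```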
33 changes: 11 additions & 22 deletions aten/src/ATen/native/native_functions.yaml
@@ -3600,28 +3600,6 @@
     CPU, CUDA: sub_
     SparseCPU, SparseCUDA: sub_sparse_

-- func: heaviside.Tensor_out(Tensor self, Tensor val, *, Tensor(a!) out) -> Tensor(a!)
-  dispatch:
-    CPU, CUDA: heaviside_out
-
-- func: heaviside.Scalar_out(Tensor self, Scalar val, *, Tensor(a!) out) -> Tensor(a!)
-  dispatch:
-    CPU, CUDA: heaviside_out
-
-- func: heaviside.Tensor(Tensor self, Tensor val) -> Tensor
-  use_c10_dispatcher: full
-  variants: function, method
-
-- func: heaviside.Scalar(Tensor self, Scalar val) -> Tensor
-  use_c10_dispatcher: full
-  variants: function, method
-
-- func: heaviside_.Tensor(Tensor(a!) self, Tensor val) -> Tensor(a!)
-  variants: method
-
-- func: heaviside_.Scalar(Tensor(a!) self, Scalar val) -> Tensor(a!)
-  variants: method
-
 # For C++ only, until we have conversion from C++ numbers to Tensor
 - func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
@@ -3635,6 +3613,17 @@
   use_c10_dispatcher: full
   variants: function

+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: heaviside_out
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  variants: method
+
 # For C++ only, until we have conversion from C++ numbers to Tensor
 - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
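The three schema entries registered above surface in Python as the functional, out=, and in-place variants of the same operator. A quick usage sketch (assuming a build with this PR applied):

```python
import torch

input = torch.tensor([-1.5, 0., 2.0])
values = torch.tensor([0.5, 0.5, 0.5])
out = torch.empty_like(input)

torch.heaviside(input, values)           # heaviside: functional form (also a Tensor method)
torch.heaviside(input, values, out=out)  # heaviside.out: writes the result into `out`
input.heaviside_(values)                 # heaviside_: in-place Tensor method
```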
44 changes: 22 additions & 22 deletions test/test_torch.py
@@ -6263,63 +6263,63 @@ def test_bitwise_xor(self, device):
                           torch.testing.get_all_dtypes(include_complex=False))))
     def test_heaviside(self, device, dtypes):
Collaborator:

This is a pretty cool test, but I think you can simplify it by generating an array with NumPy's integers and then a second NumPy array with integers or rand depending on the dtype of the rhs argument. Then you can run NumPy's heaviside on these args to get your expected values.

For your actual values you can torch.from_numpy(lhs/rhs).to(device=device, dtype=dtype) the NumPy arrays and run your torch.heaviside. Then you can compare the results with self.assertEqual(actual, expected, exact_dtype=False), and verify that torch.heaviside produced the correct dtype by checking the dtype of the actual tensor against the result of torch.result_type.

This should let you write a thorough but compact test, I think. What are your thoughts?

Contributor (Author):

Wow, LGTM. I will improve my test case, thanks for your advice! 👍

Collaborator:

Tests look good. This first test will need a tweak depending on how you decide to handle type promotion. It's probably easiest to require the types be the same; then you could just check for that case and that a runtime error is thrown.

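A minimal sketch of the structure the reviewer suggests above, with hypothetical shapes, dtypes, and device (the final test below ends up requiring matching dtypes instead of relying on type promotion):

```python
import numpy as np
import torch

rng = np.random.default_rng()
lhs = rng.integers(-10, 10, size=10).astype(np.float64)
lhs[0] = 0  # guarantee the input == 0 branch is exercised
rhs = rng.random(10)

# Expected values come from NumPy's reference implementation.
expected = torch.from_numpy(np.heaviside(lhs, rhs))

# Actual values come from torch.heaviside on converted tensors.
actual = torch.heaviside(torch.from_numpy(lhs).to(dtype=torch.float32),
                         torch.from_numpy(rhs).to(dtype=torch.float32))

# Compare values while tolerating the dtype difference (float32 vs. float64).
assert torch.allclose(actual.double(), expected)
```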
         input_dtype = dtypes[0]
-        val_dtype = dtypes[1]
+        values_dtype = dtypes[1]

         rng = np.random.default_rng()
         input = np.array(rng.integers(-10, 10, size=10),
Collaborator:

Maybe add a few zeros to input? As written there's a chance it won't have any.

Collaborator:

input[0] = input[3] = input[7] = 0?

Contributor (Author):

Addressed!
                          dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
         input[0] = input[3] = input[7] = 0
-        val = np.array(rng.integers(-10, 10, size=10),
-                       dtype=torch_to_numpy_dtype_dict[val_dtype if (val_dtype != torch.bfloat16) else torch.float64])
-        np_result = torch.from_numpy(np.heaviside(input, val)).to(device=device, dtype=input_dtype)
+        values = np.array(rng.integers(-10, 10, size=10),
+                          dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
+        np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)

         input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
-        val = torch.from_numpy(val).to(device=device, dtype=val_dtype)
+        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
         out = torch.empty_like(input)

-        if input_dtype == val_dtype:
-            torch_result = torch.heaviside(input, val)
+        if input_dtype == values_dtype:
+            torch_result = torch.heaviside(input, values)
             self.assertEqual(np_result, torch_result)

-            torch_result = input.heaviside(val)
+            torch_result = input.heaviside(values)
             self.assertEqual(np_result, torch_result)

-            torch.heaviside(input, val, out=out)
+            torch.heaviside(input, values, out=out)
             self.assertEqual(np_result, out)

-            input.heaviside_(val)
+            input.heaviside_(values)
             self.assertEqual(np_result, input)
         else:
             with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
-                torch.heaviside(input, val)
+                torch.heaviside(input, values)
             with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
-                input.heaviside(val)
+                input.heaviside(values)
             with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
-                torch.heaviside(input, val, out=out)
+                torch.heaviside(input, values, out=out)
             with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
-                input.heaviside_(val)
+                input.heaviside_(values)


     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
                           torch.testing.get_all_complex_dtypes())))
     def test_heaviside_complex(self, device, dtypes):
         input_dtype = dtypes[0]
-        val_dtype = dtypes[1]
+        values_dtype = dtypes[1]

         data = (complex(0, -6), complex(-1, 3), complex(1, 1))
         input = torch.tensor(data, device=device, dtype=input_dtype)
-        val = torch.tensor(data, device=device, dtype=val_dtype)
+        values = torch.tensor(data, device=device, dtype=values_dtype)
         out = torch.empty_like(input)
         real = input.real

-        with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'):
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
             torch.heaviside(input, real)
-        with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'):
-            real.heaviside(val)
-        with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'):
-            input.heaviside_(val)
-        with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'):
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            real.heaviside(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            input.heaviside_(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
             torch.heaviside(real, real, out=out)

     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
10 changes: 6 additions & 4 deletions torch/_torch_docs.py
@@ -5611,17 +5611,19 @@ def merge_dicts(*dicts):

 Args:
     {input}
-    values (Tensor or Scalar): The value(s) to sample from where :attr:`input` is zero.
+    values (Tensor): The values to sample from where :attr:`input` is zero.

 Keyword arguments:
     {out}

 Example::

-    >>> t = torch.tensor([-1.5, 0, 2.0])
-    >>> torch.heaviside(t, 0.5)
+    >>> input = torch.tensor([-1.5, 0, 2.0])
+    >>> values = torch.tensor([0.5])
+    >>> torch.heaviside(input, values)
     tensor([0.0000, 0.5000, 1.0000])
-    >>> torch.heaviside(t, -2)
+    >>> values = torch.tensor([1.2, -2.0, 3.5])
+    >>> torch.heaviside(input, values)
     tensor([0., -2., 1.])

 """.format(**common_args))
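The documented behavior can be checked directly against NumPy with the same example data (a parity sketch, not part of the docstring):

```python
import numpy as np
import torch

input = torch.tensor([-1.5, 0., 2.0])
values = torch.tensor([1.2, -2.0, 3.5])

torch_out = torch.heaviside(input, values)
np_out = np.heaviside(input.numpy(), values.numpy())
# Both yield [0., -2., 1.]: 0 below zero, the entry of `values` at zero, 1 above zero.
assert np.allclose(torch_out.numpy(), np_out)
```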
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -377,7 +377,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
     torch.gt: lambda input, other, out=None: -1,
     torch.hardshrink: lambda input, lambd=0.5: -1,
-    torch.heaviside: lambda input, val, out=None: -1,
+    torch.heaviside: lambda input, values, out=None: -1,
     torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
     torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
     torch.hspmm: lambda mat1, mat2, out=None: -1,
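The entry in get_testing_overrides records the Python signature that torch.heaviside exposes to the __torch_function__ override machinery, which is why the parameter is renamed here as well. A minimal interception sketch (assuming a recent PyTorch where Tensor subclasses use the classmethod protocol):

```python
import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.heaviside:
            print("torch.heaviside called with", len(args), "positional args")
        return super().__torch_function__(func, types, args, kwargs)

# t = torch.tensor([-1.0, 0.0, 1.0]).as_subclass(LoggingTensor)
# torch.heaviside(t, torch.tensor([0.5, 0.5, 0.5]))  # prints, then computes
```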