-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Implementing NumPy-like function torch.heaviside() #42523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
11fa3ce
70b03d5
0ec0e30
2b184d6
ca2a5c6
f8a45e5
209ba37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6263,63 +6263,63 @@ def test_bitwise_xor(self, device): | |
torch.testing.get_all_dtypes(include_complex=False)))) | ||
def test_heaviside(self, device, dtypes): | ||
input_dtype = dtypes[0] | ||
val_dtype = dtypes[1] | ||
values_dtype = dtypes[1] | ||
|
||
rng = np.random.default_rng() | ||
input = np.array(rng.integers(-10, 10, size=10), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a few zeros to input? As written there's a chance it won't have any. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed! |
||
dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) | ||
input[0] = input[3] = input[7] = 0 | ||
val = np.array(rng.integers(-10, 10, size=10), | ||
dtype=torch_to_numpy_dtype_dict[val_dtype if (val_dtype != torch.bfloat16) else torch.float64]) | ||
np_result = torch.from_numpy(np.heaviside(input, val)).to(device=device, dtype=input_dtype) | ||
values = np.array(rng.integers(-10, 10, size=10), | ||
dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) | ||
np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) | ||
|
||
input = torch.from_numpy(input).to(device=device, dtype=input_dtype) | ||
val = torch.from_numpy(val).to(device=device, dtype=val_dtype) | ||
values = torch.from_numpy(values).to(device=device, dtype=values_dtype) | ||
out = torch.empty_like(input) | ||
|
||
if input_dtype == val_dtype: | ||
torch_result = torch.heaviside(input, val) | ||
if input_dtype == values_dtype: | ||
torch_result = torch.heaviside(input, values) | ||
self.assertEqual(np_result, torch_result) | ||
|
||
torch_result = input.heaviside(val) | ||
torch_result = input.heaviside(values) | ||
self.assertEqual(np_result, torch_result) | ||
|
||
torch.heaviside(input, val, out=out) | ||
torch.heaviside(input, values, out=out) | ||
self.assertEqual(np_result, out) | ||
|
||
input.heaviside_(val) | ||
input.heaviside_(values) | ||
self.assertEqual(np_result, input) | ||
else: | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): | ||
torch.heaviside(input, val) | ||
torch.heaviside(input, values) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): | ||
input.heaviside(val) | ||
input.heaviside(values) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): | ||
torch.heaviside(input, val, out=out) | ||
torch.heaviside(input, values, out=out) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): | ||
input.heaviside_(val) | ||
input.heaviside_(values) | ||
|
||
|
||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found") | ||
@dtypes(*list(product(torch.testing.get_all_complex_dtypes(), | ||
torch.testing.get_all_complex_dtypes()))) | ||
def test_heaviside_complex(self, device, dtypes): | ||
input_dtype = dtypes[0] | ||
val_dtype = dtypes[1] | ||
values_dtype = dtypes[1] | ||
|
||
data = (complex(0, -6), complex(-1, 3), complex(1, 1)) | ||
input = torch.tensor(data, device=device, dtype=input_dtype) | ||
val = torch.tensor(data, device=device, dtype=val_dtype) | ||
values = torch.tensor(data, device=device, dtype=values_dtype) | ||
out = torch.empty_like(input) | ||
real = input.real | ||
|
||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'): | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): | ||
torch.heaviside(input, real) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'): | ||
real.heaviside(val) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'): | ||
input.heaviside_(val) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not implemented for complex tensors.'): | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): | ||
real.heaviside(values) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): | ||
input.heaviside_(values) | ||
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): | ||
torch.heaviside(real, real, out=out) | ||
|
||
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a pretty cool test but I think you can simplify it by generating an array with NumPy's integers and then a second NumPy array with integers or rand depending on the dtype of the rhs argument. Then you can run NumPy's heaviside on these args to get your expected values.
For your actual values you can
torch.from_numpy(lhs/rhs).to(device=device, dtype=dtype)
the NumPy arrays and run yourtorch.heaviside
. Then you can compare the results withself.assertEqual(actual, expected, exact_dtype=False)
, and verify thetorch.heaviside
function produced the correct dtype by checking the dtype of theactual
tensor against the result oftorch.result_type
.This should let you write a thorough but compact test, I think. What are your thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow, LGTM. I would improve my test case, thanks for your advice! 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests look good. This first test will need a tweak depending on how you decide to handle type promotion. It's probably easiest to require the types be the same, then you could just check for that case and that a runtime error is thrown.