Remove unnecessary dtype checks for complex types & disable complex d… · pytorch/pytorch@3f052ba · GitHub

Commit 3f052ba

Remove unnecessary dtype checks for complex types & disable complex dispatch for CPU min/max pointwise ops (#50465)
Summary:
Fixes #50064

**PROBLEM DESCRIPTION:**
1. The previous PR (#50347) for this issue had not removed the dtype checks for complex types. These type checks were added in #36377 but are no longer necessary, as we now rely on dispatch macros to produce the error messages.
2. The dtype checks for complex inputs in `clamp_max()` and `clamp_min()` had not been removed either.
3. Complex dispatch had not been removed for the min/max pointwise ops in TensorCompareKernel.cpp.

### **FIX DESCRIPTION:**

**FIX SUMMARY:**
1. Removed the dtype checks added in #36377, and added 3 new ones in TensorCompare.cpp.
2. Removed the dtype checks for complex inputs in `clamp_max()` and `clamp_min()`.
3. Disabled complex dispatch for the min/max pointwise ops in TensorCompareKernel.cpp.
4. The error messages of the exceptions raised when min/max ops are not implemented are now checked for the text _not support_ (which also matches _not supported_) or _not implemented_; one of the two must appear in the message for it to count as informative.

**REASON FOR NOT CHANGING DISPATCH FOR CUDA AND CLAMP OPS:**
The CUDA min/max kernels are not compiled & dispatched for complex types anyway, so no further changes are required there: the dispatch macros currently in use have no cases for complex types. For example,
1. the reduce CUDA ops use [AT_DISPATCH_ALL_TYPES_AND2](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L548-L575) in [ReduceMinMaxKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu), and that macro doesn't allow complex types.
2. In [MaxMinElementwiseKernel.cu](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu), the CUDA pointwise ops use [`AT_DISPATCH_FLOATING_TYPES_AND2`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h#L240-L263) for non-integral & non-boolean types, and this macro doesn't have a case for complex types either.
3. The [clamp CUDA ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UnaryOpsKernel.cu#L170-L211) use `AT_DISPATCH_ALL_TYPES_AND2`, which doesn't have a case for complex types. Similarly, the [CPU clamp min/max ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp#L428-L458) use the `AT_DISPATCH_ALL_TYPES_AND` dispatch macro, which doesn't have a case for complex types.

**REASON FOR ADDING 3 dtype CHECKS:**
There are a few cases in which the methods corresponding to `min_stub()` or `max_stub()` are not called, so the dispatch macros are never invoked and no exception is raised. Hence, `dtype` checks are necessary at 3 places to raise exceptions:
1. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L342
2. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L422
3. https://github.com/pytorch/pytorch/blob/52dcc7299925de055d330781d2fe0dad71182829/aten/src/ATen/native/TensorCompare.cpp#L389

The first dtype-check requirement can be verified with the following example Python code, based on `test_complex_unsupported()`:

```python
import unittest
import torch

class MyTestCase(unittest.TestCase):
    def test_1(self):
        t = torch.tensor((1 + 1j), device='cpu', dtype=torch.complex128)
        with self.assertRaises(Exception):
            torch.max(t, dim=0)

if __name__ == '__main__':
    unittest.main()
```

Pull Request resolved: #50465

Reviewed By: mruberry

Differential Revision: D25938106

Pulled By: ngimel

fbshipit-source-id: 95e2df02ba8583fa3ce87d4a2fdcd60b912dda46
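As a usage-level illustration of the pointwise and clamp cases, here is a minimal sketch (not part of the commit; it assumes a PyTorch build that includes this change and that the dispatch-macro error text matches the patterns asserted in the updated tests):

```python
import re
import torch

# Sketch: pointwise maximum/minimum and clamp on complex tensors are expected to
# fail inside the dispatch macros (or the remaining TORCH_CHECKs), and the error
# text should match the patterns used in the tests ("not support" / "not implemented").
t = torch.tensor(1 + 1j, dtype=torch.complex128)
pattern = re.compile(r'(not support)|(not implemented)')

for fn in (lambda: torch.maximum(t, t),
           lambda: torch.minimum(t, t),
           lambda: torch.clamp(t, max=4 + 4j)):
    try:
        fn()
    except RuntimeError as e:
        assert pattern.search(str(e)) is not None, str(e)
```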
1 parent 1fdc35d commit 3f052ba

File tree

7 files changed: +48 -52 lines changed

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 0 additions & 8 deletions
```diff
@@ -826,16 +826,12 @@ Tensor logical_xor(const Tensor& self, Scalar other) { return comparison_op(self
 Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
 
 Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");
-
   auto iter = TensorIterator::binary_op(result, self, other);
   maximum_stub(iter.device_type(), iter);
   return result;
 }
 
 Tensor maximum(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");
-
   Tensor result;
   auto iter = TensorIterator::binary_op(result, self, other);
   maximum_stub(iter.device_type(), iter);
@@ -852,16 +848,12 @@ Tensor max(const Tensor& self, const Tensor& other) {
 }
 
 Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs.");
-
   auto iter = TensorIterator::binary_op(result, self, other);
   minimum_stub(iter.device_type(), iter);
   return result;
 }
 
 Tensor minimum(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs.");
-
   Tensor result;
   auto iter = TensorIterator::binary_op(result, self, other);
   minimum_stub(iter.device_type(), iter);
```

aten/src/ATen/native/ReduceAllOps.cpp

Lines changed: 0 additions & 3 deletions
```diff
@@ -11,23 +11,20 @@ DEFINE_DISPATCH(max_all_stub);
 DEFINE_DISPATCH(_aminmax_all_stub);
 
 Tensor min(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor result = at::empty({}, self.options());
   min_all_stub(self.device().type(), result, self.contiguous());
   return result;
 }
 
 Tensor max(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor result = at::empty({}, self.options());
   max_all_stub(self.device().type(), result, self.contiguous());
   return result;
 }
 
 std::tuple<Tensor, Tensor> _aminmax_all(const Tensor &self) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.numel() > 0, "operation does not have an identity.");
   Tensor min_result = at::empty({}, self.options());
   Tensor max_result = at::empty({}, self.options());
```

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 3 additions & 12 deletions
```diff
@@ -314,7 +314,6 @@ std::tuple<Tensor &,Tensor &> mode_out(Tensor& values, Tensor& indices,
 }
 
 std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
   if (self.is_quantized()) {
     Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
@@ -329,7 +328,6 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
 
 static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "max only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -342,6 +340,7 @@ static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indic
               max_indices.device(), " for indices output");
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
+    TORCH_CHECK(!self.is_complex(), "max does not support complex inputs.");
     AT_ASSERT(max.dim() == 0);
     max_indices.resize_({}).fill_(0);
     return std::forward_as_tuple(max, max_indices);
@@ -353,7 +352,6 @@ static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indic
 
 std::tuple<Tensor&,Tensor&> max_out(Tensor& max, Tensor& max_indices,
                                     const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   auto result = [&]() {
     NoNamesGuard guard;
     return max_out_impl(max, max_indices, self, dim, keepdim);
@@ -364,7 +362,6 @@ std::tuple<Tensor&,Tensor&> max_out(Tensor& max, Tensor& max_indices,
 }
 
 std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
   if (self.is_quantized()) {
     Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
@@ -378,7 +375,6 @@ std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
 
 static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max,
                                                         const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "min_max_val only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -392,6 +388,7 @@ static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min") &&
       _dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
+    TORCH_CHECK(!self.is_complex(), "min_max does not support complex inputs.");
     return std::forward_as_tuple(min, max);
   } else {
     _aminmax_stub(self.device().type(), min, max, self, dim, keepdim);
@@ -400,7 +397,6 @@ static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max
 }
 
 std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min_max is not yet implemented for complex tensors.");
   TORCH_CHECK(!self.is_quantized(), "min is not yet implemented for quantized tensors.");
 
   Tensor min = at::empty({0}, self.options());
@@ -412,7 +408,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdi
 
 static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
                                                   const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "min only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -425,6 +420,7 @@ static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indic
               min_indices.device(), " for indices output");
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
+    TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
     AT_ASSERT(min.dim() == 0);
     min_indices.resize_({}).fill_(0);
     return std::forward_as_tuple(min, min_indices);
@@ -436,7 +432,6 @@ static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indic
 
 std::tuple<Tensor&,Tensor&> min_out(Tensor& min, Tensor& min_indices,
                                     const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   auto result = [&]() {
     NoNamesGuard guard;
     return min_out_impl(min, min_indices, self, dim, keepdim);
@@ -450,21 +445,17 @@ std::tuple<Tensor&,Tensor&> min_out(Tensor& min, Tensor& min_indices,
 // Named tensor overloads
 
 std::tuple<Tensor, Tensor> min(const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   return at::min(self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor &,Tensor &> min_out(Tensor& min, Tensor& min_indices,
                                       const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
   return at::min_out(min, min_indices, self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor, Tensor> max(const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   return at::max(self, dimname_to_position(self, dim), keepdim);
 }
 std::tuple<Tensor &,Tensor &> max_out(Tensor& max, Tensor& max_indices,
                                       const Tensor& self, Dimname dim, bool keepdim) {
-  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
   return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim);
 }
 Tensor argmax(const Tensor& self, Dimname dim, bool keepdim) {
```

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 0 additions & 3 deletions
```diff
@@ -549,7 +549,6 @@ Tensor signbit(const Tensor& self) {
 }
 
 Tensor& clamp_out(Tensor& result, const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   if (min && max) {
     TORCH_CHECK(self.layout() == Layout::Strided,
                 "clamp only supports strided layout, got: ", self.layout());
@@ -575,7 +574,6 @@ Tensor& clamp_(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
 }
 
 Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   TORCH_CHECK(self.layout() == Layout::Strided,
               "clamp_max only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self);
@@ -593,7 +591,6 @@ Tensor& clamp_max_(Tensor& self, Scalar max) {
 }
 
 Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) {
-  TORCH_CHECK(!self.is_complex(), "clamp does not support complex inputs.");
   TORCH_CHECK(self.layout() == Layout::Strided,
               "clamp_min only supports strided layout, got: ", self.layout());
   auto iter = TensorIterator::unary_op(result, self);
```

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -81,7 +81,7 @@ static void min_kernel_impl(
   TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
               "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
       const scalar_t* self_data, auto self_dim_stride) {
@@ -118,7 +118,7 @@ static void max_kernel_impl(
   TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
               "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
```
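A usage-level sketch of this kernel change (an illustrative example, not part of the commit, assuming a build that includes it): a multi-element complex CPU tensor now reaches `AT_DISPATCH_ALL_TYPES_AND`, which has no complex case, so the dispatch macro itself is expected to raise the "not implemented" error, while the single-element trivial path is covered by the `TORCH_CHECK` added in TensorCompare.cpp.

```python
import torch

# Multi-element complex tensor: dim-wise max reaches the CPU kernel, where the
# dispatch macro (now without a complex case) is expected to raise RuntimeError.
x = torch.tensor([1 + 1j, 2 + 3j], dtype=torch.complex64)
try:
    torch.max(x, dim=0)
except RuntimeError as e:
    print(e)  # expected to say the op is not implemented for the complex dtype
```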

test/test_binary_ufuncs.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1067,11 +1067,11 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
     @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
     def test_maximum_minimum_complex(self, device, dtypes):
         for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min):
-            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+            with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
                 torch_op(torch.ones(1, device=device, dtype=dtypes[0]),
                          torch.ones(1, device=device, dtype=dtypes[1]))
 
-            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+            with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
                 torch_op(torch.ones(1, device=device, dtype=dtypes[1]),
                          torch.ones(1, device=device, dtype=dtypes[0]))
```
test/test_torch.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5984,60 +5984,79 @@ def test_complex_unsupported(self, device, dtype):
59845984
# Note: whether PyTorch should support min and max on complex
59855985
# tensors is an open question.
59865986
# See https://github.com/pytorch/pytorch/issues/36374
5987-
with self.assertRaises(RuntimeError):
5987+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59885988
torch.min(t)
5989-
with self.assertRaises(RuntimeError):
5989+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59905990
t.min()
5991-
with self.assertRaises(RuntimeError):
5991+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59925992
torch.min(t, dim=0)
5993-
with self.assertRaises(RuntimeError):
5993+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59945994
torch.min(t, t)
5995-
with self.assertRaises(RuntimeError):
5995+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59965996
torch.min(t, t, out=t)
59975997

5998-
with self.assertRaises(RuntimeError):
5998+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
59995999
torch.max(t)
6000-
with self.assertRaises(RuntimeError):
6000+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60016001
t.max()
6002-
with self.assertRaises(RuntimeError):
6002+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60036003
torch.max(t, dim=0)
6004-
with self.assertRaises(RuntimeError):
6004+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60056005
torch.max(t, t)
6006-
with self.assertRaises(RuntimeError):
6006+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60076007
torch.max(t, t, out=t)
60086008

6009-
with self.assertRaises(RuntimeError):
6009+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60106010
torch.amin(t)
6011-
with self.assertRaises(RuntimeError):
6011+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60126012
t.amin()
6013-
with self.assertRaises(RuntimeError):
6013+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60146014
torch.amin(t, dim=0)
60156015

6016-
with self.assertRaises(RuntimeError):
6016+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60176017
torch.amax(t)
6018-
with self.assertRaises(RuntimeError):
6018+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60196019
t.amax()
6020-
with self.assertRaises(RuntimeError):
6020+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60216021
torch.amax(t, dim=0)
60226022

6023+
# Tests _aminmax() variants with complex inputs,
6024+
# which are currently not supported due to min & max being unsupported
6025+
# for complex inputs, as per https://github.com/pytorch/pytorch/issues/36374
6026+
# Test with a single-element tensor t, as well as a multi-element tensor x
6027+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6028+
min_val, max_val = torch._aminmax(t)
6029+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6030+
min_val = torch._aminmax(t, dim=0)[0]
6031+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6032+
max_val = torch._aminmax(t, dim=0)[1]
6033+
# Test _aminmax() with a multi-element tensor
6034+
x = torch.tensor([(1 + 1j), (2 + 3j)], device=device, dtype=dtype)
6035+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6036+
min_val, max_val = torch._aminmax(x)
6037+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6038+
min_val = torch._aminmax(x, dim=0)[0]
6039+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
6040+
max_val = torch._aminmax(x, dim=0)[1]
6041+
60236042
# Tests clamp variants with complex inputs
60246043
# Note: whether PyTorch should support clamp on complex
60256044
# tensors is an open question.
60266045
# See https://github.com/pytorch/pytorch/issues/33568
60276046
min_val = 1 + 1j
60286047
max_val = 4 + 4j
60296048
out = torch.empty((0,), device=device, dtype=dtype)
6030-
with self.assertRaises(RuntimeError):
6049+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60316050
torch.clamp(t, min=min_val)
6032-
with self.assertRaises(RuntimeError):
6051+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60336052
torch.clamp(t, max=max_val)
6034-
with self.assertRaises(RuntimeError):
6053+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60356054
torch.clamp(t, min_val, max_val)
6036-
with self.assertRaises(RuntimeError):
6055+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60376056
torch.clamp(t, min=min_val, out=out)
6038-
with self.assertRaises(RuntimeError):
6057+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60396058
torch.clamp(t, max=max_val, out=out)
6040-
with self.assertRaises(RuntimeError):
6059+
with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'):
60416060
torch.clamp(t, min_val, max_val, out=out)
60426061

60436062
def test_pickle_gradscaler(self, device):
