From 997e627af02114c2c4d035aaaa5acf7ff1c190c7 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 6 Aug 2020 15:09:24 -0500 Subject: [PATCH 1/6] Add torch.linalg.norm --- aten/src/ATen/core/interned_strings.h | 1 + aten/src/ATen/native/LinearAlgebra.cpp | 249 ++++++++- aten/src/ATen/native/native_functions.yaml | 16 + docs/source/linalg.rst | 1 + test/test_linalg.py | 475 ++++++++++++++++++ test/test_torch.py | 7 +- .../templates/python_linalg_functions.cpp | 1 + torch/csrc/api/include/torch/linalg.h | 32 ++ torch/linalg/__init__.py | 125 +++++ 9 files changed, 897 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index a38044b1a304..29119e23e711 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -166,6 +166,7 @@ namespace c10 { _(aten, clip_) \ _(aten, det) \ _(aten, linalg_det) \ + _(aten, linalg_norm) \ _(aten, append) \ _(aten, item) \ _(aten, format) \ diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 27714fab28c6..5c47abc727f1 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { if (dim.size() == 1 || dim.size() == 0) { return at::norm(self, 2, dim, keepdim); } + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead"); if (self.is_complex()){ - return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim)); + return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim)); } else { - return at::sqrt(at::sum((self * self), dim, keepdim)); + return at::sqrt(at::sum((self * self), dim_, keepdim)); } } @@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out( if (dim.size() == 1 || dim.size() == 0) { return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type()); } + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead"); if (self.is_complex()){ - return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim)); + return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim)); } else { - return at::sqrt_out(result, at::sum((self * self), dim, keepdim)); + return at::sqrt_out(result, at::sum((self * self), dim_, keepdim)); } } @@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) { Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); - auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim()); + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); Tensor p = self.permute(permutation); // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm @@ -1360,19 +1369,245 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) { TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of 
size 2"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); - auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim()); + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); Tensor p = self.permute(permutation); at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim); if (keepdim) { result.unsqueeze_(-1); - result = result.permute(permutation_reverse); + Tensor result_ = result.permute(permutation_reverse); + at::native::resize_output(result, result_.sizes()); + result.copy_(result_); + } + return result; +} + +static std::vector make_dim_list(int64_t ndim) { + std::vector dim_list(ndim); + for (int64_t ind = 0; ind < ndim; ind++) { + dim_list[ind] = ind; + } + return dim_list; +} + +static void check_str_ord_valid(std::string str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { + TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); + TORCH_CHECK(!opt_dtype.has_value(), "dtype argument is currently not supported in frobenius norm, ", + "but will be in the future"); + bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); + TORCH_CHECK(dims_valid, "order \"", str_ord, + "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); +} + +// Performs vector norm for ord = +/-infinity, and the second dimension reduction +// for matrix norms. +static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) { + Tensor result; + if (self.numel() == 0 && self.sizes()[dim] > 0) { + // This special case is needed in matrix norm for tensors with 3 or more dims, + // or in vector norm for order inf and -inf for tesnsors with 2 or more dims. + // When the sizes of the dims to be reduced are greater than 0 but another dim + // in the tensor is size 0 (thus numel == 0), we must either flatten or resize + // the second reduction dim to 1, to avoid calling min/max, which would throw + // an error. + if (self.sizes()[dim] != 1) { + auto new_sizes = self.sizes().vec(); + new_sizes[dim] = 1; + self.resize_(new_sizes); + } + result = keepdim ? self : self.flatten(dim); + } else { + if (ord > 0) { + result = std::get<0>(self.max(dim, keepdim)); + } else { + result = std::get<0>(self.min(dim, keepdim)); + } + } + return result; +} + +// Performs matrix norm +static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, + IntArrayRef dim, bool keepdim, optional opt_dtype) { + Tensor result; + auto ord = opt_ord.value_or(2.0).toDouble(); + TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, + "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type()); + TORCH_CHECK(self.layout() == Layout::Strided, + "matrix norm only supports strided layout, got: ", self.layout()); + if ((dim.size() == 0) && (self.dim() == 2)) { + dim = {0, 1}; + } + TORCH_CHECK(dim.size() == 2, "_norm_matrix: 'dim' must either specify 2 dimensions, or if ", + "'self' is 2-D 'dim' can specify 0 dimensions for a full reduction. Got 'dim' specifying ", + dim.size(), " dims and 'self' is ", self.dim(), "-D"); + ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); + TORCH_CHECK( + at::isFloatingType(scalarType) || at::isComplexType(scalarType), + "Can only calculate the mean of floating types. 
Got ", + toString(scalarType), + " instead."); + + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], + "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead"); + + Tensor self_; + + if (opt_dtype.has_value()) { + self_ = self.to(scalarType); + } else { + self_ = self; + } + + if (std::abs(ord) == 2) { + // Need to shift the reduction dims to the back, because at::svd will only operate on + // the last 2 dimensions + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); + auto permutation_reverse = create_reverse_permutation(permutation); + + result = std::get<1>(self_.permute(permutation).svd()).abs(); + result = _norm_min_max(result, ord, result.dim() - 1, keepdim); + + if (keepdim) { + result.unsqueeze_(-1); + result = result.permute(permutation_reverse); + } + } else { + // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except + // that the order of the two dims is swapped. So we can swap the dims if + // abs(p) == infinity to simplify the rest of the operation's logic. + if (std::abs(ord) == INFINITY) { + std::swap(dim_[0], dim_[1]); + } + // If the dim of the second reduction is greater than that of the first reduction + // and we are not keeping the dims, then the fact that the output of the first + // reduction will have one fewer dimension means that the second reduction dim + // will be off by one, so we need to correct that. + if ((dim_[1] > dim_[0]) && !keepdim) { + dim_[1]--; + } + if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) { + result = self_.abs().sum(dim_[0], keepdim); + result = _norm_min_max(result, ord, dim_[1], keepdim); + } else { + TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm"); + } + } + return result; +} + +// Performs vector norm +// This function mostly serves as a wrapper for at::norm, but it overrides a few cases +// for numpy compatibility. These cases are corrected within this wrapper, rather than +// in at::norm itself, to avoid breaking backward compatibility. +static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { + if (opt_ord.has_value()) { + TORCH_INTERNAL_ASSERT(dim.size() == 1); + auto ord = opt_ord.value().toDouble(); + Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self; + if (std::abs(ord) == INFINITY) { + // The ord = +/-infinity case is overridden because at::norm does not match numpy + // when the input contains extreme values (like nan or +/-inf) or if the input + // size is degenerate (like size(0), size(0, N), etc) + self_ = self_.abs(); + return _norm_min_max(self_, ord, dim[0], keepdim); + } else if ((self_.numel() == 0) && (ord < 0)) { + // For negative orders with degenerate input sizes, at::norm's result does not + // match numpy. + Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim); + if (ord >= -1) { + // Result must be infinite in this case, and the simplest way to make that + // happen is to simply add infinity + result += INFINITY; + } else { + result = result.pow(1.0 / (ord + 1)); + } + return result; + } + } else { + // If ord == None, need to check for unique dims because at::norm does not check it + // for this case. 
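+  // Note that std::unique only collapses *adjacent* equal values, so the uniqueness check
+  // below catches repeated dims such as (0, 0) but relies on any duplicate entries in 'dim'
+  // appearing next to each other after wrapping.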
+ std::vector dim_(dim); + maybe_wrap_dims(dim_, self.dim()); + bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end(); + TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")"); + } + if (opt_dtype.has_value()) { + return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); + } else { + return at::norm(self, opt_ord, dim, keepdim); + } +} + +static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + // Callers must give the ord argument as either a number, a string, or neither. + // Since the user-facing API has no direct control over how this function is called, this is an internal assert. + TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value())); + if (opt_dtype.has_value()) { + auto dtype = opt_dtype.value(); + TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got", + "dtype = ", dtype, ", out.dtype = ", result.scalar_type()); } + int64_t ndim = self.dim(); + Tensor result_ = result.clone(); + if (opt_str_ord.has_value()) { + // 'ord' is string + auto str_ord = opt_str_ord.value(); + check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); + if (str_ord == "fro") { + result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + } else if (str_ord == "nuc") { + if (opt_dim.has_value()) { + result_ = at::nuclear_norm(self, opt_dim.value(), keepdim); + } else { + result_ = at::nuclear_norm(self, keepdim); + } + } + } else { + // 'ord' is int or None + std::vector dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim); + if (!opt_num_ord.has_value() || dim_.size() == 1) { + result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype); + } else if (dim_.size() == 2) { + result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype); + } else { + TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is " + "not 1-D or 2-D"); + } + } + resize_output(result, result_.sizes()); + result.copy_(result_); return result; } +// Numerical or None norms +Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + Tensor result = at::empty({0}, opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).to(self.device()); + return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +// Frobenius and nuclear norms +Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + Tensor result = at::empty({0}, opt_dtype.has_value() ? 
opt_dtype.value() : self.scalar_type()).to(self.device()); + return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + +// Numerical or None norms +Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype); +} + +// Frobenius and nuclear norms +Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); +} + static inline Tensor _chain_matmul_general(TensorList matrices, std::vector>& order, int64_t i, int64_t j) { if (i == j) return matrices[i]; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0d2241033281..c6b87eb1950f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7288,6 +7288,22 @@ - func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_norm(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.ord_str(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_norm.ord_str_out(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 1152267f3609..834b6a60ac93 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -13,3 +13,4 @@ Functions --------- .. autofunction:: det +.. 
autofunction:: norm diff --git a/test/test_linalg.py b/test/test_linalg.py index 73e63c19b646..cbcdb6344844 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1,10 +1,14 @@ import torch import unittest +import itertools +from math import inf, nan, isnan from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack) +from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args +from torch.autograd import gradcheck if TEST_NUMPY: import numpy as np @@ -54,6 +58,477 @@ def test_det(self, device, dtype): with self.assertRaises(IndexError): op(t) + # This test confirms that torch.linalg.norm's dtype argument works + # as expected, according to the function's documentation + def test_norm_dtype(self, device): + def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype): + msg = ( + f'input_size={input_size}, ord={ord}, keepdim={keepdim}, ' + f'from_dtype={from_dtype}, to_dtype={to_dtype}') + input = torch.randn(*input_size, dtype=from_dtype, device=device) + result = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=from_dtype) + self.assertEqual(result.dtype, from_dtype, msg=msg) + result_converted = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) + self.assertEqual(result_converted.dtype, to_dtype, msg=msg) + self.assertEqual(result.to(compare_dtype), result_converted.to(compare_dtype), msg=msg) + + result_out_converted = torch.empty_like(result_converted) + torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_converted) + self.assertEqual(result_out_converted.dtype, to_dtype, msg=msg) + self.assertEqual(result_converted, result_out_converted, msg=msg) + + ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + ord_matrix = [1, -1, 2, -2, inf, -inf, None] + S = 10 + test_cases = [ + ((S, ), ord_vector), + ((S, S), ord_matrix), + ] + for keepdim in [True, False]: + for input_size, ord_settings in test_cases: + for ord in ord_settings: + # float to double + run_test_case(input_size, ord, keepdim, torch.float, torch.double, torch.float) + # double to float + run_test_case(input_size, ord, keepdim, torch.double, torch.double, torch.float) + + # Make sure that setting dtype != out.dtype raises an error + dtype_pairs = [ + (torch.float, torch.double), + (torch.double, torch.float), + ] + for keepdim in [True, False]: + for input_size, ord_settings in test_cases: + for ord in ord_settings: + for dtype, out_dtype in dtype_pairs: + input = torch.rand(*input_size) + result = torch.Tensor().to(out_dtype) + with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'): + torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result) + + # TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test + # and add 'nuc' and 'fro' to ord_matrix above + for ord in ['nuc', 'fro']: + input = torch.randn(10, 10, device=device) + with self.assertRaisesRegex(RuntimeError, r'dtype argument is currently not supported'): + torch.linalg.norm(input, ord, dtype=torch.float) + + # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that + # their vector norm results match + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_vector(self, device, dtype): + def run_test_case(input, p, dim, keepdim): + result = 
torch.linalg.norm(input, ord, dim, keepdim) + input_numpy = input.cpu().numpy() + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + self.assertEqual(result, result_numpy, msg=msg) + + result_out = torch.empty_like(result) + torch.linalg.norm(input, ord, dim, keepdim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + + ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + S = 10 + test_cases = [ + # input size, p settings, dim + ((S, ), ord_vector, None), + ((S, ), ord_vector, (0, )), + ((S, S, S), ord_vector, (0, )), + ((S, S, S), ord_vector, (1, )), + ((S, S, S), ord_vector, (2, )), + ((S, S, S), ord_vector, (-1, )), + ((S, S, S), ord_vector, (-2, )), + ] + L = 1_000_000 + if dtype == torch.double: + test_cases.append(((L, ), ord_vector, None)) + for keepdim in [True, False]: + for input_size, ord_settings, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_test_case(input, ord, dim, keepdim) + + # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that + # their matrix norm results match + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_matrix(self, device, dtype): + def run_test_case(input, p, dim, keepdim): + result = torch.linalg.norm(input, ord, dim, keepdim) + input_numpy = input.cpu().numpy() + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + self.assertEqual(result, result_numpy, msg=msg) + + result_out = torch.empty_like(result) + torch.linalg.norm(input, ord, dim, keepdim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + + ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro', None] + S = 10 + test_cases = [ + # input size, p settings, dim + ((S, S), ord_matrix, None), + ((S, S), ord_matrix, (0, 1)), + ((S, S), ord_matrix, (1, 0)), + ((S, S, S, S), ord_matrix, (2, 0)), + ((S, S, S, S), ord_matrix, (-1, -2)), + ((S, S, S, S), ord_matrix, (-1, -3)), + ((S, S, S, S), ord_matrix, (-3, 2)), + ] + L = 1_000 + if dtype == torch.double: + test_cases.append(((L, L), ord_matrix, None)) + for keepdim in [True, False]: + for input_size, ord_settings, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_test_case(input, ord, dim, keepdim) + + # Test autograd and jit functionality for linalg functions. + # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, + # the `test_cases` entries below should be moved there. These entries are in a similar format, + # so they should work with minimal changes. + @dtypes(torch.float, torch.double) + def test_autograd_and_jit(self, device, dtype): + torch.manual_seed(0) + S = 10 + NO_ARGS = None # NOTE: refer to common_methods_invocations.py if you need this feature + test_cases = [ + # NOTE: Not all the features from common_methods_invocations.py are functional here, since this + # is only a temporary solution. 
+ # ( + # method name, + # input size/constructing fn, + # args (tuple represents shape of a tensor arg), + # test variant name (will be used at test name suffix), // optional + # (should_check_autodiff[bool], nonfusible_nodes, fusible_nodes) for autodiff, // optional + # indices for possible dim arg, // optional + # fn mapping output to part that should be gradcheck'ed, // optional + # kwargs // optional + # ) + ('norm', (S,), (), 'default_1d'), + ('norm', (S, S), (), 'default_2d'), + ('norm', (S, S, S), (), 'default_3d'), + ('norm', (S,), (inf,), 'vector_inf'), + ('norm', (S,), (3.5,), 'vector_3_5'), + ('norm', (S,), (2,), 'vector_2'), + ('norm', (S,), (1,), 'vector_1'), + ('norm', (S,), (0,), 'vector_0'), + ('norm', (S,), (-inf,), 'vector_neg_inf'), + ('norm', (S,), (-3.5,), 'vector_neg_3_5'), + ('norm', (S,), (2,), 'vector_neg_2'), + ('norm', (S,), (1,), 'vector_neg_1'), + ('norm', (S, S), (inf,), 'matrix_inf'), + ('norm', (S, S), (2,), 'matrix_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), (1,), 'matrix_1'), + ('norm', (S, S), (-inf,), 'matrix_neg_inf'), + ('norm', (S, S), (-2,), 'matrix_neg_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), (-1,), 'matrix_neg_1'), + ('norm', (S, S), ('fro',), 'fro'), + ('norm', (S, S), ('fro', [0, 1]), 'fro_dim'), + ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), ('nuc', [0, 1]), 'nuc_dim', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ] + for test_case in test_cases: + func_name = test_case[0] + func = getattr(torch.linalg, func_name) + input_size = test_case[1] + args = list(test_case[2]) + test_case_name = test_case[3] if len(test_case) >= 4 else None + mapping_funcs = list(test_case[6]) if len(test_case) >= 7 else None + + # Skip a test if a decorator tells us to + if mapping_funcs is not None: + def decorated_func(self, device, dtype): + pass + for mapping_func in mapping_funcs: + decorated_func = mapping_func(decorated_func) + try: + decorated_func(self, device, dtype) + except unittest.SkipTest: + continue + + msg = f'function name: {func_name}, case name: {test_case_name}' + + # Test JIT + input = torch.randn(*input_size, dtype=dtype, device=device) + input_script = input.clone().detach() + script_method, tensors = gen_script_fn_and_args("linalg.norm", "functional", input_script, *args) + self.assertEqual( + func(input, *args), + script_method(input_script), + msg=msg) + + # Test autograd + # gradcheck is only designed to work with torch.double inputs + if dtype == torch.double: + input = torch.randn(*input_size, dtype=dtype, device=device, requires_grad=True) + + def run_func(input): + return func(input, *args) + self.assertTrue(gradcheck(run_func, input), msg=msg) + + # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments + # to ensure that they both throw errors + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_errors(self, device, dtype): + def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): + test_case_info = ( + f'test case input.size()={input.size()}, ord={ord}, dim={dim}, ' + f'keepdim={keepdim}, dtype={dtype}') + + with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info): + torch.linalg.norm(input, ord, dim, keepdim) + + input_numpy = input.cpu().numpy() + + msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"' + with self.assertRaises(Exception, 
msg=test_case_info): + np.linalg.norm(input_numpy, ord, dim, keepdim) + + S = 10 + error_test_cases = [ + # input size, p settings, dim, error type, error regex + ((S, ), ['fro'], None, RuntimeError, r'order "fro" can only be used if either len\(dim\) == 2'), + ((S, ), ['nuc'], None, RuntimeError, r'order "nuc" can only be used if either len\(dim\) == 2'), + ((S, S), [3.5], None, RuntimeError, r'Order 3.5 not supported for matrix norm'), + ((S, S), [0], None, RuntimeError, r'Order 0 not supported for matrix norm'), + ((S, S), ['nuc'], (0, ), RuntimeError, r'order "nuc" can only be used if either len\(dim\) == 2'), + ((S, S), ['fro'], (0, ), RuntimeError, r'order "fro" can only be used if either len\(dim\) == 2'), + ((S, S), ['nuc'], (0, 0), RuntimeError, r'duplicate or invalid dimensions'), + ((S, S), ['fro', 0], (0, 0), RuntimeError, r'Expected dims to be different'), + ((S, S), ['fro', 'nuc', 0], (0, 4), IndexError, r'Dimension out of range'), + ((S, ), [0], (4, ), IndexError, r'Dimension out of range'), + ((S, ), [None], (0, 0), RuntimeError, r'Expected dims to be different, got this instead'), + ((S, S, S), [1], (0, 1, 2), RuntimeError, r"'dim' must specify 1 or 2 dimensions"), + ((S, S, S), [1], None, RuntimeError, r"'dim' must specify 1 or 2 dimensions"), + ((S, S), ['garbage'], (0, 1), RuntimeError, r'Invalid norm order: garbage'), + ] + for keepdim in [True, False]: + for input_size, ord_settings, dim, error_type, error_regex in error_test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_error_test_case(input, ord, dim, keepdim, error_type, error_regex) + + # Test complex number inputs for linalg.norm. Some cases are not supported yet, so + # this test also verifies that those cases raise an error. 
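+    # For example, with a 1-D complex input on CPU, ord=1 is expected to succeed, while ord=2 and
+    # ord=None currently raise "norm with p=2 not supported for complex tensors".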
+ @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.cfloat, torch.cdouble) + def test_norm_complex(self, device, dtype): + def gen_error_message(input_size, ord, keepdim, dim=None): + return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( + input_size, ord, keepdim, dim) + + if self.device_type == 'cpu': + supported_vector_ords = [0, 1, 3, inf, -1, -2, -3, -inf] + supported_matrix_ords = ['nuc', 1, 2, inf, -1, -2, -inf] + unsupported_vector_ords = [ + (2, r'norm with p=2 not supported for complex tensors'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + unsupported_matrix_ords = [ + ('fro', r'frobenius norm not supported for complex tensors'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + + elif self.device_type == 'cuda': + supported_vector_ords = [inf, -inf] + supported_matrix_ords = [1, inf, -1, -inf] + unsupported_vector_ords = [ + (0, r'norm_cuda" not implemented for \'Complex'), + (1, r'norm_cuda" not implemented for \'Complex'), + (2, r'norm with p=2 not supported for complex tensors'), + (-1, r'norm_cuda" not implemented for \'Complex'), + (-2, r'norm_cuda" not implemented for \'Complex'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + unsupported_matrix_ords = [ + (None, r'norm with p=2 not supported for complex tensors'), + ('fro', r'frobenius norm not supported for complex tensors'), + (2, r'"svd_cuda" not implemented for \'Complex'), + (-2, r'"svd_cuda" not implemented for \'Complex'), + ('nuc', r'"svd_cuda" not implemented for \'Complex'), + ] + + # Test supported ords + for keepdim in [False, True]: + # vector norm + x = torch.randn(25, device=device, dtype=dtype) + xn = x.cpu().numpy() + for ord in supported_vector_ords: + res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, ord, keepdims=keepdim) + msg = gen_error_message(x.size(), ord, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # matrix norm + x = torch.randn(25, 25, device=device, dtype=dtype) + xn = x.cpu().numpy() + for ord in supported_matrix_ords: + # TODO: Need to fix abort when nuclear norm is given cdouble input: + # "double free or corruption (!prev) Aborted (core dumped)" + if ord == 'nuc' and dtype == torch.cdouble: + continue + res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, ord, keepdims=keepdim) + msg = gen_error_message(x.size(), ord, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # Test unsupported ords + # vector norm + x = torch.randn(25, device=device, dtype=dtype) + for ord, error_msg in unsupported_vector_ords: + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.linalg.norm(x, ord) + + # matrix norm + x = torch.randn(25, 25, device=device, dtype=dtype) + for ord, error_msg in unsupported_matrix_ords: + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.linalg.norm(x, ord) + + # Make sure that linalg.norm raises an error if dim is an integer + # TODO: When integer dims are supported in norm, remove this test + def test_norm_dim_int_error(self, device): + input = torch.randn(10, device=device) + with self.assertRaisesRegex(TypeError, r'linalg_norm\(\) received an invalid combination of arguments'): + torch.linalg.norm(input, dim=0) + + # Test that linal.norm gives the same result as numpy when inputs + # contain extreme values (inf, 
-inf, nan) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_norm_extreme_values(self, device): + vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf] + vectors = [] + matrices = [] + for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2): + vectors.append(list(pair)) + matrices.append([[pair[0], pair[1]]]) + matrices.append([[pair[0]], [pair[1]]]) + for vector in vectors: + x = torch.tensor(vector).to(device) + x_n = x.cpu().numpy() + for ord in vector_ords: + msg = f'ord={ord}, vector={vector}' + result = torch.linalg.norm(x, ord=ord) + result_n = np.linalg.norm(x_n, ord=ord) + self.assertEqual(result, result_n, msg=msg) + + # TODO: Need to fix these cases + def is_broken_matrix_norm_case(ord, x): + if self.device_type == 'cuda': + if x.size() == torch.Size([1, 2]): + if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1: + return True + return False + + for matrix in matrices: + x = torch.tensor(matrix).to(device) + x_n = x.cpu().numpy() + for ord in matrix_ords: + if is_broken_matrix_norm_case(ord, x): + continue + msg = f'ord={ord}, matrix={matrix}' + result = torch.linalg.norm(x, ord=ord) + result_n = np.linalg.norm(x_n, ord=ord) + self.assertEqual(result, result_n, msg=msg) + + # Test degenerate shape results match numpy for linalg.norm vector norms + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_vector_degenerate_shapes(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + if should_error: + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + else: + if dtype in [torch.cfloat, torch.cdouble] and ord in [2, None]: + # TODO: Once these ord values have support for complex numbers, + # remove this error test case + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + return + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + result = torch.linalg.norm(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, dim + ((0, ), [inf, -inf], None), + ((0, S), [inf, -inf], (0,)), + ((0, S), [], (1,)), + ((S, 0), [], (0,)), + ((S, 0), [inf, -inf], (1,)), + ] + for keepdim in [True, False]: + for input_size, error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_vector: + run_test_case(input, ord, dim, keepdim, ord in error_ords) + + # Test degenerate shape results match numpy for linalg.norm matrix norms + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_matrix_degenerate_shapes(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + if dtype in [torch.cfloat, torch.cdouble] and ord in ['fro', None]: + # TODO: Once these ord values have support for complex numbers, + # remove this error test case + with self.assertRaises(RuntimeError): + 
torch.linalg.norm(input, ord, dim, keepdim) + return + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + if should_error: + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + else: + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + result = torch.linalg.norm(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, dim + ((0, 0), [1, 2, inf, -1, -2, -inf], None), + ((0, S), [2, inf, -2, -inf], None), + ((S, 0), [1, 2, -1, -2], None), + ((S, S, 0), [], (0, 1)), + ((1, S, 0), [], (0, 1)), + ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), + ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), + ] + for keepdim in [True, False]: + for input_size, error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_matrix: + run_test_case(input, ord, dim, keepdim, ord in error_ords) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/test/test_torch.py b/test/test_torch.py index aa65896010cd..88b328e9f9a2 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9983,13 +9983,14 @@ def check_single_nuclear_norm(x, axes): @skipCUDAIfNoMagma def test_nuclear_norm_exceptions(self, device): for lst in [], [1], [1, 2]: - for axes in (), (0,), (0, 1): - x = torch.tensor(lst, dtype=torch.double, device=device) + x = torch.tensor(lst, dtype=torch.double, device=device) + for axes in (), (0,): self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) + self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2)) + self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) def test_embedding_scalar_weight_error(self, device): indices = torch.rand(2, 2, device=device).long() diff --git a/tools/autograd/templates/python_linalg_functions.cpp b/tools/autograd/templates/python_linalg_functions.cpp index fa139eef0b87..b02438e31189 100644 --- a/tools/autograd/templates/python_linalg_functions.cpp +++ b/tools/autograd/templates/python_linalg_functions.cpp @@ -12,6 +12,7 @@ using at::Tensor; using at::Scalar; +using at::ScalarType; using at::MemoryFormat; using at::Generator; using at::IntArrayRef; diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index d1c4c60df6f2..5ce90dcc972e 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -12,6 +12,22 @@ inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } +inline Tensor norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + 
return torch::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -21,4 +37,20 @@ inline Tensor linalg_det(const Tensor& self) { return detail::det(self); } +inline Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + }} // torch::linalg diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 61850fab1dba..758701e0d402 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -13,3 +13,128 @@ Alias of :func:`torch.det`. """) + +norm = _add_docstr(_linalg.linalg_norm, r""" +linalg.norm(input, ord=None, dim=None, keepdim=False, out=None, dtype=None) -> Tensor + +Returns the matrix norm or vector norm of a given tensor. + +This function can calculate one of eight different types of matrix norms, or one +of an infinite number of vector norms, depending on both the number of reduction +dimensions and the value of the `ord` parameter. + +Args: + input (Tensor): The input tensor. If dim is None, x must be 1-D or 2-D, unless :attr:`ord` + is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D + will be returned. + + ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm. + inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object. + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- not supported -- + 'nuc' nuclear norm -- not supported -- + inf max(sum(abs(x), dim=1)) max(abs(x)) + -inf min(sum(abs(x), dim=1)) min(abs(x)) + 0 -- not supported -- sum(x != 0) + 1 max(sum(abs(x), dim=0)) as below + -1 min(sum(abs(x), dim=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- not supported -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Default: ``None`` + + dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int, + vector norm will be calculated over the specified dimension. If :attr:`dim` + is a 2-tuple of ints, matrix norm will be calculated over the specified + dimensions. If :attr:`dim` is None, matrix norm will be calculated + when the input tensor has two dimensions, and vector norm will be + calculated when the input tensor has one dimension. 
Default: ``None`` + + keepdim (bool, optional): If set to True, the reduced dimensions are retained + in the result as dimensions with size one. Default: ``False`` + + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. If this argument is used in conjunction with the + :attr:`out` argument, the output tensor's type must match this argument; + otherwise, a RuntimeError will be raised. This argument is not currently + supported for :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + +Examples:: + + >>> import torch + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> b = a.reshape((3, 3)) + >>> b + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + + >>> LA.norm(a) + tensor(7.7460) + >>> LA.norm(b) + tensor(7.7460) + >>> LA.norm(b, 'fro') + tensor(7.7460) + >>> LA.norm(a, float('inf')) + tensor(4.) + >>> LA.norm(b, float('inf')) + tensor(9.) + >>> LA.norm(a, -float('inf')) + tensor(0.) + >>> LA.norm(b, -float('inf')) + tensor(2.) + + >>> LA.norm(a, 1) + tensor(20.) + >>> LA.norm(b, 1) + tensor(7.) + >>> LA.norm(a, -1) + tensor(0.) + >>> LA.norm(b, -1) + tensor(6.) + >>> LA.norm(a, 2) + tensor(7.7460) + >>> LA.norm(b, 2) + tensor(7.3485) + + >>> LA.norm(a, -2) + tensor(0.) + >>> LA.norm(b.double(), -2) + tensor(1.8570e-16, dtype=torch.float64) + >>> LA.norm(a, 3) + tensor(5.8480) + >>> LA.norm(a, -3) + tensor(0.) + +Using the :attr:`dim` argument to compute vector norms:: + + >>> c = torch.tensor([[1., 2., 3.], + ... [-1, 1, 4]]) + >>> LA.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> LA.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> LA.norm(c, ord=1, dim=1) + tensor([6., 6.]) + +Using the :attr:`dim` argument to compute matrix norms:: + + >>> m = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(m, dim=(1,2)) + tensor([ 3.7417, 11.2250]) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (tensor(3.7417), tensor(11.2250)) +""") From 56b1e42134098c9729418080a7b736f59a6e992a Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 25 Aug 2020 14:41:58 -0500 Subject: [PATCH 2/6] Add svd issue link to broken test cases --- test/test_linalg.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index cbcdb6344844..6b58443fb18a 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -427,11 +427,13 @@ def test_norm_extreme_values(self, device): result_n = np.linalg.norm(x_n, ord=ord) self.assertEqual(result, result_n, msg=msg) - # TODO: Need to fix these cases + # TODO: Remove this function once the broken cases are fixed def is_broken_matrix_norm_case(ord, x): if self.device_type == 'cuda': if x.size() == torch.Size([1, 2]): if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1: + # These cases are broken because of an issue with svd + # https://github.com/pytorch/pytorch/issues/43567 return True return False @@ -439,12 +441,14 @@ def is_broken_matrix_norm_case(ord, x): x = torch.tensor(matrix).to(device) x_n = x.cpu().numpy() for ord in matrix_ords: - if is_broken_matrix_norm_case(ord, x): - continue msg = f'ord={ord}, matrix={matrix}' result = torch.linalg.norm(x, ord=ord) result_n = np.linalg.norm(x_n, ord=ord) - self.assertEqual(result, result_n, msg=msg) + + 
if is_broken_matrix_norm_case(ord, x): + self.assertNotEqual(result, result_n, msg=msg) + else: + self.assertEqual(result, result_n, msg=msg) # Test degenerate shape results match numpy for linalg.norm vector norms @skipCUDAIfNoMagma From f2be712afe7ac52681d99b0a232a0ca7f6d599c8 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 25 Aug 2020 19:32:11 -0500 Subject: [PATCH 3/6] Small fixes to comments, error messages, documentation, and code --- aten/src/ATen/native/LinearAlgebra.cpp | 35 +++++++++++++------------- test/test_linalg.py | 2 +- torch/linalg/__init__.py | 10 +++++--- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 5c47abc727f1..00f12ae3af94 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1386,6 +1386,8 @@ Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bo return result; } +// Creates a vector of length ndim with values equal to its indices +// (e.g. [0, 1, 2, ..., ndim-1]) static std::vector make_dim_list(int64_t ndim) { std::vector dim_list(ndim); for (int64_t ind = 0; ind < ndim; ind++) { @@ -1394,10 +1396,10 @@ static std::vector make_dim_list(int64_t ndim) { return dim_list; } -static void check_str_ord_valid(std::string str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { +// Checks for valid arguments to linalg_norm when type(ord) == str +static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); - TORCH_CHECK(!opt_dtype.has_value(), "dtype argument is currently not supported in frobenius norm, ", - "but will be in the future"); + TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument"); bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); TORCH_CHECK(dims_valid, "order \"", str_ord, "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); @@ -1439,26 +1441,21 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "matrix norm only supports strided layout, got: ", self.layout()); - if ((dim.size() == 0) && (self.dim() == 2)) { - dim = {0, 1}; - } - TORCH_CHECK(dim.size() == 2, "_norm_matrix: 'dim' must either specify 2 dimensions, or if ", - "'self' is 2-D 'dim' can specify 0 dimensions for a full reduction. Got 'dim' specifying ", - dim.size(), " dims and 'self' is ", self.dim(), "-D"); - ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); - TORCH_CHECK( - at::isFloatingType(scalarType) || at::isComplexType(scalarType), - "Can only calculate the mean of floating types. Got ", - toString(scalarType), - " instead."); + TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must either specify 2 dimensions. ", + "Got 'dim' specifying ", dim.size(), " dims"); auto dim_ = dim.vec(); maybe_wrap_dims(dim_, self.dim()); TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead"); - Tensor self_; + ScalarType scalarType = opt_dtype.has_value() ? 
opt_dtype.value() : self.scalar_type(); + TORCH_CHECK( + at::isFloatingType(scalarType) || at::isComplexType(scalarType), + "Can only calculate the mean of floating and complex types. Got ", + toString(scalarType), " instead."); + Tensor self_; if (opt_dtype.has_value()) { self_ = self.to(scalarType); } else { @@ -1588,13 +1585,15 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional // Numerical or None norms Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { - Tensor result = at::empty({0}, opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).to(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); } // Frobenius and nuclear norms Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { - Tensor result = at::empty({0}, opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).to(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } diff --git a/test/test_linalg.py b/test/test_linalg.py index 6b58443fb18a..1e3d8d94374e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -110,7 +110,7 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) # and add 'nuc' and 'fro' to ord_matrix above for ord in ['nuc', 'fro']: input = torch.randn(10, 10, device=device) - with self.assertRaisesRegex(RuntimeError, r'dtype argument is currently not supported'): + with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"): torch.linalg.norm(input, ord, dtype=torch.float) # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 758701e0d402..5e2b59c45c80 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -15,7 +15,7 @@ """) norm = _add_docstr(_linalg.linalg_norm, r""" -linalg.norm(input, ord=None, dim=None, keepdim=False, out=None, dtype=None) -> Tensor +linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor Returns the matrix norm or vector norm of a given tensor. @@ -60,14 +60,16 @@ keepdim (bool, optional): If set to True, the reduced dimensions are retained in the result as dimensions with size one. Default: ``False`` +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to :attr:`dtype` before performing the operation, and the returned tensor's type will be :attr:`dtype`. If this argument is used in conjunction with the - :attr:`out` argument, the output tensor's type must match this argument; - otherwise, a RuntimeError will be raised. This argument is not currently - supported for :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + :attr:`out` argument, the output tensor's type must match this argument or a + RuntimeError will be raised. This argument is not currently supported for + :attr:`ord='nuc'` or :attr:`ord='fro'`. 
Default: ``None`` Examples:: From ef39be2224f81ea331489f4342c8af57f2677e9c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 27 Aug 2020 11:33:37 -0500 Subject: [PATCH 4/6] Use set_ in nuclear_norm_out to avoid resize and copy * Also add unit test --- aten/src/ATen/native/LinearAlgebra.cpp | 3 +-- test/test_torch.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 00f12ae3af94..3b91b0b4678d 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1380,8 +1380,7 @@ Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bo if (keepdim) { result.unsqueeze_(-1); Tensor result_ = result.permute(permutation_reverse); - at::native::resize_output(result, result_.sizes()); - result.copy_(result_); + result.set_(result_); } return result; } diff --git a/test/test_torch.py b/test/test_torch.py index 88b328e9f9a2..a3badced551f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9798,6 +9798,31 @@ def test_diagflat(self, device): expected = torch.diag(x.contiguous().view(-1)) self.assertEqual(result, expected) + # Ensure that nuclear_norm's out variant gives the same result as the non-out + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64) + def test_nuclear_norm_out(self, device, dtype): + test_cases = [ + # input size, dim + ((25, 25), None), + ((25, 25), (0, 1)), + ((25, 25), (1, 0)), + ((25, 25, 25), (2, 0)), + ((25, 25, 25), (0, 1)), + ] + for keepdim in [False, True]: + for input_size, dim in test_cases: + x = torch.randn(*input_size, device=device, dtype=dtype) + result_out = torch.tensor(0, device=device, dtype=dtype) + if dim is None: + result = torch.nuclear_norm(x, keepdim=keepdim) + torch.nuclear_norm(x, keepdim=keepdim, out=result_out) + else: + result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) + torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) + self.assertEqual(result, result_out) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "Numpy not found") From 063d4c1201159beebff80db41242c4ac73fbe438 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 27 Aug 2020 14:26:18 -0500 Subject: [PATCH 5/6] Small fixes * Avoid unecessary clone * Use torch.empty rather than torch.tensor in unit test --- aten/src/ATen/native/LinearAlgebra.cpp | 2 +- test/test_torch.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 5e8ce0f101f3..a692ca0a800e 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1551,7 +1551,7 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional "dtype = ", dtype, ", out.dtype = ", result.scalar_type()); } int64_t ndim = self.dim(); - Tensor result_ = result.clone(); + Tensor result_; if (opt_str_ord.has_value()) { // 'ord' is string auto str_ord = opt_str_ord.value(); diff --git a/test/test_torch.py b/test/test_torch.py index c7757bf251d2..ecb884b2f0fe 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9816,15 +9816,16 @@ def test_nuclear_norm_out(self, device, dtype): ] for keepdim in [False, True]: for input_size, dim in test_cases: + msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}' x = torch.randn(*input_size, device=device, dtype=dtype) - result_out = torch.tensor(0, device=device, 
dtype=dtype) + result_out = torch.empty(0, device=device, dtype=dtype) if dim is None: result = torch.nuclear_norm(x, keepdim=keepdim) torch.nuclear_norm(x, keepdim=keepdim, out=result_out) else: result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) - self.assertEqual(result, result_out) + self.assertEqual(result, result_out, msg=msg) @skipCUDAIfNoMagma @skipCPUIfNoLapack From 96d33fb350a3d90895453890a34410a0ca012480 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 27 Aug 2020 23:00:04 -0500 Subject: [PATCH 6/6] Avoid running nuclear norm test on XLA --- test/test_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_torch.py b/test/test_torch.py index ecb884b2f0fe..def4f5993e81 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9802,6 +9802,7 @@ def test_diagflat(self, device): self.assertEqual(result, expected) # Ensure that nuclear_norm's out variant gives the same result as the non-out + @onlyOnCPUAndCUDA @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64)