 #include <ATen/NativeFunctions.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/Resize.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/Parallel.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
@@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   if (dim.size() == 1 || dim.size() == 0) {
     return at::norm(self, 2, dim, keepdim);
   }
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
   if (self.is_complex()) {
-    return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim));
+    return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
   } else {
-    return at::sqrt(at::sum((self * self), dim, keepdim));
+    return at::sqrt(at::sum((self * self), dim_, keepdim));
   }
 }
 
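A minimal usage sketch (an editor's illustration, not part of the diff, assuming #include <ATen/ATen.h>): the new lines wrap negative dims before comparing them, so duplicate dims are rejected in either spelling.

    at::Tensor t = at::randn({3, 4});
    at::Tensor f = at::frobenius_norm(t, {-2, -1}, /*keepdim=*/false); // {-2, -1} wraps to {0, 1}
    // at::frobenius_norm(t, {1, -1}, false); // now fails: "Expected dims to be different"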
@@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out(
   if (dim.size() == 1 || dim.size() == 0) {
     return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
   }
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
   if (self.is_complex()) {
-    return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim));
+    return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim));
   } else {
-    return at::sqrt_out(result, at::sum((self * self), dim, keepdim));
+    return at::sqrt_out(result, at::sum((self * self), dim_, keepdim));
   }
 }
 
@@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
 
-  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   auto permutation_reverse = create_reverse_permutation(permutation);
   Tensor p = self.permute(permutation);
   // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
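A hedged sketch (editor's illustration, not part of the diff): with the wrapping above, batched inputs can name the matrix dims from the end.

    at::Tensor t = at::randn({5, 3, 4});
    at::Tensor n = at::nuclear_norm(t, {-2, -1}, /*keepdim=*/false); // singular-value sums per batch; shape {5}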
@@ -1360,19 +1369,243 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
 
 Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
 
-  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   auto permutation_reverse = create_reverse_permutation(permutation);
 
   Tensor p = self.permute(permutation);
   at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
   if (keepdim) {
     result.unsqueeze_(-1);
-    result = result.permute(permutation_reverse);
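+    // permute returns a view; set_ installs that view's storage and metadata into
+    // the tensor the caller passed as `result`, so the out= argument itself is
+    // what gets updated rather than a rebound local handle.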
+    Tensor result_ = result.permute(permutation_reverse);
+    result.set_(result_);
+  }
+  return result;
+}
+
+// Creates a vector of length ndim with values equal to its indices
+// (e.g. [0, 1, 2, ..., ndim-1])
+static std::vector<int64_t> make_dim_list(int64_t ndim) {
+  std::vector<int64_t> dim_list(ndim);
+  for (int64_t ind = 0; ind < ndim; ind++) {
+    dim_list[ind] = ind;
+  }
+  return dim_list;
+}
+
+// Checks for valid arguments to linalg_norm when type(ord) == str
+static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
+  TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
+  TORCH_CHECK(!opt_dtype.has_value(), "ord='", str_ord, "' does not yet support the dtype argument");
+  bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
+  TORCH_CHECK(dims_valid, "order \"", str_ord,
+    "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
+}
+
+// Performs vector norm for ord = +/-infinity, and the second dimension reduction
+// for matrix norms.
+static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) {
+  Tensor result;
+  if (self.numel() == 0 && self.sizes()[dim] > 0) {
+    // This special case is needed in matrix norm for tensors with 3 or more dims,
+    // or in vector norm for order inf and -inf for tensors with 2 or more dims.
+    // When the sizes of the dims to be reduced are greater than 0 but another dim
+    // in the tensor is size 0 (thus numel == 0), we must either flatten or resize
+    // the second reduction dim to 1, to avoid calling min/max, which would throw
+    // an error.
+    if (self.sizes()[dim] != 1) {
+      auto new_sizes = self.sizes().vec();
+      new_sizes[dim] = 1;
+      self.resize_(new_sizes);
+    }
+    result = keepdim ? self : self.flatten(dim);
+  } else {
+    if (ord > 0) {
+      result = std::get<0>(self.max(dim, keepdim));
+    } else {
+      result = std::get<0>(self.min(dim, keepdim));
+    }
+  }
+  return result;
+}
+
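To make the special case concrete, a sketch of the shapes involved (editor's illustration, not part of the diff):

    at::Tensor x = at::randn({0, 3}); // numel() == 0, but sizes()[1] == 3 > 0
    // Reducing dim 1 with min/max would hit the error described in the comment
    // above, so the function first resizes dim 1 to 1 and returns the (still
    // empty) tensor, flattened unless keepdim is set.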
+// Performs matrix norm
+static Tensor _linalg_norm_matrix(const Tensor& self, optional<Scalar> opt_ord,
+                                  IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  Tensor result;
+  auto ord = opt_ord.value_or(2.0).toDouble();
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+      "matrix norm only supports CPU and CUDA device types, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+      "matrix norm only supports strided layout, got: ", self.layout());
+
+  TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must specify exactly 2 dimensions. ",
+    "Got 'dim' specifying ", dim.size(), " dims");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1],
+    "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead");
+
+  ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
+  TORCH_CHECK(
+      at::isFloatingType(scalarType) || at::isComplexType(scalarType),
+      "Can only calculate the norm of floating point and complex types. Got ",
+      toString(scalarType), " instead.");
+
+  Tensor self_;
+  if (opt_dtype.has_value()) {
+    self_ = self.to(scalarType);
+  } else {
+    self_ = self;
+  }
+
+  if (std::abs(ord) == 2) {
+    // Need to shift the reduction dims to the back, because at::svd will only operate on
+    // the last 2 dimensions
+    auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
+    auto permutation_reverse = create_reverse_permutation(permutation);
+
+    result = std::get<1>(self_.permute(permutation).svd()).abs();
+    result = _norm_min_max(result, ord, result.dim() - 1, keepdim);
+
+    if (keepdim) {
+      result.unsqueeze_(-1);
+      result = result.permute(permutation_reverse);
+    }
+  } else {
+    // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except
+    // that the order of the two dims is swapped. So we can swap the dims if
+    // abs(p) == infinity to simplify the rest of the operation's logic.
+    if (std::abs(ord) == INFINITY) {
+      std::swap(dim_[0], dim_[1]);
+    }
+    // If the dim of the second reduction is greater than that of the first reduction
+    // and we are not keeping the dims, then the fact that the output of the first
+    // reduction will have one fewer dimension means that the second reduction dim
+    // will be off by one, so we need to correct that.
+    if ((dim_[1] > dim_[0]) && !keepdim) {
+      dim_[1]--;
+    }
+    if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) {
+      result = self_.abs().sum(dim_[0], keepdim);
+      result = _norm_min_max(result, ord, dim_[1], keepdim);
+    } else {
+      TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm");
+    }
   }
   return result;
 }
 
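As a concrete reading of the ord = 1 and ord = infinity branches (editor's sketch, not part of the diff, assuming #include <ATen/ATen.h>):

    at::Tensor A = at::tensor({1.0, -2.0, 3.0, 4.0}).view({2, 2}); // [[1, -2], [3, 4]]
    // ord = 1: maximum absolute column sum -> sum over dim 0, then max: max(4, 6) = 6
    at::Tensor n1 = A.abs().sum(0).max();
    // ord = inf: the dims are swapped, giving the maximum absolute row sum: max(3, 7) = 7
    at::Tensor ninf = A.abs().sum(1).max();
    // ord = +/-2: largest / smallest singular value, via at::svd on the permuted input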
+// Performs vector norm
+// This function mostly serves as a wrapper for at::norm, but it overrides a few cases
+// for numpy compatibility. These cases are corrected within this wrapper, rather than
+// in at::norm itself, to avoid breaking backward compatibility.
+static Tensor _linalg_norm_vector(const Tensor& self, optional<Scalar> opt_ord, std::vector<int64_t> dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  if (opt_ord.has_value()) {
+    TORCH_INTERNAL_ASSERT(dim.size() == 1);
+    auto ord = opt_ord.value().toDouble();
+    Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
+    if (std::abs(ord) == INFINITY) {
+      // The ord = +/-infinity case is overridden because at::norm does not match numpy
+      // when the input contains extreme values (like nan or +/-inf) or if the input
+      // size is degenerate (like size(0), size(0, N), etc)
+      self_ = self_.abs();
+      return _norm_min_max(self_, ord, dim[0], keepdim);
+    } else if ((self_.numel() == 0) && (ord < 0)) {
+      // For negative orders with degenerate input sizes, at::norm's result does not
+      // match numpy.
+      Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim);
+      if (ord >= -1) {
+        // Result must be infinite in this case, and the simplest way to make that
+        // happen is to simply add infinity
+        result += INFINITY;
+      } else {
+        result = result.pow(1.0 / (ord + 1));
+      }
+      return result;
+    }
+  } else {
+    // If ord == None, need to check for unique dims because at::norm does not check it
+    // for this case. std::unique only removes adjacent duplicates, so sort first.
+    std::vector<int64_t> dim_(dim);
+    maybe_wrap_dims(dim_, self.dim());
+    std::sort(dim_.begin(), dim_.end());
+    bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end();
+    TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")");
+  }
+  if (opt_dtype.has_value()) {
+    return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value());
+  } else {
+    return at::norm(self, opt_ord, dim, keepdim);
+  }
+}
+
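A hedged walk-through of the negative-order branch above, restating its arithmetic: for an empty input, sum(|x|^(ord + 1)) over zero elements is 0. When -1 <= ord < 0, the explicit `result += INFINITY` forces an infinite result; when ord < -1, the exponent 1.0 / (ord + 1) is negative, and 0 raised to a negative power is likewise inf. Both paths therefore produce the infinite value that the comments attribute to numpy's behavior for degenerate sizes.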
+static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional<Scalar> opt_num_ord, optional<std::string> opt_str_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  // Callers must give the ord argument as either a number, a string, or neither.
+  // Since the user-facing API has no direct control over how this function is called, this is an internal assert.
+  TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value()));
+  if (opt_dtype.has_value()) {
+    auto dtype = opt_dtype.value();
+    TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got ",
+      "dtype = ", dtype, ", out.dtype = ", result.scalar_type());
+  }
+  int64_t ndim = self.dim();
+  Tensor result_;
+  if (opt_str_ord.has_value()) {
+    // 'ord' is string
+    auto str_ord = opt_str_ord.value();
+    check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
+    if (str_ord == "fro") {
+      result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
+    } else if (str_ord == "nuc") {
+      if (opt_dim.has_value()) {
+        result_ = at::nuclear_norm(self, opt_dim.value(), keepdim);
+      } else {
+        result_ = at::nuclear_norm(self, keepdim);
+      }
+    }
+  } else {
+    // 'ord' is int or None
+    std::vector<int64_t> dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim);
+    if (!opt_num_ord.has_value() || dim_.size() == 1) {
+      result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype);
+    } else if (dim_.size() == 2) {
+      result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype);
+    } else {
+      TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is "
+          "not 1-D or 2-D");
+    }
+  }
+  resize_output(result, result_.sizes());
+  result.copy_(result_);
+  return result;
+}
+
+// Numerical or None norms
+Tensor linalg_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+  Tensor result = at::empty({0}, options);
+  return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor linalg_norm(const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+  Tensor result = at::empty({0}, options);
+  return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Numerical or None norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
+}
+
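Taken together, a hedged end-to-end sketch of the new entry points (editor's illustration, assuming the at::linalg_norm bindings generated for this diff and #include <ATen/ATen.h>):

    at::Tensor m = at::randn({3, 3});
    // Numeric ord with dim = None on a 2-D input: make_dim_list defaults dim to {0, 1},
    // so this computes the matrix 2-norm (largest singular value).
    at::Tensor two = at::linalg_norm(m, 2.0, c10::nullopt, /*keepdim=*/false, c10::nullopt);
    // String ord dispatches to frobenius_norm / nuclear_norm via linalg_norm_out_impl.
    at::Tensor fro = at::linalg_norm(m, "fro", c10::nullopt, /*keepdim=*/false, c10::nullopt);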
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
   if (i == j)
     return matrices[i];