Add `torch.linalg.norm` (#42749) · pytorch/pytorch@68b9daa · GitHub

Commit 68b9daa

kurtamohler authored and facebook-github-bot committed
Add torch.linalg.norm (#42749)
Summary: Adds `torch.linalg.norm` function that matches the behavior of `numpy.linalg.norm`.

Additional changes:
* Add support for dimension wrapping in `frobenius_norm` and `nuclear_norm`
* Fix `out` argument behavior for `nuclear_norm`
* Fix issue where `frobenius_norm` allowed duplicates in `dim` argument
* Add `_norm_matrix`

Closes #24802

Pull Request resolved: #42749

Reviewed By: ngimel

Differential Revision: D23336234

Pulled By: mruberry

fbshipit-source-id: f0aba3089a3a0bf856aa9c4215e673ff34228fac
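For orientation, here is a minimal, hedged sketch of the intended call patterns. It assumes the Python binding `torch.linalg.norm` generated from the `native_functions.yaml` entries in this commit; outputs are omitted and exact numerics are not asserted.

    import torch

    a = torch.arange(9, dtype=torch.float32) - 4   # 1-D input
    b = a.reshape(3, 3)                            # 2-D input

    # Vector norms: 'ord' is a number or None (None falls back to the 2-norm).
    torch.linalg.norm(a)                      # sqrt(sum(a**2))
    torch.linalg.norm(a, ord=1)               # sum(abs(a))
    torch.linalg.norm(a, ord=float('inf'))    # max(abs(a))

    # Matrix norms: 'ord' may also be the string 'fro' or 'nuc'.
    torch.linalg.norm(b, ord='fro')            # Frobenius norm
    torch.linalg.norm(b, ord='nuc')            # nuclear norm (sum of singular values)
    torch.linalg.norm(b, ord=2, dim=(0, 1))    # largest singular value
    torch.linalg.norm(b, ord=1, dim=(0, 1))    # max absolute column sum

    # keepdim, dtype, and out follow the schema added below.
    out = torch.empty(0)
    torch.linalg.norm(b, ord=2, dim=1, keepdim=False, out=out)  # per-row 2-norms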
1 parent cd0bab8 commit 68b9daa

File tree

9 files changed: +928 −10 lines changed


aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ namespace c10 {
   _(aten, clip_) \
   _(aten, det) \
   _(aten, linalg_det) \
+  _(aten, linalg_norm) \
   _(aten, append) \
   _(aten, item) \
   _(aten, format) \

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 240 additions & 7 deletions
@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/Resize.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/Parallel.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
@@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   if (dim.size() == 1 || dim.size() == 0) {
     return at::norm(self, 2, dim, keepdim);
   }
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
   if (self.is_complex()){
-    return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim));
+    return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
   } else {
-    return at::sqrt(at::sum((self * self), dim, keepdim));
+    return at::sqrt(at::sum((self * self), dim_, keepdim));
   }
 }
 
@@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out(
   if (dim.size() == 1 || dim.size() == 0) {
     return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
   }
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
   if (self.is_complex()){
-    return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim));
+    return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim));
   } else {
-    return at::sqrt_out(result, at::sum((self * self), dim, keepdim));
+    return at::sqrt_out(result, at::sum((self * self), dim_, keepdim));
   }
 }
 
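The two frobenius_norm hunks above add negative-dimension wrapping and reject duplicate entries in `dim`. A hedged sketch of the user-visible effect, exercised here through `torch.linalg.norm` with `ord='fro'` (the same backend is reached by `torch.norm(..., p='fro')`):

    import torch

    x = torch.randn(2, 3, 4)

    # Negative dims are wrapped before the checks, so these are equivalent.
    f1 = torch.linalg.norm(x, ord='fro', dim=(-2, -1))
    f2 = torch.linalg.norm(x, ord='fro', dim=(1, 2))
    assert torch.allclose(f1, f2)

    # Duplicate dims, including a positive/negative alias of the same axis
    # such as (1, -2), should now hit the new TORCH_CHECK and raise.
    try:
        torch.linalg.norm(x, ord='fro', dim=(1, -2))
    except RuntimeError as err:
        print(err)  # "Expected dims to be different, got ..."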
@@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
 
-  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   auto permutation_reverse = create_reverse_permutation(permutation);
   Tensor p = self.permute(permutation);
   // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
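nuclear_norm gets the same dimension wrapping here, and the next hunk applies it to the out= variant as well, also fixing how the keepdim result is permuted back into `out`. A hedged illustration; the cross-check against `torch.svd` is only meant to show what the nuclear norm computes:

    import torch

    x = torch.randn(5, 3, 4)

    # Negative dims are wrapped before building the permutation.
    n = torch.linalg.norm(x, ord='nuc', dim=(-2, -1))

    # Nuclear norm == sum of singular values of each 3x4 matrix in the batch.
    s = torch.svd(x).S
    assert torch.allclose(n, s.sum(dim=-1), atol=1e-5)

    # The out= path this commit fixes is reached via torch.norm with p='nuc';
    # after the fix, 'out' should carry the permuted (keepdim) shape.
    x2 = torch.randn(3, 4, 5)
    out = torch.empty(0)
    torch.norm(x2, p='nuc', dim=(0, 1), keepdim=True, out=out)
    assert out.shape == (1, 1, 5)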
@@ -1360,19 +1369,243 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
 
 Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
 
-  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   auto permutation_reverse = create_reverse_permutation(permutation);
 
   Tensor p = self.permute(permutation);
   at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
   if (keepdim) {
     result.unsqueeze_(-1);
-    result = result.permute(permutation_reverse);
+    Tensor result_ = result.permute(permutation_reverse);
+    result.set_(result_);
+  }
+  return result;
+}
+
+// Creates a vector of length ndim with values equal to its indices
+// (e.g. [0, 1, 2, ..., ndim-1])
+static std::vector<int64_t> make_dim_list(int64_t ndim) {
+  std::vector<int64_t> dim_list(ndim);
+  for (int64_t ind = 0; ind < ndim; ind++) {
+    dim_list[ind] = ind;
+  }
+  return dim_list;
+}
+
+// Checks for valid arguments to linalg_norm when type(ord) == str
+static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
+  TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
+  TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument");
+  bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
+  TORCH_CHECK(dims_valid, "order \"", str_ord,
+    "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
+}
+
+// Performs vector norm for ord = +/-infinity, and the second dimension reduction
+// for matrix norms.
+static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) {
+  Tensor result;
+  if (self.numel() == 0 && self.sizes()[dim] > 0) {
+    // This special case is needed in matrix norm for tensors with 3 or more dims,
+    // or in vector norm for order inf and -inf for tensors with 2 or more dims.
+    // When the sizes of the dims to be reduced are greater than 0 but another dim
+    // in the tensor is size 0 (thus numel == 0), we must either flatten or resize
+    // the second reduction dim to 1, to avoid calling min/max, which would throw
+    // an error.
+    if (self.sizes()[dim] != 1) {
+      auto new_sizes = self.sizes().vec();
+      new_sizes[dim] = 1;
+      self.resize_(new_sizes);
+    }
+    result = keepdim ? self : self.flatten(dim);
+  } else {
+    if (ord > 0) {
+      result = std::get<0>(self.max(dim, keepdim));
+    } else {
+      result = std::get<0>(self.min(dim, keepdim));
+    }
+  }
+  return result;
+}
+
+// Performs matrix norm
+static Tensor _linalg_norm_matrix(const Tensor &self, optional<Scalar> opt_ord,
+                                  IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  Tensor result;
+  auto ord = opt_ord.value_or(2.0).toDouble();
+  TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+    "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type());
+  TORCH_CHECK(self.layout() == Layout::Strided,
+    "matrix norm only supports strided layout, got: ", self.layout());
+
+  TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must either specify 2 dimensions. ",
+    "Got 'dim' specifying ", dim.size(), " dims");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+  TORCH_CHECK(dim_[0] != dim_[1],
+    "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead");
+
+  ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
+  TORCH_CHECK(
+    at::isFloatingType(scalarType) || at::isComplexType(scalarType),
+    "Can only calculate the mean of floating and complex types. Got ",
+    toString(scalarType), " instead.");
+
+  Tensor self_;
+  if (opt_dtype.has_value()) {
+    self_ = self.to(scalarType);
+  } else {
+    self_ = self;
+  }
+
+  if (std::abs(ord) == 2) {
+    // Need to shift the reduction dims to the back, because at::svd will only operate on
+    // the last 2 dimensions
+    auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
+    auto permutation_reverse = create_reverse_permutation(permutation);
+
+    result = std::get<1>(self_.permute(permutation).svd()).abs();
+    result = _norm_min_max(result, ord, result.dim() - 1, keepdim);
+
+    if (keepdim) {
+      result.unsqueeze_(-1);
+      result = result.permute(permutation_reverse);
+    }
+  } else {
+    // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except
+    // that the order of the two dims is swapped. So we can swap the dims if
+    // abs(p) == infinity to simplify the rest of the operation's logic.
+    if (std::abs(ord) == INFINITY) {
+      std::swap(dim_[0], dim_[1]);
+    }
+    // If the dim of the second reduction is greater than that of the first reduction
+    // and we are not keeping the dims, then the fact that the output of the first
+    // reduction will have one fewer dimension means that the second reduction dim
+    // will be off by one, so we need to correct that.
+    if ((dim_[1] > dim_[0]) && !keepdim) {
+      dim_[1]--;
+    }
+    if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) {
+      result = self_.abs().sum(dim_[0], keepdim);
+      result = _norm_min_max(result, ord, dim_[1], keepdim);
+    } else {
+      TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm");
+    }
   }
   return result;
 }
 
+// Performs vector norm
+// This function mostly serves as a wrapper for at::norm, but it overrides a few cases
+// for numpy compatibility. These cases are corrected within this wrapper, rather than
+// in at::norm itself, to avoid breaking backward compatibility.
+static Tensor _linalg_norm_vector(const Tensor& self, optional<Scalar> opt_ord, std::vector<int64_t> dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  if (opt_ord.has_value()) {
+    TORCH_INTERNAL_ASSERT(dim.size() == 1);
+    auto ord = opt_ord.value().toDouble();
+    Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
+    if (std::abs(ord) == INFINITY) {
+      // The ord = +/-infinity case is overridden because at::norm does not match numpy
+      // when the input contains extreme values (like nan or +/-inf) or if the input
+      // size is degenerate (like size(0), size(0, N), etc)
+      self_ = self_.abs();
+      return _norm_min_max(self_, ord, dim[0], keepdim);
+    } else if ((self_.numel() == 0) && (ord < 0)) {
+      // For negative orders with degenerate input sizes, at::norm's result does not
+      // match numpy.
+      Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim);
+      if (ord >= -1) {
+        // Result must be infinite in this case, and the simplest way to make that
+        // happen is to simply add infinity
+        result += INFINITY;
+      } else {
+        result = result.pow(1.0 / (ord + 1));
+      }
+      return result;
+    }
+  } else {
+    // If ord == None, need to check for unique dims because at::norm does not check it
+    // for this case.
+    std::vector<int64_t> dim_(dim);
+    maybe_wrap_dims(dim_, self.dim());
+    bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end();
+    TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")");
+  }
+  if (opt_dtype.has_value()) {
+    return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value());
+  } else {
+    return at::norm(self, opt_ord, dim, keepdim);
+  }
+}
+
+static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional<Scalar> opt_num_ord, optional<std::string> opt_str_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  // Callers must give the ord argument as either a number, a string, or neither.
+  // Since the user-facing API has no direct control over how this function is called, this is an internal assert.
+  TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value()));
+  if (opt_dtype.has_value()) {
+    auto dtype = opt_dtype.value();
+    TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got",
+      "dtype = ", dtype, ", out.dtype = ", result.scalar_type());
+  }
+  int64_t ndim = self.dim();
+  Tensor result_;
+  if (opt_str_ord.has_value()) {
+    // 'ord' is string
+    auto str_ord = opt_str_ord.value();
+    check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
+    if (str_ord == "fro") {
+      result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
+    } else if (str_ord == "nuc") {
+      if (opt_dim.has_value()) {
+        result_ = at::nuclear_norm(self, opt_dim.value(), keepdim);
+      } else {
+        result_ = at::nuclear_norm(self, keepdim);
+      }
+    }
+  } else {
+    // 'ord' is int or None
+    std::vector<int64_t> dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim);
+    if (!opt_num_ord.has_value() || dim_.size() == 1) {
+      result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype);
+    } else if (dim_.size() == 2) {
+      result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype);
+    } else {
+      TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is "
+        "not 1-D or 2-D");
+    }
+  }
+  resize_output(result, result_.sizes());
+  result.copy_(result_);
+  return result;
+}
+
+// Numerical or None norms
+Tensor linalg_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+  Tensor result = at::empty({0}, options);
+  return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor linalg_norm(const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+  Tensor result = at::empty({0}, options);
+  return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Numerical or None norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
+}
+
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
   if (i == j)
     return matrices[i];
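To make the reduction logic in `_linalg_norm_matrix` and `_linalg_norm_vector` concrete, here is a hedged Python cross-check of the equivalences the code above relies on (ord = +/-2 via singular values, ord = +/-1 and +/-inf via an absolute-value sum followed by a max or min). Inputs and tolerances are illustrative only.

    import torch

    x = torch.randn(4, 6)
    s = torch.svd(x).S  # singular values

    # Matrix-norm branches of _linalg_norm_matrix:
    assert torch.allclose(torch.linalg.norm(x, ord=2,  dim=(0, 1)), s.max())
    assert torch.allclose(torch.linalg.norm(x, ord=-2, dim=(0, 1)), s.min())
    assert torch.allclose(torch.linalg.norm(x, ord=1,  dim=(0, 1)),
                          x.abs().sum(dim=0).max())   # max column sum
    assert torch.allclose(torch.linalg.norm(x, ord=float('inf'), dim=(0, 1)),
                          x.abs().sum(dim=1).max())   # max row sum

    # Vector-norm branch of _linalg_norm_vector: ord = +/-inf is handled locally
    # (abs then max/min) rather than delegated to at::norm.
    v = torch.tensor([3.0, -7.0, 0.5])
    assert torch.allclose(torch.linalg.norm(v, ord=float('inf')), v.abs().max())
    assert torch.allclose(torch.linalg.norm(v, ord=float('-inf')), v.abs().min())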

aten/src/ATen/native/native_functions.yaml

Lines changed: 16 additions & 0 deletions
@@ -7347,6 +7347,22 @@
 
 - func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
 
+- func: linalg_norm(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.ord_str(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.ord_str_out(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
 ## Functions that are only for testing
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
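The four schema entries above surface in Python as a single `torch.linalg.norm` call: the binding selects an overload from the type of `ord`, and the `.out` variants back the `out=` keyword (routed through `linalg_norm_out_impl`). A hedged sketch:

    import torch

    x = torch.randn(3, 3)

    # linalg_norm / linalg_norm.out: 'ord' is a Scalar or None.
    torch.linalg.norm(x, ord=3, dim=0)
    out = torch.empty(0, dtype=torch.float64)
    torch.linalg.norm(x, ord=3, dim=0, dtype=torch.float64, out=out)

    # linalg_norm.ord_str / linalg_norm.ord_str_out: 'ord' is a string
    # ('fro' or 'nuc'; note the dtype argument is not yet supported here).
    torch.linalg.norm(x, ord='nuc')
    out2 = torch.empty(0)
    torch.linalg.norm(x, ord='fro', out=out2)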

docs/source/linalg.rst

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ Functions
 ---------
 
 .. autofunction:: det
+.. autofunction:: norm

0 commit comments
