Add `torch.linalg.norm` by kurtamohler · Pull Request #42749 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Add torch.linalg.norm #42749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ namespace c10 {
_(aten, clip_) \
_(aten, det) \
_(aten, linalg_det) \
_(aten, linalg_norm) \
_(aten, append) \
_(aten, item) \
_(aten, format) \
Expand Down
247 changes: 240 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/TensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/LegacyTHFunctionsCPU.h>
Expand Down Expand Up @@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
if (dim.size() == 1 || dim.size() == 0) {
return at::norm(self, 2, dim, keepdim);
}
auto dim_ = dim.vec();
maybe_wrap_dims(dim_, self.dim());
TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
if (self.is_complex()){
return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim));
return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
} else {
return at::sqrt(at::sum((self * self), dim, keepdim));
return at::sqrt(at::sum((self * self), dim_, keepdim));
}
}

Expand All @@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out(
if (dim.size() == 1 || dim.size() == 0) {
return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
}
auto dim_ = dim.vec();
maybe_wrap_dims(dim_, self.dim());
TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
if (self.is_complex()){
return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim));
return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim));
} else {
return at::sqrt_out(result, at::sum((self * self), dim, keepdim));
return at::sqrt_out(result, at::sum((self * self), dim_, keepdim));
}
}

Expand Down Expand Up @@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {

Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
auto dim_ = dim.vec();
maybe_wrap_dims(dim_, self.dim());

auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);
Tensor p = self.permute(permutation);
// Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
Expand All @@ -1360,19 +1369,243 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {

Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
auto dim_ = dim.vec();
maybe_wrap_dims(dim_, self.dim());

auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);

Tensor p = self.permute(permutation);
at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
if (keepdim) {
result.unsqueeze_(-1);
result = result.permute(permutation_reverse);
Tensor result_ = result.permute(permutation_reverse);
result.set_(result_);
}
return result;
}

// Creates a vector of length ndim with values equal to its indices
// (e.g. [0, 1, 2, ..., ndim-1])
static std::vector<int64_t> make_dim_list(int64_t ndim) {
std::vector<int64_t> dim_list(ndim);
for (int64_t ind = 0; ind < ndim; ind++) {
dim_list[ind] = ind;
}
return dim_list;
}

// Checks for valid arguments to linalg_norm when type(ord) == str
static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument");
bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
TORCH_CHECK(dims_valid, "order \"", str_ord,
"\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
}

// Performs vector norm for ord = +/-infinity, and the second dimension reduction
// for matrix norms.
static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) {
Tensor result;
if (self.numel() == 0 && self.sizes()[dim] > 0) {
// This special case is needed in matrix norm for tensors with 3 or more dims,
// or in vector norm for order inf and -inf for tesnsors with 2 or more dims.
// When the sizes of the dims to be reduced are greater than 0 but another dim
// in the tensor is size 0 (thus numel == 0), we must either flatten or resize
// the second reduction dim to 1, to avoid calling min/max, which would throw
// an error.
if (self.sizes()[dim] != 1) {
auto new_sizes = self.sizes().vec();
new_sizes[dim] = 1;
self.resize_(new_sizes);
}
result = keepdim ? s 8000 elf : self.flatten(dim);
} else {
if (ord > 0) {
result = std::get<0>(self.max(dim, keepdim));
} else {
result = std::get<0>(self.min(dim, keepdim));
}
}
return result;
}

// Performs matrix norm
static Tensor _linalg_norm_matrix(const Tensor &self, optional<Scalar> opt_ord,
IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
Tensor result;
auto ord = opt_ord.value_or(2.0).toDouble();
TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
"matrix norm only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"matrix norm only supports strided layout, got: ", self.layout());

TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must either specify 2 dimensions. ",
"Got 'dim' specifying ", dim.size(), " dims");
auto dim_ = dim.vec();
maybe_wrap_dims(dim_, self.dim());
TORCH_CHECK(dim_[0] != dim_[1],
"Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead");

ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
TORCH_CHECK(
at::isFloatingType(scalarType) || at::isComplexType(scalarType),
"Can only calculate the mean of floating and complex types. Got ",
toString(scalarType), " instead.");

Tensor self_;
if (opt_dtype.has_value()) {
self_ = self.to(scalarType);
} else {
self_ = self;
}

if (std::abs(ord) == 2) {
// Need to shift the reduction dims to the back, because at::svd will only operate on
// the last 2 dimensions
auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);

result = std::get<1>(self_.permute(permutation).svd()).abs();
result = _norm_min_max(result, ord, result.dim() - 1, keepdim);

if (keepdim) {
result.unsqueeze_(-1);
result = result.permute(permutation_reverse);
}
} else {
// abs(p) == infinity and abs(p) == 1 will perform identical reductions, except
// that the order of the two dims is swapped. So we can swap the dims if
// abs(p) == infinity to simplify the rest of the operation's logic.
if (std::abs(ord) == INFINITY) {
std::swap(dim_[0], dim_[1]);
}
// If the dim of the second reduction is greater than that of the first reduction
// and we are not keeping the dims, then the fact that the output of the first
// reduction will have one fewer dimension means that the second reduction dim
// will be off by one, so we need to correct that.
if ((dim_[1] > dim_[0]) && !keepdim) {
dim_[1]--;
}
if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) {
result = self_.abs().sum(dim_[0], keepdim);
result = _norm_min_max(result, ord, dim_[1], keepdim);
} else {
TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm");
}
}
return result;
}

// Performs vector norm
// This function mostly serves as a wrapper for at::norm, but it overrides a few cases
// for numpy compatibility. These cases are corrected within this wrapper, rather than
// in at::norm itself, to avoid breaking backward compatibility.
// Performs vector norm
// This function mostly serves as a wrapper for at::norm, but it overrides a few cases
// for numpy compatibility. These cases are corrected within this wrapper, rather than
// in at::norm itself, to avoid breaking backward compatibility.
static Tensor _linalg_norm_vector(const Tensor& self, optional<Scalar> opt_ord, std::vector<int64_t> dim, bool keepdim, optional<ScalarType> opt_dtype) {
  if (opt_ord.has_value()) {
    TORCH_INTERNAL_ASSERT(dim.size() == 1);
    auto ord = opt_ord.value().toDouble();
    Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
    if (std::abs(ord) == INFINITY) {
      // The ord = +/-infinity case is overridden because at::norm does not match numpy
      // when the input contains extreme values (like nan or +/-inf) or if the input
      // size is degenerate (like size(0), size(0, N), etc)
      self_ = self_.abs();
      return _norm_min_max(self_, ord, dim[0], keepdim);
    } else if ((self_.numel() == 0) && (ord < 0)) {
      // For negative orders with degenerate input sizes, at::norm's result does not
      // match numpy.
      Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim);
      if (ord >= -1) {
        // Result must be infinite in this case, and the simplest way to make that
        // happen is to simply add infinity
        result += INFINITY;
      } else {
        result = result.pow(1.0 / (ord + 1));
      }
      return result;
    }
  } else {
    // If ord == None, need to check for unique dims because at::norm does not check it
    // for this case.
    std::vector<int64_t> dim_(dim);
    maybe_wrap_dims(dim_, self.dim());
    // std::unique only removes *adjacent* duplicates, so sort the (local copy of
    // the) dims first; otherwise duplicates like (0, 2, 0) would slip through.
    std::sort(dim_.begin(), dim_.end());
    bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end();
    TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")");
  }
  if (opt_dtype.has_value()) {
    return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value());
  } else {
    return at::norm(self, opt_ord, dim, keepdim);
  }
}

// Shared implementation backing all four linalg_norm entry points.
// Dispatches on the kind of 'ord' (string / numeric / none) and the number of
// reduction dims, then writes the answer into 'result'.
static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional<Scalar> opt_num_ord, optional<std::string> opt_str_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
  // Callers must give the ord argument as either a number, a string, or neither.
  // Since the user-facing API has no direct control over how this function is called, this is an internal assert.
  TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value()));
  if (opt_dtype.has_value()) {
    auto dtype = opt_dtype.value();
    TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got",
      "dtype = ", dtype, ", out.dtype = ", result.scalar_type());
  }
  int64_t ndim = self.dim();
  Tensor result_;
  if (opt_str_ord.has_value()) {
    // 'ord' is string: "fro" or "nuc" (validated by check_str_ord_valid)
    auto str_ord = opt_str_ord.value();
    check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
    if (str_ord == "fro") {
      result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
    } else if (str_ord == "nuc") {
      if (opt_dim.has_value()) {
        result_ = at::nuclear_norm(self, opt_dim.value(), keepdim);
      } else {
        result_ = at::nuclear_norm(self, keepdim);
      }
    }
  } else {
    // 'ord' is int or None; with no dim given, reduce over all dims
    std::vector<int64_t> dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim);
    if (!opt_num_ord.has_value() || dim_.size() == 1) {
      result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype);
    } else if (dim_.size() == 2) {
      result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype);
    } else {
      TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is "
        "not 1-D or 2-D");
    }
  }
  // Resize-and-copy into 'result' rather than replacing its storage via set_():
  // the out= contract aims to preserve the caller's storage where possible.
  resize_output(result, result_.sizes());
  result.copy_(result_);
  return result;
}

// Numerical or None norms
Tensor linalg_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
Tensor result = at::empty({0}, options);
return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}

// Frobenius and nuclear norms
Tensor linalg_norm(const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
Tensor result = at::empty({0}, options);
return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}

// Numerical or None norms
Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype);
}

// Frobenius and nuclear norms
Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
}

static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
if (i == j)
return matrices[i];
Expand Down
16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7307,6 +7307,22 @@

- func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)

- func: linalg_norm(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg
variants: function

- func: linalg_norm.ord_str(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg
variants: function

- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
variants: function

- func: linalg_norm.ord_str_out(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
variants: function

## Functions that are only for testing
# It is undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
Expand Down
1 change: 1 addition & 0 deletions docs/source/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Functions
---------

.. autofunction:: det
.. autofunction:: norm
Loading
111
0