Treat dim=[] same as dim=None by Silv3S · Pull Request #153570 · pytorch/pytorch · GitHub

Treat dim=[] same as dim=None #153570


Closed. Silv3S wants to merge 7 commits.

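The behaviour being codified, sketched as a quick example (an illustration assumed from the PR title and the new inductor test, not part of the diff): an empty dim list now reduces over every dimension, exactly like dim=None.

import torch

x = torch.randn(16, 32)
# dim=[] is treated as a full reduction, matching dim=None,
# both in eager mode and under torch.compile after this change.
ref_none = torch.linalg.vector_norm(x, ord=2, dim=None)
ref_empty = torch.linalg.vector_norm(x, ord=2, dim=[])
print(torch.allclose(ref_none, ref_empty))  # expected: True
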
9 changes: 9 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -5318,6 +5318,15 @@ def fn(x):

        self.common(fn, (x,))

+    def test_vector_norm_compile(self):
+        x = torch.randn([16, 32], dtype=torch.float)
+        ref = torch.linalg.vector_norm(x, ord=2, dim=[], keepdim=False, dtype=None)
+        compiled_vector_norm = torch.compile(
+            torch.linalg.vector_norm, backend="inductor"
+        )
+        res = compiled_vector_norm(x, ord=2, dim=[], keepdim=False, dtype=None)
+        self.assertEqual(ref, res)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
9 changes: 5 additions & 4 deletions test/test_reductions.py
@@ -65,7 +65,7 @@ def _rand_shape(dim, min_size, max_size):
        shape.append(random.randint(min_size, max_size))
    return tuple(shape)

-def _reduced_shape(shape, dim=None, keepdim=False):
+def _reduced_shape(shape, empty_dim_as_none=False, dim=None, keepdim=False):
    """Computes the expected reduced shape given dim and keepdim

    Args:
@@ -77,7 +77,7 @@ def _reduced_shape(shape, dim=None, keepdim=False):
    Returns:
        The reduced shape
    """
-    if dim is None:
+    if dim is None or (empty_dim_as_none and dim == []):
        return [1] * len(shape) if keepdim else []

    # Wrap negative dims
@@ -105,7 +105,8 @@ def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim)
        t = make_tensor(shape, dtype=torch.float, device=device)
        args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim))
        result = op(t, *args, **dim_keepdim, **kwargs)
-        expected_shape = _reduced_shape(shape, **dim_keepdim)
+        empty_dim_as_none = (op.name == "linalg.vector_norm" or op.name == "_refs.linalg.vector_norm")
+        expected_shape = _reduced_shape(shape, empty_dim_as_none, **dim_keepdim)
        self.assertEqual(result.shape, expected_shape, f"""
expected output shape to be {expected_shape} but got {list(result.shape)}
for input shape {shape} and {dim_keepdim}
@@ -314,7 +315,7 @@ def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo):
        for dim in [1] + [[1, 2]] if op.supports_multiple_dims else []:
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            result = op(t, *args, dim=dim, **kwargs)
-            self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
+            self.assertEqual(result.shape, _reduced_shape(t.shape, dim=dim))

    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Helper method to test noncontiguous input tensors."""
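A self-contained sketch of the helper's new empty_dim_as_none flag (an assumed reimplementation for illustration, not the code from the test file): when the flag is set, dim=[] falls into the same branch as dim=None.

def reduced_shape_sketch(shape, empty_dim_as_none=False, dim=None, keepdim=False):
    # dim=None (or dim=[] when the flag is set) means a full reduction.
    if dim is None or (empty_dim_as_none and dim == []):
        return [1] * len(shape) if keepdim else []
    dims = [dim] if isinstance(dim, int) else list(dim)
    dims = [d % len(shape) for d in dims]  # wrap negative dims
    if keepdim:
        return [1 if i in dims else s for i, s in enumerate(shape)]
    return [s for i, s in enumerate(shape) if i not in dims]

print(reduced_shape_sketch((2, 3, 4), empty_dim_as_none=True, dim=[]))                # []
print(reduced_shape_sketch((2, 3, 4), empty_dim_as_none=True, dim=[], keepdim=True))  # [1, 1, 1]
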
3 changes: 3 additions & 0 deletions torch/_refs/linalg/__init__.py
@@ -165,6 +165,9 @@ def vector_norm(
    reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)

    is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
+    if dim == []:
+        dim = None
+
    if (dim is None and x.numel() == 1) or (
        dim is not None
        and (x.ndim > 0 and all(guard_size_oblivious(x.shape[d] == 1) for d in dim))
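One plausible reason the normalisation sits right before this check (my reading, not stated in the diff): all() over an empty iterable is vacuously True, so dim=[] would otherwise satisfy the size-1 fast-path condition without any dimension actually being inspected. A minimal sketch of that vacuous-truth pitfall:

shape = (16, 32)
dim = []
# The generator below is empty, so all(...) returns True and the condition
# holds even though dim=[] is meant to behave like a full reduction.
takes_fast_path = dim is not None and len(shape) > 0 and all(shape[d] == 1 for d in dim)
print(takes_fast_path)  # True
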
7 changes: 0 additions & 7 deletions torch/testing/_internal/opinfo/definitions/linalg.py
@@ -1751,13 +1751,6 @@ def make_input():
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        generate_args_kwargs=sample_kwargs_vector_norm,
        aten_name="linalg_vector_norm",
-        skips=(
-            # FIXME: sum reduces all dimensions when dim=[]
-            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
-            DecorateInfo(
-                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
-            ),
-        ),
    ),
    OpInfo(
        "linalg.lu_factor",
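The removed FIXME points at the behaviour this PR codifies: the sum-based implementation already reduces every dimension when dim=[]. A short illustration (assumed from the FIXME text, not part of the diff):

import torch

x = torch.ones(2, 3)
# As the removed FIXME notes, sum reduces all dimensions when dim=[],
# so vector_norm built on top of it behaves like dim=None as well.
print(torch.sum(x, dim=[]))                # expected: tensor(6.), a 0-d full reduction
print(torch.linalg.vector_norm(x, dim=[])) # also a 0-d tensor, consistent with dim=None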