10000 Fix signature of torch.sparse_coo_tensor() (#152681) · pytorch/pytorch@cf7451f · GitHub
[go: up one dir, main page]

Skip to content

Commit cf7451f

Browse files
ILCSFNOpytorchmergebot
authored andcommitted
Fix signature of torch.sparse_coo_tensor() (#152681)
Fixes #145371 @pearu Searched all and find these codes, wondering whether is the root cause of the issue, could you have a review? Thanks a lot! Pull Request resolved: #152681 Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/nikitaved
1 parent f58143b commit cf7451f

File tree

4 files changed

+23
-16
lines changed

4 files changed

+23
-16
lines changed

aten/src/ATen/native/sparse/SparseTensor.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,14 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
356356
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
357357
}
358358

359-
return at::_sparse_coo_tensor_with_dims_and_tensors(
360-
sparse_dim,
361-
dense_dim,
362-
computed_sizes,
359+
return at::native::_sparse_coo_tensor_unsafe(
363360
indices,
364361
values,
365-
values.options().layout(kSparse),
362+
computed_sizes,
363+
optTypeMetaToScalarType(options.dtype_opt()),
364+
options.layout_opt(),
365+
options.device_opt(),
366+
options.pinned_memory_opt(),
366367
is_coalesced);
367368
}
368369

test/test_sparse.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -440,18 +440,22 @@ def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced):
440440
self.assertEqual(t.is_coalesced(), coalesced)
441441

442442
def func(indices, values, shape, is_coalesced):
443-
s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced)
443+
if shape is None:
444+
s = torch.sparse_coo_tensor(indices, values, check_invariants=True, is_coalesced=is_coalesced)
445+
else:
446+
s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced)
444447
self.assertEqual(s.is_coalesced(), is_coalesced)
445448
return s.to_dense(masked_grad=False)
446449

447-
if coalesced:
448-
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
449-
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
450-
else:
451-
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
452-
with self.assertRaisesRegex(RuntimeError,
453-
"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
454-
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
450+
for shape in {t.shape, None}:
451+
if coalesced:
452+
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, False))
453+
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, True))
454+
else:
455+
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, False))
456+
with self.assertRaisesRegex(RuntimeError,
457+
"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
458+
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, True))
455459

456460
@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
457461
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")

torch/csrc/autograd/python_torch_functions_manual.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ static PyObject* THPVariable_sparse_coo_tensor(
221221
PyObject* kwargs) {
222222
HANDLE_TH_ERRORS
223223
static PythonArgParser parser({
224-
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
224+
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None, bool is_coalesced=None)",
225225
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None, bool is_coalesced=None)",
226226
"sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
227227
});

torch/csrc/utils/tensor_new.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,7 @@ Tensor sparse_coo_tensor_ctor(
11571157
ARG_PIN_MEMORY,
11581158
ARG_REQUIRES_GRAD,
11591159
ARG_CHECK_INVARIANTS,
1160+
ARG_IS_COALESCED,
11601161
ARGS_COUNT
11611162
};
11621163
enum {
@@ -1218,7 +1219,8 @@ Tensor sparse_coo_tensor_ctor(
12181219
return at::sparse_coo_tensor(
12191220
indices,
12201221
values,
1221-
values.options().layout(at::kSparse).pinned_memory(pin_memory))
1222+
values.options().layout(at::kSparse).pinned_memory(pin_memory),
1223+
r.toBoolOptional(ARG_IS_COALESCED))
12221224
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
12231225
} else if (r.idx == 1) {
12241226
bool pin_memory = r.toBool(ARG_PIN_MEMORY1);

0 commit comments

Comments
 (0)
0