ENH: Add a force argument to `numpy()` (#78564) (#78564) · pytorch/pytorch@b711b6c · GitHub
Commit b711b6c

HaoZeke authored and facebook-github-bot committed
ENH: Add a force argument to numpy() (#78564) (#78564)
Summary: **Reopened** to help with merge issues. See #59790 for full context. Fixes #20778. Helps #71688.

Finalizes martinPasen's force argument for `Tensor.numpy()`. It is set to False by default. If it is set to True then we:
1. detach the Tensor, if requires_grad == True
2. move to cpu, if not on cpu already
3. use .resolve_conj(), if .is_conj() == True
4. use .resolve_neg(), if .is_neg() == True

cc albanD

Pull Request resolved: #78564
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/3f58dd18dc6fc18ed82fb1632cea48373c0a7798

Reviewed By: seemethere

Differential Revision: D36935606

Pulled By: seemethere

fbshipit-source-id: dc2dd7f569feb8da29add55db3d1625241ff8d77
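As a quick illustration of the behaviour described above, a minimal Python sketch (the tensor values are arbitrary and not part of the patch):

    import torch

    t = torch.arange(4.0, requires_grad=True)

    # Plain t.numpy() raises: "Can't call numpy() on Tensor that requires grad. ..."
    # With force=True the call is roughly the chain spelled out in the summary:
    a = t.numpy(force=True)
    b = t.detach().cpu().resolve_conj().resolve_neg().numpy()
    assert (a == b).all()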
1 parent 8a711ed commit b711b6c

File tree

6 files changed, +92 -39 lines changed


test/test_numpy_interop.py

Lines changed: 25 additions & 0 deletions
@@ -156,6 +156,31 @@ def test_to_numpy_bool(self, device) -> None:
         self.assertEqual(y.dtype, np.bool_)
         self.assertEqual(x[0], y[0])
 
+    def test_to_numpy_force_argument(self, device) -> None:
+        for force in [False, True]:
+            for requires_grad in [False, True]:
+                for sparse in [False, True]:
+                    for conj in [False, True]:
+                        data = [[1 + 2j, -2 + 3j], [-1 - 2j, 3 - 2j]]
+                        x = torch.tensor(data, requires_grad=requires_grad, device=device)
+                        y = x
+                        if sparse:
+                            if requires_grad:
+                                continue
+                            x = x.to_sparse()
+                        if conj:
+                            x = x.conj()
+                            y = x.resolve_conj()
+                        expect_error = requires_grad or sparse or conj or not device == 'cpu'
+                        error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
+                        if not force and expect_error:
+                            self.assertRaisesRegex((RuntimeError, TypeError), error_msg, lambda: x.numpy())
+                            self.assertRaisesRegex((RuntimeError, TypeError), error_msg, lambda: x.numpy(force=False))
+                        elif force and sparse:
+                            self.assertRaisesRegex(TypeError, error_msg, lambda: x.numpy(force=True))
+                        else:
+                            self.assertEqual(x.numpy(force=force), y)
+
     def test_from_numpy(self, device) -> None:
         dtypes = [
             np.double,
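One detail the new test encodes: force=True does not change the layout, so a sparse tensor is still rejected with a TypeError and has to be densified by the caller. A small sketch of that branch (values arbitrary):

    import torch

    x = torch.tensor([[0.0, 1.0], [2.0, 0.0]]).to_sparse()
    try:
        x.numpy(force=True)            # still rejected: layout is not strided
    except TypeError as err:
        print(err)                     # the message points at converting the layout first
    x.to_dense().numpy(force=True)     # fine once the tensor is strided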

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 13 additions & 6 deletions
@@ -791,15 +791,22 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
 
 // implemented on the python object bc PyObjects not declarable in native_functions.yaml
 // See: ATen/native/README.md for more context
-static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
+static PyObject * THPVariable_numpy(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
-  if (check_has_torch_function(self)) {
-    return handle_torch_function(self, "numpy");
+  static PythonArgParser parser({
+    "numpy(*, bool force=False)"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
   }
+
   jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
-  auto& self_ = THPVariable_Unpack(self);
-  return torch::utils::tensor_to_numpy(self_);
+  return torch::utils::tensor_to_numpy(self_, r.toBool(0));
   END_HANDLE_TH_ERRORS
 }
 
@@ -1271,7 +1278,7 @@ PyMethodDef variable_methods[] = {
   {"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, NULL},
   {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, NULL},
   {"numel", THPVariable_numel, METH_NOARGS, NULL},
-  {"numpy", THPVariable_numpy, METH_NOARGS, NULL},
+  {"numpy", castPyCFunctionWithKeywords(THPVariable_numpy), METH_VARARGS | METH_KEYWORDS, NULL},
   {"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, NULL},
   {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
   {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
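Because the signature registered with the argument parser above is "numpy(*, bool force=False)", the new argument is keyword-only from Python; a sketch of the resulting calling convention (values arbitrary):

    import torch

    t = torch.ones(3)
    t.numpy()              # unchanged: force defaults to False
    t.numpy(force=True)    # new keyword-only argument
    # t.numpy(True)        # expected to raise TypeError, since force cannot be passed positionally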

tools/pyi/gen_pyi.py

Lines changed: 1 addition & 1 deletion
@@ -642,7 +642,7 @@ def gen_pyi(
         "cuda": [
             "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
         ],
-        "numpy": ["def numpy(self) -> Any: ..."],
+        "numpy": ["def numpy(self, *, force: _bool=False) -> Any: ..."],
         "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
         "map_": [
             "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."

torch/_tensor_docs.py

Lines changed: 19 additions & 4 deletions
@@ -2841,11 +2841,26 @@ def callable(a, b) -> number
 
 add_docstr_all('numpy',
                r"""
-numpy() -> numpy.ndarray
+numpy(*, force=False) -> numpy.ndarray
 
-Returns :attr:`self` tensor as a NumPy :class:`ndarray`. This tensor and the
-returned :class:`ndarray` share the same underlying storage. Changes to
-:attr:`self` tensor will be reflected in the :class:`ndarray` and vice versa.
+Returns the tensor as a NumPy :class:`ndarray`.
+
+If :attr:`force` is ``False`` (the default), the conversion
+is performed only if the tensor is on the CPU, does not require grad,
+does not have its conjugate bit set, and is a dtype and layout that
+NumPy supports. The returned ndarray and the tensor will share their
+storage, so changes to the tensor will be reflected in the ndarray
+and vice versa.
+
+If :attr:`force` is ``True`` this is equivalent to
+calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``.
+If the tensor isn't on the CPU or the conjugate or negative bit is set,
+the tensor won't share its storage with the returned ndarray.
+Setting :attr:`force` to ``True`` can be a useful shorthand.
+
+Args:
+    force (bool): if ``True``, the ndarray may be a copy of the tensor
+           instead of always sharing memory, defaults to ``False``.
 """)
 
 add_docstr_all('orgqr',
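The sharing behaviour spelled out in the docstring can be checked with NumPy directly; a minimal CPU-only sketch (values arbitrary):

    import numpy as np
    import torch

    t = torch.arange(4.0)
    print(np.shares_memory(t.numpy(force=True), t.numpy()))  # True: nothing here forces a copy

    g = torch.arange(4.0, requires_grad=True)
    b = g.numpy(force=True)                         # roughly g.detach().cpu().resolve_conj().resolve_neg().numpy()
    print(np.shares_memory(b, g.detach().numpy()))  # True: detach() is a view, so storage is still shared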

torch/csrc/utils/tensor_numpy.cpp

Lines changed: 33 additions & 27 deletions
@@ -105,49 +105,55 @@ static std::vector<int64_t> seq_to_aten_shape(PyObject *py_seq) {
   return result;
 }
 
-PyObject* tensor_to_numpy(const at::Tensor& tensor) {
+PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force/*=false*/) {
   TORCH_CHECK(is_numpy_available(), "Numpy is not available");
 
-  TORCH_CHECK_TYPE(tensor.device().type() == DeviceType::CPU,
-      "can't convert ", tensor.device().str().c_str(),
-      " device type tensor to numpy. Use Tensor.cpu() to ",
-      "copy the tensor to host memory first.");
+  TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(),
+      ".numpy() is not supported for tensor subclasses.");
 
   TORCH_CHECK_TYPE(tensor.layout() == Layout::Strided,
       "can't convert ", c10::str(tensor.layout()).c_str(),
-      " layout tensor to numpy.",
-      "convert the tensor to a strided layout first.");
-
-  TORCH_CHECK(!(at::GradMode::is_enabled() && tensor.requires_grad()),
-      "Can't call numpy() on Tensor that requires grad. "
-      "Use tensor.detach().numpy() instead.");
-
-  TORCH_CHECK(!tensor.is_conj(),
-      "Can't call numpy() on Tensor that has conjugate bit set. ",
-      "Use tensor.resolve_conj().numpy() instead.");
+      " layout tensor to numpy. ",
+      "Use Tensor.dense() first.");
+
+  if (!force){
+    TORCH_CHECK_TYPE(tensor.device().type() == DeviceType::CPU,
+        "can't convert ", tensor.device().str().c_str(),
+        " device type tensor to numpy. Use Tensor.cpu() to ",
+        "copy the tensor to host memory first.");
+
+    TORCH_CHECK(!(at::GradMode::is_enabled() && tensor.requires_grad()),
+        "Can't call numpy() on Tensor that requires grad. "
+        "Use tensor.detach().numpy() instead.");
+
+    TORCH_CHECK(!tensor.is_conj(),
+        "Can't call numpy() on Tensor that has conjugate bit set. ",
+        "Use tensor.resolve_conj().numpy() instead.");
+
+    TORCH_CHECK(!tensor.is_neg(),
+        "Can't call numpy() on Tensor that has negative bit set. "
+        "Use tensor.resolve_neg().numpy() instead.");
+  }
 
-  TORCH_CHECK(!tensor.is_neg(),
-      "Can't call numpy() on Tensor that has negative bit set. "
-      "Use tensor.resolve_neg().numpy() instead.");
+  auto prepared_tensor = tensor.detach().cpu().resolve_conj().resolve_neg();
 
-  TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");
+  auto dtype = aten_to_numpy_dtype(prepared_tensor.scalar_type());
+  auto sizes = to_numpy_shape(prepared_tensor.sizes());
+  auto strides = to_numpy_shape(prepared_tensor.strides());
 
-  auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
-  auto sizes = to_numpy_shape(tensor.sizes());
-  auto strides = to_numpy_shape(tensor.strides());
   // NumPy strides use bytes. Torch strides use element counts.
-  auto element_size_in_bytes = tensor.element_size();
+  auto element_size_in_bytes = prepared_tensor.element_size();
   for (auto& stride : strides) {
     stride *= element_size_in_bytes;
   }
 
   auto array = THPObjectPtr(PyArray_New(
       &PyArray_Type,
-      tensor.dim(),
+      prepared_tensor.dim(),
       sizes.data(),
       dtype,
       strides.data(),
-      tensor.data_ptr(),
+      prepared_tensor.data_ptr(),
       0,
       NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE,
       nullptr));
@@ -157,13 +163,13 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
   // object of the ndarray to the tensor and disabling resizes on the storage.
   // This is not sufficient. For example, the tensor's storage may be changed
  // via Tensor.set_, which can free the underlying memory.
-  PyObject* py_tensor = THPVariable_Wrap(tensor);
+  PyObject* py_tensor = THPVariable_Wrap(prepared_tensor);
   if (!py_tensor) throw python_error();
   if (PyArray_SetBaseObject((PyArrayObject*)array.get(), py_tensor) == -1) {
     return nullptr;
   }
   // Use the private storage API
-  tensor.storage().unsafeGetStorageImpl()->set_resizable(false);
+  prepared_tensor.storage().unsafeGetStorageImpl()->set_resizable(false);
 
   return array.release();
 }
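Seen from Python, the restructuring above keeps the old errors when force is left at its default and routes force=True through the prepared_tensor path; a sketch using the conjugate bit (values arbitrary):

    import torch

    z = torch.tensor([1 + 2j, 3 - 4j]).conj()    # sets the conjugate bit lazily
    try:
        z.numpy()                                # RuntimeError: suggests resolve_conj() first
    except RuntimeError as err:
        print(err)
    print(z.numpy(force=True))                   # resolve_conj() is applied internally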

torch/csrc/utils/tensor_numpy.h

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 namespace torch { namespace utils {
 
-PyObject* tensor_to_numpy(const at::Tensor& tensor);
+PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force=false);
 at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true);
 
 int aten_to_numpy_dtype(const at::ScalarType scalar_type);

0 commit comments
