8000 Update torch::stable::Tensor() default constructor (#159507) · pytorch/pytorch@5e0d293 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5e0d293

Browse files
mikaylagawareckichuanhaozhuge
authored andcommitted
Update torch::stable::Tensor() default constructor (#159507)
Allows things like ```cpp Tensor cu_seqlens_q; if (...) { cu_seqlens_q = ... } ... ``` Also adds `torch::stable::Tensor.defined()` Pull Request resolved: #159507 Approved by: https://github.com/janeyx99
1 parent d5fe6a6 commit 5e0d293

File tree

6 files changed

+89
-3
lines changed

6 files changed

+89
-3
lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
320320
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
321321
m.impl("my_zero_", &boxed_my_zero_);
322322
}
323+
324+
bool test_default_constructor(bool defined) {
325+
Tensor out;
326+
if (defined) {
327+
AtenTensorHandle defined_ath;
328+
int64_t sizes[] = {2, 3};
329+
int64_t strides[] = {3, 1};
330+
aoti_torch_empty_strided(
331+
2,
332+
sizes,
333+
strides,
334+
aoti_torch_dtype_float32(),
335+
aoti_torch_device_type_cpu(),
336+
0,
337+
&defined_ath);
338+
out = Tensor(defined_ath);
339+
}
340+
return out.defined();
341+
}
342+
343+
void boxed_test_default_constructor(
344+
StableIValue* stack,
345+
uint64_t num_args,
346+
uint64_t num_outputs) {
347+
bool res = test_default_constructor(to<bool>(stack[0]));
348+
stack[0] = from(res);
349+
}
350+
351+
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
352+
m.def("test_default_constructor(bool undefined) -> bool");
353+
}
354+
355+
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
356+
m.impl("test_default_constructor", &boxed_test_default_constructor);
357+
}

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor:
164164
Returns: The modified tensor (same as input)
165165
"""
166166
return torch.ops.libtorch_agnostic.fill_infinity.default(t)
167+
168+
169+
def test_default_constructor(defined) -> bool:
170+
"""
171+
Tests the default constructor for torch::stable::Tensor.
172+
173+
Args:
174+
defined: bool - if True, tests defined tensor; if False, tests undefined tensor
175+
176+
Returns: bool - result of calling .defined() on the tensor
177+
"""
178+
return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,20 @@ def test_fill_infinity(self, device):
218218
expected = torch.full_like(t, math.inf)
219219
self.assertEqual(out, expected)
220220

221+
@onlyCPU
222+
def test_default_constructor(self):
223+
import libtorch_agnostic
224+
225+
defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor(
226+
True
227+
)
228+
self.assertTrue(defined_tensor_is_defined)
229+
230+
undefined_tensor_is_defined = (
231+
libtorch_agnostic.ops.test_default_constructor(False)
232+
)
233+
self.assertFalse(undefined_tensor_is_defined)
234+
221235
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
222236

223237
if __name__ == "__main__":

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
227227
AOTI_TORCH_EXPORT AOTITorchError
228228
aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous);
229229

230+
AOTI_TORCH_EXPORT AOTITorchError
231+
aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined);
232+
230233
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
231234
AtenTensorHandle orig_handle,
232235
AtenTensorHandle* new_handle);

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous(
402402
});
403403
}
404404

405+
AOTITorchError aoti_torch_is_defined(
406+
AtenTensorHandle tensor,
407+
bool* ret_is_defined) {
408+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
409+
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
410+
*ret_is_defined = t->defined();
411+
});
412+
}
413+
405414
AOTITorchError aoti_torch_new_tensor_handle(
406415
AtenTensorHandle orig_handle,
407416
AtenTensorHandle* new_handle) {
@@ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) {
12041213
if (msg) {
12051214
std::cout << " " << msg;
12061215
}
1207-
std::cout << " "
1208-
<< "]:" << '\n';
1216+
std::cout << " " << "]:" << '\n';
12091217

12101218
// Print exact tensor values for small size tensors
12111219
const int64_t numel = t->numel();

torch/csrc/stable/tensor.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,15 @@ class Tensor {
2929
std::shared_ptr<AtenTensorOpaque> ath_;
3030

3131
public:
32-
Tensor() = delete;
32+
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
33+
// Steals ownership from the ATH
34+
Tensor() {
35+
AtenTensorHandle ret;
36+
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
37+
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
38+
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
39+
});
40+
}
3341

3442
// Construct a stable::Tensor from an AtenTensorHandle (ATH)
3543
// Steals ownership from the ATH
@@ -115,6 +123,12 @@ class Tensor {
115123
return size;
116124
}
117125

126+
bool defined() const {
127+
bool defined;
128+
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
129+
return defined;
130+
}
131+
118132
// =============================================================================
119133
// END of C-shimified TensorBase APIs
120134
// =============================================================================

0 commit comments

Comments
 (0)
0