Add shim.h C API to call dispatcher on our own aten ops (#148832) · pytorch/pytorch@e6ef062

Commit e6ef062

janeyx99 authored and pytorchmergebot committed
Add shim.h C API to call dispatcher on our own aten ops (#148832)
This PR still needs testing through some cpp extension.

Pull Request resolved: #148832
Approved by: https://github.com/albanD, https://github.com/atalman
ghstack dependencies: #148124
1 parent cf19efd · commit e6ef062

File tree

6 files changed: +163 −0 lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp

Lines changed: 22 additions & 0 deletions
@@ -125,3 +125,25 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("identity", &boxed_identity);
 }
+
+RAIIATH my_abs(RAIIATH t) {
+  const auto num_args = 1;
+  StableIValue stack[num_args];
+  stack[0] = from(t.release());
+  aoti_torch_call_dispatcher("aten::abs", "", stack);
+  return RAIIATH(to<AtenTensorHandle>(stack[0]));
+}
+
+void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH t(to<AtenTensorHandle>(stack[0]));
+  RAIIATH raiiath_res = my_abs(std::move(t));
+  stack[0] = from(raiiath_res.release());
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("my_abs(Tensor t) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("my_abs", &boxed_my_abs);
+}
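For illustration (not part of this commit), the same calling convention extends to multi-argument ops: size the stack to max(num_args, num_returns), pack arguments at indices 0..n-1, and read returns starting at index 0. A minimal sketch of a hypothetical my_mul, assuming the RAIIATH, from, and to<AtenTensorHandle> helpers already used above:

    RAIIATH my_mul(RAIIATH a, RAIIATH b) {
      // aten::mul.Tensor has schema (Tensor self, Tensor other) -> Tensor,
      // so two stack slots cover both the 2 arguments and the 1 return
      StableIValue stack[2];
      stack[0] = from(a.release());
      stack[1] = from(b.release());
      aoti_torch_call_dispatcher("aten::mul", "Tensor", stack);
      // ret0 is written back at index 0
      return RAIIATH(to<AtenTensorHandle>(stack[0]));
    }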

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 13 additions & 0 deletions
@@ -36,3 +36,16 @@ def identity(t) -> Tensor:
         a Tensor, the same as input.
     """
     return torch.ops.libtorch_agnostic.identity.default(t)
+
+
+def my_abs(t) -> Tensor:
+    """
+    Returns abs on the input tensor, outputs a new Tensor
+
+    Args:
+        t: any Tensor
+
+    Returns:
+        a Tensor
+    """
+    return torch.ops.libtorch_agnostic.my_abs.default(t)

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 17 additions & 0 deletions
@@ -52,6 +52,23 @@ def _run_identity(prior_mem):
         curr_mem = torch.cuda.memory_allocated(device)
         self.assertEqual(curr_mem, init_mem)
 
+    def test_my_abs(self, device):
+        t = torch.rand(32, 16, device=device)
+        cpu_t = libtorch_agnostic.ops.my_abs(t)
+        self.assertEqual(cpu_t, torch.abs(t))
+
+        def _make_cuda_tensors(prior_mem):
+            cuda_t = libtorch_agnostic.ops.my_abs(t)
+            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+            self.assertEqual(cuda_t, torch.abs(t))
+
+        if t.is_cuda:
+            init_mem = torch.cuda.memory_allocated(device)
+            for _ in range(3):
+                _make_cuda_tensors(init_mem)
+            curr_mem = torch.cuda.memory_allocated(device)
+            self.assertEqual(curr_mem, init_mem)
+
     @onlyCUDA
     def test_z_delete_torch_lib(self, device):
         # Why the z + CUDA? THIS TEST MUST BE RUN LAST

test/test_cpp_extensions_aot.py

Lines changed: 17 additions & 0 deletions
@@ -270,6 +270,23 @@ def _run_identity(prior_mem, device):
         curr_mem = torch.cuda.memory_allocated(device)
         self.assertEqual(curr_mem, init_mem)
 
+        # (3) test calling our dispatcher on aten::abs via my_abs
+        t = torch.rand(32, 16, device=device)
+        cpu_t = libtorch_agnostic.ops.my_abs(t)
+        self.assertEqual(cpu_t, torch.abs(t))
+
+        def _make_cuda_tensors(prior_mem):
+            cuda_t = libtorch_agnostic.ops.my_abs(t)
+            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+            self.assertEqual(cuda_t, torch.abs(t))
+
+        if t.is_cuda:
+            init_mem = torch.cuda.memory_allocated(device)
+            for _ in range(3):
+                _make_cuda_tensors(init_mem)
+            curr_mem = torch.cuda.memory_allocated(device)
+            self.assertEqual(curr_mem, init_mem)
+
 
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestPybindTypeCasters(common.TestCase):

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 9 additions & 0 deletions
@@ -677,6 +677,15 @@ aoti_torch_library_def(TorchLibraryHandle self, const char* schema);
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_delete_library_object(TorchLibraryHandle tlh);
 
+// calls the op overload defined by a given opName, overloadName, and a
+// stack of StableIValues. This call will populate any return values of the
+// op into the stack in their StableIValue form, with ret0 at index 0, ret1
+// at index 1, and so on.
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
+    const char* opName,
+    const char* overloadName,
+    StableIValue* stack);
+
 #ifdef USE_CUDA
 
 struct CUDAGuardOpaque;
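Like the other shim entry points, this returns an AOTITorchError rather than throwing; the extension examples in this commit ignore the return code for brevity. A hedged caller-side sketch, assuming the AOTI_TORCH_SUCCESS constant defined earlier in shim.h:

    AOTITorchError err = aoti_torch_call_dispatcher("aten::abs", "", stack);
    if (err != AOTI_TORCH_SUCCESS) {
      // the dispatcher call threw; the shim converted the exception into
      // an error code, so handle or propagate it here
    }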

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 85 additions & 0 deletions
@@ -1429,3 +1429,88 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
       { delete reinterpret_cast<torch::Library*>(tlh); });
 }
+
+static c10::IValue to_ivalue(
+    const c10::TypePtr& arg_type,
+    const StableIValue stable_ivalue) {
+  switch (arg_type->kind()) {
+    case c10::TypeKind::TensorType: {
+      // stable_ivalue must be an ATH
+      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
+          to<AtenTensorHandle>(stable_ivalue));
+      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
+          ret_raiiath.get());
+      return (c10::IValue(arg));
+    }
+    case c10::TypeKind::IntType: {
+      return c10::IValue(to<int64_t>(stable_ivalue));
+    }
+    case c10::TypeKind::FloatType: {
+      return c10::IValue(to<double>(stable_ivalue));
+    }
+    case c10::TypeKind::BoolType: {
+      return c10::IValue(to<bool>(stable_ivalue));
+    }
+    case c10::TypeKind::ScalarTypeType: {
+      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
+    }
+    case c10::TypeKind::LayoutType: {
+      return c10::IValue(to<c10::Layout>(stable_ivalue));
+    }
+    case c10::TypeKind::MemoryFormatType: {
+      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
+    }
+    default: {
+      TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
+    }
+  }
+}
+
+AOTITorchError aoti_torch_call_dispatcher(
+    const char* opName,
+    const char* overloadName,
+    StableIValue* stack) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    static auto op =
+        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
+
+    const auto& schema = op.schema();
+    const auto num_returns = schema.returns().size();
+    const auto num_arguments = schema.arguments().size();
+
+    torch::jit::Stack ivalue_stack;
+    // we will only need max(num_args, num_returns)
+    ivalue_stack.reserve(std::max(num_arguments, num_returns));
+
+    // convert StableIValue stack to c10::IValue stack
+    for (const auto idx : c10::irange(num_arguments)) {
+      auto stable_ivalue = stack[idx];
+      auto arg_type = schema.arguments()[idx].type();
+      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
+    }
+
+    op.callBoxed(ivalue_stack);
+
+    // there should then be num_returns IValues on the stack, which
+    // we will convert to StableIValue and repopulate user input stack
+    for (const auto idx : c10::irange(num_returns)) {
+      const c10::IValue& ret = torch::jit::pop(ivalue_stack);
+      const auto stack_idx = num_returns - idx - 1;
+      if (ret.isInt()) {
+        stack[stack_idx] = from(ret.toInt());
+      } else if (ret.isDouble()) {
+        stack[stack_idx] = from(ret.toDouble());
+      } else if (ret.isBool()) {
+        stack[stack_idx] = from(ret.toBool());
+      } else if (ret.isNone()) {
+        stack[stack_idx] = from(nullptr);
+      } else if (ret.isTensor()) {
+        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
+            std::move(const_cast<at::Tensor&>(ret.toTensor())));
+        stack[stack_idx] = from(ath);
+      } else {
+        TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
+      }
+    }
+  });
+}
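One subtlety in the return loop above: torch::jit::pop yields the most recently pushed IValue first, so stack_idx = num_returns - idx - 1 writes the returns back in schema order. A short trace for a hypothetical op with two returns:

    // num_returns == 2; the op left (ret0, ret1) on ivalue_stack.
    // idx == 0: pop() -> ret1; stack_idx = 2 - 0 - 1 = 1; stack[1] = from(ret1)
    // idx == 1: pop() -> ret0; stack_idx = 2 - 1 - 1 = 0; stack[0] = from(ret0)
    // Net effect: ret0 at stack[0], ret1 at stack[1], per the shim.h contract.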
