8000 [caffe2] Export operators to c10 without including ATen/Tensor.h · pytorch/pytorch@da06b0a · GitHub
[go: up one dir, main page]

Skip to content

Commit da06b0a

Browse files
committed
[caffe2] Export operators to c10 without including ATen/Tensor.h
With this change, the only caffe2 files that depend on `ATen/Tensor.h` are ones that directly use the ATen API. Specifically, ``` [ "caffe2/CMakeFiles/torch_cuda_cpp.dir/contrib/aten/aten_op_gpu.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/core/tensor.cc.o", "caffe2/CMakeFiles/torch_cuda_cpp.dir/operators/layer_norm_op.cu.o", "caffe2/CMakeFiles/torch_cpu.dir/core/IValueInterface.cc.o", "caffe2/CMakeFiles/cuda_tensor_interop_test.dir/__/aten/src/ATen/test/cuda_tensor_interop_test.cpp.o", "caffe2/CMakeFiles/torch_cpu.dir/contrib/aten/aten_op.cc.o", "caffe2/CMakeFiles/caffe2_pybind11_state_gpu.dir/python/pybind_state.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/operators/layer_norm_op.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/core/export_c10_op_to_caffe2.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/core/export_caffe2_op_to_c10.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/operators/enforce_finite_op.cc.o", "caffe2/CMakeFiles/torch_cpu.dir/core/operator.cc.o", "caffe2/CMakeFiles/tensor_interop_test.dir/__/aten/src/ATen/test/tensor_interop_test.cpp.o", "caffe2/CMakeFiles/caffe2_pybind11_state.dir/python/pybind_state.cc.o" ] ``` ghstack-source-id: 990ba65 Pull Request resolved: #67096
1 parent 04b21c3 commit da06b0a

File tree

3 files changed

+237
-171
lines changed

3 files changed

+237
-171
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
2+
#include <caffe2/core/export_caffe2_op_to_c10.h>
3+
#undef TORCH_ASSERT_NO_OPERATORS
4+
5+
#if defined(EXPOSE_C2_OPS) || \
6+
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
7+
8+
#include <ATen/core/function_schema.h>
9+
#include <ATen/core/dispatch/Dispatcher.h>
10+
#include <torch/csrc/jit/frontend/function_schema_parser.h>
11+
#include <torch/library.h>
12+
13+
namespace caffe2 {
14+
namespace detail {
15+
16+
// This function is inline in the hope that compilers optimizing for speed will
// inline it into call_caffe2_op_from_c10, allowing call_op to be inlined and
// avoiding the function pointer indirection, while compilers optimizing for
// binary size will keep it a separate function instead of inlining it into
// a template and will reuse the binary code of this function between ops.
// We measured and confirmed that binary size off the instagram ios app is
// reduced when having _call_caffe2_op_from_c10 separate from the templated
// call_caffe2_op_from_c10.
// NOTE(review): this definition is not declared `inline` here and the
// `_call_caffe2_op_from_c10` name in the comment above no longer appears in
// this file — confirm the comment still reflects the current structure.
//
// Boxed-calling bridge: pops the c10 arguments off `stack`, hands them to the
// type-erased caffe2 kernel via `call_op`, and pushes the resulting tensors
// back onto `stack` in the form the c10 schema promises.
void call_caffe2_op_from_c10(
    const OperatorHandle &opHandle,
    c10::Stack* stack,
    _CallCaffe2OpFunc* call_op) {
  // precondition: on the stack, there's one IValue for each argument of the
  // c10 schema. The last argument is an optional tensor list that
  // (if not ivalue::None) contains a preallocated output tensor for each
  // operator output.

  // As an invariant, we don't want any autograd gradients to be tracked in
  // Caffe2 operators.
  at::NoGradGuard guard;

  // Sanity-check the schema matches the calling convention above: at least
  // one argument, and the last one is `Tensor[]?` (the preallocated outputs).
  const auto &schema = opHandle.schema();
  AT_ASSERT(
      schema.arguments().size() != 0 &&
      schema.arguments().back().type()->isSubtypeOf(
          *OptionalType::create(ListType::ofTensors())));
  IValue preallocated_outputs = torch::jit::pop(*stack);

  const size_t num_outputs = schema.returns().size();
  const size_t num_inputs = schema.arguments().size() -
      1; // -1 because the last argument is the list of preallocated tensors

  c10::List<at::Tensor> outputs;
  if (preallocated_outputs.isNone()) {
    // either the schema doesn't support preallocated outputs or it does but
    // they haven't been passed in. Pass a list of uninitialized tensors to
    // the caffe2 operator as preallocated outputs.
    outputs.resize(num_outputs);
  } else {
    AT_ASSERT(preallocated_outputs.isTensorList());
    outputs = std::move(preallocated_outputs).toTensorList();
  }

  // TODO Avoid vector allocation. One idea would be to keep the std::vector
  // instances in the cache.
  std::vector<IValue> inputs = torch::jit::pop(*stack, num_inputs);

  // Convert outputs to caffe2::Tensor
  c10::SmallVector<caffe2::Tensor, 6> outputs_c2(num_outputs);
  for (auto i : c10::irange(num_outputs)) {
    outputs_c2[i] = caffe2::Tensor(outputs.get(i));
  }

  // Invoke the actual caffe2 operator through the type-erased function
  // pointer. Stream -1 presumably means "default stream" — confirm against
  // _CallCaffe2OpFunc's contract.
  const StreamId stream(-1);
  auto new_outputs_c2 = (*call_op)(schema, std::move(inputs), outputs_c2, stream);


  // If the schema declares a single `Tensor[]` return, all outputs are pushed
  // as one list IValue; otherwise each output tensor is pushed individually.
  bool return_tensor_list = false;
  if (schema.returns().size() == 1) {
    auto type = schema.returns()[0].type();
    if (c10::ListTypePtr list_type = type->cast<c10::ListType>()) {
      if (list_type->getElementType()->kind() == c10::TypeKind::TensorType) {
        return_tensor_list = true;
      }
    }
  }
  if (return_tensor_list) {
    for (auto i : c10::irange(num_outputs)) {
      outputs.set(i, at::Tensor(std::move(new_outputs_c2[i])));
    }
    torch::jit::push(*stack, outputs);
  } else {
    for (auto i : c10::irange(num_outputs)) {
      torch::jit::push(*stack, at::Tensor(std::move(new_outputs_c2[i])));
    }
  }

  // postcondition: All inputs are cleared from the stack, there's now one
  // IValue for each output which holds the result. This
  // might reuse one of the preallocated tensors but doesn't have
  // to.
}
98+
99+
// Builds the c10 FunctionSchema under which a caffe2 operator is registered:
// the user-supplied schema string, extended with a trailing hidden
// `Tensor[]?` argument (PREALLOCATED_OUTPUT_ARGNAME) through which callers
// may hand in preallocated output tensors.
//
// Throws std::logic_error on mobile/xplat builds, where the schema parser
// is not linked in.
static FunctionSchema make_function_schema_for_c10(const char* schema_str) {
#if !defined(EXPOSE_C2_OPS) && \
    (defined(CAFFE2_IS_XPLAT_BUILD) || defined(C10_MOBILE))
  throw std::logic_error(
      "We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
#else
  const c10::FunctionSchema base = torch::jit::parseSchema(schema_str);

  // Copy the parsed argument list and append the hidden output-list argument.
  auto extended_args = base.arguments();
  extended_args.emplace_back(
      PREALLOCATED_OUTPUT_ARGNAME,
      c10::OptionalType::create(c10::ListType::ofTensors()),
      nullopt,
      IValue());

  return FunctionSchema(
      base.name(),
      base.overload_name(),
      std::move(extended_args),
      base.returns(),
      base.is_vararg(),
      base.is_varret());
#endif
}
122+
123+
// Registers a CPU kernel for a "_caffe2" operator with the dispatcher.
// The static Library is constructed unconditionally (its constructor
// registers the IMPL library); the kernel itself is only registered when
// the CPU dispatch key passes the allowlist check.
InitCPUDefinition::InitCPUDefinition(const char *name, KernelFunction func) {
  static torch::Library cpu_lib(
      torch::Library::IMPL, "_caffe2", c10::DispatchKey::CPU,
      __FILE__, __LINE__);
  if (!c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::CPU)) {
    return; // CPU key filtered out of this build — skip registration
  }
  cpu_lib.def(name, torch::CppFunction::makeFromKernelFunction(func));
}
131+
132+
// Registers a CUDA kernel for a "_caffe2" operator with the dispatcher.
// The static Library is constructed unconditionally (its constructor
// registers the IMPL library); the kernel itself is only registered when
// the CUDA dispatch key passes the allowlist check.
InitCUDADefinition::InitCUDADefinition(const char *name, KernelFunction func) {
  static torch::Library cuda_lib(
      torch::Library::IMPL, "_caffe2", c10::DispatchKey::CUDA,
      __FILE__, __LINE__);
  if (!c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::CUDA)) {
    return; // CUDA key filtered out of this build — skip registration
  }
  cuda_lib.def(name, torch::CppFunction::makeFromKernelFunction(func));
}
140+
141+
// Registers a HIP (ROCm) kernel for a "_caffe2" operator with the dispatcher.
// The static Library is constructed unconditionally (its constructor
// registers the IMPL library); the kernel itself is only registered when
// the HIP dispatch key passes the allowlist check.
InitHIPDefinition::InitHIPDefinition(const char *name, KernelFunction func) {
  static torch::Library hip_lib(
      torch::Library::IMPL, "_caffe2", c10::DispatchKey::HIP,
      __FILE__, __LINE__);
  if (!c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::HIP)) {
    return; // HIP key filtered out of this build — skip registration
  }
  hip_lib.def(name, torch::CppFunction::makeFromKernelFunction(func));
}
149+
150+
// Registers an operator schema in the "_caffe2" namespace. All schemas share
// one static FRAGMENT library so repeated registrations extend the same
// library instead of redefining it.
InitSchema::InitSchema(const char *schema_str) {
  static torch::Library schema_lib(
      torch::Library::FRAGMENT, "_caffe2", c10::nullopt,
      __FILE__, __LINE__);
  // Parse the schema string and append the hidden preallocated-outputs
  // argument before handing it to the library.
  schema_lib.def(make_function_schema_for_c10(schema_str));
}
156+
157+
} // namespace detail
158+
} // namespace caffe2
159+
160+
#endif

0 commit comments

Comments
 (0)
0