1
1
#include < nlohmann/json.hpp>
2
2
#include < fstream>
3
3
#include < iostream>
4
- #include < vector>
5
4
6
- #include < c10/util/Exception.h>
7
5
#include < torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
8
6
9
7
namespace {
@@ -15,7 +13,7 @@ at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
15
13
namespace torch ::aot_inductor {
16
14
17
15
void OSSProxyExecutor::prefill_stack_with_static_arguments (
18
- size_t index,
16
+ int index,
19
17
const at::TypePtr& schema_arg_type,
20
18
const nlohmann::json& serialized_arg,
21
19
OSSOpKernel& op_kernel) {
@@ -36,6 +34,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
36
34
index,
37
35
" but got " ,
38
36
serialized_arg_type);
37
+ stack.emplace_back ();
39
38
dynamic_args.emplace_back (index, DynamicArgType::TensorType, 1 );
40
39
break ;
41
40
}
@@ -48,6 +47,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
48
47
index,
49
48
" but got " ,
50
49
serialized_arg_type);
50
+ stack.emplace_back ();
51
51
dynamic_args.emplace_back (index, DynamicArgType::IntType, 1 );
52
52
break ;
53
53
}
@@ -61,6 +61,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
61
61
index,
62
62
" but got " ,
63
63
serialized_arg_type);
64
+ stack.emplace_back ();
64
65
dynamic_args.emplace_back (index, DynamicArgType::IntType, 1 );
65
66
break ;
66
67
}
@@ -73,7 +74,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
73
74
index,
74
75
" but got " ,
75
76
serialized_arg_type);
76
- stack.at (index) = serialized_arg_val.get <double >();
77
+ stack.emplace_back ( serialized_arg_val.get <double >() );
77
78
break ;
78
79
}
79
80
case c10::TypeKind::BoolType: {
@@ -85,17 +86,18 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
85
86
index,
86
87
" but got " ,
87
88
serialized_arg_type);
88
- stack.at (index) = serialized_arg_val.get <bool >();
89
+ stack.emplace_back ( serialized_arg_val.get <bool >() );
89
90
break ;
90
91
}
91
92
case c10::TypeKind::NumberType: {
92
93
if (serialized_arg_type == " as_int" ) {
93
94
// Only int Scalar is treated as dynamic arg for now
95
+ stack.emplace_back ();
94
96
dynamic_args.emplace_back (index, DynamicArgType::IntType, 1 );
95
97
} else if (serialized_arg_type == " as_float" ) {
96
- stack.at (index) = serialized_arg_val.get <doub
F438
le >();
98
+ stack.emplace_back ( serialized_arg_val.get <double >() );
97
99
} else if (serialized_arg_type == " as_bool" ) {
98
- stack.at (index) = serialized_arg_val.get <bool >();
100
+ stack.emplace_back ( serialized_arg_val.get <bool >() );
99
101
} else {
100
102
TORCH_CHECK (
101
103
false ,
@@ -117,7 +119,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
117
119
index,
118
120
" but got " ,
119
121
serialized_arg_type);
120
- stack.at (index) = serialized_arg_val.get <std::string>();
122
+ stack.emplace_back ( serialized_arg_val.get <std::string>() );
121
123
break ;
122
124
}
123
125
case c10::TypeKind::ScalarTypeType: {
@@ -129,7 +131,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
129
131
index,
130
132
" but got " ,
131
133
serialized_arg_type);
132
- stack.at (index) = serialized_arg_val.get <c10::ScalarType>();
134
+ stack.emplace_back ( serialized_arg_val.get <c10::ScalarType>() );
133
135
break ;
134
136
}
135
137
case c10::TypeKind::MemoryFormatType: {
@@ -141,7 +143,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
141
143
index,
142
144
" but got " ,
143
145
serialized_arg_type);
144
- stack.at (index) = serialized_arg_val.get <c10::MemoryFormat>();
146
+ stack.emplace_back ( serialized_arg_val.get <c10::MemoryFormat>() );
145
147
break ;
146
148
}
147
149
case c10::TypeKind::LayoutType: {
@@ -153,7 +155,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
153
155
index,
154
156
" but got " ,
155
157
serialized_arg_type);
156
- stack.at (index) = serialized_arg_val.get <c10::Layout>();
158
+ stack.emplace_back ( serialized_arg_val.get <c10::Layout>() );
157
159
break ;
158
160
}
159
161
case c10::TypeKind::DeviceObjType: {
@@ -167,8 +169,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
167
169
serialized_arg_type);
168
170
169
171
std::string device_string = serialized_arg_val[" type" ].get <std::string>();
170
- if (serialized_arg_val.contains (" index" ) &&
171
- serialized_arg_val[" index" ].is_number ()) {
172
+ if (serialized_arg_val[" index" ].is_number ()) {
172
173
device_string += " :" + serialized_arg_val[" index" ].get <std::string>();
173
174
}
174
175
@@ -181,7 +182,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
181
182
<< device << " . Please ensure this is intentional." ;
182
183
}
183
184
184
- stack.at (index) = *device_;
185
+ stack.emplace_back ( *device_) ;
185
186
break ;
186
187
}
187
188
case c10::TypeKind::ListType: {
@@ -195,6 +196,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
195
196
" but got " ,
196
197
serialized_arg_type);
197
198
TORCH_CHECK (serialized_arg_type == " as_tensors" );
199
+ stack.emplace_back ();
198
200
dynamic_args.emplace_back (
199
201
index, DynamicArgType::ListTensorType, serialized_arg_val.size ());
200
202
} else if (schema_arg_type->isSubtypeOf (at::ListType::ofInts ())) {
@@ -208,6 +210,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
208
210
serialized_arg_type);
209
211
dynamic_args.emplace_back (
210
212
index, DynamicArgType::ListIntType, serialized_arg_val.size ());
213
+ stack.emplace_back ();
211
214
} else if (schema_arg_type->isSubtypeOf (at::ListType::ofSymInts ())) {
212
215
TORCH_CHECK (
213
216
serialized_arg_type == " as_ints" ||
@@ -220,6 +223,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
220
223
serialized_arg_type);
221
224
dynamic_args.emplace_back (
222
225
index, DynamicArgType::ListIntType, serialized_arg_val.size ());
226
+ stack.emplace_back ();
223
227
} else if (schema_arg_type->isSubtypeOf (at::ListType::ofFloats ())) {
224
228
TORCH_CHECK (
225
229
serialized_arg_type == " as_floats" ,
@@ -233,7 +237,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
233
237
for (const auto & arg : serialized_arg_val) {
234
238
ret.push_back (arg.get <double >());
235
239
}
236
- stack.at (index) = std::move (ret);
240
+ stack.emplace_back (ret);
237
241
} else if (schema_arg_type->isSubtypeOf (at::ListType::ofBools ())) {
238
242
TORCH_CHECK (
239
243
serialized_arg_type == " as_bools" ,
@@ -247,23 +251,24 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
247
251
for (const auto & arg : serialized_arg_val) {
248
252
ret.push_back (arg.get <bool >());
249
253
}
250
- stack.at (index) = std::move (ret);
254
+ stack.emplace_back (ret);
251
255
} else if (schema_arg_type->isSubtypeOf (at::ListType::ofNumbers ())) {
252
256
if (serialized_arg_type == " as_ints" ) {
253
257
dynamic_args.emplace_back (
254
258
index, DynamicArgType::ListIntType, serialized_arg_val.size ());
259
+ stack.emplace_back ();
255
260
} else if (serialized_arg_type == " as_floats" ) {
256
261
std::vector<double > ret;
257
262
for (const auto & arg : serialized_arg_val) {
258
263
ret.push_back (arg);
259
264
}
260
- stack.at (index) = std::move (ret);
265
+ stack.emplace_back (ret);
261
266
} else if (serialized_arg_type == " as_bools" ) {
262
267
std::vector<bool > ret;
263
268
for (const auto & arg : serialized_arg_val) {
264
269
ret.push_back (arg);
265
270
}
266
- stack.at (index) = std::move (ret);
271
+ stack.emplace_back (ret);
267
272
} else {
268
273
TORCH_CHECK (
269
274
false ,
@@ -281,12 +286,14 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
281
286
for (const auto & arg : serialized_arg_val) {
282
287
list_item_types.push_back (arg.begin ().key ());
283
288
}
289
+ stack.emplace_back ();
284
290
dynamic_args.emplace_back (
285
291
index,
286
292
DynamicArgType::ListOptionalTensorType,
287
293
serialized_arg_val.size (),
288
294
list_item_types);
289
295
} else if (serialized_arg_type == " as_tensors" ) {
296
+ stack.emplace_back ();
290
297
dynamic_args.emplace_back (
291
298
index, DynamicArgType::ListTensorType, serialized_arg_val.size ());
292
299
} else {
@@ -312,7 +319,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
312
319
for (const auto & arg : serialized_arg_val) {
313
320
ret.push_back (arg.get <std::string>());
314
321
}
315
- stack.at (index) = std::move (ret);
322
+ stack.emplace_back (ret);
316
323
} else {
317
324
TORCH_CHECK (
318
325
false ,
@@ -330,7 +337,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
330
337
schema_arg_type->castRaw <at::OptionalType>()->getElementType ();
331
338
332
339
if (serialized_arg_type == " as_none" ) {
333
- stack.at (index) = c10::IValue{} ;
340
+ stack.emplace_back (std::nullopt) ;
334
341
if (inner_type->kind () == c10::TypeKind::TensorType) {
335
342
// Tensor is None
336
343
dynamic_args.emplace_back (index, DynamicArgType::TensorType, 0 );
@@ -374,35 +381,16 @@ void OSSProxyExecutor::get_input_info_from_serialized(
374
381
const std::vector<c10::Argument>& schema_args,
375
382
const nlohmann::json& serialized_node,
376
383
OSSOpKernel& op_kernel) {
377
- std::vector<bool > filled (schema_args.size (), false );
378
- TORCH_CHECK (op_kernel.stack_ .size () == 0 );
379
- op_kernel.stack_ .resize (schema_args.size ());
384
+ int index = 0 ;
380
385
for (const auto & named_argument : serialized_node[" inputs" ]) {
381
386
const auto & arg = named_argument[" arg" ];
382
- const auto & name = named_argument[" name" ];
383
-
384
- // Doing a linear lookup in the schema to find the index
385
- // of a static argument. Should be fine performance wise
386
- // because we usually only have small amount of arguments.
387
- for (size_t index = 0 ; index < schema_args.size (); index++) {
388
- auto & schema_arg = schema_args[index];
389
- if (schema_arg.name () == name) {
390
- prefill_stack_with_static_arguments (
391
- index, schema_arg.real_type (), arg, op_kernel);
392
- filled[index] = true ;
393
- break ;
394
- }
395
- }
396
- }
387
+ auto & schema_arg = schema_args[index];
397
388
398
- // If an argument is not filled and has a default value, we should
399
- // also prefill the default value.
400
- for (size_t index = 0 ; index < schema_args.size (); index++) {
401
- if (!filled[index] && schema_args[index].default_value ()) {
402
- auto default_value = *schema_args[index].default_value ();
403
- op_kernel.stack_ .at (index) = default_value;
404
- }
389
+ prefill_stack_with_static_arguments (
390
+ index++, schema_arg.real_type (), arg, op_kernel);
405
391
}
392
+
393
+ // TODO: prefill default values
406
394
}
407
395
408
396
// Populates op_kernel.outputs_
0 commit comments