8000 Revert "[aoti] Assign proxy call args by name, and support default va… · pytorch/pytorch@106acf0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 106acf0

Browse files
Revert "[aoti] Assign proxy call args by name, and support default values. (#146263)"
This reverts commit 11f6980. Reverted #146263 on behalf of https://github.com/atalman due to multiple build failures, please see associated diff ([comment](#146263 (comment)))
1 parent e0f22e5 commit 106acf0

File tree

3 files changed

+34
-63
lines changed

3 files changed

+34
-63
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4275,23 +4275,6 @@ def forward(self, x):
42754275
model, example_inputs, "aoti_torch_clone_preserve_strides", 0
42764276
)
42774277

4278-
def test_stft(self):
4279-
N_FFT = 400
4280-
HOP_LENGTH = 160
4281-
4282-
class Model(torch.nn.Module):
4283-
def forward(self, x):
4284-
window = torch.hann_window(N_FFT).to(x.device)
4285-
stft = torch.stft(
4286-
x, N_FFT, HOP_LENGTH, window=window, return_complex=True
4287-
)
4288-
magnitudes = stft[..., :-1].abs() ** 2
4289-
return magnitudes
4290-
4291-
model = Model()
4292-
example_inputs = (torch.randn(500, device=self.device),)
4293-
self.check_model(model, example_inputs)
4294-
42954278
def test_conv3d(self):
42964279
if self.device != GPU_TYPE or not is_big_gpu():
42974280
raise unittest.SkipTest("requires modern GPU to run max-autotune")

torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#include <nlohmann/json.hpp>
22
#include <fstream>
33
#include <iostream>
4-
#include <vector>
54

6-
#include <c10/util/Exception.h>
75
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
86

97
namespace {
@@ -15,7 +13,7 @@ at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
1513
namespace torch::aot_inductor {
1614

1715
void OSSProxyExecutor::prefill_stack_with_static_arguments(
18-
size_t index,
16+
int index,
1917
const at::TypePtr& schema_arg_type,
2018
const nlohmann::json& serialized_arg,
2119
OSSOpKernel& op_kernel) {
@@ -36,6 +34,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
3634
index,
3735
" but got ",
3836
serialized_arg_type);
37+
stack.emplace_back();
3938
dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1);
4039
break;
4140
}
@@ -48,6 +47,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
4847
index,
4948
" but got ",
5049
serialized_arg_type);
50+
stack.emplace_back();
5151
dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
5252
break;
5353
}
@@ -61,6 +61,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
6161
index,
6262
" but got ",
6363
serialized_arg_type);
64+
stack.emplace_back();
6465
dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
6566
break;
6667
}
@@ -73,7 +74,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
7374
index,
7475
" but got ",
7576
serialized_arg_type);
76-
stack.at(index) = serialized_arg_val.get<double>();
77+
stack.emplace_back(serialized_arg_val.get<double>());
7778
break;
7879
}
7980
case c10::TypeKind::BoolType: {
@@ -85,17 +86,18 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
8586
index,
8687
" but got ",
8788
serialized_arg_type);
88-
stack.at(index) = serialized_arg_val.get<bool>();
89+
stack.emplace_back(serialized_arg_val.get<bool>());
8990
break;
9091
}
9192
case c10::TypeKind::NumberType: {
9293
if (serialized_arg_type == "as_int") {
9394
// Only int Scalar is treated as dynamic arg for now
95+
stack.emplace_back();
9496
dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
9597
} else if (serialized_arg_type == "as_float") {
96-
stack.at(index) = serialized_arg_val.get<doub F438 le>();
98+
stack.emplace_back(serialized_arg_val.get<double>());
9799
} else if (serialized_arg_type == "as_bool") {
98-
stack.at(index) = serialized_arg_val.get<bool>();
100+
stack.emplace_back(serialized_arg_val.get<bool>());
99101
} else {
100102
TORCH_CHECK(
101103
false,
@@ -117,7 +119,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
117119
index,
118120
" but got ",
119121
serialized_arg_type);
120-
stack.at(index) = serialized_arg_val.get<std::string>();
122+
stack.emplace_back(serialized_arg_val.get<std::string>());
121123
break;
122124
}
123125
case c10::TypeKind::ScalarTypeType: {
@@ -129,7 +131,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
129131
index,
130132
" but got ",
131133
serialized_arg_type);
132-
stack.at(index) = serialized_arg_val.get<c10::ScalarType>();
134+
stack.emplace_back(serialized_arg_val.get<c10::ScalarType>());
133135
break;
134136
}
135137
case c10::TypeKind::MemoryFormatType: {
@@ -141,7 +143,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
141143
index,
142144
" but got ",
143145
serialized_arg_type);
144-
stack.at(index) = serialized_arg_val.get<c10::MemoryFormat>();
146+
stack.emplace_back(serialized_arg_val.get<c10::MemoryFormat>());
145147
break;
146148
}
147149
case c10::TypeKind::LayoutType: {
@@ -153,7 +155,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
153155
index,
154156
" but got ",
155157
serialized_arg_type);
156-
stack.at(index) = serialized_arg_val.get<c10::Layout>();
158+
stack.emplace_back(serialized_arg_val.get<c10::Layout>());
157159
break;
158160
}
159161
case c10::TypeKind::DeviceObjType: {
@@ -167,8 +169,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
167169
serialized_arg_type);
168170

169171
std::string device_string = serialized_arg_val["type"].get<std::string>();
170-
if (serialized_arg_val.contains("index") &&
171-
serialized_arg_val["index"].is_number()) {
172+
if (serialized_arg_val["index"].is_number()) {
172173
device_string += ":" + serialized_arg_val["index"].get<std::string>();
173174
}
174175

@@ -181,7 +182,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
181182
<< device << ". Please ensure this is intentional.";
182183
}
183184

184-
stack.at(index) = *device_;
185+
stack.emplace_back(*device_);
185186
break;
186187
}
187188
case c10::TypeKind::ListType: {
@@ -195,6 +196,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
195196
" but got ",
196197
serialized_arg_type);
197198
TORCH_CHECK(serialized_arg_type == "as_tensors");
199+
stack.emplace_back();
198200
dynamic_args.emplace_back(
199201
index, DynamicArgType::ListTensorType, serialized_arg_val.size());
200202
} else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) {
@@ -208,6 +210,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
208210
serialized_arg_type);
209211
dynamic_args.emplace_back(
210212
index, DynamicArgType::ListIntType, serialized_arg_val.size());
213+
stack.emplace_back();
211214
} else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) {
212215
TORCH_CHECK(
213216
serialized_arg_type == "as_ints" ||
@@ -220,6 +223,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
220223
serialized_arg_type);
221224
dynamic_args.emplace_back(
222225
index, DynamicArgType::ListIntType, serialized_arg_val.size());
226+
stack.emplace_back();
223227
} else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) {
224228
TORCH_CHECK(
225229
serialized_arg_type == "as_floats",
@@ -233,7 +237,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
233237
for (const auto& arg : serialized_arg_val) {
234238
ret.push_back(arg.get<double>());
235239
}
236-
stack.at(index) = std::move(ret);
240+
stack.emplace_back(ret);
237241
} else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) {
238242
TORCH_CHECK(
239243
serialized_arg_type == "as_bools",
@@ -247,23 +251,24 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
247251
for (const auto& arg : serialized_arg_val) {
248252
ret.push_back(arg.get<bool>());
249253
}
250-
stack.at(index) = std::move(ret);
254+
stack.emplace_back(ret);
251255
} else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) {
252256
if (serialized_arg_type == "as_ints") {
253257
dynamic_args.emplace_back(
254258
index, DynamicArgType::ListIntType, serialized_arg_val.size());
259+
stack.emplace_back();
255260
} else if (serialized_arg_type == "as_floats") {
256261
std::vector<double> ret;
257262
for (const auto& arg : serialized_arg_val) {
258263
ret.push_back(arg);
259264
}
260-
stack.at(index) = std::move(ret);
265+
stack.emplace_back(ret);
261266
} else if (serialized_arg_type == "as_bools") {
262267
std::vector<bool> ret;
263268
for (const auto& arg : serialized_arg_val) {
264269
ret.push_back(arg);
265270
}
266-
stack.at(index) = std::move(ret);
271+
stack.emplace_back(ret);
267272
} else {
268273
TORCH_CHECK(
269274
false,
@@ -281,12 +286,14 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
281286
for (const auto& arg : serialized_arg_val) {
282287
list_item_types.push_back(arg.begin().key());
283288
}
289+
stack.emplace_back();
284290
dynamic_args.emplace_back(
285291
index,
286292
DynamicArgType::ListOptionalTensorType,
287293
serialized_arg_val.size(),
288294
list_item_types);
289295
} else if (serialized_arg_type == "as_tensors") {
296+
stack.emplace_back();
290297
dynamic_args.emplace_back(
291298
index, DynamicArgType::ListTensorType, serialized_arg_val.size());
292299
} else {
@@ -312,7 +319,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
312319
for (const auto& arg : serialized_arg_val) {
313320
ret.push_back(arg.get<std::string>());
314321
}
315-
stack.at(index) = std::move(ret);
322+
stack.emplace_back(ret);
316323
} else {
317324
TORCH_CHECK(
318325
false,
@@ -330,7 +337,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
330337
schema_arg_type->castRaw<at::OptionalType>()->getElementType();
331338

332339
if (serialized_arg_type == "as_none") {
333-
stack.at(index) = c10::IValue{};
340+
stack.emplace_back(std::nullopt);
334341
if (inner_type->kind() == c10::TypeKind::TensorType) {
335342
// Tensor is None
336343
dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0);
@@ -374,35 +381,16 @@ void OSSProxyExecutor::get_input_info_from_serialized(
374381
const std::vector<c10::Argument>& schema_args,
375382
const nlohmann::json& serialized_node,
376383
OSSOpKernel& op_kernel) {
377-
std::vector<bool> filled(schema_args.size(), false);
378-
TORCH_CHECK(op_kernel.stack_.size() == 0);
379-
op_kernel.stack_.resize(schema_args.size());
384+
int index = 0;
380385
for (const auto& named_argument : serialized_node["inputs"]) {
381386
const auto& arg = named_argument["arg"];
382-
const auto& name = named_argument["name"];
383-
384-
// Doing a linear lookup in the schema to find the index
385-
// of a static argument. Should be fine performance wise
386-
// because we usually only have small amount of arguments.
387-
for (size_t index = 0; index < schema_args.size(); index++) {
388-
auto& schema_arg = schema_args[index];
389-
if (schema_arg.name() == name) {
390-
prefill_stack_with_static_arguments(
391-
index, schema_arg.real_type(), arg, op_kernel);
392-
filled[index] = true;
393-
break;
394-
}
395-
}
396-
}
387+
auto& schema_arg = schema_args[index];
397388

398-
// If an argument is not filled and has a default value, we should
399-
// also prefill the default value.
400-
for (size_t index = 0; index < schema_args.size(); index++) {
401-
if (!filled[index] && schema_args[index].default_value()) {
402-
auto default_value = *schema_args[index].default_value();
403-
op_kernel.stack_.at(index) = default_value;
404-
}
389+
prefill_stack_with_static_arguments(
390+
index++, schema_arg.real_type(), arg, op_kernel);
405391
}
392+
393+
// TODO: prefill default values
406394
}
407395

408396
// Populates op_kernel.outputs_

torch/csrc/inductor/aoti_torch/oss_proxy_executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class OSSProxyExecutor : public ProxyExecutor {
8181

8282
private:
8383
void prefill_stack_with_static_arguments(
84-
size_t index,
84+
int index,
8585
const at::TypePtr& schema_arg_type,
8686
const nlohmann::json& serialized_arg,
8787
OSSOpKernel& op_kernel);

0 commit comments

Comments
 (0)
0