8000 Make TS recognize input arg name (#73253) · cyyever/pytorch_private@3cf94ef · GitHub
[go: up one dir, main page]

Skip to content

Commit 3cf94ef

Browse files
tugsbayasgalancyyever
authored andcommitted
Make TS recognize input arg name (#73253)
Summary: Pull Request resolved: pytorch/pytorch#73253 This PR allows TS schema_matching to match input arg with self for aten operators. This is because, operators in their functional form have input as paremeter instead of self. fixes: pytorch/pytorch#71994 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D34427556 Pulled By: tugsbayasgalan fbshipit-source-id: 96c2340d605c59634bf6e37db1db6025d93a933a (cherry picked from commit 45a593d73bc5e6308dd80a4a29afed8e318a0a1c)
1 parent 74ddffd commit 3cf94ef

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

test/test_jit.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15771,6 +15771,13 @@ def __init__(self,
1577115771

1577215772
torch.jit.script(M(2, 3))
1577315773

15774+
def test_input_keyword_in_schema(self):
15775+
def f(x):
15776+
return torch.ceil(input=x)
15777+
15778+
inp = torch.randn(10)
15779+
self.checkScript(f, (inp, ))
15780+
1577415781
def test_module_method_reassignment(self):
1577515782
class Foo(torch.nn.Module):
1577615783
def __init__(self):

torch/csrc/jit/frontend/schema_matching.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,17 @@ static Value* tryMatchArgument(
230230

231231
c10::optional<size_t> findInputWithName(
232232
const std::string& name,
233-
at::ArrayRef<NamedValue> kwargs) {
233+
at::ArrayRef<NamedValue> kwargs,
234+
bool is_aten) {
234235
for (const auto i : c10::irange(kwargs.size())) {
235-
if (kwargs[i].name() == name)
236+
// TS doesn't understand that the self argument in function
237+
// scheams is renamed to input for the functional variant
238+
if (is_aten && name == "self" && kwargs[i].name() == "input") {
239+
return i;
240+
}
241+
if (kwargs[i].name() == name) {
236242
return i;
243+
}
237244
}
238245
return c10::nullopt;
239246
}
@@ -342,6 +349,13 @@ static c10::optional<MatchedSchema> tryMatchSchema(
342349
std::vector<Value*> positional_inputs;
343350
std::vector<bool> used_kwarg(kwargs.size(), false);
344351

352+
auto schema_namespace = schema.operator_name().getNamespace();
353+
bool is_aten = false;
354+
if (schema_namespace.has_value()) {
355+
if (schema_namespace.value() == "aten") {
356+
is_aten = true;
357+
}
358+
}
345359
// if we finish the loop will we have consumed all arguments?
346360
size_t used_args = 0;
347361
for (const auto schema_i : c10::irange(schema.arguments().size())) {
@@ -386,7 +400,8 @@ static c10::optional<MatchedSchema> tryMatchSchema(
386400
// used
387401
actual_named_value = args[used_args];
388402
used_args++;
389-
} else if (auto kwarg_idx = findInputWithName(arg.name(), kwargs)) {
403+
} else if (
404+
auto kwarg_idx = findInputWithName(arg.name(), kwargs, is_aten)) {
390405
const NamedValue& nv = kwargs[*kwarg_idx];
391406
if (used_kwarg[*kwarg_idx]) {
392407
if (failure_messages) {

torch/csrc/jit/frontend/schema_matching.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ TORCH_API Value* emitBuiltinCall(
5353

5454
TORCH_API c10::optional<size_t> findInputWithName(
5555
const std::string& name,
56-
at::ArrayRef<NamedValue> kwargs);
56+
at::ArrayRef<NamedValue> kwargs,
57+
bool is_aten = false);
5758

5859
// applies implicit conversion from value trying to turn it into type
5960
// concrete_type it succeeds if the return_value->isSubtypeOf(concrete_type)

0 commit comments

Comments
 (0)
29D0
0