8000 print default value in FunctionSignature (#127059) · pytorch/pytorch@c9172d4 · GitHub
[go: up one dir, main page]

Skip to content

Commit c9172d4

Browse files
huihoaanpytorchmergebot
authored andcommitted
print default value in FunctionSignature (#127059)
Fixes #[126758](#126758) and #[126759](#126759) The output information in the issue is not accurate because `FunctionSignature::toString()` print the schema strings without default. https://github.com/pytorch/pytorch/blob/cb6ef68caa22c1a2f7a4e8583c0e7c923c8bfd17/torch/csrc/utils/python_arg_parser.cpp#L1282-L1283 This pr, by adding a `default_value` to save the default str ,which shoule be priented. Of course, can also add an new api to reverse `default_bool/default_int` to string, which is slightly more complicated. result: ![image](https://github.com/pytorch/pytorch/assets/37650440/f58a4cbf-b0f4-4c81-9106-59f0d35c54ea) Pull Request resolved: #127059 Approved by: https://github.com/janeyx99
1 parent 045309a commit c9172d4

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
12071207
} else {
12081208
throw std::runtime_error("unknown parameter type");
12091209
}
1210+
default_value = str;
12101211
}
12111212

12121213
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@@ -1280,7 +1281,6 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
12801281
}
12811282

12821283
std::string FunctionSignature::toString() const {
1283-
// TODO: consider printing more proper schema strings with defaults,
12841284
// optionals, etc.
12851285
std::ostringstream ss;
12861286
bool keyword_already = false;
@@ -1295,6 +1295,9 @@ std::string FunctionSignature::toString() const {
12951295
keyword_already = true;
12961296
}
12971297
ss << param.type_name() << " " << param.name;
1298+
if (param.optional) {
1299+
ss << " = " << param.default_value;
1300+
}
12981301
i++;
12991302
}
13001303
ss << ")";

torch/csrc/utils/python_arg_parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct FunctionParameter {
347347
at::ScalarType default_scalartype;
348348
at::Layout default_layout;
349349
};
350+
std::string default_value;
350351
};
351352

352353
template <int N>

0 commit comments

Comments
 (0)
0