10000 unwrapped super methods in shell helper methods · nipype/nipype2pydra@abf3551 · GitHub
[go: up one dir, main page]

Skip to content

Commit abf3551

Browse files
committed
unwrapped super methods in shell helper methods
1 parent 955e6c8 commit abf3551

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

nipype2pydra/interface/shell_command.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
UsedSymbols,
1414
split_source_into_statements,
1515
INBUILT_NIPYPE_TRAIT_NAMES,
16+
extract_args,
1617
find_super_method,
1718
)
1819
from fileformats.core.mixin import WithClassifiers
@@ -27,6 +28,7 @@
2728
@attrs.define(slots=False)
2829
class ShellCommandInterfaceConverter(BaseInterfaceConverter):
2930

31+
converter_type = "shell_command"
3032
_format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict)
3133

3234
@cached_property
@@ -237,11 +239,19 @@ def output_fields(self):
237239

238240
@property
239241
def formatted_input_field_names(self):
240-
return re.findall(r"name == \"(\w+)\"", self._format_arg_body)
242+
if not self._format_arg_body:
243+
return []
244+
sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0]
245+
name_arg = re.match(r"\s*def _format_arg\(self, (\w+),", sig).group(1)
246+
return re.findall(name_arg + r" == \"(\w+)\"", self._format_arg_body)
241247

242248
@property
243249
def callable_default_input_field_names(self):
244-
return re.findall(r"name == \"(\w+)\"", self._gen_filename_body)
250+
if not self._gen_filename_body:
251+
return []
252+
sig = inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[0]
253+
name_arg = re.match(r"\s*def _gen_filename\((\w+),", sig).group(1)
254+
return re.findall(name_arg + r" == \"(\w+)\"", self._gen_filename_body)
245255

246256
@property
247257
def callable_output_fields(self):
@@ -262,17 +272,13 @@ def callable_output_field_names(self):
262272
def _format_arg_body(self):
263273
if self.method_omitted("_format_arg"):
264274
return ""
265-
return _strip_doc_string(
266-
inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1]
267-
)
275+
return self._unwrap_supers(self.nipype_interface._format_arg)
268276

269277
@cached_property
270278
def _gen_filename_body(self):
271279
if self.method_omitted("_gen_filename"):
272280
return ""
273-
return _strip_doc_string(
274-
inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1]
275-
)
281+
return self._unwrap_supers(self.nipype_interface._gen_filename)
276282

277283
@property
278284
def format_arg_code(self):
@@ -333,9 +339,7 @@ def format_arg_code(self):
333339
def parse_inputs_code(self) -> str:
334340
if "_parse_inputs" not in self.included_methods:
335341
return ""
336-
body = _strip_doc_string(
337-
inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1]
338-
)
342+
body = self._unwrap_supers(self.nipype_interface._parse_inputs)
339343
body = self._process_inputs(body)
340344
body = re.sub(
341345
r"self.\_format_arg\((\w+), (\w+), (\w+)\)",
@@ -412,11 +416,7 @@ def callables_code(self):
412416
code_str = ""
413417
if "aggregate_outputs" in self.included_methods:
414418
func_name = "aggregate_outputs"
415-
agg_body = _strip_doc_string(
416-
inspect.getsource(self.nipype_interface.aggregate_outputs).split(
417-
"\n", 1
418-
)[-1]
419-
)
419+
agg_body = self._unwrap_supers(self.nipype_interface.aggregate_outputs)
420420
need_list_outputs = bool(re.findall(r"\b_list_outputs\b", agg_body))
421421
agg_body = self._process_inputs(agg_body)
422422

@@ -476,11 +476,7 @@ def callables_code(self):
476476

477477
return code_str
478478
else:
479-
lo_body = _strip_doc_string(
480-
inspect.getsource(self.nipype_interface._list_outputs).split(
481-
"\n", 1
482-
)[-1]
483-
)
479+
lo_body = self._unwrap_supers(self.nipype_interface._list_outputs)
484480
lo_body = self._process_inputs(lo_body)
485481

486482
if not lo_body:
@@ -538,6 +534,32 @@ def method_omitted(self, method_name: str) -> bool:
538534
find_super_method(self.nipype_interface, method_name, include_class=True)[1]
539535
)
540536

537+
def _unwrap_supers(
538+
self, method: ty.Callable, base=None, base_replacement="", arg_names=None
539+
) -> str:
540+
if base is None:
541+
base = self.nipype_interface
542+
if self.package.is_omitted(base):
543+
return base_replacement
544+
method_name = method.__name__
545+
sig, body = inspect.getsource(method).split("\n", 1)
546+
body = _strip_doc_string(body)
547+
args = extract_args(sig)[1][1:]
548+
if arg_names:
549+
for new, old 341A in zip(args, arg_names):
550+
if new != old:
551+
body = re.sub(r"\b" + old + r"\b", new, body)
552+
super_re = re.compile(
553+
r"\n\s*(return )?super\([^\)]*\)\." + method_name + r"\([^\)]+\)"
554+
)
555+
if super_re.search(body):
556+
super_method, base = find_super_method(base, method_name)
557+
super_body = self._unwrap_supers(
558+
super_method, base, base_replacement, arg_names=args
559+
)
560+
body = super_re.sub("\n" + super_body, body)
561+
return body
562+
541563

542564
def _strip_doc_string(body: str) -> str:
543565
if re.match(r"\s*(\"|')", body):

nipype2pydra/utils/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,14 +551,14 @@ def get_return_line(func: ty.Union[str, ty.Callable]) -> str:
551551

552552
def find_super_method(
553553
super_base: type, method_name: str, include_class: bool = False
554-
) -> ty.Optional[ty.Tuple[ty.Callable, type]]:
554+
) -> ty.Tuple[ty.Optional[ty.Callable], ty.Optional[type]]:
555555
mro = super_base.__mro__
556556
if not include_class:
557557
mro = mro[1:]
558558
for base in mro:
559559
if method_name in base.__dict__: # Found the match
560560
return getattr(base, method_name), base
561-
return None
561+
return None, None
562562
# raise RuntimeError(
563563
# f"Could not find super of '{method_name}' method in base classes of "
564564
# f"{super_base}"

0 commit comments

Comments
 (0)
0