10000 Cleanup _libcall and _usercall distinction (#199) · microsoft/onnxscript@830f8dc · GitHub
[go: up one dir, main page]

Skip to content

Commit 830f8dc

Browse files
Cleanup _libcall and _usercall distinction (#199)
Cleanup _libcall and _usercall distinction Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 7d8afbf commit 830f8dc

File tree

1 file changed

+96
-69
lines changed

1 file changed

+96
-69
lines changed

onnxscript/values.py

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,91 @@ class OnnxClosure:
142142
function: Any
143143

144144

145+
UserModeValue = Any
146+
EagerModeValue = Any
147+
ExtendedModeValue = Any
148+
149+
# UserModeValue = Union[Optional[np.ndarray], List["UserModeValue"], Tuple["UserModeValue", ...]]
150+
151+
# EagerModeValue = Union[
152+
# Optional["tensor.Tensor"], List["EagerModeValue"], Tuple["EagerModeValue", ...]
153+
# ]
154+
155+
# ExtendedModeValue = Union[
156+
# Optional["tensor.Tensor"],
157+
# List["ExtendedModeValue"],
158+
# Tuple["ExtendedModeValue", ...],
159+
# np.ndarray,
160+
# int,
161+
# float,
162+
# bool,
163+
# ]
164+
165+
166+
def _adapt_to_eager_mode(inputs: ExtendedModeValue) -> EagerModeValue:
167+
"""Adapts inputs into representation used by onnxscript eager mode.
168+
169+
This does the following transformations:
170+
* It adds an onnxscript Tensor wrapper around numpy arrays, which
171+
allows the use of overloaded operators like + to be controlled by onnxscript.
172+
* It also provides a promotion of scalars into tensors as a convenience.
173+
This is needed to complement the similar promotion supported by the
174+
onnxscript converter (for example, when an attribute is promoted and used
175+
as an input argument).
176+
177+
Args:
178+
inputs: a list/tuple of inputs to an ONNX function
179+
180+
Returns:
181+
a pair (wrapped_inputs, flag) where flag indicates whether any numpy array
182+
was wrapped into a Tensor.
183+
"""
184+
has_array = False
185+
186+
def adapt(input: ExtendedModeValue) -> EagerModeValue:
187+
if isinstance(input, np.ndarray):
188+
nonlocal has_array
189+
has_array = True
190+
return tensor.Tensor(input)
191+
elif isinstance(input, tensor.Tensor):
192+
return input
193+
elif isinstance(input, (bool, int, float)):
194+
return tensor.Tensor(np.array(input))
195+
elif input is None:
196+
return None
197+
elif isinstance(input, list):
198+
return [adapt(elt) for elt in input]
199+
elif isinstance(input, tuple):
200+
return tuple(adapt(elt) for elt in input)
201+
raise TypeError(f"Unexpected input type {type(input)}.")
202+
203+
result = adapt(inputs)
204+
return result, has_array
205+
206+
207+
def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue:
208+
"""Unwraps Tensor wrapper around numpy arrays.
209+
210+
Args:
211+
output: output of an ONNX function, which can be either a single
212+
onnx value or a list/tuple of onnx values.
213+
214+
Returns:
215+
unwrapped output
216+
"""
217+
if isinstance(output, tensor.Tensor):
218+
return output.value
219+
elif output is None:
220+
return None
221+
elif isinstance(output, list):
222+
return [_adapt_to_user_mode(elt) for elt in output]
223+
elif isinstance(output, tuple):
224+
return tuple(_adapt_to_user_mode(elt) for elt in output)
225+
elif isinstance(output, np.ndarray):
226+
return output
227+
raise TypeError(f"Unexpected type {type(output)}.")
228+
229+
145230
class OnnxFunction(Op):
146231
"""Represents an ONNX op for which a function-body has been defined in onnxscript.
147232
@@ -185,75 +270,17 @@ def fun(*args, **kwargs):
185270

186271
def __call__(self, *args, **kwargs):
187272
"""Implements an eager-mode execution of an onnxscript function."""
188-
if len(args) == 0:
189-
# Operator Constant, it is usually called within a function.
190-
return self._libcall(**kwargs)
191-
if isinstance(args[0], tensor.Tensor):
192-
return self._libcall(*args, **kwargs)
193-
return self._usercall(*args, **kwargs)
194-
195-
def _usercall(self, *args, **kwargs):
196-
"""Eager mode"""
197-
new_args = []
198-
for i, a in enumerate(args):
199-
if isinstance(a, np.ndarray):
200-
new_args.append(tensor.Tensor(a))
201-
elif isinstance(a, (bool, int, float)):
202-
new_args.append(tensor.Tensor(np.array(a)))
203-
else:
204-
raise TypeError(f"Unexpected input type {type(a)} for an input {i}.")
205-
res = self.function(*new_args, **kwargs)
206-
if isinstance(res, np.ndarray):
207-
return res
208-
if isinstance(res, tensor.Tensor):
209-
return res.value
210-
if isinstance(res, (list, tuple)):
211-
unwrapped = []
212-
for i, r in enumerate(res):
213-
if isinstance(r, np.ndarray):
214-
unwrapped.append(r)
215-
elif isinstance(r, tensor.Tensor):
216-
unwrapped.append(r.value)
217-
else:
218-
raise TypeError(
219-
f"Unexpected output type {type(r)} for an output {i} "
220-
f"in function {self.function!r}."
221-
)
222-
if isinstance(res, tuple):
223-
return tuple(unwrapped)
224-
return unwrapped
225-
raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.")
226-
227-
def _libcall(self, *args, **kwargs):
228-
"""This method must be called when a function decoracted with `script`
229-
calls another one decorated with `script`.
230-
"""
231-
new_args = []
232-
for i, a in enumerate(args):
233-
if isinstance(a, tensor.Tensor):
234-
new_args.append(a)
235-
elif isinstance(a, bool):
236-
# TODO: default values for function parameters
237-
# are not properly handled yet. This section
238-
# should disappear.
239-
new_args.append(tensor.Tensor(np.array(a)))
240-
else:
241-
raise TypeError(f"Unexpected input type {type(a)} for an input {i}.")
242-
res = self.function(*new_args, **kwargs)
243-
if isinstance(res, tensor.Tensor):
244-
return res
245-
if isinstance(res, tuple):
246-
unwrapped = []
247-
for i, r in enumerate(res):
248-
if isinstance(r, tensor.Tensor):
249-
unwrapped.append(r)
250-
else:
251-
raise TypeError(
252-
f"Unexpected output type {type(r)} for an output {i} "
253-
f"in function {self.function!r}."
254-
)
255-
return tuple(unwrapped)
256-
raise TypeError(f"Unexpected output type {type(res)} in function {self.function!r}.")
273+
new_args, has_array = _adapt_to_eager_mode(args)
274+
result = self.function(*new_args, **kwargs)
275+
276+
# We use a heuristic to decide whether to return output values as
277+
# numpy arrays or tensor.Tensors. If the function has at least one
278+
# numpy array as input, we return numpy arrays. Otherwise, we return
279+
# tensor.Tensors. We could use a user-specified flag to control this
280+
# or explicitly track whether this is a top-level function-call or
281+
# a nested function-call.
282+
283+
return _adapt_to_user_mode(result) if has_array else result
257284

258285
def to_function_proto(self, domain=None):
259286
"""Converts the function into :class:`onnx.FunctionProto`."""

0 commit comments

Comments
 (0)
0