@@ -142,6 +142,91 @@ class OnnxClosure:
142
142
function : Any
143
143
144
144
145
+ UserModeValue = Any
146
+ EagerModeValue = Any
147
+ ExtendedModeValue = Any
148
+
149
+ # UserModeValue = Union[Optional[np.ndarray], List["UserModeValue"], Tuple["UserModeValue", ...]]
150
+
151
+ # EagerModeValue = Union[
152
+ # Optional["tensor.Tensor"], List["EagerModeValue"], Tuple["EagerModeValue", ...]
153
+ # ]
154
+
155
+ # ExtendedModeValue = Union[
156
+ # Optional["tensor.Tensor"],
157
+ # List["ExtendedModeValue"],
158
+ # Tuple["ExtendedModeValue", ...],
159
+ # np.ndarray,
160
+ # int,
161
+ # float,
162
+ # bool,
163
+ # ]
164
+
165
+
166
+ def _adapt_to_eager_mode (inputs : ExtendedModeValue ) -> EagerModeValue :
167
+ """Adapts inputs into representation used by onnxscript eager mode.
168
+
169
+ This does the following transformations:
170
+ * It adds an onnxscript Tensor wrapper around numpy arrays, which
171
+ allows the use of overloaded operators like + to be controlled by onnxscript.
172
+ * It also provides a promotion of scalars into tensors as a convenience.
173
+ This is needed to complement the similar promotion supported by the
174
+ onnxscript converter (for example, when an attribute is promoted and used
175
+ as an input argument).
176
+
177
+ Args:
178
+ inputs: a list/tuple of inputs to an ONNX function
179
+
180
+ Returns:
181
+ a pair (wrapped_inputs, flag) where flag indicates whether any numpy array
182
+ was wrapped into a Tensor.
183
+ """
184
+ has_array = False
185
+
186
+ def adapt (input : ExtendedModeValue ) -> EagerModeValue :
187
+ if isinstance (input , np .ndarray ):
188
+ nonlocal has_array
189
+ has_array = True
190
+ return tensor .Tensor (input )
191
+ elif isinstance (input , tensor .Tensor ):
192
+ return input
193
+ elif isinstance (input , (bool , int , float )):
194
+ return tensor .Tensor (np .array (input ))
195
+ elif input is None :
196
+ return None
197
+ elif isinstance (input , list ):
198
+ return [adapt (elt ) for elt in input ]
199
+ elif isinstance (input , tuple ):
200
+ return tuple (adapt (elt ) for elt in input )
201
+ raise TypeError (f"Unexpected input type { type (input )} ." )
202
+
203
+ result = adapt (inputs )
204
+ return result , has_array
205
+
206
+
207
+ def _adapt_to_user_mode (output : ExtendedModeValue ) -> UserModeValue :
208
+ """Unwraps Tensor wrapper around numpy arrays.
209
+
210
+ Args:
211
+ output: output of an ONNX function, which can be either a single
212
+ onnx value or a list/tuple of onnx values.
213
+
214
+ Returns:
215
+ unwrapped output
216
+ """
217
+ if isinstance (output , tensor .Tensor ):
218
+ return output .value
219
+ elif output is None :
220
+ return None
221
+ elif isinstance (output , list ):
222
+ return [_adapt_to_user_mode (elt ) for elt in output ]
223
+ elif isinstance (output , tuple ):
224
+ return tuple (_adapt_to_user_mode (elt ) for elt in output )
225
+ elif isinstance (output , np .ndarray ):
226
+ return output
227
+ raise TypeError (f"Unexpected type { type (output )} ." )
228
+
229
+
145
230
class OnnxFunction (Op ):
146
231
"""Represents an ONNX op for which a function-body has been defined in onnxscript.
147
232
@@ -185,75 +270,17 @@ def fun(*args, **kwargs):
185
270
186
271
def __call__ (self , * args , ** kwargs ):
187
272
"""Implements an eager-mode execution of an onnxscript function."""
188
- if len (args ) == 0 :
189
- # Operator Constant, it is usually called within a function.
190
- return self ._libcall (** kwargs )
191
- if isinstance (args [0 ], tensor .Tensor ):
192
- return self ._libcall (* args , ** kwargs )
193
- return self ._usercall (* args , ** kwargs )
194
-
195
- def _usercall (self , * args , ** kwargs ):
196
- """Eager mode"""
197
- new_args = []
198
- for i , a in enumerate (args ):
199
- if isinstance (a , np .ndarray ):
200
- new_args .append (tensor .Tensor (a ))
201
- elif isinstance (a , (bool , int , float )):
202
- new_args .append (tensor .Tensor (np .array (a )))
203
- else :
204
- raise TypeError (f"Unexpected input type { type (a )} for an input { i } ." )
205
- res = self .function (* new_args , ** kwargs )
206
- if isinstance (res , np .ndarray ):
207
- return res
208
- if isinstance (res , tensor .Tensor ):
209
- return res .value
210
- if isinstance (res , (list , tuple )):
211
- unwrapped = []
212
- for i , r in enumerate (res ):
213
- if isinstance (r , np .ndarray ):
214
- unwrapped .append (r )
215
- elif isinstance (r , tensor .Tensor ):
216
- unwrapped .append (r .value )
217
- else :
218
- raise TypeError (
219
- f"Unexpected output type { type (r )} for an output { i } "
220
- f"in function { self .function !r} ."
221
- )
222
- if isinstance (res , tuple ):
223
- return tuple (unwrapped )
224
- return unwrapped
225
- raise TypeError (f"Unexpected output type { type (res )} in function { self .function !r} ." )
226
-
227
- def _libcall (self , * args , ** kwargs ):
228
- """This method must be called when a function decoracted with `script`
229
- calls another one decorated with `script`.
230
- """
231
- new_args = []
232
- for i , a in enumerate (args ):
233
- if isinstance (a , tensor .Tensor ):
234
- new_args .append (a )
235
- elif isinstance (a , bool ):
236
- # TODO: default values for function parameters
237
- # are not properly handled yet. This section
238
- # should disappear.
239
- new_args .append (tensor .Tensor (np .array (a )))
240
- else :
241
- raise TypeError (f"Unexpected input type { type (a )} for an input { i } ." )
242
- res = self .function (* new_args , ** kwargs )
243
- if isinstance (res , tensor .Tensor ):
244
- return res
245
- if isinstance (res , tuple ):
246
- unwrapped = []
247
- for i , r in enumerate (res ):
248
- if isinstance (r , tensor .Tensor ):
249
- unwrapped .append (r )
250
- else :
251
- raise TypeError (
252
- f"Unexpected output type { type (r )} for an output { i } "
253
- f"in function { self .function !r} ."
254
- )
255
- return tuple (unwrapped )
256
- raise TypeError (f"Unexpected output type { type (res )} in function { self .function !r} ." )
273
+ new_args , has_array = _adapt_to_eager_mode (args )
274
+ result = self .function (* new_args , ** kwargs )
275
+
276
+ # We use a heuristic to decide whether to return output values as
277
+ # numpy arrays or tensor.Tensors. If the function has at least one
278
+ # numpy array as input, we return numpy arrays. Otherwise, we return
279
+ # tensor.Tensors. We could use a user-specified flag to control this
280
+ # or explicitly track whether this is a top-level function-call or
281
+ # a nested function-call.
282
+
283
+ return _adapt_to_user_mode (result ) if has_array else result
257
284
258
285
def to_function_proto (self , domain = None ):
259
286
"""Converts the function into :class:`onnx.FunctionProto`."""
0 commit comments