@@ -150,12 +150,7 @@ def false_fn(x, z):
         x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
         return x, z

-        onnx_program = torch.onnx.export(
-            CondModel(),
-            (torch.tensor([1, 2]),),
-            dynamo=True,
-            fallback=False,
-        )
+        onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))

         onnx_testing.assert_onnx_program(onnx_program)
         onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
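The removed inline `torch.onnx.export(..., dynamo=True, fallback=False)` call is folded into the test class's `self.export` helper, whose definition is not part of this diff. Below is a minimal sketch of what that helper presumably wraps, inferred from the call it replaces; the helper name and signature here are assumptions:

```python
import torch

# Hypothetical stand-in for the self.export test helper; the real definition
# lives on the test base class outside this diff.
def export(model: torch.nn.Module, args: tuple, **kwargs):
    return torch.onnx.export(
        model,
        args,
        dynamo=True,     # route through the torch.export-based ONNX exporter
        fallback=False,  # fail instead of falling back to the TorchScript exporter
        **kwargs,
    )
```

The two `assert_onnx_program` calls then exercise both branches of the `torch.cond`: the default input `[1, 2]` has a positive sum and takes `true_fn`, while `[-1, -2]` takes `false_fn`.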
@@ -194,56 +189,34 @@ def forward(self, x):
         _ = self.export(exported_program)

     @common_utils.parametrize(
-        "float8_type",
+        "float8_type, onnx_type",
         [
             common_utils.subtest(
-                torch.float8_e5m2,
+                (torch.float8_e5m2, ir.DataType.FLOAT8E5M2),
                 name="torch_float8_e5m2",
             ),
             common_utils.subtest(
-                torch.float8_e5m2fnuz,
+                (torch.float8_e5m2fnuz, ir.DataType.FLOAT8E5M2FNUZ),
                 name="torch_float8_e5m2fnuz",
             ),
             common_utils.subtest(
-                torch.float8_e4m3fn,
+                (torch.float8_e4m3fn, ir.DataType.FLOAT8E4M3FN),
                 name="torch_float8_e4m3fn",
             ),
             common_utils.subtest(
-                torch.float8_e4m3fnuz,
+                (torch.float8_e4m3fnuz, ir.DataType.FLOAT8E4M3FNUZ),
                 name="torch_float8_e4m3fnuz",
             ),
         ],
     )
-    def test_float8_support(self, float8_type):
+    def test_float8_support(self, float8_type: torch.dtype, onnx_type: ir.DataType):
         class Float8Module(torch.nn.Module):
             def forward(self, input: torch.Tensor):
                 input = input.to(float8_type)
                 return input

-        _ = self.export(Float8Module(), (torch.randn(1, 2),))
-
-    def test_bfloat16_support(self):
-        class BfloatModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                # Test parameters
-                self.param = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.bfloat16))
-
-            def forward(self, x):
-                # Test constant tensors are stored as bfloat16
-                const = torch.tensor(1.0, dtype=torch.bfloat16)
-                return x * const * self.param
-
-        input = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
-        onnx_program = self.export(BfloatModel(), (input,), optimize=False)
-        initializers = onnx_program.model.graph.initializers.values()
-        self.assertEqual(len(initializers), 2)
-        for initializer in initializers:
-            self.assertEqual(initializer.dtype, ir.DataType.BFLOAT16)
-        self.assertEqual(onnx_program.model.graph.inputs[0].dtype, ir.DataType.BFLOAT16)
-        self.assertEqual(
-            onnx_program.model.graph.outputs[0].dtype, ir.DataType.BFLOAT16
-        )
+        onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
+        self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)

     def test_export_with_logging_logger(self):
         logger = logging.getLogger(__name__)
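With the parametrization in place, the float8 test now pins down the torch-to-ONNX dtype mapping instead of only checking that export succeeds. A standalone sketch of the same assertion, assuming `ir` refers to `onnxscript.ir` as the `ir.DataType` references above suggest:

```python
import torch
from onnxscript import ir  # assumed to be the `ir` module used by these tests

class Float8Module(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to one of the float8 dtypes covered by the parametrization
        return x.to(torch.float8_e5m2)

# Export via the dynamo-based exporter and verify the graph output carries
# the matching ONNX data type.
onnx_program = torch.onnx.export(Float8Module(), (torch.randn(1, 2),), dynamo=True)
assert onnx_program.model.graph.outputs[0].dtype == ir.DataType.FLOAT8E5M2
```

The same pattern covers the other three variants: `float8_e5m2fnuz` → `FLOAT8E5M2FNUZ`, `float8_e4m3fn` → `FLOAT8E4M3FN`, and `float8_e4m3fnuz` → `FLOAT8E4M3FNUZ`.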