@@ -517,11 +517,21 @@ def calculate_range(dtype: torch.dtype) -> tuple:
         # 8-bit floating-point format with e5m2 layout
         min_val = -57344.0
         max_val = 57344.0
+    elif dtype == torch.float8_e4m3fn:
+        min_val = -448.0
+        max_val = 448.0
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
     return min_val, max_val
 
 
+def has_overflow(node: torch.fx.Node, max_val: float) -> bool:
+    # inspect the activation's recorded tensor value for magnitudes
+    # outside the target fp8 range; reduce over all dims so the
+    # result is a single scalar that .item() can extract
+    abs_val = torch.ops.aten.abs.default(node.meta["val"])
+    amax = torch.ops.aten.amax.default(abs_val)
+
+    return amax.item() > max_val
+
+
 def quantize_activation_fw(graph: torch.fx.Graph) -> None:
     output = graph.find_nodes(op="output")[0]
     fwd_outputs = output.args[0]
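
Side note: the hard-coded bounds in calculate_range can be cross-checked against
torch.finfo, which reports the largest finite value of each fp8 format. A minimal
standalone check (not part of the patch):

    import torch

    # largest finite values of the two fp8 formats handled above
    assert torch.finfo(torch.float8_e5m2).max == 57344.0
    assert torch.finfo(torch.float8_e4m3fn).max == 448.0
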
@@ -532,23 +542,10 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
     for node in fwd_outputs:
         # check if the activation node is the node saved for quantization
         if node.meta.get("saved_for_quantization", False):
-            # case: use scaling
-            if torch._inductor.config.post_grad_fusion_options[
+            # use non-scaling by default, fall back to scaling on overflow
+            if not torch._inductor.config.post_grad_fusion_options[
                 "activation_quantization_aten_pass"
-            ].get("use_scaling", False):
-                # calculating the scale
-                scale_node = calculate_quantization_scaling(
-                    graph, node, clamp_max, 1e-12
-                )
-                # converting to fp8
-                quant_node = perform_quantization(
-                    graph, node, scale_node, quant_type, clamp_min, clamp_max
-                )
-                if not is_sym_node(scale_node):
-                    tensor_scale_nodes.append(scale_node)
-                else:
-                    sym_scale_nodes.append(scale_node)
-            else:
+            ].get("use_scaling", False) and not has_overflow(node, clamp_max):
                 # case: do not use scaling
                 with graph.inserting_after(node):
                     quant_node = graph.call_function(
@@ -564,6 +561,19 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
                     quant_node.meta["tensor_meta"] = extract_tensor_metadata(
                         quant_node.meta["val"]
                     )
+            else:
+                # calculating the scale
+                scale_node = calculate_quantization_scaling(
+                    graph, node, clamp_max, 1e-12
+                )
+                # converting to fp8
+                quant_node = perform_quantization(
+                    graph, node, scale_node, quant_type, clamp_min, clamp_max
+                )
+                if not is_sym_node(scale_node):
+                    tensor_scale_nodes.append(scale_node)
+                else:
+                    sym_scale_nodes.append(scale_node)
             node_to_quant[node] = quant_node
     # only update the return node args, and remain all other users unchanged
     output_updated_args = [
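
Taken together, the change makes unscaled fp8 conversion the default and falls
back to scaled quantization when the config requests it or the activation would
overflow the target format. A rough standalone sketch of that decision, using
eager tensors and a hypothetical helper name in place of the pass's fx-graph
machinery:

    import torch

    def choose_quantization(t: torch.Tensor, use_scaling: bool) -> str:
        # clamp_max for float8_e4m3fn, as returned by calculate_range
        clamp_max = 448.0
        # same overflow test as has_overflow, applied to a real tensor
        overflow = t.abs().amax().item() > clamp_max
        return "scaled" if use_scaling or overflow else "direct"

    print(choose_quantization(torch.randn(8, 16), use_scaling=False))        # direct
    print(choose_quantization(torch.randn(8, 16) * 1e4, use_scaling=False))  # scaled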