[PT2][Optimus][fp8 computation quantization] Add fallback logic · pytorch/pytorch@fc4e798 · GitHub

Commit fc4e798

mengluy0125 authored and facebook-github-bot committed

[PT2][Optimus][fp8 computation quantization] Add fallback logic
Summary: The data may overflow when doing non-scaling quantization. We therefore check its data range and automatically fall back to the scaling version if such overflow is detected.

Differential Revision: D74610644
1 parent 641e4be commit fc4e798
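In effect, the pass now tries the cheaper non-scaling path first and only pays for a scale computation when the activation's magnitude exceeds the fp8 range. A minimal standalone sketch of that policy in eager mode (the function name and structure here are illustrative, not the graph-pass code in this diff):

```python
import torch

def quantize_with_fallback(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
    # Hypothetical sketch: cast directly when the data fits the fp8 range,
    # otherwise rescale into range first (the "scaling version").
    max_val = torch.finfo(dtype).max  # 448.0 for e4m3fn
    amax = x.abs().amax()
    if amax.item() <= max_val:
        return x.to(dtype)  # non-scaling path: plain downcast
    scale = max_val / amax  # scaling path: shrink into representable range
    return (x * scale).clamp(-max_val, max_val).to(dtype)
```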

File tree

1 file changed: torch/_functorch/partitioners.py (+26 −16)
```diff
@@ -517,11 +517,21 @@ def calculate_range(dtype: torch.dtype) -> tuple:
         # 8-bit floating-point format with e5m2 layout
         min_val = -57344.0
         max_val = 57344.0
+    elif dtype == torch.float8_e4m3fn:
+        min_val = -448
+        max_val = 448
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
     return min_val, max_val


+def has_overflow(node: torch.fx.Node, max_val: float) -> bool:
+    abs = torch.ops.aten.abs.default(node.meta["val"])
+    amax = torch.ops.aten.amax.default(abs, [-1], True)
+
+    return amax.item() > max_val
+
+
 def quantize_activation_fw(graph: torch.fx.Graph) -> None:
     output = graph.find_nodes(op="output")[0]
     fwd_outputs = output.args[0]
```
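The constants in `calculate_range` match PyTorch's own `torch.finfo` values for the two fp8 formats, and `has_overflow` reduces the saved activation to its largest magnitude before comparing. A quick check of both, on a concrete tensor rather than a node's `meta["val"]`:

```python
import torch

print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0

# Same overflow test as has_overflow above, but using any() so the
# keepdim'd row-wise amax of a 2-D tensor is handled correctly.
x = torch.randn(4, 8) * 1000
amax = torch.ops.aten.amax.default(torch.ops.aten.abs.default(x), [-1], True)
print(bool((amax > 448.0).any()))  # likely True at this magnitude
```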
```diff
@@ -532,23 +542,10 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
     for node in fwd_outputs:
         # check if the activation node is the node saved for quantization
         if node.meta.get("saved_for_quantization", False):
-            # case: use scaling
-            if torch._inductor.config.post_grad_fusion_options[
+            # use non-scaling by default, fall back to scaling on overflow
+            if not torch._inductor.config.post_grad_fusion_options[
                 "activation_quantization_aten_pass"
-            ].get("use_scaling", False):
-                # calculating the scale
-                scale_node = calculate_quantization_scaling(
-                    graph, node, clamp_max, 1e-12
-                )
-                # converting to fp8
-                quant_node = perform_quantization(
-                    graph, node, scale_node, quant_type, clamp_min, clamp_max
-                )
-                if not is_sym_node(scale_node):
-                    tensor_scale_nodes.append(scale_node)
-                else:
-                    sym_scale_nodes.append(scale_node)
-            else:
+            ].get("use_scaling", False) and not has_overflow(node, clamp_max):
                 # case: do not use scaling
                 with graph.inserting_after(node):
                     quant_node = graph.call_function(
```
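For context, the branch above reads the pass's options out of Inductor's post-grad fusion config, where `use_scaling` is a per-pass flag. A hedged sketch of how that dictionary might be populated (the key names are taken from the diff; any options beyond `use_scaling` are not shown here):

```python
import torch

# "activation_quantization_aten_pass" and "use_scaling" are the keys this
# diff reads; leaving use_scaling False lets the new overflow check decide.
torch._inductor.config.post_grad_fusion_options = {
    "activation_quantization_aten_pass": {"use_scaling": False},
}
```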
```diff
@@ -564,6 +561,19 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
                     quant_node.meta["tensor_meta"] = extract_tensor_metadata(
                         quant_node.meta["val"]
                     )
+            else:
+                # calculating the scale
+                scale_node = calculate_quantization_scaling(
+                    graph, node, clamp_max, 1e-12
+                )
+                # converting to fp8
+                quant_node = perform_quantization(
+                    graph, node, scale_node, quant_type, clamp_min, clamp_max
+                )
+                if not is_sym_node(scale_node):
+                    tensor_scale_nodes.append(scale_node)
+                else:
+                    sym_scale_nodes.append(scale_node)
             node_to_quant[node] = quant_node
     # only update the return node args, and remain all other users unchanged
     output_updated_args = [
```
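`calculate_quantization_scaling` and `perform_quantization` themselves are not shown in this diff. A plausible eager-mode reading of the scaled path they implement, assuming amax-based scaling with the `1e-12` epsilon guarding division by zero:

```python
import torch

def quantize_scaled(x: torch.Tensor, clamp_max: float, eps: float = 1e-12,
                    quant_type: torch.dtype = torch.float8_e4m3fn):
    # Assumed sketch: scale so the largest magnitude maps to clamp_max,
    # clamp to the representable range, then downcast to fp8.
    scale = clamp_max / torch.clamp(x.abs().amax(), min=eps)
    quant = (x * scale).clamp(-clamp_max, clamp_max).to(quant_type)
    return quant, scale  # the scale is kept alongside the quantized tensor
```

The saved scale (tracked here via `tensor_scale_nodes`/`sym_scale_nodes`) would presumably be used to dequantize the activation in the backward pass.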
