Add a Float8LinearInference module to support static, dynamic, and wo quant by drisspg · Pull Request #287 · pytorch-labs/float8_experimental · GitHub
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add a Float8LinearInference module to support static, dynamic, and wo quant #287

Closed · wants to merge 9 commits
1 change: 1 addition & 0 deletions .github/workflows/python-app.yml
@@ -25,6 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install -e .
pip install -e .'[dev]'
pip install -e .'[test]'
1 change: 0 additions & 1 deletion benchmarks/profile_linear_float8.py
@@ -21,7 +21,6 @@
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
1 change: 0 additions & 1 deletion benchmarks/utils.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import collections
import json
import re


7 changes: 6 additions & 1 deletion float8_experimental/__init__.py
@@ -5,6 +5,11 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig])

__all__ = ["Float8Tensor", "Float8Linear"]
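
The `add_safe_globals` registration above is what lets checkpoints containing `Float8Tensor` and `ScaledMMConfig` be loaded with `torch.load(..., weights_only=True)`. A minimal sketch of the mechanism, using a stand-in `NamedTuple` rather than the real classes (the name `MyMMConfig` and its field are illustrative, not part of this PR; requires a PyTorch build that ships `add_safe_globals`):

```python
import torch
from typing import NamedTuple

from torch.serialization import add_safe_globals  # present in recent PyTorch nightlies


class MyMMConfig(NamedTuple):
    # Stand-in for a small config object stored inside a checkpoint,
    # analogous to ScaledMMConfig in this PR; the field is illustrative.
    emulate: bool = False


# Without this registration, torch.load(..., weights_only=True) rejects
# MyMMConfig because it is not on the allowlist of safe globals.
add_safe_globals([MyMMConfig])

torch.save({"mm_config": MyMMConfig(emulate=True)}, "ckpt.pt")
state = torch.load("ckpt.pt", weights_only=True)
print(state["mm_config"])  # MyMMConfig(emulate=True)
```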
4 changes: 2 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -62,8 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear):
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, x):
x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
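
For context on the hunk above: `cast_to_float8_e4m3fn` performs dynamic per-tensor scaling of the activation, deriving a scale from the current tensor's amax before the saturated cast. The sketch below illustrates that idea only; it is not the library's implementation, and the helper name and clamping details are assumptions:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3fn value


def dynamic_cast_to_e4m3(x: torch.Tensor):
    # Per-tensor dynamic scaling: the scale is derived from the current
    # tensor's amax, so no history buffers are needed.
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    scale = E4M3_MAX / amax
    # Scale into range, saturate, then cast; keep the scale to undo it later.
    x_scaled = (x.to(torch.float32) * scale).clamp(-E4M3_MAX, E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn), scale


x = torch.randn(16, 32)
x_fp8, scale = dynamic_cast_to_e4m3(x)
x_hat = x_fp8.to(torch.float32) / scale  # dequantize; the error is the fp8 rounding
print((x - x_hat).abs().max())
```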
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear.py
@@ -312,10 +312,10 @@ def float8_post_forward(self):
self.is_amax_initialized = True
self.amax_and_scale_synced = False

def forward(self, x):
self.float8_pre_forward(x)
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
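
Unlike the dynamic module, `Float8Linear` uses delayed scaling: the scale applied in the current forward comes from amax values recorded in earlier iterations, which is what the bookkeeping around `float8_pre_forward`/`float8_post_forward` exists to support. The toy sketch below illustrates only that idea; it is not the module's actual buffer layout or update rule:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


class DelayedScale:
    """Toy illustration of delayed scaling: the scale used for the current
    forward is computed from amaxes observed in previous iterations."""

    def __init__(self, history_len: int = 16):
        self.amax_history = torch.ones(history_len)  # rolling buffer of past amaxes

    def scale(self) -> torch.Tensor:
        # Use the max over the history rather than the current tensor's amax.
        return E4M3_MAX / self.amax_history.max().clamp(min=1e-12)

    def update(self, x: torch.Tensor) -> None:
        # Record the current amax so future iterations can use it.
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = x.abs().max().to(torch.float32)


x_state = DelayedScale()
for _ in range(3):
    x = torch.randn(8, 8)
    s = x_state.scale()  # scale chosen before looking at the current batch
    x_fp8 = (x * s).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    x_state.update(x)  # observe the current amax for later iterations
```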
59 changes: 41 additions & 18 deletions float8_experimental/float8_linear_utils.py
@@ -6,7 +6,7 @@
import copy
import logging
from enum import auto, Enum
from typing import Callable, List, Optional, Type
from typing import Callable, List, Optional, Type, Union

import torch
import torch.distributed as dist
@@ -97,45 +97,51 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
)


def swap_linear_with_float8_linear(
def swap_linear_layers(
module: nn.Module,
module_cls: Type[nn.Module],
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
) -> Optional[nn.Module]:
"""
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``).
Generic function to swap linear layers in a module with a new type of linear layer.

Note:
If applied to a root-level nn.Linear, the module is not modified in place;
the swapped module is returned instead.

Args:
module (torch.nn.Module): Module to modify.
module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
Linear submodules of these skipped modules will also be skipped.
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
skip_fqn_list: If specified, a list of module FQNs to skip.
linear_layer_filter: If specified, only the linear layers
that pass the filter function will be swapped.
from_float_kwargs: Additional keyword arguments for from_float_func.

Returns:
nn.Module: The modified module with swapped linear layers.
"""
module_names_to_skip = set(skip_fqn_list or [])

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
)
return module_cls.from_float(module, emulate=emulate)
return from_float_func(
module,
)

# Mark all modules to skip as visited
root_module = module
visited_modules = {root_module}

for module_name, module in root_module.named_modules():
if module_name in module_names_to_skip:
visited_modules.add(module)

# Run a post-order traversal to swap linears
def post_order_traversal(
module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
):
@@ -144,14 +150,15 @@ def post_order_traversal(
if child_module not in visited_modules:
visited_modules.add(child_module)
post_order_traversal(child_module, child_module_name, module)

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
float8linear_module = module_cls.from_float(module, emulate=emulate)
setattr(parent_module, module_name, float8linear_module)
new_linear_module = from_float_func(module)
setattr(parent_module, module_name, new_linear_module)

post_order_traversal(root_module, "", None)
# Without this explicit `del`, this set only gets deleted upon an explicit
@@ -160,6 +167,22 @@ def post_order_traversal(
return root_module


def swap_linear_with_float8_linear(
module: nn.Module,
module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> Optional[nn.Module]:
return swap_linear_layers(
module,
lambda m: module_cls.from_float(m, emulate=emulate),
skip_fqn_list=skip_fqn_list,
linear_layer_filter=linear_layer_filter,
)


def get_float8_layers(model: torch.nn.Module):
"""Iterates through the model and returns all the Float8Linear layers.
Args:
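
For reference, a usage sketch of the new `swap_linear_layers` helper introduced above (assumes this PR's branch is installed; `MyLinear` is a hypothetical stand-in used only to demonstrate the swap, not part of the library):

```python
import torch.nn as nn

from float8_experimental.float8_linear_utils import swap_linear_layers


class MyLinear(nn.Linear):
    """Hypothetical drop-in linear used only to demonstrate the swap helper;
    with the library, from_float_func would be something like
    lambda m: Float8DynamicLinear.from_float(m, emulate=True)."""

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "MyLinear":
        new = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new.weight = mod.weight
        if mod.bias is not None:
            new.bias = mod.bias
        return new


model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Sequential(nn.Linear(64, 16)),
)

model = swap_linear_layers(
    model,
    MyLinear.from_float,
    # Only swap layers whose dimensions are multiples of 16, in the spirit of
    # filter_out_small_unaligned_layers above.
    linear_layer_filter=lambda lin: lin.in_features % 16 == 0 and lin.out_features % 16 == 0,
)
print(model)
```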
1 change: 1 addition & 0 deletions float8_experimental/float8_tensor.py
@@ -266,6 +266,7 @@ def to_float8(
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
amax_buffer: a buffer to store the amax value in prior to conversion
mm_config: Defines the configuration for the scaled_mm

Returns:
Float8Tensor: a float8 tensor
2 changes: 1 addition & 1 deletion float8_experimental/float8_utils.py
@@ -139,7 +139,7 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")


def compute_error(x: torch.Tensor, y: torch.Tensor):
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the error between two tensors in dB.

For more details see:
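
The diff above only adds a return annotation to `compute_error`; for readers unfamiliar with the metric, here is a hedged sketch of one common SQNR-style definition of an error in dB (the exact expression used in float8_utils.py is not shown in this diff and may differ):

```python
import torch


def compute_error_db(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # SQNR-style metric: power of the reference signal over the power of the
    # difference, in decibels. Larger values mean y tracks x more closely.
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


x = torch.randn(1024)
y = x + 0.01 * torch.randn(1024)
print(compute_error_db(x, y))  # roughly 40 dB for ~1% relative noise
```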