Introducing Parameter Sharding and Torch backend for Tensor Parallelism #21724

Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances Keras's capabilities for large-scale model training by introducing a robust framework for Tensor Parallelism. It provides a new PyTorch-specific distributed backend with collective communication primitives and a flexible parameter sharding mechanism. This allows Keras models to efficiently distribute their parameters across multiple devices, paving the way for more advanced distributed training strategies within the Keras ecosystem.
Code Review
This pull request introduces a foundational framework for Tensor Parallelism in Keras and adds a Torch backend for distributed communication. The changes are substantial and add significant new capabilities. My review focuses on the correctness, generality, and test coverage of this new framework. I've identified some critical issues, such as backend-specific implementations in what should be a backend-agnostic framework, and tests that don't cover the new Torch implementation. There are also opportunities to improve code quality by removing hardcoded logic and reducing code duplication. Addressing these points will help ensure the new Tensor Parallelism framework is robust and maintainable.
```python
def train_step(self, data):
    """Custom training step for the parameter-sharded model.

    This method performs a standard forward and backward pass but
    adds a crucial gradient synchronization step (`all_reduce`) before
    applying gradients. This ensures that each device updates its
    local weight shards using gradients computed from all devices.

    Args:
        data: A tuple of (x, y, sample_weight) as passed by `fit()`.

    Returns:
        A dictionary mapping metric names to their current values.
    """
    import tensorflow as tf

    import keras

    x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    synced_gradients = self.communicator.all_reduce(
        gradients, op="sum", axis_name="model"
    )
    self.optimizer.apply_gradients(
        zip(synced_gradients, trainable_vars)
    )

    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    return {m.name: m.result() for m in self.metrics}
```
The `train_step` method in `ParameterShardedModel` is implemented using `tf.GradientTape`, which is specific to the TensorFlow backend. This will cause errors when using any other backend, including PyTorch, which this PR aims to support. The training logic should be backend-agnostic, possibly by leveraging the backend-specific `Trainer` implementations that `keras.Model` already inherits from.
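For reference, a minimal sketch of what a Torch-backend counterpart could look like, following the custom `train_step` pattern from the Keras-with-PyTorch guides. The `self.communicator.all_reduce` call is kept as used in the reviewed code; the wrapping class and how `communicator` is attached are assumptions, not the PR's actual implementation.

```python
import torch

import keras


class TorchParameterShardedModel(keras.Model):
    """Hypothetical Torch-backend variant of the reviewed train_step.

    Assumes `self.communicator` is attached during sharding, as in the
    reviewed ParameterShardedModel.
    """

    def train_step(self, data):
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Forward and backward pass with torch autograd.
        self.zero_grad()
        y_pred = self(x, training=True)
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )
        loss.backward()

        trainable_vars = self.trainable_variables
        gradients = [v.value.grad for v in trainable_vars]

        # Sum-reduce gradients across the model-parallel group before applying.
        synced_gradients = self.communicator.all_reduce(
            gradients, op="sum", axis_name="model"
        )
        with torch.no_grad():
            self.optimizer.apply(synced_gradients, trainable_vars)

        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```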
```python
@pytest.mark.skipif(
    keras.backend.backend() != "jax",
    reason="This test is JAX-specific.",
)
```
This test file is marked as JAX-specific and uses JAX/XLA-specific configurations (e.g., `XLA_FLAGS`). Since this pull request introduces tensor parallelism features for the PyTorch backend, it's crucial to have tests that validate this functionality on PyTorch. This test does not run for the torch backend and therefore doesn't verify the new implementation. Please add tests for the PyTorch backend.
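As a starting point, here is a minimal sketch of a Torch-specific test scaffold. It exercises `torch.distributed` directly with a single-process gloo group rather than the PR's wrapper API (whose exact names are not assumed here).

```python
import os

import pytest
import torch
import torch.distributed as dist

import keras


@pytest.mark.skipif(
    keras.backend.backend() != "torch",
    reason="This test is Torch-specific.",
)
def test_all_reduce_in_single_process_world():
    # A single-process "world" on the gloo backend lets the collective run
    # without multiple devices; all_reduce is then the identity.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    if not dist.is_initialized():
        dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.tensor([1.0, 2.0, 3.0])
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    assert torch.allclose(x, torch.tensor([1.0, 2.0, 3.0]))
    dist.destroy_process_group()
```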
```python
if isinstance(layer, layers.Add):
    try:
        if "feedforward_output" in layer.name:
            residual_source_name = layer.name.replace(
                "feedforward_output", "self_attention_output"
            )
        elif "self_attention_output" in layer.name:
            residual_source_name = layer.name.replace(
                "self_attention_output", "input_layer_norm"
            )
        else:
            residual_source_name = None

        if (
            residual_source_name
            and residual_source_name in tensor_cache
        ):
            layer_inputs = [
                current_tensor,
                tensor_cache[residual_source_name],
            ]
        else:
            layer_inputs = [current_tensor, current_tensor]
    except Exception:
        layer_inputs = [current_tensor, current_tensor]
```
The `call` method in `ParameterShardedModel` contains hardcoded logic that depends on specific layer names like `"feedforward_output"`, `"self_attention_output"`, and `"input_layer_norm"` to handle residual connections. This makes the parameter sharding framework highly coupled to a specific model architecture (likely a Transformer) and not generally applicable. This logic should be generalized to not rely on fragile name matching. A more robust approach could involve using the functional API's graph structure to identify residual connections.
```python
def apply_gradients(
    gradients: List[torch.Tensor],
    trainable_vars: List[torch.Tensor],
    learning_rate: float = 0.001,
) -> List[torch.Tensor]:
    """Applies gradients and returns the updated variables.

    Updates are performed in-place within a `torch.no_grad()` context
    to prevent the update operation from being part of the computation graph.
    """
    with torch.no_grad():
        updated_vars = []
        for grad, var in zip(gradients, trainable_vars):
            if grad is not None:
                var.sub_(learning_rate * grad)
            updated_vars.append(var)
        return updated_vars
```
The `apply_gradients` function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., `_apply_gradients`) or moving it into the test suite to avoid confusion.
```python
@pytest.mark.skipif(
    backend.backend() != "torch",
    reason="Jax Backend specific test",
)
```
The `skipif` reason "Jax Backend specific test" appears to be a copy-paste error. This test class is for the PyTorch distributed backend and should be labeled as such.
Suggested change:

```diff
 @pytest.mark.skipif(
     backend.backend() != "torch",
-    reason="Jax Backend specific test",
+    reason="Torch Backend specific test",
 )
```
```python
def apply_parameter_sharding_to_existing_model(
    model, config: ConfigKeras, rank: int, world_size: int
):
    """Applies parameter sharding directly to an existing model instance.

    This function modifies a model in-place. Instead of returning a new
    wrapped model, it shards the weights and attaches the sharding strategy
    to the original model object. This is useful when the model's execution
    logic is handled externally.

    Args:
        model: The Keras model to modify.
        config (ConfigKeras): Configuration object with sharding rules.
        rank (int): The rank of the current process.
        world_size (int): The total number of processes.

    Returns:
        The modified model with an attached `_tensor_parallel_sharding`
        strategy attribute.
    """

    sharding_strategy = ParameterShardingStrategy(world_size, rank)
    for pattern, action in config.state_rules.items():
        if isinstance(action, StateActionKeras):
            matching_params = sharding_strategy._find_matching_parameters(
                model, pattern
            )

            for param_name, param in matching_params:
                try:
                    param_id = id(param.experimental_ref())
                except AttributeError:
                    param_id = id(param)

                if param_id in sharding_strategy.sharded_weights_by_id:
                    sharding_strategy.sharded_weights[param_name] = (
                        sharding_strategy.sharded_weights_by_id[param_id]
                    )
                    existing_param_name = next(
                        k
                        for k, v in sharding_strategy.sharded_weights.items()
                        if v
                        is sharding_strategy.sharded_weights_by_id[param_id]
                    )
                    sharding_strategy.weight_mapping[param_name] = (
                        sharding_strategy.weight_mapping[existing_param_name]
                    )
                    continue

                sharded_param = action(param, rank)

                sharding_strategy.sharded_weights[param_name] = sharded_param
                sharding_strategy.sharded_weights_by_id[param_id] = (
                    sharded_param
                )

                sharding_strategy.weight_mapping[param_name] = {
                    "original_shape": param.shape,
                    "sharded_shape": sharded_param.shape,
                    "action": action,
                }

    model._tensor_parallel_sharding = sharding_strategy

    return model
```
The function `apply_parameter_sharding_to_existing_model` contains logic for finding and sharding parameters that is very similar to the implementation within `ParameterShardingStrategy.shard_model_parameters` (lines 137-175). This code duplication can make maintenance harder. Consider refactoring the common logic into a shared private helper function to improve code reuse and maintainability.
This pull request introduces a foundational framework for Tensor Parallelism in Keras (`Parameter_sharding.py`), enabling the training of large-scale models by sharding their parameters across multiple devices. This is a significant step towards supporting advanced distributed training strategies directly within the Keras ecosystem.
The core of this contribution is a new, backend-agnostic parameter sharding framework and the necessary distributed communication primitives for the PyTorch backend.
Key Changes

PyTorch Distributed Backend
- A new distributed_backend.py module has been added for the PyTorch backend.
- It implements essential collective communication operations (all_reduce, all_gather, broadcast, scatter) using the torch.distributed package, as sketched below.
- Provides helper functions for gradient computation (compute_gradients) and device management, aligning its interface with other Keras backends.
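For context, here is a minimal sketch of how such collectives typically map onto `torch.distributed` primitives; it illustrates the underlying package API, not the PR's exact wrapper signatures.

```python
import torch
import torch.distributed as dist


def all_reduce(tensor, op="sum"):
    """Sum (or max) a tensor across all processes, in place, and return it."""
    reduce_op = {"sum": dist.ReduceOp.SUM, "max": dist.ReduceOp.MAX}[op]
    dist.all_reduce(tensor, op=reduce_op)
    return tensor


def all_gather(tensor, dim=0):
    """Gather equally shaped shards from all processes and concatenate them."""
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=dim)


def broadcast(tensor, src=0):
    """Copy `tensor` from the source rank to every other process, in place."""
    dist.broadcast(tensor, src=src)
    return tensor
```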
Parameter Sharding Framework
- Introduces a powerful parameter sharding API under keras/src/distribution/tensor_parallel/.
- ParameterShardingStrategy: A new class that manages the logic for splitting model weights based on user-defined rules specified in a ConfigKeras object.
- ShardedWeight: A wrapper class for sharded keras.Variable objects, allowing them to be seamlessly integrated into the model.
- make_parameter_sharded_model: A factory function that takes a standard Keras model and returns a sharded version, automatically handling the weight splitting and model wrapping. The wrapped ParameterShardedModel injects communication ops (e.g., all-reduce) into the forward pass to ensure correct computations.