Introducing Parameter Sharding and Torch backend for Tensor Parallelism by buildwithsuhana · Pull Request #21724 · keras-team/keras · GitHub

Conversation

buildwithsuhana
Contributor

This pull request introduces a foundational framework for Tensor Parallelism in Keras (Parameter_sharding.py), enabling the training of large-scale models by sharding their parameters across multiple devices. This is a significant step towards supporting advanced distributed training strategies directly within the Keras ecosystem.

The core of this contribution is a new, backend-agnostic parameter sharding framework and the necessary distributed communication primitives for the PyTorch backend.

Key Changes

PyTorch Distributed Backend
A new distributed_backend.py module has been added for the PyTorch backend.

It implements essential collective communication operations (all_reduce, all_gather, broadcast, scatter) using the torch.distributed package.

It provides helper functions for gradient computation (compute_gradients) and device management, aligning its interface with other Keras backends.
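For reviewers less familiar with torch.distributed, the snippet below is a minimal, standalone sketch of the collectives the new backend wraps. The single-process gloo group and hard-coded rendezvous address are for illustration only and are not how the PR's distributed_backend.py initializes process groups.

```python
import os

import torch
import torch.distributed as dist

# Illustration only: a single-process "gloo" group so the collectives can run
# on one CPU. Real runs would use torchrun (or similar) with rank/world_size
# taken from the environment.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.tensor([1.0, 2.0, 3.0])
dist.all_reduce(x, op=dist.ReduceOp.SUM)  # sum the tensor across all ranks

# Gather a copy of every rank's tensor into a Python list.
gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, x)

dist.destroy_process_group()
```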

Parameter Sharding Framework
Introduces a powerful parameter sharding API under keras/src/distribution/tensor_parallel/.

ParameterShardingStrategy: A new class that manages the logic for splitting model weights based on user-defined rules specified in a ConfigKeras object.

ShardedWeight: A wrapper class for sharded keras.Variable objects, allowing them to be seamlessly integrated into the model.

make_parameter_sharded_model: A factory function that takes a standard Keras model and returns a sharded version, automatically handling the weight splitting and model wrapping. The wrapped ParameterShardedModel injects communication ops (e.g., all-reduce) into the forward pass to ensure correct computations.
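Conceptually, each sharding rule maps a weight to the slice owned by the current rank. The standalone NumPy sketch below (not this PR's API; shard_along_axis is a hypothetical helper) illustrates the per-rank splitting that ParameterShardingStrategy performs for every matching weight.

```python
import numpy as np


def shard_along_axis(weight, rank, world_size, axis):
    """Hypothetical helper: return the slice of `weight` owned by `rank`."""
    size = weight.shape[axis]
    per_rank = -(-size // world_size)  # ceiling division
    start = rank * per_rank
    end = min(start + per_rank, size)
    index = [slice(None)] * weight.ndim
    index[axis] = slice(start, end)
    return weight[tuple(index)]


# A (784, 512) Dense kernel split column-wise across 4 devices:
kernel = np.zeros((784, 512), dtype="float32")
shard = shard_along_axis(kernel, rank=1, world_size=4, axis=1)
print(shard.shape)  # (784, 128)
```

At run time, the communication ops injected into the forward pass (for example the all-reduce mentioned above) are what stitch the per-rank partial results back together.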

buildwithsuhana marked this pull request as draft on October 7, 2025, 05:40
Contributor

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances Keras's capabilities for large-scale model training by introducing a robust framework for Tensor Parallelism. It provides a new PyTorch-specific distributed backend with collective communication primitives and a flexible parameter sharding mechanism. This allows Keras models to efficiently distribute their parameters across multiple devices, paving the way for more advanced distributed training strategies within the Keras ecosystem.

Highlights

  • PyTorch Distributed Backend: A new distributed_backend.py module is introduced for PyTorch, providing essential collective communication operations like all_reduce, all_gather, broadcast, and scatter, along with gradient computation and device management utilities.
  • Parameter Sharding Framework: A foundational framework for parameter sharding is added under keras/src/distribution/tensor_parallel/. This includes ParameterShardingStrategy for managing weight splitting, ShardedWeight for wrapping sharded variables, and make_parameter_sharded_model for creating sharded Keras models.
  • Backend Integration: The new PyTorch distributed backend is integrated into the Keras backend system, and placeholder distributed_backend = None is added for other backends (TensorFlow, JAX, NumPy, OpenVINO) to maintain consistency.
  • Comprehensive Testing: New unit tests are added for both the PyTorch distributed backend functions and the parameter sharding framework, ensuring correctness and reliability of the new features.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot left a comment


Code Review

This pull request introduces a foundational framework for Tensor Parallelism in Keras and adds a Torch backend for distributed communication. The changes are substantial and add significant new capabilities. My review focuses on the correctness, generality, and test coverage of this new framework. I've identified some critical issues, such as backend-specific implementations in what should be a backend-agnostic framework, and tests that don't cover the new Torch implementation. There are also opportunities to improve code quality by removing hardcoded logic and reducing code duplication. Addressing these points will help ensure the new Tensor Parallelism framework is robust and maintainable.

Comment on lines +375 to +413
def train_step(self, data):
    """Custom training step for the parameter-sharded model.

    This method performs a standard forward and backward pass but
    adds a crucial gradient synchronization step (`all_reduce`) before
    applying gradients. This ensures that each device updates its
    local weight shards using gradients computed from all devices.

    Args:
        data: A tuple of (x, y, sample_weight) as passed by `fit()`.

    Returns:
        A dictionary mapping metric names to their current values.
    """
    import tensorflow as tf

    import keras

    x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    synced_gradients = self.communicator.all_reduce(
        gradients, op="sum", axis_name="model"
    )
    self.optimizer.apply_gradients(
        zip(synced_gradients, trainable_vars)
    )

    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    return {m.name: m.result() for m in self.metrics}

critical

The train_step method in ParameterShardedModel is implemented using tf.GradientTape, which is specific to the TensorFlow backend. This will cause errors when using any other backend, including PyTorch, which this PR aims to support. The training logic should be backend-agnostic, possibly by leveraging the backend-specific Trainer implementations that keras.Model already inherits from.
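For the torch backend specifically, a train_step along the lines of the standard Keras 3 "custom train_step with PyTorch" pattern would avoid the tf.GradientTape dependency. This is only a sketch: TorchShardedModel is a hypothetical stand-in for ParameterShardedModel, and self.communicator.all_reduce is assumed to behave as in this PR; a truly backend-agnostic solution would still need per-backend dispatch or reuse of the inherited Trainer logic.

```python
import torch

import keras


class TorchShardedModel(keras.Model):
    # Sketch of a torch-native train_step; `self.communicator` is assumed to
    # be attached by the parameter-sharding wrapper, as in this PR.
    def train_step(self, data):
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Clear leftover gradients, then run the forward pass.
        self.zero_grad()
        y_pred = self(x, training=True)
        loss = self.compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
        )

        # Backward pass via native autograd instead of tf.GradientTape.
        loss.backward()
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Synchronize gradients across model-parallel ranks before applying.
        gradients = self.communicator.all_reduce(
            gradients, op="sum", axis_name="model"
        )
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss).
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}
```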

Comment on lines 22 to 25
@pytest.mark.skipif(
    keras.backend.backend() != "jax",
    reason="This test is JAX-specific.",
)

high

This test file is marked as JAX-specific and uses JAX/XLA-specific configurations (e.g., XLA_FLAGS). Since this pull request introduces tensor parallelism features for the PyTorch backend, it's crucial to have tests that validate this functionality on PyTorch. This test does not run for the torch backend and therefore doesn't verify the new implementation. Please add tests for the PyTorch backend.
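As a starting point, a torch-marked test along these lines would at least exercise the suite under the torch backend. This is a minimal sketch that only checks the shard/concatenate round-trip invariant with plain torch ops, not this PR's API, and the test name is hypothetical.

```python
import pytest

from keras.src import backend


@pytest.mark.skipif(
    backend.backend() != "torch",
    reason="Torch backend specific test",
)
def test_column_shard_round_trip():
    import torch

    # Column shards must concatenate back to the original kernel; this is the
    # basic invariant that column-wise parameter sharding relies on.
    world_size = 4
    kernel = torch.randn(8, 16)
    shards = torch.chunk(kernel, world_size, dim=1)
    assert torch.equal(torch.cat(shards, dim=1), kernel)
```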

Comment on lines +477 to +501
if isinstance(layer, layers.Add):
    try:
        if "feedforward_output" in layer.name:
            residual_source_name = layer.name.replace(
                "feedforward_output", "self_attention_output"
            )
        elif "self_attention_output" in layer.name:
            residual_source_name = layer.name.replace(
                "self_attention_output", "input_layer_norm"
            )
        else:
            residual_source_name = None

        if (
            residual_source_name
            and residual_source_name in tensor_cache
        ):
            layer_inputs = [
                current_tensor,
                tensor_cache[residual_source_name],
            ]
        else:
            layer_inputs = [current_tensor, current_tensor]
    except Exception:
        layer_inputs = [current_tensor, current_tensor]

high

The call method in ParameterShardedModel contains hardcoded logic that depends on specific layer names like "feedforward_output", "self_attention_output", and "input_layer_norm" to handle residual connections. This makes the parameter sharding framework highly coupled to a specific model architecture (likely a Transformer) and not generally applicable. This logic should be generalized to not rely on fragile name matching. A more robust approach could involve using the functional API's graph structure to identify residual connections.
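For example, one way to avoid name matching entirely is to let Keras replay the original functional graph with keras.models.clone_model, swapping layers through clone_function. The sketch below is illustrative only; build_tiny_block and clone_layer are hypothetical helpers, not part of this PR.

```python
import keras
from keras import layers


def build_tiny_block():
    # Tiny functional model with a residual Add connection.
    inputs = keras.Input(shape=(16,))
    h = layers.Dense(16, name="ffn")(inputs)
    outputs = layers.Add(name="residual_add")([inputs, h])
    return keras.Model(inputs, outputs)


def clone_layer(layer):
    # Hypothetical hook: return a sharded replacement for layers that match a
    # sharding rule; here it simply recreates each layer from its config.
    return layer.__class__.from_config(layer.get_config())


model = build_tiny_block()
# clone_model replays the functional graph, so the Add layer is rewired to the
# same two inputs as in the original model -- no name matching required.
rewired = keras.models.clone_model(model, clone_function=clone_layer)
```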

Comment on lines +31 to +47
def apply_gradients(
    gradients: List[torch.Tensor],
    trainable_vars: List[torch.Tensor],
    learning_rate: float = 0.001,
) -> List[torch.Tensor]:
    """Applies gradients and returns the updated variables.

    Updates are performed in-place within a `torch.no_grad()` context
    to prevent the update operation from being part of the computation graph.
    """
    with torch.no_grad():
        updated_vars = []
        for grad, var in zip(gradients, trainable_vars):
            if grad is not None:
                var.sub_(learning_rate * grad)
            updated_vars.append(var)
    return updated_vars

medium

The apply_gradients function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., _apply_gradients) or moving it into the test suite to avoid confusion.

Comment on lines +12 to +15
@pytest.mark.skipif(
    backend.backend() != "torch",
    reason="Jax Backend specific test",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The skipif reason "Jax Backend specific test" appears to be a copy-paste error. This test class is for the PyTorch distributed backend and should be labeled as such.

Suggested change
 @pytest.mark.skipif(
     backend.backend() != "torch",
-    reason="Jax Backend specific test",
+    reason="Torch Backend specific test",
 )

Comment on lines +605 to +669
def apply_parameter_sharding_to_existing_model(
    model, config: ConfigKeras, rank: int, world_size: int
):
    """Applies parameter sharding directly to an existing model instance.

    This function modifies a model in-place. Instead of returning a new
    wrapped model, it shards the weights and attaches the sharding strategy
    to the original model object. This is useful when the model's execution
    logic is handled externally.

    Args:
        model: The Keras model to modify.
        config (ConfigKeras): Configuration object with sharding rules.
        rank (int): The rank of the current process.
        world_size (int): The total number of processes.

    Returns:
        The modified model with an attached `_tensor_parallel_sharding`
        strategy attribute.
    """

    sharding_strategy = ParameterShardingStrategy(world_size, rank)
    for pattern, action in config.state_rules.items():
        if isinstance(action, StateActionKeras):
            matching_params = sharding_strategy._find_matching_parameters(
                model, pattern
            )

            for param_name, param in matching_params:
                try:
                    param_id = id(param.experimental_ref())
                except AttributeError:
                    param_id = id(param)

                if param_id in sharding_strategy.sharded_weights_by_id:
                    sharding_strategy.sharded_weights[param_name] = (
                        sharding_strategy.sharded_weights_by_id[param_id]
                    )
                    existing_param_name = next(
                        k
                        for k, v in sharding_strategy.sharded_weights.items()
                        if v
                        is sharding_strategy.sharded_weights_by_id[param_id]
                    )
                    sharding_strategy.weight_mapping[param_name] = (
                        sharding_strategy.weight_mapping[existing_param_name]
                    )
                    continue

                sharded_param = action(param, rank)

                sharding_strategy.sharded_weights[param_name] = sharded_param
                sharding_strategy.sharded_weights_by_id[param_id] = (
                    sharded_param
                )

                sharding_strategy.weight_mapping[param_name] = {
                    "original_shape": param.shape,
                    "sharded_shape": sharded_param.shape,
                    "action": action,
                }

    model._tensor_parallel_sharding = sharding_strategy

    return model

medium

The function apply_parameter_sharding_to_existing_model contains logic for finding and sharding parameters that is very similar to the implementation within ParameterShardingStrategy.shard_model_parameters (lines 137-175). This code duplication can make maintenance harder. Consider refactoring the common logic into a shared private helper function to improve code reuse and maintainability.
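One possible shape for such a helper, sketched using only names that already appear in this PR; the helper name _shard_matching_parameters is hypothetical, and this is not the PR's actual refactor.

```python
def _shard_matching_parameters(sharding_strategy, model, config, rank):
    """Sketch of a shared helper that both `shard_model_parameters` and
    `apply_parameter_sharding_to_existing_model` could call."""
    for pattern, action in config.state_rules.items():
        if not isinstance(action, StateActionKeras):
            continue
        matches = sharding_strategy._find_matching_parameters(model, pattern)
        for param_name, param in matches:
            try:
                param_id = id(param.experimental_ref())
            except AttributeError:
                param_id = id(param)

            # Shard each underlying variable only once, even if it is reached
            # through several matching names.
            sharded = sharding_strategy.sharded_weights_by_id.get(param_id)
            if sharded is None:
                sharded = action(param, rank)
                sharding_strategy.sharded_weights_by_id[param_id] = sharded

            sharding_strategy.sharded_weights[param_name] = sharded
            sharding_strategy.weight_mapping[param_name] = {
                "original_shape": param.shape,
                "sharded_shape": sharded.shape,
                "action": action,
            }
```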
