Fix stateful RNN to raise ValueError on batch size mismatch by El3ssar · Pull Request #21249 · keras-team/keras · GitHub

Fix stateful RNN to raise ValueError on batch size mismatch #21249

Open. El3ssar wants to merge 2 commits into master.

Conversation

El3ssar commented May 4, 2025

This PR fixes a silent failure when calling stateful RNN layers (SimpleRNN, GRU, LSTM) with an input that doesn't match the fixed batch size.

Previously:

  • SimpleRNN would silently broadcast the input across all internal states.
  • GRU and LSTM would crash in CuDNN but still mutate internal state before failing.

Now:

  • A ValueError is raised early in RNN.call() if the batch size of the input doesn't match the expected value from the internal state.
  • This avoids state corruption and aligns the behavior across all RNN variants.
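
The early check described above can be sketched in plain Python. This is a minimal illustration and not the code from this PR: the helper name `check_stateful_batch_size`, its arguments, and the error message are all hypothetical, and shapes are represented as plain tuples rather than backend tensors.

```python
def check_stateful_batch_size(input_shape, state_shapes):
    """Raise ValueError if the input batch size differs from any state's.

    Hypothetical helper: shapes are tuples whose first entry is the batch
    dimension, e.g. (batch, timesteps, features) for the input and
    (batch, units) for each state.
    """
    input_batch = input_shape[0]
    for state_shape in state_shapes:
        state_batch = state_shape[0]
        # A None batch dimension is unknown, so there is nothing to compare.
        if input_batch is None or state_batch is None:
            continue
        if input_batch != state_batch:
            raise ValueError(
                "Stateful RNN batch size mismatch: input has batch size "
                f"{input_batch}, but the stored state has batch size "
                f"{state_batch}."
            )


# Matching batch sizes pass silently.
check_stateful_batch_size((4, 10, 8), [(4, 16), (4, 16)])

# A mismatch fails before any state could be updated.
try:
    check_stateful_batch_size((1, 10, 8), [(4, 16), (4, 16)])
except ValueError as err:
    print(err)
```

Running the comparison before the recurrent step is what keeps the stored state intact when the call is rejected.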

Fixes #21183

Manually tested by calling a stateful model with incorrect batch sizes
and verifying that the new ValueError is raised as expected.

When using stateful=True, RNN layers now raise a ValueError if the input
batch size does not match the internal state. This prevents silent broadcasting
(SimpleRNN) or state mutation prior to CuDNN errors (GRU/LSTM).

Fixes keras-team#21183
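
Why validating before mutating matters can be shown with a small analogy. The sketch below uses plain Python lists and two hypothetical functions (`unsafe_update` and `safe_update`); it is not a model of how Keras or cuDNN actually handle state, only an illustration of the ordering problem.

```python
def unsafe_update(state, batch):
    """Mutate first, validate later: a failure leaves `state` modified."""
    for i, value in enumerate(batch):
        if i >= len(state):
            raise ValueError("batch is larger than the stored state")
        state[i] += value


def safe_update(state, batch):
    """Validate first, so a failure leaves `state` untouched."""
    if len(batch) != len(state):
        raise ValueError(
            f"batch size {len(batch)} does not match state size {len(state)}"
        )
    for i, value in enumerate(batch):
        state[i] += value


state_a = [0.0, 0.0]
try:
    unsafe_update(state_a, [1.0, 1.0, 1.0])
except ValueError:
    pass
print(state_a)  # [1.0, 1.0]: modified even though the call failed

state_b = [0.0, 0.0]
try:
    safe_update(state_b, [1.0, 1.0, 1.0])
except ValueError:
    pass
print(state_b)  # [0.0, 0.0]: unchanged
```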

google-cla bot commented May 4, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

codecov-commenter commented May 4, 2025

Codecov Report

Attention: Patch coverage is 0% with 5 lines in your changes missing coverage. Please review.

Project coverage is 34.56%. Comparing base (f5171b3) to head (3180ed4).
Report is 124 commits behind head on master.

Files with missing lines       Patch %   Lines
keras/src/layers/rnn/rnn.py    0.00%     5 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (f5171b3) and HEAD (3180ed4).

HEAD has 8 uploads less than BASE
Flag               BASE (f5171b3)   HEAD (3180ed4)
keras              5                1
keras-numpy        1                0
keras-torch        1                0
keras-tensorflow   1                0
keras-jax          1                0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21249       +/-   ##
===========================================
- Coverage   82.60%   34.56%   -48.04%     
===========================================
  Files         564      567        +3     
  Lines       54543    56219     +1676     
  Branches     8472     8788      +316     
===========================================
- Hits        45054    19431    -25623     
- Misses       7402    35910    +28508     
+ Partials     2087      878     -1209     
Flag               Coverage Δ
keras              34.56% <0.00%> (-47.86%) ⬇️
keras-jax          ?
keras-numpy        ?
keras-openvino     34.56% <0.00%> (+1.56%) ⬆️
keras-tensorflow   ?
keras-torch        ?

keerthanakadiri (Contributor) commented

Hi @El3ssar, can you please sign the CLA? Thanks!

@gbaned gbaned requested a review from fchollet May 5, 2025 09:19
@gbaned gbaned added this to PR Queue May 5, 2025
@github-project-automation github-project-automation bot moved this to Assigned Reviewer in PR Queue May 5, 2025
fchollet (Collaborator) left a comment

Thanks for the PR! Please add a simple test case for the fix (using self.assertRaisesRegex).

El3ssar (Author) commented May 22, 2025

> Thanks for the PR! Please add a simple test case for the fix (using self.assertRaisesRegex)

Gladly. Where should I put the test?

fchollet (Collaborator) commented

You can put the test in keras/src/layers/rnn/rnn_test.py alongside test_statefulness_two_states. You can target layer = layers.RNN(TwoStatesRNNCell(2), stateful=True) as the RNN layer.
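
A sketch of the shape such a test might take, using `assertRaisesRegex` from the standard `unittest` module. `ToyStatefulRNN` is a hypothetical stand-in and not Keras code, so the example runs without Keras installed. In rnn_test.py the layer under test would instead be the `layers.RNN(TwoStatesRNNCell(2), stateful=True)` named above, and the regex would need to match whatever message the fix actually raises.

```python
import unittest


class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful RNN layer (not Keras code)."""

    def __init__(self, units):
        self.units = units
        self.state = None  # Created on the first call, then reused.

    def __call__(self, inputs):
        # `inputs` is a nested list shaped (batch, timesteps, features).
        batch_size = len(inputs)
        if self.state is None:
            self.state = [[0.0] * self.units for _ in range(batch_size)]
        elif len(self.state) != batch_size:
            # Fail before touching the state so it is never corrupted.
            raise ValueError(
                f"batch size mismatch: got {batch_size}, "
                f"expected {len(self.state)}"
            )
        # Toy update: add the sum of each sample's features to its state.
        for row, sample in zip(self.state, inputs):
            total = sum(sum(step) for step in sample)
            for i in range(self.units):
                row[i] += total
        return [list(row) for row in self.state]


class StatefulBatchSizeTest(unittest.TestCase):
    def test_batch_size_mismatch_raises(self):
        layer = ToyStatefulRNN(units=2)
        layer([[[1.0]], [[2.0]]])  # First call fixes the batch size at 2.
        state_before = [list(row) for row in layer.state]
        with self.assertRaisesRegex(ValueError, "batch size mismatch"):
            layer([[[1.0]]])  # Batch size 1 does not match the state.
        # The failed call must leave the state untouched.
        self.assertEqual(layer.state, state_before)
```

The second assertion goes slightly beyond what was requested: it also checks that the rejected call did not modify the stored state, which is the corruption the PR description sets out to prevent.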

El3ssar (Author) commented Jul 25, 2025

@fchollet I've added the requested test in rnn_test.py using TwoStatesRNNCell. It fails without the fix and passes with it.

Projects
Status: Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

stateful=True RNN silently broadcasts input when batch_size mismatch occurs
5 participants