🚨🚨🚨 [pipelines] update defaults in pipelines that can generate
#38129
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the `Ready for review` button.
```diff
@@ -1688,7 +1688,7 @@ def _prepare_generation_config(
                 if custom_gen_config_value == default_value and model_gen_config_value != default_value:
                     modified_values[key] = model_gen_config_value
                     setattr(generation_config, key, model_gen_config_value)
-            if len(modified_values) > 0:
+            if use_model_defaults is None and len(modified_values) > 0:
```
We only want to print the warning [model-specific values overriding global defaults] when `use_model_defaults` is unset. In pipelines we were already using the model defaults, so this change prevents unwanted warnings.
(Model defaults are messy at the moment; model-specific generation config files should untangle this mess.)
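For context, a caller can also silence that warning by making the choice explicit; a minimal sketch, assuming `use_model_defaults` is exposed as a `generate` argument as in recent `transformers` releases (model name purely illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello,", return_tensors="pt")

# Passing use_model_defaults explicitly means the flag is no longer None,
# so the "model-specific values overriding defaults" warning is skipped.
outputs = model.generate(**inputs, use_model_defaults=True, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```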
This area feels a little confusing and bug-prone because we have `default_generation_config`, `self.generation_config` and `generation_config` all being accessed in the same loop. Maybe we can alias `self.generation_config` or something? I feel like I'm extremely likely to get it mixed up with `generation_config` if I edit this code.
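A toy sketch of the kind of renaming being suggested, with hypothetical names (not what the PR implements):

```python
from transformers import GenerationConfig

def merge_model_defaults(
    global_defaults: GenerationConfig,
    model_defaults: GenerationConfig,
    user_config: GenerationConfig,
) -> GenerationConfig:
    """Give each of the three configs an unambiguous name before the merge loop."""
    for key, default_value in global_defaults.to_dict().items():
        user_value = getattr(user_config, key, default_value)
        model_value = getattr(model_defaults, key, default_value)
        # Fall back to the model's value only when the caller kept the global default.
        if user_value == default_value and model_value != default_value:
            setattr(user_config, key, model_value)
    return user_config

merged = merge_model_defaults(GenerationConfig(), GenerationConfig(num_beams=5), GenerationConfig())
print(merged.num_beams)  # 5
```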
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
yayyyyy - quite excited for this!
```diff
@@ -129,6 +134,11 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
     [huggingface.co/models](https://huggingface.co/models?filter=document-question-answering).
     """

+    # Make sure the docstring is updated when the default generation config is changed
+    _default_generation_config = GenerationConfig(
+        max_new_tokens=256,
```
Suggested change:
```diff
-        max_new_tokens=256,
+        max_new_tokens=1024,
```
maybe a bit higher, specifically if someone wants to just pass a long prompt
it's for max new tokens (not total tokens), so it doesn't take the prompt into account! Same answer as below: I'd encourage not setting values that are too high either
Makes sense, my understanding was that if an EOS token is encountered it will exit generation early anyway and not generate all the tokens
We might want to enable something like `max_new_length="auto"`, which maximizes `max_new_length` according to available HW? As VB wrote, in most models we don't actually hit `max_new_length` because of the EOS token(s). But a large default `max_new_tokens` will 100% hit OOM issues (it was one of the reasons behind the old default btw 👀)
ah yeah! might be a good idea, but maybe in another PR as follow-up if we don't see a lot of edge cases from this one.
yes, 100% as a follow-up
```diff
@@ -95,6 +102,13 @@ class TextGenerationPipeline(Pipeline):
     begging for his blessing. <eod> </s> <eos>
     """

+    # Make sure the docstring is updated when the default generation config is changed
+    _default_generation_config = GenerationConfig(
+        max_new_tokens=256,
```
Suggested change:
```diff
-        max_new_tokens=256,
+        max_new_tokens=1024,
```
I also think a sensible default here would be 1K, for typical "summarise X" use cases, wdyt?
I think 1024 can be really, really long to generate - for example some Qwen models have 2k set and even on an M4 laptop I can end up waiting minutes for an output
I think having a very large default for systems that are streaming is fine, but for systems which are blocking, I'd actually advise something lower (256 already seems high to me, but I can live with it)
256 sounds good to me IMO!
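Whichever default lands, callers can still override it when building the pipeline or per call; a minimal sketch (model name purely illustrative):

```python
from transformers import pipeline

# The pipeline ships with max_new_tokens=256 by default after this PR;
# passing the argument at construction time overrides it for every call...
generator = pipeline("text-generation", model="gpt2", max_new_tokens=1024)

# ...and it can also be overridden for a single call.
print(generator("Summarise the plot of Hamlet:", max_new_tokens=64)[0]["generated_text"])
```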
```diff
@@ -204,8 +204,9 @@ def test_validate(self):

-        # By default we throw a short warning. However, we log with INFO level the details.
+        # Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
         with CaptureLogger(logger) as captured_logs:
-            GenerationConfig(do_sample=False, temperature=0.5)
+            with LoggingLevel(logging.WARNING):
```
there is something in our test suite that likely changes the default logging level -- this test is flaky because the line below is capturing INFO level logs 🤔
these flaky checks were merged today (2025-may-19) by me
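To make the flakiness concrete, a small sketch of the same capture done by hand, pinning the verbosity first (logger path and behaviour assumed from the diff above):

```python
from transformers import GenerationConfig
from transformers.testing_utils import CaptureLogger
from transformers.utils import logging

logger = logging.get_logger("transformers.generation.configuration_utils")

# If a previously-run test left the transformers verbosity at INFO, the extra
# detail lines are captured too and exact-match assertions become flaky.
# Pinning the level makes the captured output deterministic.
logging.set_verbosity_warning()
with CaptureLogger(logger) as captured_logs:
    GenerationConfig(do_sample=False, temperature=0.5)
print(captured_logs.out)
```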
LGTM! The changes to tests and individual pipelines are mostly straightforward, so the core changes are:
- Replacing errors with a warning when an assistant model is passed for generation modes that don't support it (e.g. beam search)
- Flagging model-specific values overriding defaults in `generation/utils.py`
- Adding the `_pipeline_calls_generate` attribute in `pipelines/base.py` to replace `model.can_generate()`, which doesn't always indicate that the pipeline will be generating text
Those changes all make sense to me! I made a couple of small comments, but they're nits and none of them are blockers - it's fine to ignore them or deal with them after merging.
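For reference, a rough sketch of how such a class-level flag can be consumed, with everything except `_pipeline_calls_generate` being hypothetical:

```python
from transformers import Pipeline

class MyGeneratingPipeline(Pipeline):
    # Declares that this pipeline calls `generate`, independently of what
    # `model.can_generate()` reports for the underlying model.
    _pipeline_calls_generate = True

def pipeline_will_generate(pipeline_cls) -> bool:
    # Illustrative check in the spirit of the new attribute.
    return getattr(pipeline_cls, "_pipeline_calls_generate", False)

print(pipeline_will_generate(MyGeneratingPipeline))  # True
```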
```python
        self.assistant_model, self.assistant_tokenizer = load_assistant_model(
            self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
        )
```
It feels a little weird for this to be outside `self.model.can_generate()` - is there a case where the model can't generate but we still have `assistant_model`? I'm guessing this is related to the changes in `configuration_utils.py` that replace errors for unexpected assistant model kwargs with warnings?
> is there a case where the model can't generate but we still have `assistant_model`?

No, moving these lines in :)

The replacement of errors by warnings is because of the new defaults in some pipelines, e.g. on ASR we now set `num_beams=5`, which would make any attempt at calling assisted generation crash unless we manually set `num_beams=1`. We want the code to run if we pass an assistant model, so the warning is a compromise 🤔
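To illustrate the compromise, a sketch under the assumption that the ASR pipeline accepts an `assistant_model` checkpoint as popped in the snippet above (model names illustrative, not a recommendation):

```python
from transformers import pipeline

# With the new ASR default of num_beams=5, assisted generation would be
# incompatible (it requires num_beams=1). This PR downgrades the error to a
# warning; passing num_beams=1 explicitly keeps assisted generation active.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    assistant_model="distil-whisper/distil-large-v3",
    num_beams=1,
)
print(asr("sample.wav")["text"])
```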
```diff
@@ -420,6 +420,7 @@ def test_return_timestamps_in_preprocess(self):

     @slow
     @require_torch
+    @unittest.skip("TODO (joao, eustache): this test is failing, find the breaking PR and fix the cause or the test")
```
Slightly nervous about all the skips here, but if we do a follow-up PR soon then it's probably fine!
btw, these tests were already failing on `main`, and I think for quite a long time 👀
"text2text-generation", | ||
model="patrickvonplaten/t5-tiny-random", | ||
framework="pt", | ||
num_beams=1, |
is `num_beams=1` not the default?
not with these changes (we recommend using beam search in text2text models in our docs, so might as well set it as default for those pipelines!)
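In other words, after this change greedy decoding has to be requested explicitly in those pipelines; a small sketch (model name illustrative):

```python
from transformers import pipeline

# text2text pipelines now default to beam search; num_beams=1 opts back out.
summarizer = pipeline("text2text-generation", model="t5-small", num_beams=1)
print(summarizer("summarize: The quick brown fox jumps over the lazy dog.")[0]["generated_text"])
```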
…gingface#38129) * pipeline generation defaults * add max_new_tokens=20 in test pipelines * pop all kwargs that are used to parameterize generation config * add class attr that tell us whether a pipeline calls generate * tmp commit * pt text gen pipeline tests passing * remove failing tf tests * fix text gen pipeline mixin test corner case * update text_to_audio pipeline tests * trigger tests * a few more tests * skips * some more audio tests * not slow * broken * lower severity of generation mode errors * fix all asr pipeline tests * nit * skip * image to text pipeline tests * text2test pipeline * last pipelines * fix flaky * PR comments * handle generate attrs more carefully in models that cant generate * same as above
Some PEFT integration tests involving text generation pipelines were failing since #38129 because the base model is too small to generate longer sequences. Setting max_new_tokens fixes this.
What does this PR do?
TL;DR our defaults for text generation are very outdated. Most notably, the default maximum length.
This PR adds the tooling to easily add pipeline-specific `generate` defaults, and adds new defaults to all pipelines that call `generate`. In all pipelines that call `generate`, the new `max_new_tokens` default is 256.
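A condensed sketch of the per-pipeline pattern this tooling enables, mirroring the diffs above (class name hypothetical):

```python
from transformers import GenerationConfig, Pipeline

class MySummarizationPipeline(Pipeline):
    # Opt this pipeline into the new machinery: it calls `generate`...
    _pipeline_calls_generate = True

    # ...and ships its own generation defaults.
    # Make sure the docstring is updated when the default generation config is changed
    _default_generation_config = GenerationConfig(
        max_new_tokens=256,
    )
```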
Tests
This PR reviewed the tests of the pipelines that call `generate` such that their CI is now green. Overall changes:
- already failing on `main` -> added a skip with a TODO
- `@slow`
- `max_new_tokens`
Example
On main it prints
With this PR it prints