[Model] add dots1 by redmoe-moutain · Pull Request #38143 · huggingface/transformers

[Model] add dots1 #38143

Merged
merged 7 commits into from Jun 25, 2025

Conversation

redmoe-moutain
Contributor
@redmoe-moutain redmoe-moutain commented May 15, 2025

What does this PR do?

Support the dots.llm1 model by rednote-hilab

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@redmoe-moutain redmoe-moutain marked this pull request as ready for review May 20, 2025 09:24
@redmoe-moutain redmoe-moutain changed the title add dots1 [Model] add dots1 May 20, 2025
@redmoe-moutain redmoe-moutain marked this pull request as draft May 20, 2025 09:29
@redmoe-moutain redmoe-moutain marked this pull request as ready for review May 20, 2025 10:28
@redmoe-moutain
Contributor Author

@ArthurZucker Could you please take a look?

@Rocketknight1
Member
Rocketknight1 commented May 22, 2025

Hi @redmoe-moutain, is there an existing pre-trained dots1 model somewhere? We generally don't add architectures until we need them to support a significant model checkpoint.

@redmoe-moutain
Contributor Author

@Rocketknight1 We're rolling out the open-source models dots.llm1. You can check out the pretrained model here: https://huggingface.co/rednote-hilab/dots.llm1.base. The instruct version and a detailed report are coming soon.
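
For anyone who wants to try the checkpoint while reviewing, here is a minimal loading sketch; it assumes the standard `AutoModelForCausalLM` path this PR wires up, and the dtype and device placement are only illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative: load the released base checkpoint with this PR checked out.
tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots.llm1.base")
model = AutoModelForCausalLM.from_pretrained(
    "rednote-hilab/dots.llm1.base",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```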

@Rocketknight1
Member

Cool! In that case, @Cyrilvallez can you take the review?

Collaborator
@ArthurZucker ArthurZucker left a comment

Looks actually very nice! Thanks for the clean modular, I actually have a hard time believing it!

```python
)


class Dots1ModelTester:
```
Collaborator

Can you check test_modeling_llama? We have a new, simpler mixin for testing!
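
For reference, a rough sketch of the mixin-based style being suggested; the class names and import path mirror what test_modeling_llama looked like around this time and should be treated as assumptions, not the exact API:

```python
# Hypothetical sketch of the simpler tester mixin; names are assumptions.
from transformers import Dots1Config
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class Dots1ModelTester(CausalLMModelTester):
    config_class = Dots1Config


@require_torch
class Dots1ModelTest(CausalLMModelTest):
    model_tester_class = Dots1ModelTester
```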

@@ -0,0 +1,192 @@
```python
from ...configuration_utils import PretrainedConfig
```
Collaborator

missing a license!

Comment on lines +102 to +127
"layers.*.mlp.experts.*.gate_proj": "local_colwise",
"layers.*.mlp.experts.*.up_proj": "local_colwise",
"layers.*.mlp.experts.*.down_proj": "local_rowwise",
"layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
"layers.*.mlp.shared_experts.gate_proj": "local_colwise",
"layers.*.mlp.shared_experts.up_proj": "local_colwise",
"layers.*.mlp.shared_experts.down_proj": "local_rowwise",
"layers.*.mlp.shared_experts": "local",
"layers.*.mlp.gate_proj": "local_colwise",
"layers.*.mlp.up_proj": "local_colwise",
Collaborator

quick q, did you test TP to make sure it works?

@redmoe-moutain
Contributor Author

@ArthurZucker Thank you for your insightful review. I've updated the license and testing as suggested.

While we've tested on PP, we haven't yet covered TP testing cases. Could you provide some examples of how we should approach TP? Any guidance would be greatly appreciated.

@redmoe-moutain
Contributor Author

@ArthurZucker Could you please take another look?

@Rocketknight1
Member

cc @Cyrilvallez for core maintainer review since Arthur is out!

Member
@Cyrilvallez Cyrilvallez left a comment

Amazingly simple modular! Super nice 🤗 Added a few comments, but this is truly almost ready to be shipped 👌
For TP, you can check out the example in the doc here. Let me know if something is still unclear!
You could add more integration tests as well, for example trying prompts beyond the sliding window as in Qwen3, but this is optional.
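
For the optional integration tests, a rough sketch of the kind of beyond-the-sliding-window check meant here (the prompt length and the assertion are placeholders; a real test would pin an expected output string):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_accelerator, slow


@slow
@require_torch_accelerator
def test_generation_beyond_sliding_window():
    # Hypothetical: feed a prompt assumed to exceed the sliding window and
    # check that greedy generation still completes with the expected length.
    model_id = "rednote-hilab/dots.llm1.base"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    long_prompt = "hello " * 8192
    inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=5, do_sample=False)
    assert out.shape[1] == inputs.input_ids.shape[1] + 5
```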

@@ -0,0 +1,40 @@
```md
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
```
Member

It's 2025! 🤗

@@ -0,0 +1,27 @@
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
```
Member

same here haha

Comment on lines 74 to 76
```python
pretraining_tp (`int`, *optional*, defaults to 1):
    Experimental: tensor parallelism rank used during pretraining. This is necessary for exact reproducibility
    of pretraining results.
```
Member

This should be removed!

Comment on lines 89 to 90
```python
use_sliding_window (`bool`, *optional*, defaults to `False`):
    Whether to use sliding window attention.
```
Member

The best is to not have this arg, and simply check whether sliding_window is None instead, so this should be removed.
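
A minimal sketch of the suggested pattern; only `config.sliding_window` comes from the snippet above, and the surrounding attention-kwargs plumbing is hypothetical:

```python
# Sketch: derive the behavior from sliding_window itself instead of
# carrying a separate use_sliding_window flag.
if config.sliding_window is not None:
    # hypothetical call site; pass the window size to the attention kernel
    attn_kwargs["sliding_window"] = config.sliding_window
```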

Comment on lines +1 to +2
```python
from ...modeling_outputs import CausalLMOutputWithPast
from ...processing_utils import Unpack
```
Member

Missing a license here at the top

Comment on lines 11 to 14
```python
from ..llama.modeling_llama import (
    KwargsForCausalLM,
    LlamaRMSNorm,
)
```
Member

Let's import those 2 classes from Qwen3 instead! They are similar, and it's easier to follow if we import everything from the same model!
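
A sketch of the suggested change in the modular file; this assumes Qwen3's modeling module exposes both names, which follows from the "they are similar" reasoning above but is still an assumption:

```python
# Sketch: import both helpers from Qwen3 so everything comes from one model.
from ..qwen3.modeling_qwen3 import (
    KwargsForCausalLM,
    Qwen3RMSNorm,
)
```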

@@ -0,0 +1,144 @@
```python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
```
Member

2025 as well!

```python
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(text)
```
Member

Let's remove this print

@redmoe-moutain
Contributor Author

@Cyrilvallez Thanks for the review. It looks much cleaner now!

I followed the documentation to test TP with:

```python
model = AutoModelForCausalLM.from_pretrained("rednote-hilab/dots.llm1.inst", tp_plan="auto", torch_dtype=torch.bfloat16)
```

However, I encountered the following error:

```
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/GitHub/transformers/src/transformers/models/dots1/modeling_dots1.py", line 308, in forward
[rank0]:     topk_indices, topk_weights = self.gate(hidden_states)
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/mnt/miniconda3/envs/vllm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1803, in inner
[rank0]:     hook_result = hook(self, args, result)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/GitHub/transformers/src/transformers/integrations/tensor_parallel.py", line 329, in <lambda>
[rank0]:     module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
[rank0]:                                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/GitHub/transformers/src/transformers/integrations/tensor_parallel.py", line 447, in _prepare_output_fn
[rank0]:     return outputs.to_local() if use_local_output else outputs
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]: AttributeError: 'tuple' object has no attribute 'to_local'
```

I modified it to apply `.to_local()` to each element when `outputs` is a tuple:

```python
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
    # Some modules (e.g. the MoE gate) return a tuple of tensors, so convert
    # each element back to a local tensor individually.
    if isinstance(outputs, tuple):
        return tuple(output.to_local() if use_local_output else output for output in outputs)
    return outputs.to_local() if use_local_output else outputs
```

After this change, it works as expected.
Let me know if you’d like me to open a separate issue or PR to discuss this further.

Thanks!

Collaborator
@ArthurZucker ArthurZucker left a comment

Not seeing the changes to tensor parallel here, but yes, let's open a different PR for that and merge this!


```md
## Overview

The `dots.llm1` model was proposed in the dots.llm1 technical report by the rednote-hilab team.
```
Collaborator

Can we add a hyperlink here?

Contributor Author

Sure, I've added the hyperlink.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker merged commit 7503cb9 into huggingface:main Jun 25, 2025
18 checks passed
@ArthurZucker
Collaborator

Thanks for bearing with us, and kudos for the release!
