Add Arcee model support by Crystalcareai · Pull Request #38621 · huggingface/transformers

Merged: 23 commits merged into huggingface:main on Jun 24, 2025

Conversation

Crystalcareai (Contributor)

Summary

This PR adds support for the Arcee model architecture, laying the groundwork for the upcoming Arcee Foundation Model (AFM) release. Arcee is a decoder-only transformer model based on the Llama architecture with a key modification: it uses ReLU² (ReLU-squared) activation in the MLP blocks instead of SiLU, following recent research showing improved training efficiency with squared activations.

Model Description

Arcee is architecturally similar to Llama but with the following distinctions:

  • ReLU² activation: uses x * relu(x) in the MLP layers for improved gradient flow (see the sketch after this list)
  • Efficiency-focused design: intended for efficient training and inference
  • Extended context: supports longer contexts via RoPE scaling (including YARN)
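For clarity, here is a minimal standalone sketch of that ReLU² activation (in the modeling code it is selected through transformers' ACT2FN registry via hidden_act="relu2", as the excerpts later in this thread show):

import torch

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # ReLU²: equivalent to relu(x) ** 2, written here as x * relu(x)
    # (both are x² for x > 0 and 0 otherwise).
    return x * torch.relu(x)

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
print(relu_squared(x))  # tensor([0.0000, 0.0000, 0.0000, 2.2500, 9.0000])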

Implementation Details

  • Modular implementation inheriting from Llama components where applicable
  • Custom ArceeMLP class implementing the ReLU² activation
  • Full support for all standard transformers features (a brief usage sketch follows this list):
    • Flash Attention 2, SDPA, and other attention backends
    • Gradient checkpointing
    • Quantization support (including quantized caches)
    • All standard model variants (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification)
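As a rough illustration of how the model plugs into those standard APIs, a hedged usage sketch; the arcee-ai/AFM-4.5B checkpoint name is taken from this PR's docstrings and may not be publicly available yet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "arcee-ai/AFM-4.5B"  # checkpoint name referenced in this PR; assumed here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "flash_attention_2" if installed
    device_map="auto",
)

inputs = tokenizer("Arcee is a decoder-only transformer that", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))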

Testing

  • Added comprehensive test suite following standard transformers test patterns
  • Tests for all model variants and core functionality
  • Specific test for ReLU² activation verification
  • RoPE scaling tests including YARN support
  • Model loading, forward, and backward passes verified
  • Compatibility with the existing architecture and infrastructure confirmed

Crystalcareai and others added 8 commits June 4, 2025 14:23
- Add ArceeConfig and model mappings for all task types (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification)
- Add auto-loading support through AutoModel, AutoConfig, and AutoTokenizer
- Use LlamaTokenizer for tokenization
- Add FX graph support for Arcee models
- Create lazy loading module structure for Arcee
- Add test_modeling_arcee.py following standard transformers test patterns
- Include tests for all model variants (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification)
- Add specific test for ReLU² activation in ArceeMLP
- Add RoPE scaling tests including YARN support
- Follow CausalLMModelTest pattern used by similar models
- Add comprehensive model documentation with usage examples
- Include all model variants in autodoc
- Add to table of contents in proper alphabetical order
- Fixes documentation coverage for Arcee model classes
@Rocketknight1 (Member)

Looks good @Crystalcareai! Feel free to ping us whenever you're ready for review. You can also resolve the code style errors with `pip install -e .[quality]` followed by `make style` or `make fixup`.

@Crystalcareai (Contributor, Author) commented Jun 11, 2025

@Rocketknight1 Hey, I think I'm ready for a review. I've got a lot of the tests passing, though I'm still getting some failures that don't seem to be related to my code. Let me know how best I can get this ready for merging.

@Cyrilvallez (Member) left a comment

Hey! Very clean first implementation with modular, congrats!! 🤗 We can still make it even simpler though, see my comments 🚀 Also, let's make sure the copyright headers at the top of the files have the correct information (dates and company names mostly).

But very nice work in general! 🤗

@@ -0,0 +1,104 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

It's 2025! 🤗

@@ -0,0 +1,32 @@
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.

wrong dates and companies here as well

Comment on lines 3 to 6
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.

Not true, to remove

Comment on lines 1 to 7
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.

same comment as above here!

Comment on lines 215 to 229
class ArceeMLP(nn.Module):
    """Arcee MLP with configurable activation function (typically relu2)"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.up_proj(x)))
        return down_proj

You could actually inherit that directly from Nemotron as it's 1:1 similar!
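In modular terms that suggestion would look roughly like the following sketch (assuming the modular converter resolves the Nemotron import; not necessarily the exact code that landed in this PR):

# modular_arcee.py (sketch)
from transformers.models.nemotron.modeling_nemotron import NemotronMLP


class ArceeMLP(NemotronMLP):
    # NemotronMLP already computes down_proj(act_fn(up_proj(x))) with a
    # configurable activation, so no body is needed here.
    pass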

Comment on lines 288 to 326
"""
The Arcee Model transformer with a sequence classification head on top (linear layer).
"""

def __init__(self, config):
self.config_class = ArceeConfig
super().__init__(config)
self.model = ArceeModel(config)
# Initialize weights and apply final processing
self.post_init()


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
"""
The Arcee Model transformer with a span classification head on top for extractive question-answering tasks.
"""

def __init__(self, config):
self.config_class = ArceeConfig
super().__init__(config)
# Note: LlamaForQuestionAnswering uses self.transformer, not self.model
self.transformer = ArceeModel(config)
# Initialize weights and apply final processing
self.post_init()


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(LlamaForTokenClassification):
"""
The Arcee Model transformer with a token classification head on top.
"""

def __init__(self, config):
self.config_class = ArceeConfig
super().__init__(config)
self.model = ArceeModel(config)
# Initialize weights and apply final processing
self.post_init()

Similarly for those, we don't actually need to rewrite the init at all 🤗
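Following that suggestion, the modular definitions could shrink to roughly this sketch (decorators kept from the excerpt above; the Llama parents' __init__ already builds everything, including the sequence-classification variant not fully shown in the excerpt):

@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
    # No custom __init__: the parent already creates the backbone and calls
    # post_init(); config_class comes from the generated ArceePreTrainedModel.
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(LlamaForTokenClassification):
    pass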

Comment on lines 118 to 119
"altclip",
"arcee",

This change should be reverted - this is more for legacy purposes

Comment on lines 90 to 116
    def test_model_rope_scaling(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for scaling_type in ["linear", "dynamic"]:
            config.rope_scaling = {"type": scaling_type, "factor": 2.0}
            model = ArceeModel(config)
            model.to(torch_device)
            model.eval()
            input_ids = torch.randint(0, config.vocab_size, (1, 10)).to(torch_device)
            with torch.no_grad():
                model(input_ids)

    def test_model_rope_scaling_yarn(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.rope_scaling = {
            "type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "beta_fast": 32,
            "beta_slow": 1,
        }
        model = ArceeModel(config)
        model.to(torch_device)
        model.eval()
        input_ids = torch.randint(0, config.vocab_size, (1, 10)).to(torch_device)
        with torch.no_grad():
            model(input_ids)

Those tests are not needed and can be removed

Comment on lines 118 to 135
    def test_arcee_mlp_uses_relu_squared(self):
        """Test that ArceeMLP uses ReLU² activation instead of SiLU."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.hidden_act = "relu2"  # Ensure we're using relu2 activation
        model = ArceeModel(config)

        # Check that the MLP layers use the correct activation
        for layer in model.layers:
            mlp = layer.mlp
            # Test with a simple input
            x = torch.randn(1, 10, config.hidden_size)
            up_output = mlp.up_proj(x)

            # Verify ReLU² activation: x * relu(x)
            expected_activation = up_output * torch.relu(up_output)
            actual_activation = mlp.act_fn(up_output)

            self.assertTrue(torch.allclose(expected_activation, actual_activation, atol=1e-5))

I don't mind testing this to make sure, but let's not use a for loop if we don't actually use the loop!
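One way to drop the loop would be to check a single layer's MLP, roughly like this sketch of the same test:

    def test_arcee_mlp_uses_relu_squared(self):
        """Check that ArceeMLP applies ReLU² (x * relu(x)); sketch without the loop."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.hidden_act = "relu2"
        model = ArceeModel(config)

        mlp = model.layers[0].mlp  # one layer is enough, they all share the config
        x = torch.randn(1, 10, config.hidden_size)
        up_output = mlp.up_proj(x)
        expected = up_output * torch.relu(up_output)
        self.assertTrue(torch.allclose(mlp.act_fn(up_output), expected, atol=1e-5))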

@@ -0,0 +1,153 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.

date here as well haha

@pranav4501 (Contributor) commented Jun 16, 2025

Hi @Cyrilvallez,
Thanks for the feedback; I made the requested refactoring changes.
Also, while removing the init from the modular implementation as suggested, the generated modeling code does not have self.config_class = ArceeConfig from the previous version. Is that redundant as well?

@Cyrilvallez (Member) commented Jun 19, 2025

> Also, while removing the init from the modular implementation as suggested, the generated modeling code does not have self.config_class = ArceeConfig from the previous version. Is that redundant as well?

Yes, it's already in the PreTrainedModel!
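For context, a rough sketch of why the assignment is redundant (based on the usual layout of generated modeling files; not an excerpt from this PR):

# The modular converter emits an ArceePreTrainedModel that already carries
# the config class as a class attribute, so every Arcee model class
# inherits it without setting it in __init__.
class ArceePreTrainedModel(PreTrainedModel):
    config_class = ArceeConfig
    base_model_prefix = "model"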

@Cyrilvallez (Member) left a comment

Alright, it's perfect! 🤗🤗 I just added some minor comments about the config; most notably, the tp_plan should reflect the new MLP (this is very important for your model to be available as a backend in vLLM/TGI and other frameworks), but otherwise all good! Great work! I'll merge as soon as you make those small changes 🤗

Comment on lines 85 to 89
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).

This arg should be removed!

Comment on lines +137 to +138

model_type = "arcee"

Here you should add the following TP plan as a class attribute (we need to change it, as the MLP is slightly different from Llama's):

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
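(Presumably there is no gate_proj entry here because, unlike Llama's gated MLP, the ArceeMLP shown earlier only defines up_proj and down_proj.)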

        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        pretraining_tp=1,
@Cyrilvallez (Member), Jun 19, 2025

pretraining_tp to remove here

Comment on lines 192 to 197
        # Validate the correctness of rotary position embeddings parameters using Arcee's custom validation
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


No need to overwrite the check, but we should delete the pretraining_tp attribute to make it disappear in the actual config.

Suggested change:
-        # Validate the correctness of rotary position embeddings parameters using Arcee's custom validation
-        # BC: if there is a 'type' field, copy it to 'rope_type'.
-        if self.rope_scaling is not None and "type" in self.rope_scaling:
-            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        rope_config_validation(self)
+        del self.pretraining_tp

            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pretraining_tp=pretraining_tp,
@Cyrilvallez (Member), Jun 19, 2025

pretraining_tp to remove here as well

Comment on lines +108 to +122
@require_torch_accelerator
class ArceeIntegrationTest(unittest.TestCase):
    def tearDown(self):
        import gc

        gc.collect()
        torch.cuda.empty_cache()

    @slow
    def test_model_from_pretrained(self):
        # This test would be enabled once a pretrained model is available
        # For now, we just test that the model can be instantiated
        config = ArceeConfig()
        model = ArceeForCausalLM(config)
        self.assertIsInstance(model, ArceeForCausalLM)

Would be nice to add a few integration tests based on real checkpoints as well if possible! 🤗 Otherwise you can open another PR later if that's more convenient for you.
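A typical shape for such a test once a public checkpoint exists might be the following sketch; the checkpoint name comes from this PR's docstrings and the assertion is deliberately weak since no reference output exists yet:

    @slow
    def test_generation(self):
        # Placeholder checkpoint and assertion: to be tightened once a real
        # Arcee checkpoint is released and reference outputs are known.
        model_id = "arcee-ai/AFM-4.5B"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = ArceeForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)

        inputs = tokenizer("Hello, my name is", return_tensors="pt").to(torch_device)
        output = model.generate(**inputs, max_new_tokens=10, do_sample=False)
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        self.assertTrue(text.startswith("Hello, my name is"))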

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pranav4501 (Contributor)

@Cyrilvallez Thanks for the feedback. I removed pretraining_tp from the configurations and added scaffolding for generation integration testing. We will add more robust integration tests and update the checkpoints with the release.

@Cyrilvallez (Member) left a comment

All right, merging! Thanks a lot! TP plan is still wrong, but I'll update it myself after the merge! 🤗🚀

Cyrilvallez merged commit 71de20b into huggingface:main on Jun 24, 2025
18 checks passed