added GPTNeoXForTokenClassification (#23002) · githubhjs/transformers@614e191 · GitHub

Commit 614e191

peter-sk and sgugger authored
added GPTNeoXForTokenClassification (huggingface#23002)
* initial commit
* added GPTNeoXForTokenClassification
* typo
* doc fixed extra comma that turned into a tuple
* unifying variable names fixing forward call
* classifier_dropout is in config

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 1933231 commit 614e191

File tree

9 files changed: +130 lines, -5 lines


docs/source/en/model_doc/gpt_neox.mdx

Lines changed: 6 additions & 1 deletion
@@ -82,4 +82,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
 ## GPTNeoXForSequenceClassification
 
 [[autodoc]] GPTNeoXForSequenceClassification
-    - forward
+    - forward
+
+## GPTNeoXForTokenClassification
+
+[[autodoc]] GPTNeoXForTokenClassification
+    - forward
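
With this documentation entry in place, the new head can be driven like any other token-classification model, for example through the pipeline API. A minimal sketch; the checkpoint name is a made-up placeholder for a fine-tuned GPT-NeoX model:

from transformers import pipeline

# "my-org/gpt-neox-ner-example" is hypothetical; substitute a GPT-NeoX
# checkpoint fine-tuned for token classification.
classifier = pipeline("token-classification", model="my-org/gpt-neox-ner-example")
print(classifier("Hugging Face is based in New York City."))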

docs/source/en/tasks/token_classification.mdx

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
 
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
-[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1697,6 +1697,7 @@
             "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTNeoXForCausalLM",
             "GPTNeoXForSequenceClassification",
+            "GPTNeoXForTokenClassification",
             "GPTNeoXLayer",
             "GPTNeoXModel",
             "GPTNeoXPreTrainedModel",
@@ -5230,6 +5231,7 @@
         GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTNeoXForCausalLM,
         GPTNeoXForSequenceClassification,
+        GPTNeoXForTokenClassification,
         GPTNeoXLayer,
         GPTNeoXModel,
         GPTNeoXPreTrainedModel,
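
These entries register the class in the lazy `_import_structure` mapping and the `TYPE_CHECKING` imports, so it becomes importable from the top-level package (and, together with the subpackage change below, from `transformers.models.gpt_neox`). A quick sketch of what this enables, assuming PyTorch is installed:

from transformers import GPTNeoXForTokenClassification
from transformers.models.gpt_neox import GPTNeoXForTokenClassification as SubpackageClass

# Both import paths resolve to the same class once it is registered.
assert GPTNeoXForTokenClassification is SubpackageClass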

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -814,6 +814,7 @@
         ("gpt-sw3", "GPT2ForTokenClassification"),
         ("gpt2", "GPT2ForTokenClassification"),
         ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
+        ("gpt_neox", "GPTNeoXForTokenClassification"),
         ("ibert", "IBertForTokenClassification"),
         ("layoutlm", "LayoutLMForTokenClassification"),
         ("layoutlmv2", "LayoutLMv2ForTokenClassification"),

src/transformers/models/gpt_neox/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@
         "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTNeoXForCausalLM",
         "GPTNeoXForSequenceClassification",
+        "GPTNeoXForTokenClassification",
         "GPTNeoXLayer",
         "GPTNeoXModel",
         "GPTNeoXPreTrainedModel",
@@ -64,6 +65,7 @@
         GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTNeoXForCausalLM,
         GPTNeoXForSequenceClassification,
+        GPTNeoXForTokenClassification,
         GPTNeoXLayer,
         GPTNeoXModel,
         GPTNeoXPreTrainedModel,

src/transformers/models/gpt_neox/configuration_gpt_neox.py

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,10 @@ class GPTNeoXConfig(PretrainedConfig):
             percentage of hidden dimensions to allocate to rotary embeddings
         rotary_emb_base (`int`, *optional*, defaults to 10000)
             base for computing rotary embeddings frequency
+        classifier_dropout (`float`, *optional*, defaults to 0.1):
+            Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
+
+            The dropout ratio for the hidden layer.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
             The maximum sequence length that this model might ever be used with. Typically set this to something large
             just in case (e.g., 512 or 1024 or 2048).
@@ -95,6 +99,7 @@ def __init__(
         hidden_act="gelu",
         rotary_pct=0.25,
         rotary_emb_base=10000,
+        classifier_dropout=0.1,
         max_position_embeddings=2048,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
@@ -115,6 +120,7 @@ def __init__(
         self.hidden_act = hidden_act
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
+        self.classifier_dropout = classifier_dropout
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
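
`classifier_dropout` is an ordinary config attribute, so it can be passed to `GPTNeoXConfig` like any other hyperparameter. A small sketch; all values other than `classifier_dropout` are scaled-down illustrations, not GPT-NeoX defaults:

from transformers import GPTNeoXConfig

# Only classifier_dropout is the new field introduced by this commit.
config = GPTNeoXConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    classifier_dropout=0.2,
)
print(config.classifier_dropout)  # 0.2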

src/transformers/models/gpt_neox/modeling_gpt_neox.py

Lines changed: 83 additions & 1 deletion
@@ -28,7 +28,12 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_gpt_neox import GPTNeoXConfig
@@ -873,3 +878,80 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.gpt_neox = GPTNeoXModel(config)
+        self.dropout = nn.Dropout(config.classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint="LarsJonasson/pythia-410m-deduped-sft-swedish",
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_loss=0.25,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.gpt_neox(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
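
The new head runs the `GPTNeoXModel` backbone, applies dropout, and projects every hidden state to `num_labels` logits, so loading a causal-LM checkpoint gives pretrained backbone weights with a freshly initialized classifier on top. A usage sketch; `EleutherAI/pythia-70m` is only an example backbone and the three labels are placeholders:

import torch
from transformers import AutoTokenizer, GPTNeoXForTokenClassification

# Backbone weights come from the causal-LM checkpoint; the classifier layer
# is randomly initialized, so real use requires fine-tuning first.
model = GPTNeoXForTokenClassification.from_pretrained("EleutherAI/pythia-70m", num_labels=3)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

inputs = tokenizer("Hello from Copenhagen", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# One logit vector per input token: (batch_size, sequence_length, num_labels)
print(logits.shape)
print(logits.argmax(dim=-1))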

src/transformers/utils/dummy_pt_objects.py

Lines changed: 7 additions & 0 deletions
@@ -3308,6 +3308,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class GPTNeoXForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTNeoXLayer(metaclass=DummyObject):
     _backends = ["torch"]
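
The dummy object keeps imports working in environments without PyTorch while failing loudly on use. A small sketch of the behaviour this stub is meant to provide (it only matters when torch is absent):

from transformers import GPTNeoXForTokenClassification
from transformers.utils import is_torch_available

if not is_torch_available():
    try:
        GPTNeoXForTokenClassification()
    except ImportError as err:
        # The message explains that the "torch" backend is required.
        print(err)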

tests/models/gpt_neox/test_modeling_gpt_neox.py

Lines changed: 22 additions & 2 deletions
@@ -29,7 +29,12 @@
 if is_torch_available():
     import torch
 
-    from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
+    from transformers import (
+        GPTNeoXForCausalLM,
+        GPTNeoXForSequenceClassification,
+        GPTNeoXForTokenClassification,
+        GPTNeoXModel,
+    )
 
 
 class GPTNeoXModelTester:
@@ -153,6 +158,14 @@ def create_and_check_for_sequence_classification(self, config, input_ids, input_
         result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
 
+    def create_and_check_for_token_classification(self, config, input_ids, input_mask, token_labels):
+        config.num_labels = self.num_labels
+        model = GPTNeoXForTokenClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
     def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
         config.is_decoder = True
         model = GPTNeoXForCausalLM(config=config)
@@ -200,13 +213,16 @@ def prepare_config_and_inputs_for_common(self):
 @require_torch
 class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
-        (GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
+        (GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": GPTNeoXModel,
             "text-classification": GPTNeoXForSequenceClassification,
+            "token-classification": GPTNeoXForTokenClassification,
             "text-generation": GPTNeoXForCausalLM,
             "zero-shot": GPTNeoXForSequenceClassification,
         }
@@ -253,6 +269,10 @@ def test_model_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
 
+    def test_model_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
     @unittest.skip(reason="Feed forward chunking is not implemented")
     def test_feed_forward_chunking(self):
         pass
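
The new tester method only verifies output shapes on a randomly initialized model, which is straightforward to reproduce outside the test harness. A standalone sketch with a deliberately tiny, made-up config (the test suite uses its own ModelTester values):

import torch
from transformers import GPTNeoXConfig, GPTNeoXForTokenClassification

# Tiny, illustrative hyperparameters; not GPT-NeoX defaults.
config = GPTNeoXConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    num_labels=5,
)
model = GPTNeoXForTokenClassification(config).eval()

batch_size, seq_length = 2, 7
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
labels = torch.randint(0, config.num_labels, (batch_size, seq_length))

with torch.no_grad():
    result = model(input_ids, labels=labels)

# Same shape check as create_and_check_for_token_classification.
assert result.logits.shape == (batch_size, seq_length, config.num_labels)
assert result.loss is not None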

0 commit comments
