@@ -4,6 +4,7 @@
 from torch import nn
 from .layers import Linears, MultiHeadAttention
 from .crf import CRF
+from .ncrf import NCRF


 class CRFDecoder(nn.Module):
@@ -555,3 +556,49 @@ def create(cls, label_size, intent_size,
         return cls(label_size=label_size, intent_size=intent_size,
                    embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                    rnn_layers=rnn_layers, dropout_p=dropout_p, pad_idx=pad_idx, use_cuda=use_cuda)
+
+
+class AttnNCRFJointDecoder(nn.Module):
+    def __init__(self,
+                 crf, label_size, input_dim, intent_size, input_dropout=0.5,
+                 key_dim=64, val_dim=64, num_heads=3, nbest=8):
+        super(AttnNCRFJointDecoder, self).__init__()
+        self.input_dim = input_dim
+        self.attn = MultiHeadAttention(key_dim, val_dim, input_dim, num_heads, input_dropout)
+        self.linear = Linears(in_features=input_dim,
+                              out_features=label_size,
+                              hiddens=[input_dim // 2])
+        self.crf = crf
+        self.label_size = label_size
+        self.intent_size = intent_size
+        self.intent_out = PoolingLinearClassifier(input_dim, intent_size, input_dropout)
+        self.intent_loss = nn.CrossEntropyLoss()
+        self.nbest = nbest
+
+    def forward_model(self, inputs, labels_mask=None):
+        batch_size, seq_len, input_dim = inputs.size()
+        inputs, hidden = self.attn(inputs, inputs, inputs, labels_mask)
+        intent_output = self.intent_out(inputs)
+        output = inputs.contiguous().view(-1, self.input_dim)
+        # Fully-connected layer
+        output = self.linear.forward(output)
+        output = output.view(batch_size, seq_len, self.label_size)
+        return output, intent_output
+
+    def forward(self, inputs, labels_mask):
+        self.eval()
+        logits, intent_output = self.forward_model(inputs)
+        _, preds = self.crf._viterbi_decode_nbest(logits, labels_mask, self.nbest)
+        self.train()
+        return preds, intent_output.argmax(-1)
+
+    def score(self, inputs, labels_mask, labels, cls_ids):
+        logits, intent_output = self.forward_model(inputs)
+        crf_score = self.crf.neg_log_likelihood_loss(logits, labels_mask, labels) / logits.shape[0]
+        return crf_score + self.intent_loss(intent_output, cls_ids)
+
+    @classmethod
+    def create(cls, label_size, input_dim, intent_size, input_dropout=0.5, key_dim=64,
+               val_dim=64, num_heads=3, use_cuda=True, nbest=8):
+        return cls(NCRF(label_size + 2, use_cuda), label_size, input_dim, intent_size, input_dropout,
+                   key_dim, val_dim, num_heads, nbest)
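For reviewers trying the change locally, a minimal usage sketch of the new decoder follows. It assumes an upstream encoder that yields hidden states of shape (batch_size, seq_len, input_dim) and that PoolingLinearClassifier and NCRF behave as referenced in the diff; the sizes, tensor names, and mask dtype below are illustrative guesses, not taken from this PR.

import torch

# Illustrative sizes; real values come from the encoder config and the label/intent vocabularies.
batch_size, seq_len, input_dim = 16, 32, 768
label_size, intent_size = 12, 7

# AttnNCRFJointDecoder is assumed to be importable from this module (file path not shown in the diff).
decoder = AttnNCRFJointDecoder.create(label_size, input_dim, intent_size, use_cuda=False)

encoder_out = torch.randn(batch_size, seq_len, input_dim)        # encoder hidden states
labels_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 1 = real token, 0 = padding (NCRF may expect byte/bool)
labels = torch.randint(0, label_size, (batch_size, seq_len))     # gold tag ids
cls_ids = torch.randint(0, intent_size, (batch_size,))           # gold intent ids

loss = decoder.score(encoder_out, labels_mask, labels, cls_ids)  # joint NCRF + intent cross-entropy loss
loss.backward()

preds, intents = decoder(encoder_out, labels_mask)               # n-best tag paths, intent argmax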