fix bugs · flypythoncom/ner-bert@730cd17 · GitHub

Commit 730cd17

Author: Ubuntu
Commit message: fix bugs
1 parent 41e4bc1 commit 730cd17

File tree

3 files changed: +10 -6 lines changed

  • modules
    • data
      • bert_data.py
    • train
      • train.py
    • utils
      • utils.py

    modules/data/bert_data.py

    Lines changed: 3 additions & 2 deletions
@@ -135,6 +135,7 @@ def collate_fn(self, data):
 def get_data(
         df, tokenizer, label2idx=None, max_seq_len=424, pad="<pad>", cls2idx=None,
         is_cls=False, is_meta=False):
+    tqdm_notebook = tqdm
     if label2idx is None:
         label2idx = {pad: 0, '[CLS]': 1, '[SEP]': 2}
     features = []
@@ -173,8 +174,8 @@ def get_data(
         bert_tokens.append("[CLS]")
         bert_labels.append("[CLS]")
         orig_tokens = []
-        orig_tokens.extend(text.split())
-        labels = labels.split()
+        orig_tokens.extend(str(text).split())
+        labels = str(labels).split()
         pad_idx = label2idx[pad]
         assert len(orig_tokens) == len(labels)
         prev_label = ""
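The str() wrappers above guard get_data against non-string cells in the DataFrame (for example NaN, which pandas stores as a float and which has no .split()). A minimal sketch of that failure mode, using a hypothetical two-column frame shaped like the one get_data expects:

    import pandas as pd

    # Hypothetical frame: one text column, one space-separated label column.
    df = pd.DataFrame({"text": ["John lives in Kyiv", float("nan")],
                       "labels": ["B_PER O O B_LOC", float("nan")]})

    row = df.iloc[1]
    # row["text"].split()              # AttributeError: 'float' object has no attribute 'split'
    tokens = str(row["text"]).split()  # ['nan'] -- no crash, though the row still needs cleaning
    print(tokens)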

    modules/train/train.py

    Lines changed: 6 additions & 3 deletions
@@ -16,7 +16,8 @@ def train_step(dl, model, optimizer, lr_scheduler=None, clip=None, num_epoch=1):
     model.train()
     epoch_loss = 0
     idx = 0
-    for batch in tqdm_notebook(dl, total=len(dl), leave=False):
+    pr = tqdm_notebook(dl, total=len(dl), leave=False)
+    for batch in pr:
         idx += 1
         model.zero_grad()
         loss = model.score(batch)
@@ -25,7 +26,9 @@ def train_step(dl, model, optimizer, lr_scheduler=None, clip=None, num_epoch=1):
         _ = torch.nn.utils.clip_grad_norm(model.parameters(), clip)
         optimizer.step()
         optimizer.zero_grad()
-        epoch_loss += loss.data.cpu().tolist()
+        loss = loss.data.cpu().tolist()
+        epoch_loss += loss
+        pr.set_description("train loss: {}".format(epoch_loss / idx))
     if lr_scheduler is not None:
         lr_scheduler.step()
     # torch.cuda.empty_cache()
@@ -133,7 +136,7 @@ def predict(dl, model, id2label, id2cls=None):
 class NerLearner(object):
     def __init__(self, model, data, best_model_path, lr=0.001, betas=list([0.8, 0.9]), clip=5,
                  verbose=True, sup_labels=None, t_total=-1, warmup=0.1, weight_decay=0.01):
-        if ipython_info():
+        if ipython_info() or True:
             global tqdm_notebook
             tqdm_notebook = tqdm
         self.model = model
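The train_step change keeps a handle on the tqdm iterator so the running mean loss can be written onto the progress bar each batch. A minimal sketch of the same pattern with plain tqdm and dummy loss values (no model or DataLoader involved):

    from tqdm import tqdm

    dummy_losses = [0.9, 0.7, 0.55, 0.4]  # stand-ins for loss.data.cpu().tolist()

    epoch_loss, idx = 0.0, 0
    pr = tqdm(dummy_losses, total=len(dummy_losses), leave=False)
    for loss in pr:
        idx += 1
        epoch_loss += loss
        # Same call as in the commit: show the running mean loss in the bar text.
        pr.set_description("train loss: {}".format(epoch_loss / idx))

Note that the "if ipython_info() or True:" change in NerLearner.__init__ makes that branch unconditional, so tqdm_notebook is always rebound to plain tqdm regardless of environment.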

    modules/utils/utils.py

    Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def voting_choicer(tok_map, labels):
     for origin_idx in tok_map:
 
         vote_labels = Counter(
-            ["I_" + l.split("_")[1] if l not in ["[SEP]", "[CLS]"] else "B_O" for l in labels[prev_idx:origin_idx]])
+            ["I_" + l.split("_")[1] if l not in ["[SEP]", "[CLS]"] else "I_O" for l in labels[prev_idx:origin_idx]])
         # vote_labels = Counter(c)
         lb = sorted(list(vote_labels), key=lambda x: vote_labels[x])
         if len(lb):
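voting_choicer collapses the labels predicted for one word's sub-tokens into a single tag by counting them with collections.Counter; the commit only changes the filler tag used at [CLS]/[SEP] positions from B_O to I_O. A small sketch of the voting idea in isolation (the variable names here are illustrative, not the repo's API):

    from collections import Counter

    # Labels predicted for the sub-tokens of a single original word.
    sub_token_labels = ["I_PER", "I_PER", "I_O"]

    votes = Counter(sub_token_labels)
    # Keep the most frequent label for the whole word.
    word_label = max(votes, key=votes.get)
    print(word_label)  # I_PER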

0 commit comments