add BertBiLSTMNCRF as released model · flypythoncom/ner-bert@fd69b4e · GitHub
[go: up one dir, main page]

Skip to content

Commit fd69b4e

Browse files
committed
add BertBiLSTMNCRF as released model
1 parent 1f1c62d commit fd69b4e

File tree

7 files changed

+491
-15
lines changed

7 files changed

+491
-15
lines changed

modules/data/bert_data.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from torch.utils.data import DataLoader
22
from modules.data import tokenization
3-
from modules.utils.utils import ipython_info
43
import torch
54
import pandas as pd
65
import numpy as np
7-
from tqdm._tqdm_notebook import tqdm_notebook
86
from tqdm import tqdm
97
import json
108

@@ -195,7 +193,8 @@ def get_data(
195193
if is_meta:
196194
meta_tokens.extend([meta[idx_]] * len(cur_tokens))
197195
bert_tokens.extend(cur_tokens)
198-
bert_label = [prefix + label] + ["X"] * (len(cur_tokens) - 1) # ["I_" + label] * (len(cur_tokens) - 1)
196+
# ["I_" + label] * (len(cur_tokens) - 1)
197+
bert_label = [prefix + label] + ["X"] * (len(cur_tokens) - 1)
199198
bert_labels.extend(bert_label)
200199
bert_tokens.append("[SEP]")
201200
bert_labels.append("[SEP]")
@@ -262,14 +261,13 @@ def get_data(
262261

263262

264263
def get_bert_data_loaders(train, valid, vocab_file, batch_size=16, cuda=True, is_cls=False,
265-
do_lower_case=False, max_seq_len=424, is_meta=False):
264+
do_lower_case=False, max_seq_len=424, is_meta=False, label2idx=None, cls2idx=None):
266265
train = pd.read_csv(train)
267266
valid = pd.read_csv(valid)
268267

269-
cls2idx = None
270-
271268
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
272-
train_f, label2idx = get_data(train, tokenizer, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
269+
train_f, label2idx = get_data(
270+
train, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
273271
if is_cls:
274272
label2idx, cls2idx = label2idx
275273
train_dl = DataLoaderForTrain(
@@ -300,8 +298,14 @@ def get_bert_data_loader_for_predict(path, learner):
300298

301299
class BertNerData(object):
302300

303-
def __init__(self, train_dl, valid_dl, tokenizer, label2idx, max_seq_len=424,
301+
def __init__(self, train_path, valid_path, vocab_file, data_type,
302+
train_dl=None, valid_dl=None, tokenizer=None,
303+
label2idx=None, max_seq_len=424,
304304
cls2idx=None, batch_size=16, cuda=True, is_meta=False):
305+
self.train_path = train_path
306+
self.valid_path = valid_path
307+
self.data_type = data_type
308+
self.vocab_file = vocab_file
305309
self.train_dl = train_dl
306310
self.valid_dl = valid_dl
307311
self.tokenizer = tokenizer
@@ -317,13 +321,44 @@ def __init__(self, train_dl, valid_dl, tokenizer, label2idx, max_seq_len=424,
317321
self.is_cls = True
318322
self.id2cls = sorted(cls2idx.keys(), key=lambda x: cls2idx[x])
319323

324+
@classmethod
def from_config(cls, config):
    """Rebuild the data object from a configuration mapping.

    The mapping is expected to carry the keys written by ``get_config``:
    ``train_path``, ``valid_path``, ``vocab_file``, ``data_type``,
    ``max_seq_len``, ``batch_size``, ``cuda``, ``is_meta``, ``label2idx``
    and ``cls2idx``.  ``is_cls`` is optional: when it is absent it is
    inferred from whether ``cls2idx`` is set, so that a config that lacks
    the key can still be loaded instead of failing with ``KeyError``.

    Args:
        config: dict-like object holding the keys listed above.

    Returns:
        A new instance of ``cls`` built from freshly created data loaders.

    Raises:
        NotImplementedError: if ``config["data_type"]`` is neither
            ``"bert_cased"`` nor ``"bert_uncased"``.
        KeyError: if a required key is missing from ``config``.
    """
    data_type = config["data_type"]
    if data_type == "bert_cased":
        do_lower_case = False
    elif data_type == "bert_uncased":
        do_lower_case = True
    else:
        raise NotImplementedError("No requested mode :(.")
    # Both supported modes use the same loader factory.
    fn = get_bert_data_loaders

    cls2idx = config["cls2idx"]
    # NOTE(review): assumes a classification head is in use exactly when
    # cls2idx is provided -- confirm against __init__.
    is_cls = config.get("is_cls", cls2idx is not None)

    # The values returned by fn fill the positional constructor slots that
    # follow data_type (train_dl, valid_dl, tokenizer, ...).
    return cls(
        config["train_path"], config["valid_path"], config["vocab_file"], data_type,
        *fn(config["train_path"], config["valid_path"], config["vocab_file"], config["batch_size"],
            config["cuda"], is_cls, do_lower_case, config["max_seq_len"], config["is_meta"],
            label2idx=config["label2idx"], cls2idx=cls2idx),
        batch_size=config["batch_size"], cuda=config["cuda"], is_meta=config["is_meta"])
340+
341+
def get_config(self):
    """Return the settings needed to rebuild this object via ``from_config``.

    Returns:
        dict with the keys ``train_path``, ``valid_path``, ``vocab_file``,
        ``data_type``, ``max_seq_len``, ``batch_size``, ``cuda``,
        ``is_meta``, ``is_cls``, ``label2idx`` and ``cls2idx``.
    """
    config = {
        "train_path": self.train_path,
        "valid_path": self.valid_path,
        "vocab_file": self.vocab_file,
        "data_type": self.data_type,
        "max_seq_len": self.max_seq_len,
        "batch_size": self.batch_size,
        "cuda": self.cuda,
        "is_meta": self.is_meta,
        # from_config reads "is_cls"; without it the config cannot be
        # round-tripped.  Fall back to the presence of cls2idx when the
        # attribute was never set on this instance.
        "is_cls": getattr(self, "is_cls", self.cls2idx is not None),
        "label2idx": self.label2idx,
        "cls2idx": self.cls2idx
    }
    return config
355+
# with open(config_path, "w") as f:
356+
# json.dump(config, f)
357+
320358
@classmethod
321359
def create(cls,
322360
train_path, valid_path, vocab_file, batch_size=16, cuda=True, is_cls=False,
323361
data_type="bert_cased", max_seq_len=424, is_meta=False):
324-
if ipython_info():
325-
global tqdm_notebook
326-
tqdm_notebook = tqdm
327362
if data_type == "bert_cased":
328363
do_lower_case = False
329364
fn = get_bert_data_loaders
@@ -332,6 +367,6 @@ def create(cls,
332367
fn = get_bert_data_loaders
333368
else:
334369
raise NotImplementedError("No requested mode :(.")
335-
return cls(*fn(
370+
return cls(train_path, valid_path, vocab_file, data_type, *fn(
336371
train_path, valid_path, vocab_file, batch_size, cuda, is_cls, do_lower_case, max_seq_len, is_meta),
337372
batch_size=batch_size, cuda=cuda, is_meta=is_meta)

0 commit comments

Comments
 (0)
0