fix get_data return · flypythoncom/ner-bert@8052aea · GitHub
[go: up one dir, main page]

Skip to content

Commit 8052aea

Browse files
committed
fix get_data return
1 parent f6dba0e commit 8052aea

File tree

1 file changed

+1
-26
lines changed

1 file changed

+1
-26
lines changed

modules/data/bert_data.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -264,32 +264,7 @@ def get_data(
264264
assert len(input_ids) == len(input_type_ids)
265265
assert len(input_ids) == len(labels_ids)
266266
assert len(input_ids) == len(labels_mask)
267-
if is_cls:
268-
return features, (label2idx, cls2idx)
269-
return features, label2idx
270-
271-
272-
def get_bert_data_loaders(train, valid, vocab_file, batch_size=16, cuda=True, is_cls=False,
273-
do_lower_case=False, max_seq_len=424, is_meta=False, label2idx=None, cls2idx=None):
274-
train = pd.read_csv(train)
275-
valid = pd.read_csv(valid)
276-
277-
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
278-
train_f, label2idx = get_data(
279-
train, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
280-
if is_cls:
281-
label2idx, cls2idx = label2idx
282-
train_dl = DataLoaderForTrain(
283-
train_f, batch_size=batch_size, shuffle=True, cuda=cuda)
284-
valid_f, label2idx = get_data(
285-
valid, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
286-
if is_cls:
287-
label2idx, cls2idx = label2idx
288-
valid_dl = DataLoaderForTrain(
289-
valid_f, batch_size=batch_size, cuda=cuda, shuffle=False)
290-
if is_cls:
291-
return train_dl, valid_dl, tokenizer, label2idx, max_seq_len, cls2idx
292-
return train_dl, valid_dl, tokenizer, label2idx, max_seq_len
267+
return features, label2idx, cls2idx, meta2idx
293268

294269

295270
def get_bert_data_loader_for_predict(path, learner):

0 commit comments

Comments
 (0)
0