1
1
from torch .utils .data import DataLoader
2
2
from modules .data import tokenization
3
- from modules .utils .utils import ipython_info
4
3
import torch
5
4
import pandas as pd
6
5
import numpy as np
7
- from tqdm ._tqdm_notebook import tqdm_notebook
8
6
from tqdm import tqdm
9
7
import json
10
8
@@ -195,7 +193,8 @@ def get_data(
195
193
if is_meta :
196
194
meta_tokens .extend ([meta [idx_ ]] * len (cur_tokens ))
197
195
bert_tokens .extend (cur_tokens )
198
- bert_label = [prefix + label ] + ["X" ] * (len (cur_tokens ) - 1 ) # ["I_" + label] * (len(cur_tokens) - 1)
196
+ # ["I_" + label] * (len(cur_tokens) - 1)
197
+ bert_label = [prefix + label ] + ["X" ] * (len (cur_tokens ) - 1 )
199
198
bert_labels .extend (bert_label )
200
199
bert_tokens .append ("[SEP]" )
201
200
bert_labels .append ("[SEP]" )
@@ -262,14 +261,13 @@ def get_data(
262
261
263
262
264
263
def get_bert_data_loaders (train , valid , vocab_file , batch_size = 16 , cuda = True , is_cls = False ,
265
- do_lower_case = False , max_seq_len = 424 , is_meta = False ):
264
+ do_lower_case = False , max_seq_len = 424 , is_meta = False , label2idx = None , cls2idx = None ):
266
265
train = pd .read_csv (train )
267
266
valid = pd .read_csv (valid )
268
267
269
- cls2idx = None
270
-
271
268
tokenizer = tokenization .FullTokenizer (vocab_file = vocab_file , do_lower_case = do_lower_case )
272
- train_f , label2idx = get_data (train , tokenizer , is_cls = is_cls , max_seq_len = max_seq_len , is_meta = is_meta )
269
+ train_f , label2idx = get_data (
270
+ train , tokenizer , label2idx , cls2idx = cls2idx , is_cls = is_cls , max_seq_len = max_seq_len , is_meta = is_meta )
273
271
if is_cls :
274
272
label2idx , cls2idx = label2idx
275
273
train_dl = DataLoaderForTrain (
@@ -300,8 +298,14 @@ def get_bert_data_loader_for_predict(path, learner):
300
298
301
299
class BertNerData (object ):
302
300
303
- def __init__ (self , train_dl , valid_dl , tokenizer , label2idx , max_seq_len = 424 ,
301
+ def __init__ (self , train_path , valid_path , vocab_file , data_type ,
302
+ train_dl = None , valid_dl = None , tokenizer = None ,
303
+ label2idx = None , max_seq_len = 424 ,
304
304
cls2idx = None , batch_size = 16 , cuda = True , is_meta = False ):
305
+ self .train_path = train_path
306
+ self .valid_path = valid_path
307
+ self .data_type = data_type
308
+ self .vocab_file = vocab_file
305
309
self .train_dl = train_dl
306
310
self .valid_dl = valid_dl
307
311
self .tokenizer = tokenizer
@@ -317,13 +321,44 @@ def __init__(self, train_dl, valid_dl, tokenizer, label2idx, max_seq_len=424,
317
321
self .is_cls = True
318
322
self .id2cls = sorted (cls2idx .keys (), key = lambda x : cls2idx [x ])
319
323
324
+ @classmethod
325
+ def from_config (cls , config ):
326
+ if config ["data_type" ] == "bert_cased" :
327
+ do_lower_case = False
328
+ fn = get_bert_data_loaders
329
+ elif config ["data_type" ] == "bert_uncased" :
330
+ do_lower_case = True
331
+ fn = get_bert_data_loaders
332
+ else :
333
+ raise NotImplementedError ("No requested mode :(." )
334
+ return cls (
335
+ config ["train_path" ], config ["valid_path" ], config ["vocab_file" ], config ["data_type" ],
336
+ * fn (config ["train_path" ], config ["valid_path" ], config ["vocab_file" ], config ["batch_size" ],
337
+ config ["cuda" ], config ["is_cls" ], do_lower_case , config ["max_seq_len" ], config ["is_meta" ],
338
+ label2idx = config ["label2idx" ], cls2idx = config ["cls2idx" ]),
339
+ batch_size = config ["batch_size" ], cuda = config ["cuda" ], is_meta = config ["is_meta" ])
340
+
341
+ def get_config (self ):
342
+ config = {
343
+ "train_path" : self .train_path ,
344
+ "valid_path" : self .valid_path ,
345
+ "vocab_file" : self .vocab_file ,
346
+ "data_type" : self .data_type ,
347
+ "max_seq_len" : self .max_seq_len ,
348
+ "batch_size" : self .batch_size ,
349
+ "cuda" : self .cuda ,
350
+ "is_meta" : self .is_meta ,
351
+ "label2idx" : self .label2idx ,
352
+ "cls2idx" : self .cls2idx
353
+ }
354
+ return config
355
+ # with open(config_path, "w") as f:
356
+ # json.dump(config, f)
357
+
320
358
@classmethod
321
359
def create (cls ,
322
360
train_path , valid_path , vocab_file , batch_size = 16 , cuda = True , is_cls = False ,
323
361
data_type = "bert_cased" , max_seq_len = 424 , is_meta = False ):
324
- if ipython_info ():
325
- global tqdm_notebook
326
- tqdm_notebook = tqdm
327
362
if data_type == "bert_cased" :
328
363
do_lower_case = False
329
364
fn = get_bert_data_loaders
@@ -332,6 +367,6 @@ def create(cls,
332
367
fn = get_bert_data_loaders
333
368
else :
334
369
raise NotImplementedError ("No requested mode :(." )
335
- return cls (* fn (
370
+ return cls (train_path , valid_path , vocab_file , data_type , * fn (
336
371
train_path , valid_path , vocab_file , batch_size , cuda , is_cls , do_lower_case , max_seq_len , is_meta ),
337
372
batch_size = batch_size , cuda = cuda , is_meta = is_meta )
0 commit comments