@@ -264,32 +264,7 @@ def get_data(
         assert len(input_ids) == len(input_type_ids)
         assert len(input_ids) == len(labels_ids)
         assert len(input_ids) == len(labels_mask)
-    if is_cls:
-        return features, (label2idx, cls2idx)
-    return features, label2idx
-
-
-def get_bert_data_loaders(train, valid, vocab_file, batch_size=16, cuda=True, is_cls=False,
-                          do_lower_case=False, max_seq_len=424, is_meta=False, label2idx=None, cls2idx=None):
-    train = pd.read_csv(train)
-    valid = pd.read_csv(valid)
-
-    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
-    train_f, label2idx = get_data(
-        train, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
-    if is_cls:
-        label2idx, cls2idx = label2idx
-    train_dl = DataLoaderForTrain(
-        train_f, batch_size=batch_size, shuffle=True, cuda=cuda)
-    valid_f, label2idx = get_data(
-        valid, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls, max_seq_len=max_seq_len, is_meta=is_meta)
-    if is_cls:
-        label2idx, cls2idx = label2idx
-    valid_dl = DataLoaderForTrain(
-        valid_f, batch_size=batch_size, cuda=cuda, shuffle=False)
-    if is_cls:
-        return train_dl, valid_dl, tokenizer, label2idx, max_seq_len, cls2idx
-    return train_dl, valid_dl, tokenizer, label2idx, max_seq_len
+    return features, label2idx, cls2idx, meta2idx
 
 
 def get_bert_data_loader_for_predict(path, learner):
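
For orientation only: the hunk above collapses the old `is_cls`-dependent returns of `get_data` into a single flat `features, label2idx, cls2idx, meta2idx` tuple and drops `get_bert_data_loaders`. The sketch below is an assumption, not code from this commit. It shows how a one-split loader could be rebuilt on top of the new return shape, reusing `get_data`, `DataLoaderForTrain`, and `pd.read_csv` exactly as they appear in the removed code; `make_train_loader` and `csv_path` are hypothetical names, and the real updated call sites are outside this hunk.

```python
import pandas as pd

def make_train_loader(csv_path, tokenizer, batch_size=16, cuda=True, is_cls=False,
                      max_seq_len=424, is_meta=False, label2idx=None, cls2idx=None):
    """Hypothetical helper (not part of this commit): one-split version of the
    removed get_bert_data_loaders, adapted to the new get_data return shape
    (features, label2idx, cls2idx, meta2idx)."""
    df = pd.read_csv(csv_path)
    # get_data keeps its old parameters; only its return value changed in this hunk.
    features, label2idx, cls2idx, meta2idx = get_data(
        df, tokenizer, label2idx, cls2idx=cls2idx, is_cls=is_cls,
        max_seq_len=max_seq_len, is_meta=is_meta)
    # DataLoaderForTrain is the same loader class the removed code used.
    dl = DataLoaderForTrain(features, batch_size=batch_size, shuffle=True, cuda=cuda)
    return dl, label2idx, cls2idx, meta2idx
```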