@@ -212,15 +212,16 @@ def get_ppl(self, input_ids: Union[List[int],
212
212
logger .info (f'sorted indices: { indices } ' )
213
213
for (start , end ) in self ._batch_iterator (sizes , max_input_len ):
214
214
logger .info (f'start: { start } , end: { end } ' )
215
- _input_ids = [input_ids [indices [i ]] for i in range (start , end )]
216
215
if start == end :
216
+ _input_ids = input_ids [indices [start ]]
217
217
loss , target_count = self ._get_long_text_ppl (
218
218
generator = generator ,
219
219
input_ids = _input_ids ,
220
220
max_input_len = max_input_len )
221
221
losses .append (loss )
222
222
target_counts .append (target_count )
223
223
else :
224
+ _input_ids = [input_ids [indices [i ]] for i in range (start , end )]
224
225
loss , target_count = self ._get_ppl (
225
226
generator = generator ,
226
227
input_ids = _input_ids ,
@@ -261,24 +262,24 @@ def _batch_iterator(self, sizes, max_value):
261
262
i += 1
262
263
263
264
def _get_long_text_ppl (self , generator , input_ids , max_input_len ):
264
- assert isinstance (input_ids , List ) and len ( input_ids ) == 1
265
- seq_len = len (input_ids [ 0 ] )
265
+ assert all ( isinstance (_ , int ) for _ in input_ids )
266
+ seq_len = len (input_ids )
266
267
assert seq_len > max_input_len
267
268
logger .info (f'get long text ppl: seq_len { seq_len } ' )
268
269
269
270
losses = []
270
271
target_counts = []
271
272
for i in range (0 , seq_len , max_input_len ):
272
- token_ids = input_ids [:, i :i + max_input_len ]
273
+ token_ids = input_ids [i :i + max_input_len ]
273
274
step = [i ]
274
275
# shift token_ids by 1 to the left
275
- target_ids = input_ids [:, i + 1 :i + 1 + max_input_len ]
276
+ target_ids = input_ids [i + 1 :i + 1 + max_input_len ]
276
277
277
278
loss , target_count = self ._get_ppl (
278
279
generator = generator ,
279
- input_ids = token_ids ,
280
+ input_ids = [ token_ids ] ,
280
281
max_input_len = max_input_len ,
281
- target_ids = target_ids ,
282
+ target_ids = [ target_ids ] ,
282
283
steps = step ,
283
284
sequence_start = (i == 0 ),
284
285
sequence_end = (i + max_input_len >= seq_len ))
0 commit comments