fix index error when computing ppl on long-text prompt (#2697) · InternLM/lmdeploy@58c9126 · GitHub

Commit 58c9126

fix index error when computing ppl on long-text prompt (#2697)
* fix index error when computing ppl on long-text prompt
* update user guide
1 parent 1d2b9c6 commit 58c9126

File tree

3 files changed: +16 −7 lines changed


docs/en/llm/pipeline.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -136,6 +136,10 @@ logits = pipe.get_logits(input_ids)
 ppl = pipe.get_ppl(input_ids)
 ```
 
+```{note}
+get_ppl returns the cross entropy loss without applying the exponential operation afterwards
+```
+
 - **Below is an example for pytorch backend. Please install triton first.**
 
 ```shell
````
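
The note added above clarifies that `get_ppl` returns the token-averaged cross entropy loss rather than perplexity itself. Below is a minimal sketch of how a caller might turn that value into a perplexity, assuming the `pipeline`/`get_ppl` API shown in these docs; the model path and prompt are placeholders, and the `[0]` indexing assumes `get_ppl` yields one loss value per input sequence:

```python
import math

from transformers import AutoTokenizer

from lmdeploy import pipeline

# Placeholder model path and prompt, for illustration only.
model_path = 'internlm/internlm2_5-7b-chat'
pipe = pipeline(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer.encode('A long prompt whose perplexity we want to measure')

# get_ppl returns the cross entropy loss; no exp is applied internally.
ce_loss = pipe.get_ppl(input_ids)[0]

# Exponentiate to obtain the actual perplexity.
ppl = math.exp(ce_loss)
print(f'cross entropy loss: {ce_loss:.4f}, perplexity: {ppl:.4f}')
```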

docs/zh_cn/llm/pipeline.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -136,6 +136,10 @@ logits = pipe.get_logits(input_ids)
 ppl = pipe.get_ppl(input_ids)
 ```
 
+```{note}
+get_ppl returns the cross entropy loss; the exp operation is not applied afterwards
+```
+
 - **Using the pytorch backend**
 
 triton needs to be installed first
````

lmdeploy/serve/utils.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -212,15 +212,16 @@ def get_ppl(self, input_ids: Union[List[int],
         logger.info(f'sorted indices: {indices}')
         for (start, end) in self._batch_iterator(sizes, max_input_len):
             logger.info(f'start: {start}, end: {end}')
-            _input_ids = [input_ids[indices[i]] for i in range(start, end)]
             if start == end:
+                _input_ids = input_ids[indices[start]]
                 loss, target_count = self._get_long_text_ppl(
                     generator=generator,
                     input_ids=_input_ids,
                     max_input_len=max_input_len)
                 losses.append(loss)
                 target_counts.append(target_count)
             else:
+                _input_ids = [input_ids[indices[i]] for i in range(start, end)]
                 loss, target_count = self._get_ppl(
                     generator=generator,
                     input_ids=_input_ids,
@@ -261,24 +262,24 @@ def _batch_iterator(self, sizes, max_value):
             i += 1
 
     def _get_long_text_ppl(self, generator, input_ids, max_input_len):
-        assert isinstance(input_ids, List) and len(input_ids) == 1
-        seq_len = len(input_ids[0])
+        assert all(isinstance(_, int) for _ in input_ids)
+        seq_len = len(input_ids)
         assert seq_len > max_input_len
         logger.info(f'get long text ppl: seq_len {seq_len}')
 
         losses = []
         target_counts = []
         for i in range(0, seq_len, max_input_len):
-            token_ids = input_ids[:, i:i + max_input_len]
+            token_ids = input_ids[i:i + max_input_len]
             step = [i]
             # shift token_ids by 1 to the left
-            target_ids = input_ids[:, i + 1:i + 1 + max_input_len]
+            target_ids = input_ids[i + 1:i + 1 + max_input_len]
 
             loss, target_count = self._get_ppl(
                 generator=generator,
-                input_ids=token_ids,
+                input_ids=[token_ids],
                 max_input_len=max_input_len,
-                target_ids=target_ids,
+                target_ids=[target_ids],
                 steps=step,
                 sequence_start=(i == 0),
                 sequence_end=(i + max_input_len >= seq_len))
```
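
With this change, `_get_long_text_ppl` receives one flat list of token ids (the old code sliced a Python list with NumPy-style `[:, ...]` indexing, which raises a TypeError) and scores it window by window, shifting the targets by one token. Below is a minimal sketch of that chunked evaluation, independent of the lmdeploy internals; the `chunk_ce` callable is a hypothetical stand-in for the engine's `_get_ppl` call, and the final reduction assumes each per-chunk loss is a sum over its scored tokens (the combination of `losses`/`target_counts` is not shown in this diff):

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for the engine call (`_get_ppl` in the diff):
# given one chunk of token ids and its shifted targets, it returns the
# summed cross entropy over the chunk and the number of scored tokens.
ChunkCE = Callable[[List[int], List[int]], Tuple[float, int]]


def long_text_ce(input_ids: List[int], max_input_len: int,
                 chunk_ce: ChunkCE) -> float:
    """Score one long sequence in max_input_len windows."""
    assert all(isinstance(t, int) for t in input_ids)
    seq_len = len(input_ids)
    assert seq_len > max_input_len

    total_loss, total_targets = 0.0, 0
    for i in range(0, seq_len, max_input_len):
        token_ids = input_ids[i:i + max_input_len]
        # Shift by one position: the token at i predicts the token at i + 1.
        target_ids = input_ids[i + 1:i + 1 + max_input_len]
        loss, target_count = chunk_ce(token_ids, target_ids)
        total_loss += loss
        total_targets += target_count

    # Token-averaged cross entropy; math.exp(...) of this value would be
    # the perplexity (get_ppl itself does not apply exp, per the docs note).
    return total_loss / total_targets
```

In the diff itself, the `sequence_start`/`sequence_end` flags appear to keep all windows of one prompt inside the same generation session, so each chunk is scored with the preceding context still cached.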

0 commit comments
