Clean up logprobs implementation · coderonion/llama-cpp-python@6153baa

Commit 6153baa

Clean up logprobs implementation
1 parent 26cc4ee commit 6153baa

File tree

1 file changed: +39 −67 lines changed


llama_cpp/llama.py

Lines changed: 39 additions & 67 deletions
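
This commit removes the inline logprob bookkeeping from the prompt-evaluation and per-step sampling paths and recomputes everything in a single pass over self.all_logits once generation has finished, alongside switching the sampling loop to self.generate. As a hedged usage sketch (the callable wrapper and the model path are assumptions about the surrounding API, not part of this diff), the code path touched here is the one behind completion requests that ask for logprobs, which require the model to be created with logits_all=True:

    # Hedged sketch: assumes the public Llama wrapper exposes __call__ /
    # create_completion around _create_completion; the model path below is
    # a placeholder, not part of this commit.
    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin", logits_all=True)

    # Asking for logprobs exercises the branch cleaned up in this commit;
    # with logits_all=False the method raises ValueError instead.
    out = llm("Q: Name the planets in the solar system? A:", max_tokens=16, logprobs=2)
    print(out["choices"][0]["logprobs"]["top_logprobs"][:3])
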
@@ -351,55 +351,19 @@ def _create_completion(
         else:
             stop_sequences = []

-        text_offset = 0
-        text_offsets: List[int] = []
-        token_logprobs: List[float] = []
-        tokens: List[str] = []
-        top_logprobs: List[Dict[str, float]] = []
-
-        self.reset()
-        self.eval(prompt_tokens)
-
         if logprobs is not None and self.params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )

-        if logprobs is not None:
-            token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
-            ]
-            logprobs_all = [
-                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
-            ]
-            for token, token_str, logprobs_token in zip(
-                prompt_tokens, token_strs, logprobs_all
-            ):
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
-                top_logprobs.append(top_logprob)
-
         finish_reason = "length"
-        while True:
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                temp=temperature,
-                repeat_penalty=repeat_penalty,
-            )
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
+        ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -443,34 +407,10 @@ def _create_completion(
                     ],
                 }

-            if logprobs is not None:
-                # TODO: Confirm wether this should happen before or after
-                # next eval.
-                token_str = self.detokenize([token]).decode("utf-8")
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                logprobs_token = [
-                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
-                ]
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
-
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-            self.eval([token])

         if stream:
             yield {
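
Both the removed per-step block and the new post-hoc block rely on Llama.logit_to_logprob to turn a raw logit into a log-probability; that helper is not part of this diff. One common way to do the conversion for a whole row of logits is a numerically stable log-softmax (a sketch under that assumption, not necessarily what the helper does):

    import math
    from typing import List

    def row_logprobs(logits: List[float]) -> List[float]:
        """Log-softmax over one row: logprob_i = logit_i - log(sum_j exp(logit_j))."""
        max_logit = max(logits)  # shift by the max for numerical stability
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        return [x - log_sum_exp for x in logits]
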
@@ -499,6 +439,38 @@ def _create_completion(

         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,
