Use pre-allocated buffers to store input_ids and scores · sirajperson/falcon-cpp-python@b95b0ff · GitHub

Commit b95b0ff

Use pre-allocated buffers to store input_ids and scores
1 parent a5e059c commit b95b0ff
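
In short: Llama.eval previously re-grew input_ids and scores with np.concatenate on every batch, copying all earlier rows each time. This commit allocates both buffers once at their full n_ctx size and tracks the number of valid rows in a new n_tokens counter. A minimal standalone sketch of the pattern, with toy sizes standing in for the real n_ctx/n_vocab:

import numpy as np

n_ctx, n_vocab = 8, 5  # toy sizes; the real values come from the model
batches = [np.random.rand(2, n_vocab).astype(np.single) for _ in range(3)]

# Old pattern: re-grow the array on every batch, copying all prior rows.
scores_old = np.ndarray((0, n_vocab), dtype=np.single)
for batch in batches:
    scores_old = np.concatenate((scores_old, batch))

# New pattern: allocate once, write into slices, track a fill count.
scores_new = np.ndarray((n_ctx, n_vocab), dtype=np.single)
n_tokens = 0
for batch in batches:
    scores_new[n_tokens : n_tokens + len(batch), :] = batch
    n_tokens += len(batch)

assert np.allclose(scores_old, scores_new[:n_tokens])

Slice writes into a pre-allocated buffer cost amortized O(1) per row, whereas repeated concatenation copies the whole history on each call, O(n²) over a full generation.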

1 file changed: +44 -42 lines changed


llama_cpp/llama.py

Lines changed: 44 additions & 42 deletions
@@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        # NOTE: This puts an integer as key in cache, which breaks,
+        # NOTE: This puts an integer as key in cache, which breaks,
         # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
         # self.cache.push(_key, side="front")  # type: ignore
         return value
@@ -166,17 +166,15 @@ def __setitem__(self, key: Sequence[int], value: "LlamaState"):
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[int],
-        eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
+        n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
     ):
-        self.eval_tokens = eval_tokens
-        self.eval_logits = eval_logits
         self.input_ids = input_ids
         self.scores = scores
+        self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
 
@@ -267,8 +265,6 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[BaseLlamaCache] = None
 
@@ -329,8 +325,30 @@ def __init__(
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
-        self._input_ids = np.array([], dtype=np.intc)
-        self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single)
+        self.n_tokens = 0
+        self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
+        self.scores: npt.NDArray[np.single] = np.ndarray(
+            (n_ctx, self._n_vocab), dtype=np.single
+        )
+
+    @property
+    def _input_ids(self) -> npt.NDArray[np.intc]:
+        return self.input_ids[: self.n_tokens]
+
+    @property
+    def _scores(self) -> npt.NDArray[np.single]:
+        return self.scores[: self.n_tokens, :]
+
+    @property
+    def eval_tokens(self) -> Deque[int]:
+        return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
+
+    @property
+    def eval_logits(self) -> Deque[List[float]]:
+        return deque(
+            self.scores[: self.n_tokens, :].tolist(),
+            maxlen=self._n_ctx if self.params.logits_all else 1,
+        )
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -397,10 +415,7 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
 
     def reset(self):
         """Reset the model state."""
-        self.eval_tokens.clear()
-        self.eval_logits.clear()
-        self._input_ids = np.array([], dtype=np.intc)
-        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+        self.n_tokens = 0
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -410,7 +425,6 @@ def eval(self, tokens: Sequence[int]):
         """
         assert self.ctx is not None
         n_ctx = self._n_ctx
-        scores: List[npt.NDArray[np.single]] = []
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self._input_ids))
@@ -425,19 +439,16 @@ def eval(self, tokens: Sequence[int]):
             if return_code != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
-            self.eval_tokens.extend(batch)
-            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
-                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
-            )
+            self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
-            self.eval_logits.extend(logits)
-            scores.append(np.array(logits, dtype=np.single))
-        self._scores = np.concatenate(scores)
+            self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
+            # Update n_tokens
+            self.n_tokens += n_tokens
 
     def _sample(
         self,
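
The write path in eval becomes two slice assignments; NumPy coerces the Python lists of token ids and logits straight into the target dtype. A minimal sketch with hypothetical values:

import numpy as np

n_ctx, n_vocab = 8, 4
input_ids = np.ndarray((n_ctx,), dtype=np.intc)
scores = np.ndarray((n_ctx, n_vocab), dtype=np.single)
n_tokens = 0

batch = [11, 12, 13]                     # token ids for this batch
logits = [[0.1] * n_vocab] * len(batch)  # one row per token (logits_all case)

input_ids[n_tokens : n_tokens + len(batch)] = batch
scores[n_tokens : n_tokens + len(batch), :] = logits
n_tokens += len(batch)

assert input_ids[:n_tokens].tolist() == [11, 12, 13]

When logits_all is false, logits holds a single row and NumPy broadcasts it across the n_tokens-row slice of scores.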
@@ -457,8 +468,7 @@ def _sample(
         logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
-        assert len(self.eval_logits) > 0
-        assert self._scores.shape[0] > 0
+        assert self.n_tokens > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -475,7 +485,6 @@
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
-            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
@@ -672,14 +681,7 @@ def generate(
                 print("Llama.generate: prefix-match hit", file=sys.stderr)
             reset = False
             tokens = tokens[longest_prefix:]
-            self._input_ids = self._input_ids[:longest_prefix]
-            self._scores = self._scores[:longest_prefix, :]
-            for _ in range(len(self.eval_tokens) - longest_prefix):
-                self.eval_tokens.pop()
-                try:
-                    self.eval_logits.pop()
-                except IndexError:
-                    pass
+            self.n_tokens = longest_prefix
 
         if reset:
             self.reset()
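
Rolling back to a cached prefix match is likewise reduced to one assignment, instead of slicing two arrays and popping two deques in a loop. A small sketch with hypothetical values:

import numpy as np

input_ids = np.array([5, 6, 7, 9], dtype=np.intc)  # 4 evaluated tokens
n_tokens = 4
longest_prefix = 2                                 # tokens shared with the cache hit

n_tokens = longest_prefix                          # discard everything past the prefix
assert input_ids[:n_tokens].tolist() == [5, 6]

The rows past the prefix stay in the buffer but are logically dead; the next eval overwrites them.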
@@ -819,7 +821,9 @@ def _create_completion(
             llama_cpp.llama_reset_timings(self.ctx)
 
         if len(prompt_tokens) > self._n_ctx:
-            raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}")
+            raise ValueError(
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
+            )
 
         # Truncate max_tokens if requested tokens would exceed the context window
         max_tokens = (
@@ -1513,22 +1517,20 @@ def save_state(self) -> LlamaState:
                 file=sys.stderr,
             )
         return LlamaState(
-            eval_tokens=self.eval_tokens.copy(),
-            eval_logits=self.eval_logits.copy(),
-            scores=self._scores.copy(),
-            input_ids=self._input_ids.copy(),
+            scores=self.scores.copy(),
+            input_ids=self.input_ids.copy(),
+            n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )
 
     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
-        self.eval_tokens = state.eval_tokens.copy()
-        self.eval_logits = state.eval_logits.copy()
-        self._scores = state.scores.copy()
-        self._input_ids = state.input_ids.copy()
+        self.scores = state.scores.copy()
+        self.input_ids = state.input_ids.copy()
+        self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
-        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        LLamaStateArrayType = llama_cpp.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
 
         if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
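
save_state/load_state now snapshot the whole pre-allocated buffers plus the counter instead of copying deques. A sketch of the round-trip with hypothetical toy state:

import numpy as np

n_ctx, n_vocab = 8, 4
input_ids = np.zeros((n_ctx,), dtype=np.intc)
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)
input_ids[:2] = [3, 4]
n_tokens = 2

saved = (input_ids.copy(), scores.copy(), n_tokens)  # like save_state

input_ids[:3] = [7, 8, 9]                            # generation moves on...
n_tokens = 3

input_ids, scores, n_tokens = saved                  # like load_state
assert input_ids[:n_tokens].tolist() == [3, 4]

Because the full (n_ctx, n_vocab) scores buffer is copied even when mostly empty, a saved state trades memory for the simplicity of a fixed-shape snapshot.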
