Black formatting · coderonion/llama-cpp-python@2cc4995 · GitHub

Commit 2cc4995

committed
Black formatting
1 parent d29b05b commit 2cc4995

File tree

6 files changed: +121 -35 lines changed


examples/fastapi_server.py

Lines changed: 5 additions & 3 deletions
@@ -5,16 +5,19 @@
 from fastapi import FastAPI
 from pydantic import BaseModel, BaseSettings, Field

+
 class Settings(BaseSettings):
     model: str

+
 app = FastAPI(
     title="🦙 llama.cpp Python API",
     version="0.0.1",
 )
 settings = Settings()
 llama = Llama(settings.model)

+
 class CompletionRequest(BaseModel):
     prompt: str
     suffix: Optional[str] = Field(None)
@@ -31,12 +34,11 @@ class Config:
         schema_extra = {
             "example": {
                 "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"]
+                "stop": ["\n", "###"],
             }
         }


-
 @app.post("/v1/completions")
 def completions(request: CompletionRequest):
-    return llama(**request.dict())
+    return llama(**request.dict())
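
For reference, the endpoint reformatted above can be exercised with a small client. The sketch below is not part of the commit: it assumes the app is served locally (for example with uvicorn on its default port 8000) and reuses the example payload from the request schema.

import requests  # third-party HTTP client, used only for this illustration

payload = {
    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
    "stop": ["\n", "###"],
}
# Port 8000 is uvicorn's default; adjust the URL to wherever the server is running.
response = requests.post("http://localhost:8000/v1/completions", json=payload)
print(response.json())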

examples/high_level_api_basic_inference.py

Lines changed: 7 additions & 2 deletions
@@ -9,6 +9,11 @@

 llm = Llama(model_path=args.model)

-output = llm("Question: What are the names of the planets in the solar system? Answer: ", max_tokens=48, stop=["Q:", "\n"], echo=True)
+output = llm(
+    "Question: What are the names of the planets in the solar system? Answer: ",
+    max_tokens=48,
+    stop=["Q:", "\n"],
+    echo=True,
+)

-print(json.dumps(output, indent=2))
+print(json.dumps(output, indent=2))
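
The same call can also be made without argparse. A minimal sketch, assuming Llama is importable from the llama_cpp package as in the examples, with a hypothetical model path; the metadata fields printed at the end come from the return value built in llama_cpp/llama.py further down.

from llama_cpp import Llama  # assumes the package exposes Llama, as the examples do

llm = Llama(model_path="./models/ggml-model.bin")  # hypothetical model path
output = llm(
    "Question: What are the names of the planets in the solar system? Answer: ",
    max_tokens=48,
    stop=["Q:", "\n"],
    echo=True,
)
# Metadata fields populated by Llama.__call__ (see llama_cpp/llama.py below).
print(output["id"], output["object"], output["model"])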

examples/langchain_custom_llm.py

Lines changed: 7 additions & 3 deletions
@@ -5,6 +5,7 @@
 from langchain.llms.base import LLM
 from typing import Optional, List, Mapping, Any

+
 class LlamaLLM(LLM):
     model_path: str
     llm: Llama
@@ -16,7 +17,7 @@ def _llm_type(self) -> str:
     def __init__(self, model_path: str, **kwargs: Any):
         model_path = model_path
         llm = Llama(model_path=model_path)
-        super().__init__(model_path=model_path, llm=llm, **kwargs)
+        super().__init__(model_path=model_path, llm=llm, **kwargs)

     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         response = self.llm(prompt, stop=stop or [])
@@ -26,6 +27,7 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
     def _identifying_params(self) -> Mapping[str, Any]:
         return {"model_path": self.model_path}

+
 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="./models/...")
 args = parser.parse_args()
@@ -34,7 +36,9 @@ def _identifying_params(self) -> Mapping[str, Any]:
 llm = LlamaLLM(model_path=args.model)

 # Basic Q&A
-answer = llm("Question: What is the capital of France? Answer: ", stop=["Question:", "\n"])
+answer = llm(
+    "Question: What is the capital of France? Answer: ", stop=["Question:", "\n"]
+)
 print(f"Answer: {answer.strip()}")

 # Using in a chain
@@ -48,4 +52,4 @@ def _identifying_params(self) -> Mapping[str, Any]:
 chain = LLMChain(llm=llm, prompt=prompt)

 # Run the chain only specifying the input variable.
-print(chain.run("colorful socks"))
+print(chain.run("colorful socks"))
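
The chain wiring that the last two hunks touch looks roughly like the sketch below. It is illustrative only: the PromptTemplate text and the model path are not taken from this diff, and the imports assume the langchain version the example targets.

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Illustrative template; the example's actual prompt lies outside the hunks shown above.
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)

llm = LlamaLLM(model_path="./models/ggml-model.bin")  # hypothetical model path
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run("colorful socks"))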

examples/low_level_api_inference.py

Lines changed: 10 additions & 2 deletions
@@ -27,7 +27,15 @@
 n = 8

 for i in range(n):
-    id = llama_cpp.llama_sample_top_p_top_k(ctx, (llama_cpp.c_int * len(embd))(*embd), n_of_tok + i, 40, 0.8, 0.2, 1.0/0.85)
+    id = llama_cpp.llama_sample_top_p_top_k(
+        ctx,
+        (llama_cpp.c_int * len(embd))(*embd),
+        n_of_tok + i,
+        40,
+        0.8,
+        0.2,
+        1.0 / 0.85,
+    )

     embd.append(id)

@@ -38,4 +46,4 @@

 llama_cpp.llama_free(ctx)

-print(prediction.decode("utf-8"))
+print(prediction.decode("utf-8"))
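
Because the reformatted call still passes bare numeric literals, a keyword-argument wrapper can make their roles explicit. This helper is hypothetical (not part of the commit); its parameter names follow the binding's signature in llama_cpp/llama_cpp.py and its defaults mirror the literals used above.

import llama_cpp


def sample_top_p_top_k(
    ctx, last_n_tokens, *, top_k=40, top_p=0.8, temp=0.2, repeat_penalty=1.0 / 0.85
):
    """Forward to llama_cpp.llama_sample_top_p_top_k with named parameters."""
    data = (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens)
    return llama_cpp.llama_sample_top_p_top_k(
        ctx, data, len(last_n_tokens), top_k, top_p, temp, repeat_penalty
    )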

llama_cpp/llama.py

Lines changed: 6 additions & 5 deletions
@@ -5,6 +5,7 @@

 from . import llama_cpp

+
 class Llama:
     def __init__(
         self,
@@ -82,7 +83,10 @@ def __call__(

         for i in range(max_tokens):
             tokens_seen = prompt_tokens + completion_tokens
-            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [self.tokens[j] for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)]
+            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [
+                self.tokens[j]
+                for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)
+            ]

             token = llama_cpp.llama_sample_top_p_top_k(
                 self.ctx,
@@ -128,9 +132,8 @@ def __call__(
                 self.ctx,
             )[:logprobs]

-
         return {
-            "id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
+            "id": f"cmpl-{str(uuid.uuid4())}",  # Likely to change
             "object": "text_completion",
             "created": int(time.time()),
             "model": self.model_path,
@@ -151,5 +154,3 @@ def __call__(

     def __del__(self):
         llama_cpp.llama_free(self.ctx)
-
-
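
The list comprehension that Black split across lines in the second hunk builds a fixed-length window of the most recent token ids, left-padded with zeros. A standalone sketch of the same logic, using illustrative values in place of the instance attributes:

# Standalone version of the last_n_tokens window built in Llama.__call__.
last_n = 8                     # window size (self.last_n in the class)
tokens = [11, 22, 33, 44, 55]  # token ids seen so far (self.tokens)
tokens_seen = len(tokens)

last_n_tokens = [0] * max(0, last_n - tokens_seen) + [
    tokens[j] for j in range(max(tokens_seen - last_n, 0), tokens_seen)
]
assert last_n_tokens == [0, 0, 0, 11, 22, 33, 44, 55]  # always exactly last_n entries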

llama_cpp/llama_cpp.py

Lines changed: 86 additions & 20 deletions
@@ -1,6 +1,15 @@
 import ctypes

-from ctypes import c_int, c_float, c_double, c_char_p, c_void_p, c_bool, POINTER, Structure
+from ctypes import (
+    c_int,
+    c_float,
+    c_double,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+)

 import pathlib

@@ -13,26 +22,32 @@
 llama_token = c_int
 llama_token_p = POINTER(llama_token)

+
 class llama_token_data(Structure):
     _fields_ = [
-        ('id', llama_token), # token id
-        ('p', c_float), # probability of the token
-        ('plog', c_float), # log probability of the token
+        ("id", llama_token),  # token id
+        ("p", c_float),  # probability of the token
+        ("plog", c_float),  # log probability of the token
     ]

+
 llama_token_data_p = POINTER(llama_token_data)

+
 class llama_context_params(Structure):
     _fields_ = [
-        ('n_ctx', c_int), # text context
-        ('n_parts', c_int), # -1 for default
-        ('seed', c_int), # RNG seed, 0 for random
-        ('f16_kv', c_bool), # use fp16 for KV cache
-        ('logits_all', c_bool), # the llama_eval() call computes all logits, not just the last one
-
-        ('vocab_only', c_bool), # only load the vocabulary, no weights
+        ("n_ctx", c_int),  # text context
+        ("n_parts", c_int),  # -1 for default
+        ("seed", c_int),  # RNG seed, 0 for random
+        ("f16_kv", c_bool),  # use fp16 for KV cache
+        (
+            "logits_all",
+            c_bool,
+        ),  # the llama_eval() call computes all logits, not just the last one
+        ("vocab_only", c_bool),  # only load the vocabulary, no weights
     ]

+
 llama_context_params_p = POINTER(llama_context_params)

 llama_context_p = c_void_p
@@ -74,7 +89,15 @@ class llama_context_params(Structure):
 lib.llama_token_eos.argtypes = []
 lib.llama_token_eos.restype = llama_token

-lib.llama_sample_top_p_top_k.argtypes = [llama_context_p, llama_token_p, c_int, c_int, c_double, c_double, c_double]
+lib.llama_sample_top_p_top_k.argtypes = [
+    llama_context_p,
+    llama_token_p,
+    c_int,
+    c_int,
+    c_double,
+    c_double,
+    c_double,
+]
 lib.llama_sample_top_p_top_k.restype = llama_token

 lib.llama_print_timings.argtypes = [llama_context_p]
@@ -86,45 +109,71 @@ class llama_context_params(Structure):
 lib.llama_print_system_info.argtypes = []
 lib.llama_print_system_info.restype = c_char_p

+
 # Python functions
 def llama_context_default_params() -> llama_context_params:
     params = lib.llama_context_default_params()
     return params

-def llama_init_from_file(path_model: bytes, params: llama_context_params) -> llama_context_p:
+
+def llama_init_from_file(
+    path_model: bytes, params: llama_context_params
+) -> llama_context_p:
     """Various functions for loading a ggml llama model.
     Allocate (almost) all memory needed for the model.
-    Return NULL on failure """
+    Return NULL on failure"""
     return lib.llama_init_from_file(path_model, params)

+
 def llama_free(ctx: llama_context_p):
     """Free all allocated memory"""
     lib.llama_free(ctx)

-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int) -> c_int:
+
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, itype: c_int, qk: c_int
+) -> c_int:
     """Returns 0 on success"""
     return lib.llama_model_quantize(fname_inp, fname_out, itype, qk)

-def llama_eval(ctx: llama_context_p, tokens: llama_token_p, n_tokens: c_int, n_past: c_int, n_threads: c_int) -> c_int:
+
+def llama_eval(
+    ctx: llama_context_p,
+    tokens: llama_token_p,
+    n_tokens: c_int,
+    n_past: c_int,
+    n_threads: c_int,
+) -> c_int:
     """Run the llama inference to obtain the logits and probabilities for the next token.
     tokens + n_tokens is the provided batch of new tokens to process
     n_past is the number of tokens to use from previous eval calls
     Returns 0 on success"""
     return lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)

-def llama_tokenize(ctx: llama_context_p, text: bytes, tokens: llama_token_p, n_max_tokens: c_int, add_bos: c_bool) -> c_int:
+
+def llama_tokenize(
+    ctx: llama_context_p,
+    text: bytes,
+    tokens: llama_token_p,
+    n_max_tokens: c_int,
+    add_bos: c_bool,
+) -> c_int:
     """Convert the provided text into tokens.
     The tokens pointer must be large enough to hold the resulting tokens.
     Returns the number of tokens on success, no more than n_max_tokens
-    Returns a negative number on failure - the number of tokens that would have been returned"""
+    Returns a negative number on failure - the number of tokens that would have been returned
+    """
     return lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)

+
 def llama_n_vocab(ctx: llama_context_p) -> c_int:
     return lib.llama_n_vocab(ctx)

+
 def llama_n_ctx(ctx: llama_context_p) -> c_int:
     return lib.llama_n_ctx(ctx)

+
 def llama_get_logits(ctx: llama_context_p):
     """Token logits obtained from the last call to llama_eval()
     The logits for the last token are stored in the last row
@@ -133,25 +182,42 @@ def llama_get_logits(ctx: llama_context_p):
     Cols: n_vocab"""
     return lib.llama_get_logits(ctx)

+
 def llama_token_to_str(ctx: llama_context_p, token: int) -> bytes:
     """Token Id -> String. Uses the vocabulary in the provided context"""
     return lib.llama_token_to_str(ctx, token)

+
 def llama_token_bos() -> llama_token:
     return lib.llama_token_bos()

+
 def llama_token_eos() -> llama_token:
     return lib.llama_token_eos()

-def llama_sample_top_p_top_k(ctx: llama_context_p, last_n_tokens_data: llama_token_p, last_n_tokens_size: c_int, top_k: c_int, top_p: c_double, temp: c_double, repeat_penalty: c_double) -> llama_token:
-    return lib.llama_sample_top_p_top_k(ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty)
+
+def llama_sample_top_p_top_k(
+    ctx: llama_context_p,
+    last_n_tokens_data: llama_token_p,
+    last_n_tokens_size: c_int,
+    top_k: c_int,
+    top_p: c_double,
+    temp: c_double,
+    repeat_penalty: c_double,
+) -> llama_token:
+    return lib.llama_sample_top_p_top_k(
+        ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty
+    )
+

 def llama_print_timings(ctx: llama_context_p):
     lib.llama_print_timings(ctx)

+
 def llama_reset_timings(ctx: llama_context_p):
     lib.llama_reset_timings(ctx)

+
 def llama_print_system_info() -> bytes:
     """Print system informaiton"""
     return lib.llama_print_system_info()
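
Taken together, the reformatted bindings follow the usual caller-allocates ctypes pattern. Below is a minimal tokenization sketch built only from the signatures above; the model path, prompt text, and buffer size are illustrative, not taken from the commit.

import llama_cpp

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"./models/ggml-model.bin", params)  # illustrative path

# llama_tokenize fills a caller-allocated llama_token array and returns the
# number of tokens written (negative on failure).
n_max_tokens = 64
tokens = (llama_cpp.llama_token * n_max_tokens)()
n_tokens = llama_cpp.llama_tokenize(ctx, b"Hello, llama!", tokens, n_max_tokens, True)
print([llama_cpp.llama_token_to_str(ctx, tokens[i]) for i in range(n_tokens)])

llama_cpp.llama_free(ctx)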

0 commit comments