feat: Add option to configure n_ubatch · jeffmaury/llama-cpp-python@6c44a3f · GitHub

Commit 6c44a3f

feat: Add option to configure n_ubatch
1 parent 47d7a62 commit 6c44a3f

3 files changed: +9 −0 lines changed

llama_cpp/llama.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,7 @@ def __init__(
     seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
     n_ctx: int = 512,
     n_batch: int = 512,
+    n_ubatch: int = 512,
     n_threads: Optional[int] = None,
     n_threads_batch: Optional[int] = None,
     rope_scaling_type: Optional[
@@ -156,6 +157,7 @@ def __init__(
     seed: RNG seed, -1 for random
     n_ctx: Text context, 0 = from model
     n_batch: Prompt processing maximum batch size
+    n_ubatch: Physical batch size
     n_threads: Number of threads to use for generation
     n_threads_batch: Number of threads to use for batch processing
     rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -309,6 +311,7 @@ def __init__(
     self.context_params = llama_cpp.llama_context_default_params()
     self.context_params.n_ctx = n_ctx
     self.context_params.n_batch = self.n_batch
+    self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
     self.context_params.n_threads = self.n_threads
     self.context_params.n_threads_batch = self.n_threads_batch
     self.context_params.rope_scaling_type = (
@@ -380,6 +383,7 @@ def __init__(
     self.n_batch = min(n_ctx, n_batch)
     self.context_params.n_ctx = self._model.n_ctx_train()
     self.context_params.n_batch = self.n_batch
+    self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

     self._ctx = self._stack.enter_context(
         contextlib.closing(
@@ -2071,6 +2075,7 @@ def __getstate__(self):
     seed=self.context_params.seed,
     n_ctx=self.context_params.n_ctx,
     n_batch=self.n_batch,
+    n_ubatch=self.context_params.n_ubatch,
     n_threads=self.context_params.n_threads,
     n_threads_batch=self.context_params.n_threads_batch,
     rope_scaling_type=self.context_params.rope_scaling_type,

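For context, a minimal usage sketch of the new option from the Python API. The model path and sizes below are hypothetical; only the n_ubatch keyword and its clamping to min(n_batch, n_ubatch) come from this commit.

    # Hypothetical example; model path and sizes are illustrative only.
    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/example.gguf",  # hypothetical path
        n_ctx=2048,
        n_batch=512,
        n_ubatch=256,  # physical batch size; effective value is min(n_batch, n_ubatch)
    )
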
llama_cpp/server/model.py

Lines changed: 1 addition & 0 deletions
@@ -249,6 +249,7 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
     seed=settings.seed,
     n_ctx=settings.n_ctx,
     n_batch=settings.n_batch,
+    n_ubatch=settings.n_ubatch,
     n_threads=settings.n_threads,
     n_threads_batch=settings.n_threads_batch,
     rope_scaling_type=settings.rope_scaling_type,

llama_cpp/server/settings.py

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ class ModelSettings(BaseSettings):
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
+    n_ubatch: int = Field(
+        default=512, ge=1, description="The physical batch size used by llama.cpp"
+    )
     n_threads: int = Field(
         default=max(multiprocessing.cpu_count() // 2, 1),
         ge=1,

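For context, a sketch of how the new field might flow through the server code touched above. ModelSettings and load_llama_from_model_settings are the names from this diff; the model path and sizes are hypothetical.

    # Hypothetical example built only from the names shown in this diff.
    from llama_cpp.server.settings import ModelSettings
    from llama_cpp.server.model import load_llama_from_model_settings

    settings = ModelSettings(
        model="./models/example.gguf",  # hypothetical path
        n_batch=512,
        n_ubatch=256,  # validated as int, ge=1, default 512
    )
    llm = load_llama_from_model_settings(settings)
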
0 commit comments