@@ -75,6 +75,7 @@ def __init__(
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
         n_batch: int = 512,
+        n_ubatch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[
@@ -156,6 +157,7 @@ def __init__(
             seed: RNG seed, -1 for random
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
+            n_ubatch: Physical batch size
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -309,6 +311,7 @@ def __init__(
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
+        self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
         self.context_params.rope_scaling_type = (
@@ -380,6 +383,7 @@ def __init__(
             self.n_batch = min(n_ctx, n_batch)
             self.context_params.n_ctx = self._model.n_ctx_train()
             self.context_params.n_batch = self.n_batch
+            self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
 
         self._ctx = self._stack.enter_context(
             contextlib.closing(
@@ -2071,6 +2075,7 @@ def __getstate__(self):
             seed=self.context_params.seed,
             n_ctx=self.context_params.n_ctx,
             n_batch=self.n_batch,
+            n_ubatch=self.context_params.n_ubatch,
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
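
A minimal usage sketch (not part of the diff), assuming the new keyword is exposed on the Llama constructor as added above; the model path and batch values are placeholders:

from llama_cpp import Llama

# Hypothetical example: n_batch is the logical prompt-processing batch size,
# n_ubatch the physical micro-batch size; per the diff above the constructor
# clamps it to min(n_batch, n_ubatch) before filling context_params.
llm = Llama(
    model_path="./models/model.gguf",  # placeholder path
    n_ctx=4096,
    n_batch=512,
    n_ubatch=256,
)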