feat: Add option to enable `flash_attn` to Llama params and ModelSettings · brookareru/llama-cpp-python@22d77ee · GitHub

Commit 22d77ee

feat: Add option to enable flash_attn to Llama params and ModelSettings

1 parent 8c2b24d · commit 22d77ee

File tree

2 files changed (+7, -0 lines changed)

llama_cpp/llama.py

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,7 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
@@ -168,6 +169,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ def __init__(
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
@@ -1774,6 +1777,7 @@ def __getstate__(self):
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
+            flash_attn=self.context_params.flash_attn,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
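For reference, a minimal usage sketch of the new flag through the high-level API. The model path is a hypothetical placeholder, and flash attention also needs a llama.cpp build and backend that support it, so treat flash_attn=True as a request rather than a guarantee:

    from llama_cpp import Llama

    # Hypothetical GGUF path; substitute any local model file.
    llm = Llama(
        model_path="./models/model.Q4_K_M.gguf",
        n_gpu_layers=-1,   # offload all layers; flash attention targets GPU inference
        flash_attn=True,   # the parameter added by this commit
    )

    out = llm("Q: Name the planets in the solar system. A:", max_tokens=48)
    print(out["choices"][0]["text"])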

llama_cpp/server/settings.py

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ class ModelSettings(BaseSettings):
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."
     )
+    flash_attn: bool = Field(
+        default=False, description="Whether to use flash attention."
+    )
     # Sampling Params
     last_n_tokens_size: int = Field(
         default=64,
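A corresponding sketch for the server side. ModelSettings is a pydantic BaseSettings class, so the new field can be set programmatically as below, or (assuming the server's usual settings plumbing) via a --flash_attn flag or environment variable when launching llama_cpp.server. The model path is again a placeholder:

    from llama_cpp.server.settings import ModelSettings

    # `model` is the only required field; everything else has defaults.
    settings = ModelSettings(
        model="./models/model.Q4_K_M.gguf",  # hypothetical path
        flash_attn=True,                     # the field added by this commit
    )
    print(settings.flash_attn)  # True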
