@@ -92,6 +92,7 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
@@ -168,6 +169,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ def __init__(
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
@@ -1774,6 +1777,7 @@ def __getstate__(self):
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
+            flash_attn=self.context_params.flash_attn,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
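Taken together, these hunks thread one new boolean through the constructor signature, the docstring, the llama.cpp context params, and pickling support in __getstate__. A minimal usage sketch of the new flag, assuming a hypothetical local GGUF model path:

from llama_cpp import Llama

# Enable flash attention at construction time; the flag is forwarded to
# llama.cpp via context_params.flash_attn. The model path is hypothetical,
# and in llama.cpp the kernel typically only takes effect for layers
# offloaded to the GPU.
llm = Llama(
    model_path="./models/model.Q4_K_M.gguf",
    n_gpu_layers=-1,  # offload all layers
    flash_attn=True,
)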