feat: Update llama.cpp · CISC/llama-cpp-python@7613d23

Commit 7613d23

committed
feat: Update llama.cpp
1 parent f7f4fa8 commit 7613d23

3 files changed: +95, -44 lines changed

llama_cpp/llama.py

Lines changed: 17 additions & 6 deletions
@@ -408,15 +408,24 @@ def __init__(
             )
         )
 
+        self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None
+
         if self.lora_path:
-            if self._model.apply_lora_from_file(
-                self.lora_path,
-                self.lora_scale,
-                self.lora_base,
-                self.n_threads,
+            assert self._model.model is not None
+            self._lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self._model.model,
+                self.lora_path.encode("utf-8"),
+            )
+            if self._lora_adapter is None:
+                raise RuntimeError(
+                    f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
+                )
+            assert self._ctx.ctx is not None
+            if llama_cpp.llama_lora_adapter_set(
+                self._ctx.ctx, self._lora_adapter, self.lora_scale
             ):
                 raise RuntimeError(
-                    f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
+                    f"Failed to set LoRA adapter from lora path: {self.lora_path}"
                 )
 
         if self.verbose:
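
For orientation, a minimal usage sketch of the high-level API after this change; the model and adapter paths are placeholders, and lora_scale is the constructor argument that now ends up in llama_lora_adapter_set (lora_base and n_threads are no longer consulted on this code path):

# Sketch only: placeholder paths, GGUF model and adapter files assumed.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/base-model.gguf",  # placeholder
    lora_path="./loras/my-adapter.gguf",    # placeholder
    lora_scale=1.0,                         # forwarded to llama_lora_adapter_set
)

out = llm("Hello", max_tokens=16)
print(out["choices"][0]["text"])
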
@@ -2077,6 +2086,8 @@ def close(self) -> None:
         self._stack.close()
 
     def __del__(self) -> None:
+        if self._lora_adapter is not None:
+            llama_cpp.llama_lora_adapter_free(self._lora_adapter)
         self.close()
 
     @staticmethod
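
The __del__ change frees the adapter handle (when one was created) before delegating to close(). A small lifetime sketch, assuming close() stays safe to call more than once; paths are again placeholders:

# Sketch: close() releases model/context state via the ExitStack, while the
# LoRA adapter handle itself is released in __del__ via llama_lora_adapter_free.
import contextlib
from llama_cpp import Llama

with contextlib.closing(
    Llama(
        model_path="./models/base-model.gguf",  # placeholder
        lora_path="./loras/my-adapter.gguf",    # placeholder
    )
) as llm:
    print(llm("Hi", max_tokens=8)["choices"][0]["text"])
# When llm is garbage-collected, __del__ frees the adapter and calls close() again.
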

llama_cpp/llama_cpp.py

Lines changed: 77 additions & 37 deletions
@@ -401,7 +401,7 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
 # LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
+# // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
 # // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
 # // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
 # LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
@@ -430,14 +430,16 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
 # LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
-
+# LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
+# LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
+# LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
+#
 # LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
 # };
 LLAMA_FTYPE_ALL_F32 = 0
 LLAMA_FTYPE_MOSTLY_F16 = 1
 LLAMA_FTYPE_MOSTLY_Q4_0 = 2
 LLAMA_FTYPE_MOSTLY_Q4_1 = 3
-LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
 LLAMA_FTYPE_MOSTLY_Q8_0 = 7
 LLAMA_FTYPE_MOSTLY_Q5_0 = 8
 LLAMA_FTYPE_MOSTLY_Q5_1 = 9
@@ -464,6 +466,9 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
 LLAMA_FTYPE_MOSTLY_IQ4_XS = 30
 LLAMA_FTYPE_MOSTLY_IQ1_M = 31
 LLAMA_FTYPE_MOSTLY_BF16 = 32
+LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33
+LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34
+LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35
 LLAMA_FTYPE_GUESSED = 1024
 
 # enum llama_rope_scaling_type {
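
Since ftype code 4 disappears while codes 33-35 are added, code that maps ftype integers to names is better off deriving the table from the module than hard-coding it. A minimal sketch, assuming only the LLAMA_FTYPE_* naming convention shown above:

# Build an ftype-code -> name table from whatever constants llama_cpp defines,
# so added codes (33-35 here) and removed ones (4 here) are picked up automatically.
import llama_cpp

FTYPE_NAMES = {
    value: name
    for name, value in vars(llama_cpp).items()
    if name.startswith("LLAMA_FTYPE_") and isinstance(value, int)
}

print(FTYPE_NAMES.get(33))  # LLAMA_FTYPE_MOSTLY_Q4_0_4_4
print(FTYPE_NAMES.get(4))   # None -- Q4_1_SOME_F16 was removed upstream
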
@@ -1100,6 +1105,12 @@ class llama_chat_message(ctypes.Structure):
 ]
 
 
+# // lora adapter
+# struct llama_lora_adapter;
+llama_lora_adapter_p = ctypes.c_void_p
+llama_lora_adapter_p_ctypes = ctypes.POINTER(ctypes.c_void_p)
+
+
 # // Helpers for getting default parameters
 # LLAMA_API struct llama_model_params llama_model_default_params(void);
 @ctypes_function(
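
The new handle is deliberately opaque: Python only stores the pointer and hands it back to the C API, so a plain void pointer is enough. A standalone ctypes sketch of the same pattern (nothing here loads a real library; the CDLL lines are illustrative and commented out):

import ctypes

# Opaque-handle pattern, as used for llama_lora_adapter above: the struct is
# never dereferenced from Python, so c_void_p stands in for the C struct type.
adapter_p = ctypes.c_void_p                         # alias used in annotations
adapter_p_ctypes = ctypes.POINTER(ctypes.c_void_p)  # used as argtype / restype

# A default-constructed pointer is NULL and evaluates as falsy.
handle = adapter_p_ctypes()
print(bool(handle))  # False

# Declaring a C function against such a handle would look roughly like this:
# lib = ctypes.CDLL("libllama.so")                     # illustrative only
# lib.llama_lora_adapter_free.argtypes = [adapter_p_ctypes]
# lib.llama_lora_adapter_free.restype = None
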
@@ -1507,43 +1518,72 @@ def llama_model_quantize(
     ...
 
 
-# // Apply a LoRA adapter to a loaded model
-# // path_base_model is the path to a higher quality model to use as a base for
-# // the layers modified by the adapter. Can be NULL to use the current loaded model.
-# // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-# // will be applied on top of the previous one
-# // Returns 0 on success
-# LLAMA_API int32_t llama_model_apply_lora_from_file(
-#     const struct llama_model * model,
-#     const char * path_lora,
-#     float scale,
-#     const char * path_base_model,
-#     int32_t n_threads);
-@ctypes_function(
-    "llama_model_apply_lora_from_file",
-    [
-        llama_model_p_ctypes,
-        ctypes.c_char_p,
-        ctypes.c_float,
-        ctypes.c_char_p,
-        ctypes.c_int32,
-    ],
+# // Load a LoRA adapter from file
+# // The loaded adapter will be associated to the given model, and will be free when the model is deleted
+# LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+#     struct llama_model * model,
+#     const char * path_lora);
+@ctypes_function(
+    "llama_lora_adapter_init",
+    [llama_model_p_ctypes, ctypes.c_char_p],
+    llama_lora_adapter_p_ctypes,
+)
+def llama_lora_adapter_init(
+    model: llama_model_p, path_lora: bytes, /
+) -> Optional[llama_lora_adapter_p]:
+    """Load a LoRA adapter from file
+    The loaded adapter will be associated to the given model, and will be free when the model is deleted"""
+    ...
+
+
+# // Add a loaded LoRA adapter to given context
+# // This will not modify model's weight
+# LLAMA_API int32_t llama_lora_adapter_set(
+#     struct llama_context * ctx,
+#     struct llama_lora_adapter * adapter,
+#     float scale);
+@ctypes_function(
+    "llama_lora_adapter_set",
+    [llama_context_p_ctypes, llama_lora_adapter_p_ctypes, ctypes.c_float],
     ctypes.c_int32,
 )
-def llama_model_apply_lora_from_file(
-    model: llama_model_p,
-    path_lora: Union[ctypes.c_char_p, bytes],
-    scale: Union[ctypes.c_float, float],
-    path_base_model: Union[ctypes.c_char_p, bytes, None],
-    n_threads: Union[ctypes.c_int32, int],
-    /,
+def llama_lora_adapter_set(
+    ctx: llama_context_p, adapter: llama_lora_adapter_p, scale: float, /
+) -> int:
+    """Add a loaded LoRA adapter to given context
+    This will not modify model's weight"""
+    ...
+
+
+# // Remove a LoRA adapter from given context
+# // Return -1 if the adapter is not present in the context
+# LLAMA_API int32_t llama_lora_adapter_remove(
+#     struct llama_context * ctx,
+#     struct llama_lora_adapter * adapter);
+@ctypes_function(
+    "llama_lora_adapter_remove",
+    [llama_context_p_ctypes, llama_lora_adapter_p_ctypes],
+    ctypes.c_int32,
+)
+def llama_lora_adapter_remove(
+    ctx: llama_context_p, adapter: llama_lora_adapter_p, /
 ) -> int:
-    """Apply a LoRA adapter to a loaded model
-    path_base_model is the path to a higher quality model to use as a base for
-    the layers modified by the adapter. Can be NULL to use the current loaded model.
-    The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    will be applied on top of the previous one
-    Returns 0 on success"""
+    """Remove a LoRA adapter from given context
+    Return -1 if the adapter is not present in the context"""
+    ...
+
+
+# // Manually free a LoRA adapter
+# // Note: loaded adapters will be free when the associated model is deleted
+# LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
+@ctypes_function(
+    "llama_lora_adapter_free",
+    [llama_lora_adapter_p_ctypes],
+    None,
+)
+def llama_lora_adapter_free(adapter: llama_lora_adapter_p, /):
+    """Manually free a LoRA adapter
+    Note: loaded adapters will be free when the associated model is deleted"""
     ...
 
 
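Taken together, the four bindings give the low-level lifecycle that Llama now follows internally: init, set against a context, optionally remove, then free. A hedged end-to-end sketch using only functions defined in llama_cpp.py (paths are placeholders, and error handling is reduced to bare checks):

# Low-level sketch of the new adapter lifecycle: init -> set -> (remove) -> free.
# Paths are placeholders; this mirrors what the high-level wrapper does above.
import llama_cpp

llama_cpp.llama_backend_init()

model_params = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/base-model.gguf", model_params)
assert model is not None

ctx_params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_new_context_with_model(model, ctx_params)
assert ctx is not None

adapter = llama_cpp.llama_lora_adapter_init(model, b"./loras/my-adapter.gguf")
if adapter is None:
    raise RuntimeError("failed to load LoRA adapter")

# Attach the adapter to this context; the base model weights are not modified.
if llama_cpp.llama_lora_adapter_set(ctx, adapter, 1.0) != 0:
    raise RuntimeError("failed to set LoRA adapter")

# ... run inference against ctx ...

# Detach (returns -1 if the adapter was not attached), then free explicitly;
# adapters are otherwise freed together with the associated model.
llama_cpp.llama_lora_adapter_remove(ctx, adapter)
llama_cpp.llama_lora_adapter_free(adapter)

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
llama_cpp.llama_backend_free()
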
vendor/llama.cpp

Submodule pointer updated (1 addition & 1 deletion); the bindings above require a llama.cpp revision that exposes the llama_lora_adapter_* API.

0 commit comments