feat: Update llama.cpp · coderonion/llama-cpp-python@56071c9 · GitHub


Commit 56071c9

feat: Update llama.cpp
1 parent 08b16af commit 56071c9

File tree

2 files changed, +231 -8 lines changed

llama_cpp/llama_cpp.py

Lines changed: 230 additions & 7 deletions
@@ -237,11 +237,18 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 LLAMA_FILE_MAGIC_GGSN = 0x6767736E
 
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
+LLAMA_FILE_MAGIC_GGSQ = 0x67677371
+
 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
 # define LLAMA_SESSION_VERSION 5
 LLAMA_SESSION_VERSION = 5
 
+#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
+LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
+LLAMA_STATE_SEQ_VERSION = 1
 
 # struct llama_model;
 llama_model_p = NewType("llama_model_p", int)
@@ -1467,6 +1474,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 
 
 # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
 # // seq_id < 0 : match any sequence
 # // p0 < 0 : [0, p1]
 # // p1 < 0 : [p0, inf)
@@ -1493,6 +1501,9 @@ def llama_kv_cache_seq_rm(
     /,
 ) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+
+    Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+
     seq_id < 0 : match any sequence
     p0 < 0 : [0, p1]
     p1 < 0 : [p0, inf)"""
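
Not part of the commit, but as a usage sketch of the semantics documented above, assuming an already-initialised llama_cpp context `ctx` whose KV cache holds sequence 0:

import llama_cpp

# Drop sequence 0's tokens from position 32 onward (p1 < 0 means [p0, inf)).
if not llama_cpp.llama_kv_cache_seq_rm(ctx, 0, 32, -1):
    # Partial removal can fail; removing the whole sequence (p0 < 0, p1 < 0) never fails.
    llama_cpp.llama_kv_cache_seq_rm(ctx, 0, -1, -1)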
@@ -1652,7 +1663,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /):
 
 # Returns the maximum size in bytes of the state (rng, logits, embedding
 # and kv_cache) - will often be smaller after compacting tokens
-# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
+# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
+@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t)
+def llama_state_get_size(ctx: llama_context_p, /) -> int:
+    """Returns the maximum size in bytes of the state (rng, logits, embedding
+    and kv_cache) - will often be smaller after compacting tokens"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+#     "use llama_state_get_size instead");
 @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t)
 def llama_get_state_size(ctx: llama_context_p, /) -> int:
     """Returns the maximum size in bytes of the state (rng, logits, embedding
@@ -1663,9 +1683,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int:
 # Copies the state to the specified destination address.
 # Destination needs to have allocated enough memory.
 # Returns the number of bytes copied
-# LLAMA_API size_t llama_copy_state_data(
+# LLAMA_API size_t llama_state_get_data(
 #     struct llama_context * ctx,
 #     uint8_t * dst);
+@ctypes_function(
+    "llama_state_get_data",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_get_data(
+    ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
+) -> int:
+    """Copies the state to the specified destination address.
+    Destination needs to have allocated enough memory.
+    Returns the number of bytes copied"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_copy_state_data(
+#     struct llama_context * ctx,
+#     uint8_t * dst),
+#     "use llama_state_get_data instead");
 @ctypes_function(
     "llama_copy_state_data",
     [
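
For illustration only (not in the diff): a minimal sketch of snapshotting the full context state with the new names, assuming `ctx` is an existing llama_context_p:

import ctypes
import llama_cpp

size = llama_cpp.llama_state_get_size(ctx)           # upper bound on the state size
buf = (ctypes.c_uint8 * size)()                      # destination buffer
n_written = llama_cpp.llama_state_get_data(ctx, buf)
state_bytes = bytes(buf[:n_written])                 # often smaller than `size`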
@@ -1685,9 +1726,26 @@ def llama_copy_state_data(
 
 # // Set the state reading from the specified address
 # // Returns the number of bytes read
-# LLAMA_API size_t llama_set_state_data(
+# LLAMA_API size_t llama_state_set_data(
 #     struct llama_context * ctx,
 #     const uint8_t * src);
+@ctypes_function(
+    "llama_state_set_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
+    ctypes.c_size_t,
+)
+def llama_state_set_data(
+    ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
+) -> int:
+    """Set the state reading from the specified address
+    Returns the number of bytes read"""
+    ...
+
+
+# LLAMA_API DEPRECATED(size_t llama_set_state_data(
+#     struct llama_context * ctx,
+#     const uint8_t * src),
+#     "use llama_state_set_data instead");
 @ctypes_function(
     "llama_set_state_data",
     [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)],
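
And the matching restore sketch, assuming `state_bytes` was produced by llama_state_get_data on a compatible context:

import ctypes
import llama_cpp

src = (ctypes.c_uint8 * len(state_bytes)).from_buffer_copy(state_bytes)
n_read = llama_cpp.llama_state_set_data(ctx, src)    # returns the number of bytes read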
@@ -1701,12 +1759,40 @@ def llama_set_state_data(
 
 
 # Save/load session file
-# LLAMA_API bool llama_load_session_file(
+# LLAMA_API bool llama_state_load_file(
 #     struct llama_context * ctx,
 #     const char * path_session,
 #     llama_token * tokens_out,
 #     size_t n_token_capacity,
 #     size_t * n_token_count_out);
+@ctypes_function(
+    "llama_state_load_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_token_p,
+        ctypes.c_size_t,
+        ctypes.POINTER(ctypes.c_size_t),
+    ],
+    ctypes.c_bool,
+)
+def llama_state_load_file(
+    ctx: llama_context_p,
+    path_session: bytes,
+    tokens_out: CtypesArray[llama_token],
+    n_token_capacity: Union[ctypes.c_size_t, int],
+    n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+    /,
+) -> bool: ...
+
+
+# LLAMA_API DEPRECATED(bool llama_load_session_file(
+#     struct llama_context * ctx,
+#     const char * path_session,
+#     llama_token * tokens_out,
+#     size_t n_token_capacity,
+#     size_t * n_token_count_out),
+#     "use llama_state_load_file instead");
 @ctypes_function(
     "llama_load_session_file",
     [
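
A usage sketch (not part of the diff) of the new file-based loader; the file name and the capacity value are placeholders, and `ctx` is assumed to be an existing llama_context_p:

import ctypes
import llama_cpp

capacity = 2048                                        # hypothetical; use the real context size
tokens_out = (llama_cpp.llama_token * capacity)()
n_count = ctypes.c_size_t(0)
ok = llama_cpp.llama_state_load_file(
    ctx, b"session.bin", tokens_out, capacity, ctypes.byref(n_count)
)
if ok:
    session_tokens = list(tokens_out[: n_count.value])  # tokens stored in the session file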
@@ -1728,11 +1814,36 @@ def llama_load_session_file(
 ) -> int: ...
 
 
-# LLAMA_API bool llama_save_session_file(
+# LLAMA_API bool llama_state_save_file(
 #     struct llama_context * ctx,
 #     const char * path_session,
 #     const llama_token * tokens,
 #     size_t n_token_count);
+@ctypes_function(
+    "llama_state_save_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_token_p,
+        ctypes.c_size_t,
+    ],
+    ctypes.c_bool,
+)
+def llama_state_save_file(
+    ctx: llama_context_p,
+    path_session: bytes,
+    tokens: CtypesArray[llama_token],
+    n_token_count: Union[ctypes.c_size_t, int],
+    /,
+) -> bool: ...
+
+
+# LLAMA_API DEPRECATED(bool llama_save_session_file(
+#     struct llama_context * ctx,
+#     const char * path_session,
+#     const llama_token * tokens,
+#     size_t n_token_count),
+#     "use llama_state_save_file instead");
 @ctypes_function(
     "llama_save_session_file",
     [
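
The corresponding save sketch, again assuming `ctx` exists and `session_tokens` holds the token ids already evaluated into it; the file name is a placeholder:

import llama_cpp

tokens = (llama_cpp.llama_token * len(session_tokens))(*session_tokens)
ok = llama_cpp.llama_state_save_file(ctx, b"session.bin", tokens, len(session_tokens))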
@@ -1752,6 +1863,116 @@ def llama_save_session_file(
 ) -> int: ...
 
 
+# // Get the exact size needed to copy the KV cache of a single sequence
+# LLAMA_API size_t llama_state_seq_get_size(
+#     struct llama_context * ctx,
+#     llama_seq_id seq_id);
+@ctypes_function(
+    "llama_state_seq_get_size",
+    [llama_context_p_ctypes, llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
+    """Get the exact size needed to copy the KV cache of a single sequence"""
+    ...
+
+
+# // Copy the KV cache of a single sequence into the specified buffer
+# LLAMA_API size_t llama_state_seq_get_data(
+#     struct llama_context * ctx,
+#     uint8_t * dst,
+#     llama_seq_id seq_id);
+@ctypes_function(
+    "llama_state_seq_get_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_data(
+    ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
+) -> int:
+    """Copy the KV cache of a single sequence into the specified buffer"""
+    ...
+
+
+# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+# // Returns:
+# // - Positive: Ok
+# // - Zero: Failed to load
+# LLAMA_API size_t llama_state_seq_set_data(
+#     struct llama_context * ctx,
+#     const uint8_t * src,
+#     llama_seq_id dest_seq_id);
+@ctypes_function(
+    "llama_state_seq_set_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_set_data(
+    ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
+) -> int:
+    """Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_save_file(
+#     struct llama_context * ctx,
+#     const char * filepath,
+#     llama_seq_id seq_id,
+#     const llama_token * tokens,
+#     size_t n_token_count);
+@ctypes_function(
+    "llama_state_seq_save_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_save_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    seq_id: llama_seq_id,
+    tokens: CtypesArray[llama_token],
+    n_token_count: Union[ctypes.c_size_t, int],
+    /,
+) -> int:
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_load_file(
+#     struct llama_context * ctx,
+#     const char * filepath,
+#     llama_seq_id dest_seq_id,
+#     llama_token * tokens_out,
+#     size_t n_token_capacity,
+#     size_t * n_token_count_out);
+@ctypes_function(
+    "llama_state_seq_load_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+        ctypes.POINTER(ctypes.c_size_t),
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_load_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    dest_seq_id: llama_seq_id,
+    tokens_out: CtypesArray[llama_token],
+    n_token_capacity: Union[ctypes.c_size_t, int],
+    n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+    /,
+) -> int:
+    ...
+
+
 # //
 # // Decoding
 # //
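
As an illustration of the new per-sequence API (not part of the diff), a sketch that copies the KV cache of sequence 0 and replays it into sequence 1, assuming `ctx` already exists:

import ctypes
import llama_cpp

size = llama_cpp.llama_state_seq_get_size(ctx, 0)
buf = (ctypes.c_uint8 * size)()
llama_cpp.llama_state_seq_get_data(ctx, buf, 0)

# Per the comments above: positive return means ok, zero means the load failed.
if llama_cpp.llama_state_seq_set_data(ctx, buf, 1) == 0:
    raise RuntimeError("failed to restore sequence state")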
@@ -1930,8 +2151,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     ...
 
 
-# // Logits for the ith token. Equivalent to:
+# // Logits for the ith token. For positive indices, Equivalent to:
 # // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+# // Negative indicies can be used to access logits in reverse order, -1 is the last logit.
 # // returns NULL for invalid ids.
 # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 @ctypes_function(
@@ -1963,8 +2185,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
     ...
 
 
-# // Get the embeddings for the ith token. Equivalent to:
+# // Get the embeddings for the ith token. For positive indices, Equivalent to:
 # // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+# // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding.
 # // shape: [n_embd] (1-dimensional)
 # // returns NULL for invalid ids.
 # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
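
Not in the diff, but a small sketch of the negative-index behaviour described in these comments, assuming `ctx` has been decoded at least once and `n_vocab` is the model's vocabulary size:

import llama_cpp

last_logits = llama_cpp.llama_get_logits_ith(ctx, -1)     # logits for the last token
last_embd = llama_cpp.llama_get_embeddings_ith(ctx, -1)   # embedding for the last token
if last_logits:  # a NULL pointer (invalid id) is falsy
    best_token = max(range(n_vocab), key=lambda t: last_logits[t])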

vendor/llama.cpp

0 commit comments
