feat: Move tokenizer to own module · notwa/llama-cpp-python@b5fca91 · GitHub

Commit b5fca91
feat: Move tokenizer to own module
1 parent 2ef7ba3 commit b5fca91

File tree: 2 files changed, +100 -65 lines changed

- llama_cpp/llama.py (+4, -65)
- llama_cpp/llama_tokenizer.py (+96, -0)

llama_cpp/llama.py

Lines changed: 4 additions & 65 deletions
@@ -2,7 +2,6 @@
 
 import os
 import sys
-import abc
 import uuid
 import time
 import multiprocessing
@@ -15,7 +14,6 @@
     Iterator,
     Deque,
     Callable,
-    Any,
 )
 from collections import deque
 
@@ -31,6 +29,10 @@
     LlamaDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
 )
+from .llama_tokenizer import (
+    BaseLlamaTokenizer,
+    LlamaTokenizer
+)
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
@@ -1747,69 +1749,6 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
     return longest_prefix
 
 
-class BaseLlamaTokenizer(abc.ABC):
-    @abc.abstractmethod
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        raise NotImplementedError
-
-
-class LlamaTokenizer(BaseLlamaTokenizer):
-    def __init__(self, llama: Llama):
-        self.llama = llama
-        self._model = llama._model  # type: ignore
-
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self._model.tokenize(text, add_bos=add_bos, special=special)
-
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        if prev_tokens is not None:
-            return self._model.detokenize(tokens[len(prev_tokens):])
-        else:
-            return self._model.detokenize(tokens)
-
-    def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
-        )
-
-    def decode(self, tokens: List[int]) -> str:
-        return self.detokenize(tokens).decode("utf-8", errors="ignore")
-
-    @classmethod
-    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
-        return cls(Llama(model_path=path, vocab_only=True))
-
-
-class LlamaHFTokenizer(BaseLlamaTokenizer):
-    def __init__(self, hf_tokenizer: Any):
-        self.hf_tokenizer = hf_tokenizer
-
-    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
-        return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)
-
-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
-        if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
-            prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
-            return text[len(prev_text):]
-        else:
-            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
-        try:
-            from transformers import AutoTokenizer
-        except ImportError:
-            raise ImportError(
-                "The `transformers` library is required to use the `HFTokenizer`."
-                "You can install it with `pip install transformers`."
-            )
-        hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
-        return cls(hf_tokenizer)
 
 
 class LlamaState:
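
This commit extracts the three tokenizer classes from llama_cpp/llama.py into the new llama_cpp/llama_tokenizer.py module. Note that llama.py re-imports only BaseLlamaTokenizer and LlamaTokenizer, not LlamaHFTokenizer, so the old import path keeps working for those two names only. A minimal sketch of what that implies for callers (not part of the diff; assumes the package is installed at this commit):

    # Both paths should resolve to the same re-exported class.
    from llama_cpp.llama import LlamaTokenizer as legacy
    from llama_cpp.llama_tokenizer import LlamaTokenizer as current
    assert legacy is current

    # LlamaHFTokenizer, by contrast, is now only importable from the new module:
    from llama_cpp.llama_tokenizer import LlamaHFTokenizer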

llama_cpp/llama_tokenizer.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import abc
+from typing import (
+    List,
+    Optional,
+    Any,
+)
+
+import llama_cpp
+from llama_cpp.llama_types import List
+
+
+class BaseLlamaTokenizer(abc.ABC):
+    @abc.abstractmethod
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        raise NotImplementedError
+
+
+class LlamaTokenizer(BaseLlamaTokenizer):
+    def __init__(self, llama: llama_cpp.Llama):
+        self.llama = llama
+        self._model = llama._model  # type: ignore
+
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self._model.tokenize(text, add_bos=add_bos, special=special)
+
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        if prev_tokens is not None:
+            return self._model.detokenize(tokens[len(prev_tokens) :])
+        else:
+            return self._model.detokenize(tokens)
+
+    def encode(
+        self, text: str, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self.tokenize(
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+
+    def decode(self, tokens: List[int]) -> str:
+        return self.detokenize(tokens).decode("utf-8", errors="ignore")
+
+    @classmethod
+    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
+        return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
+
+
+class LlamaHFTokenizer(BaseLlamaTokenizer):
+    def __init__(self, hf_tokenizer: Any):
+        self.hf_tokenizer = hf_tokenizer
+
+    def tokenize(
+        self, text: bytes, add_bos: bool = True, special: bool = True
+    ) -> List[int]:
+        return self.hf_tokenizer.encode(
+            text.decode("utf-8", errors="ignore"), add_special_tokens=special
+        )
+
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
+        if prev_tokens is not None:
+            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+            prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
+                "utf-8", errors="ignore"
+            )
+            return text[len(prev_text) :]
+        else:
+            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library is required to use the `HFTokenizer`."
+                "You can install it with `pip install transformers`."
+            )
+        hf_tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path=pretrained_model_name_or_path
+        )
+        return cls(hf_tokenizer)
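
For reference, a usage sketch of the new module (not from the diff itself; "gpt2" is an arbitrary example checkpoint, and `transformers` must be installed or from_pretrained raises ImportError as shown above):

    from llama_cpp.llama_tokenizer import LlamaHFTokenizer

    tok = LlamaHFTokenizer.from_pretrained("gpt2")

    # tokenize() accepts bytes and returns token ids.
    ids = tok.tokenize(b"hello world")

    # detokenize() returns bytes; passing prev_tokens yields only the bytes
    # contributed by the new tokens, which is the contract streaming needs.
    full = tok.detokenize(ids)
    tail = tok.detokenize(ids, prev_tokens=ids[:-1])
    assert full.endswith(tail)

Note the two implementations satisfy the prev_tokens contract differently: LlamaTokenizer slices the token list before detokenizing, while LlamaHFTokenizer decodes both sequences and slices the resulting bytes, which stays correct when a token's byte rendering depends on its neighbors.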

0 commit comments