8000 Fix Pydantic model parsing (#1087) · davidvonthenen/llama-cpp-python@c689ccc · GitHub
[go: up one dir, main page]

Skip to content

Commit c689ccc

Browse files
authored
Fix Pydantic model parsing (abetlen#1087)
1 parent 5502ac8 commit c689ccc

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

llama_cpp/llama_grammar.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,6 @@ def _add_rule(self, name: str, rule: str):
14331433

14341434
def visit(self, schema: Dict[str, Any], name: str) -> str:
14351435
schema_type: Optional[str] = schema.get("type") # type: ignore
1436-
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
14371436
rule_name = name or "root"
14381437

14391438
if "$defs" in schema:

tests/test_grammar.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,52 @@
11
import llama_cpp
2+
import json
23

34
tree = """
45
leaf ::= "."
56
node ::= leaf | "(" node node ")"
67
root ::= node
78
"""
89

10+
911
def test_grammar_from_string():
1012
grammar = llama_cpp.LlamaGrammar.from_string(tree)
1113
assert grammar._n_rules == 3
1214
assert grammar._start_rule_index == 2
1315
assert grammar.grammar is not None
16+
17+
18+
def test_composed_pydantic_grammar():
19+
"""
20+
from pydantic import BaseModel
21+
22+
class A(BaseModel):
23+
a: int
24+
25+
class B(BaseModel):
26+
a: A
27+
b: int
28+
"""
29+
30+
# This schema corresponds to the grammar in the comment above.
31+
# We don't use the pydantic models directly to avoid the dependency.
32+
schema = {
33+
"$defs": {
34+
"A": {
35+
"properties": {"a": {"title": "A", "type": "integer"}},
36+
"required": ["a"],
37+
"title": "A",
38+
"type": "object",
39+
}
40+
},
41+
"properties": {
42+
"a": {"$ref": "#/$defs/A"},
43+
"b": {"title": "B", "type": "integer"},
44+
},
45+
"required": ["a", "b"],
46+
"title": "B",
47+
"type": "object",
48+
}
49+
50+
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
51+
52+
assert grammar.grammar is not None

0 commit comments

Comments
 (0)
0