8000 fix: from_json_schema oneof/anyof bug. Closes #1097 · sjanaX01/llama-cpp-python@d3f5528 · GitHub
[go: up one dir, main page]

Skip to content

Commit d3f5528

Browse files
committed
fix: from_json_schema oneof/anyof bug. Closes abetlen#1097
1 parent 8eefdbc commit d3f5528

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

llama_cpp/llama_grammar.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,7 +1432,6 @@ def _add_rule(self, name: str, rule: str):
14321432
return key
14331433

14341434
def visit(self, schema: Dict[str, Any], name: str) -> str:
1435-
schema_type: Optional[str] = schema.get("type") # type: ignore
14361435
rule_name = name or "root"
14371436

14381437
if "$defs" in schema:
@@ -1458,7 +1457,19 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
14581457
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
14591458
return self._add_rule(rule_name, rule)
14601459

1461-
elif schema_type == "object" and "properties" in schema:
1460+
elif "$ref" in schema:
1461+
ref = schema["$ref"]
1462+
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
1463+
# inline $defs
1464+
def_name = ref[len("#/$defs/") :]
1465+
def_schema = self._defs[def_name]
1466+
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
1467+
1468+
1469+
schema_type: Optional[str] = schema.get("type") # type: ignore
1470+
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
1471+
1472+
if schema_type == "object" and "properties" in schema:
14621473
# TODO: `required` keyword
14631474
prop_order = self._prop_order
14641475
prop_pairs = sorted(
@@ -1489,14 +1500,6 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
14891500
)
14901501
return self._add_rule(rule_name, rule)
14911502

1492-
elif "$ref" in schema:
1493-
ref = schema["$ref"]
1494-
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
1495-
# inline $defs
1496-
def_name = ref[len("#/$defs/") :]
1497-
def_schema = self._defs[def_name]
1498-
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
1499-
15001503
else:
15011504
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
15021505
return self._add_rule(

tests/test_grammar.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,29 @@ class B(BaseModel):
5050
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
5151

5252
assert grammar.grammar is not None
53+
54+
55+
def test_grammar_anyof():
56+
sch = {
57+
"properties": {
58+
"temperature": {
59+
"description": "The temperature mentioned",
60+
"type": "number",
61+
},
62+
"unit": {
63+
"anyOf": [
64+
{
65+
"description": "Unit for temperature",
66+
"enum": ["celsius", "fahrenheit"],
67+
"type": "string",
68+
},
69+
{"type": "null"},
70+
],
71+
},
72+
},
73+
"type": "object",
74+
}
75+
76+
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
77+
78+
assert grammar.grammar is not None

0 commit comments

Comments
 (0)
0