-
Notifications
You must be signed in to change notification settings - Fork 12.5k
JSON schema conversion: ⚡️ faster repetitions, min/maxLength for strings, cap number length #6555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
2148f24
f771a8f
159b883
a59e943
07163fb
dcf5d32
181f984
de4e60e
6c885dc
3c81e94
67a5184
9c33ee9
ed13d47
958bdda
64e3059
ba90d5b
dfd4eb3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,37 +6,86 @@ | |
import sys | ||
from typing import Any, Dict, List, Set, Tuple, Union | ||
|
||
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): | ||
if not separator_rule: | ||
if min_items == 0 and max_items == 1: | ||
return f'{item_rule}?' | ||
elif min_items == 1 and max_items is None: | ||
return f'{item_rule}+' | ||
|
||
result = '' | ||
|
||
if min_items > 0: | ||
if item_rule_is_literal and separator_rule is None: | ||
result = '"' + (item_rule[1:-1] * min_items) + '"' | ||
else: | ||
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) | ||
|
||
def opt_repetitions(up_to_n, prefix_with_sep=False): | ||
if up_to_n == 0: | ||
return '' | ||
|
||
res = separator_rule + ' ' + item_rule if separator_rule and prefix_with_sep else item_rule | ||
if up_to_n > 1: | ||
res += ' ' + opt_repetitions(up_to_n - 1, prefix_with_sep=True) | ||
return f'({res})?' | ||
|
||
if min_items > 0 and max_items != min_items: | ||
result += ' ' | ||
|
||
if max_items is not None: | ||
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) | ||
else: | ||
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' | ||
|
||
if min_items == 0 and separator_rule: | ||
result = f'({item_rule} {item_operator}*)?' | ||
else: | ||
result += f'{item_operator}*' | ||
|
||
return result | ||
|
||
|
||
class BuiltinRule: | ||
def __init__(self, content: str, deps: list = None): | ||
self.content = content | ||
self.deps = deps or [] | ||
|
||
_up_to_15_digits = _build_repetition('[0-9]', 0, 15) | ||
|
||
# whitespace is constrained to a single space char to prevent model "running away" in | ||
# whitespace. Also maybe improves generation quality? | ||
SPACE_RULE = '" "?' | ||
|
||
PRIMITIVE_RULES = { | ||
'boolean': '("true" | "false") space', | ||
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', | ||
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space', | ||
'value' : 'object | array | string | number | boolean', | ||
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', | ||
'array' : '"[" space ( value ("," space value)* )? "]" space', | ||
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space', | ||
'string': r''' "\"" ( | ||
[^"\\] | | ||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) | ||
)* "\"" space''', | ||
'null': '"null" space', | ||
'boolean' : BuiltinRule('("true" | "false") space', []), | ||
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), | ||
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), | ||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), | ||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), | ||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), | ||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), | ||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), | ||
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), | ||
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), | ||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), | ||
'null' : BuiltinRule('"null" space', []), | ||
} | ||
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'] | ||
|
||
# TODO: support "uri", "email" string formats | ||
DATE_RULES = { | ||
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', | ||
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', | ||
'date-time': 'date "T" time', | ||
'date-string': '"\\"" date "\\"" space', | ||
'time-string': '"\\"" time "\\"" space', | ||
'date-time-string': '"\\"" date-time "\\"" space', | ||
STRING_FORMAT_RULES = { | ||
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), | ||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), | ||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']), | ||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), | ||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), | ||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), | ||
} | ||
|
||
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()]) | ||
DOTALL = '[\\U00000000-\\U0010FFFF]' | ||
DOT = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if it would be more performant or not, but I'm curious if: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated to simpler negative range, thanks!! |
||
|
||
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) | ||
|
||
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') | ||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') | ||
|
@@ -46,16 +95,16 @@ | |
NON_LITERAL_SET = set('|.()[]{}*+?') | ||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') | ||
|
||
DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])' | ||
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits | ||
|
||
class SchemaConverter: | ||
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): | ||
self._prop_order = prop_order | ||
self._allow_fetch = allow_fetch | ||
self._dotall = dotall | ||
self._raw_pattern = raw_pattern | ||
self._rules = {'space': SPACE_RULE} | ||
self._rules = { | ||
'space': SPACE_RULE, | ||
} | ||
self._refs = {} | ||
self._refs_being_resolved = set() | ||
|
||
|
@@ -65,6 +114,29 @@ def _format_literal(self, literal): | |
) | ||
return f'"{escaped}"' | ||
|
||
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: | ||
''' | ||
not_literal('a') -> '[^a]' | ||
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' | ||
''' | ||
assert len(literal) > 0, 'Empty literal not supported' | ||
def recurse(i: int): | ||
c = literal[i] | ||
if maybe_escaped_underscores and c == '_': | ||
yield f'[^{c}\\\\]' | ||
yield ' | ' | ||
yield f'"\\\\"? "{c}"' | ||
else: | ||
yield f'[^{c}]' | ||
if i < len(literal) - 1: | ||
yield ' | ' | ||
yield self._format_literal(c) | ||
yield ' (' | ||
yield from recurse(i + 1) | ||
yield ')?' | ||
|
||
return ''.join(('(', *recurse(0), ')')) | ||
|
||
def _add_rule(self, name, rule): | ||
esc_name = INVALID_RULE_CHARS_RE.sub('-', name) | ||
if esc_name not in self._rules or self._rules[esc_name] == rule: | ||
|
@@ -169,10 +241,10 @@ def transform() -> Tuple[str, bool]: | |
|
||
def get_dot(): | ||
if self._dotall: | ||
rule = '[\\U00000000-\\U0010FFFF]' | ||
rule = DOTALL | ||
else: | ||
# Accept any character... except \n and \r line break chars (\x0A and \xOD) | ||
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]' | ||
rule = DOT | ||
return self._add_rule(f'dot', rule) | ||
|
||
def join_seq(): | ||
|
@@ -246,26 +318,14 @@ def join_seq(): | |
|
||
(sub, sub_is_literal) = seq[-1] | ||
|
||
if min_times == 0 and max_times is None: | ||
seq[-1] = (f'{sub}*', False) | ||
elif min_times == 0 and max_times == 1: | ||
seq[-1] = (f'{sub}?', False) | ||
elif min_times == 1 and max_times is None: | ||
seq[-1] = (f'{sub}+', False) | ||
else: | ||
if not sub_is_literal: | ||
id = sub_rule_ids.get(sub) | ||
if id is None: | ||
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) | ||
sub_rule_ids[sub] = id | ||
sub = id | ||
|
||
seq[-1] = ( | ||
' '.join( | ||
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) + | ||
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])), | ||
False | ||
) | ||
if not sub_is_literal: | ||
id = sub_rule_ids.get(sub) | ||
if id is None: | ||
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) | ||
sub_rule_ids[sub] = id | ||
sub = id | ||
|
||
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) | ||
else: | ||
literal = '' | ||
while i < length: | ||
|
@@ -373,49 +433,47 @@ def add_component(comp_schema, is_required): | |
' "]" space') | ||
else: | ||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') | ||
list_item_operator = f'( "," space {item_rule_name} )' | ||
successive_items = "" | ||
min_items = schema.get("minItems", 0) | ||
max_items = schema.get("maxItems") | ||
if min_items > 0: | ||
successive_items = list_item_operator * (min_items - 1) | ||
min_items -= 1 | ||
if max_items is not None and max_items > min_items: | ||
successive_items += (list_item_operator + "?") * (max_items - min_items - 1) | ||
else: | ||
successive_items += list_item_operator + "*" | ||
if min_items == 0: | ||
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space' | ||
else: | ||
rule = f'"[" space {item_rule_name} {successive_items} "]" space' | ||
return self._add_rule(rule_name, rule) | ||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') | ||
|
||
elif schema_type in (None, 'string') and 'pattern' in schema: | ||
return self._visit_pattern(schema['pattern'], rule_name) | ||
|
||
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): | ||
return self._add_rule( | ||
return self._add_primitive( | ||
'root' if rule_name == 'root' else schema_format, | ||
PRIMITIVE_RULES['uuid'] | ||
) | ||
|
||
elif schema_type in (None, 'string') and schema_format in DATE_RULES: | ||
for t, r in DATE_RULES.items(): | ||
self._add_rule(t, r) | ||
return schema_format + '-string' | ||
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: | ||
prim_name = f'{schema_format}-string' | ||
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) | ||
|
||
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): | ||
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) | ||
min_len = schema.get('minLength', 0) | ||
max_len = schema.get('maxLength') | ||
|
||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') | ||
|
||
elif (schema_type == 'object') or (len(schema) == 0): | ||
for n in OBJECT_RULE_NAMES: | ||
self._add_rule(n, PRIMITIVE_RULES[n]) | ||
return self._add_rule(rule_name, 'object') | ||
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) | ||
|
||
else: | ||
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' | ||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero | ||
return self._add_rule( | ||
'root' if rule_name == 'root' else schema_type, | ||
PRIMITIVE_RULES[schema_type] | ||
) | ||
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) | ||
|
||
def _add_primitive(self, name: str, rule: BuiltinRule): | ||
n = self._add_rule(name, rule.content) | ||
|
||
for dep in rule.deps: | ||
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) | ||
assert dep_rule, f'Rule {dep} not known' | ||
if dep not in self._rules: | ||
self._add_primitive(dep, dep_rule) | ||
return n | ||
|
||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): | ||
prop_order = self._prop_order | ||
|
@@ -437,7 +495,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st | |
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') | ||
prop_kv_rule_names["*"] = self._add_rule( | ||
f'{sub_name}-kv', | ||
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' | ||
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' | ||
) | ||
optional_props.append("*") | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I'm on board with the rename to use underscores -- while there are a few other files with underscores (such as
pydantic_models_to_grammar.py
), most seem to use hyphens (pydantic-models-to-grammar-examples.py
, etc), and it seems like the old filename is possibly better?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Originally I wanted all filenames in the repo to use hyphens. But later I found out that Python does not work well when there are hyphens in the filenames (e.g. I think you cannot include a Python file that has hyphens). So I think it's better to eventually rename all Python files to use underscores in their filenames
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tbh I did all this as a prerequisite for #6389, in which i need to import the converter from Python. I also found out llama-cpp-python inlines that file in their codebase, since it's hard / not trivial to import (short of using importlib, which feels dirty).