diff --git a/doc/reflection.rst b/doc/reflector.rst similarity index 93% rename from doc/reflection.rst rename to doc/reflector.rst index 07248ab..e0f8c22 100644 --- a/doc/reflection.rst +++ b/doc/reflector.rst @@ -40,6 +40,7 @@ Full example with reflection and enums: # although the import will work if placed here, you will still have a weird error about a null URL on the first run # that it because the base is not correctly configured on first run, you need to run the code a second time + # error is something like AttributeError: 'NoneType' object has no attribute 'url' import generated.trippin People = generated.trippin.People diff --git a/odata/complextype.py b/odata/complextype.py index a8cc403..a62b808 100644 --- a/odata/complextype.py +++ b/odata/complextype.py @@ -1,23 +1,17 @@ # -*- coding: utf-8 -*- +import dataclasses +import numbers +from dataclasses import dataclass -from odata.property import PropertyBase +from odata.property import QueryBase, StringProperty, IntegerProperty, FloatProperty, CollectionQueryBase -class ComplexType(dict): - properties = dict() - - def __getattr__(self, item): - return self[item] - - def __setattr__(self, key, value): - self[key] = value - - def __repr__(self): - keys = ','.join(self.keys()) - return ''.format(keys) +@dataclass +class ComplexType: + pass -class ComplexTypeProperty(PropertyBase): +class ComplexTypeProperty(QueryBase, CollectionQueryBase): """ A property that contains a ComplexType object @@ -25,13 +19,26 @@ class ComplexTypeProperty(PropertyBase): :param type_class: A subclass of ComplexType """ - def __init__(self, name, type_class=ComplexType): + def __init__(self, name, type_class=ComplexType, is_collection: bool = False, is_nullable: bool = True): """ :type name: str """ super(ComplexTypeProperty, self).__init__(name) + self.name = name + self.is_collection = is_collection + self.is_nullable = is_nullable self.type_class = type_class + def __getattr__(self, item): + fields = dataclasses.fields(self.type_class) + for field in fields: + if field.name == item: + return ComplexTypeProperty(f"{self.name}/{field.name}", field.type, is_nullable=False) + raise AttributeError(f"{item} does not exist in {self.name}") + + def __set__(self, instance, value): + instance.__odata_complex_type__ = value + def serialize(self, value): if isinstance(value, list): data = [] @@ -43,21 +50,18 @@ def serialize(self, value): def _serialize(self, value): data = dict() - for name, prop in value.properties.items(): - prop_value = value.get(name) - - if prop_value is None: - continue - - if isinstance(prop_value, ComplexType): - serialized_value = self.serialize(prop_value) + for candidate in dir(self): + member = getattr(self, candidate) + if issubclass(member, ComplexType): + value = member.serialize() else: - serialized_value = prop('temp').serialize(prop_value) - data[name] = serialized_value + value = member + data[candidate] = value return data def deserialize(self, value): if isinstance(value, list): + self.is_collection = True data = [] for i in value: data.append(self._deserialize(i)) @@ -66,29 +70,26 @@ def deserialize(self, value): return self._deserialize(value) def _deserialize(self, value): - data = self.type_class() + data = self.type_class - for name, prop in data.properties.items(): - prop_value = value.get(name) + for member in dataclasses.fields(self.type_class): + value[member.name] = self._build_recursive(member, value) - if prop_value is None: - continue + return data(**value) - if issubclass(prop, ComplexType): - ctprop = ComplexTypeProperty('temp', type_class=prop) - deserialized_value = ctprop.deserialize(prop_value) - else: - deserialized_value = prop('temp').deserialize(prop_value) - data[name] = deserialized_value - return data + def _build_recursive(self, member, value): + if dataclasses.is_dataclass(member.type): + for child in dataclasses.fields(member.type): + value[member.name][child.name] = self._build_recursive(child, value[member.name]) + return member.type(**value[member.name]) + return value[member.name] def escape_value(self, value): - raise NotImplementedError() - - def __getattr__(self, item): - # allows ComplexType key usage in filters etc - subkey = '{0}/{1}'.format(self.name, item) - prop = self.type_class.properties[item] - if issubclass(prop, ComplexType): - return ComplexTypeProperty(subkey, type_class=prop) - return prop(subkey) + if isinstance(value, str): + return StringProperty.escape_value(self, value) + if isinstance(value, int): + return IntegerProperty.escape_value(self, value) + if isinstance(value, float): + return FloatProperty.escape_value(self, value) + + raise NotImplementedError() \ No newline at end of file diff --git a/odata/entity.py b/odata/entity.py index 4c99e50..2530ff5 100644 --- a/odata/entity.py +++ b/odata/entity.py @@ -75,6 +75,7 @@ def is_product_available(self): for product in query: print(product.name, product.is_product_available()) """ +from odata.complextype import ComplexTypeProperty from odata.exceptions import ODataConnectionError try: @@ -103,6 +104,16 @@ def __odata_url__(cls): raise ODataConnectionError(f"Cannot query {cls.__name__} objects as they don't have " f"the collection defined. Are you sure you're querying the correct object ?") + @staticmethod + def _get_simple_type(entity, property_name: str, raw_data: dict): + raw = raw_data.get(property_name) + try: + member = getattr(entity, property_name) + if member and issubclass(member.__class__, ComplexTypeProperty): + return member.deserialize(raw) + except Exception as ex: + return raw + def __new__(cls, *args, **kwargs): i = super(EntityBase, cls).__new__(cls) i.__odata__ = es = EntityState(i) @@ -113,7 +124,7 @@ def __new__(cls, *args, **kwargs): raw_data = kwargs.pop('from_data') for prop_name, prop in es.properties: - i.__odata__[prop.name] = raw_data.get(prop.name) + i.__odata__[prop.name] = EntityBase._get_simple_type(i, prop_name, raw_data) # check for values from $expand for prop_name, prop in es.navigation_properties: diff --git a/odata/enumtype.py b/odata/enumtype.py index 8c032fa..238ac37 100644 --- a/odata/enumtype.py +++ b/odata/enumtype.py @@ -17,9 +17,10 @@ class EnumTypeProperty(PropertyBase): :param enum_class: A subclass of EnumType """ - def __init__(self, name, enum_class=EnumType): + def __init__(self, name, enum_class=EnumType, is_nullable: bool = True): super(EnumTypeProperty, self).__init__(name) self.enum_class = enum_class + self.is_nullable = is_nullable def serialize(self, value): return value.name diff --git a/odata/metadata.py b/odata/metadata.py index 6a1a497..7f2ac50 100644 --- a/odata/metadata.py +++ b/odata/metadata.py @@ -7,6 +7,8 @@ import rich.console import rich.progress +from .complextype import ComplexType, ComplexTypeProperty + has_lxml = False try: from lxml import etree as ET @@ -18,10 +20,15 @@ from .entity import declarative_base, EntityBase from .exceptions import ODataReflectionError -from .property import StringProperty, IntegerProperty, DecimalProperty, DatetimeProperty, BooleanProperty, NavigationProperty, UUIDProperty +from .property import StringProperty, IntegerProperty, DecimalProperty, DatetimeProperty, BooleanProperty, NavigationProperty, UUIDProperty, \ + LocationProperty from .enumtype import EnumType, EnumTypeProperty +def str_2_bool(value: str) -> bool: + return value.lower() in ['true', 't', 'y', 'yes', "1"] + + class MetaData(object): log = logging.getLogger('odata.metadata') @@ -40,6 +47,7 @@ class MetaData(object): 'Edm.DateTimeOffset': DatetimeProperty, 'Edm.Boolean': BooleanProperty, 'Edm.Guid': UUIDProperty, + 'Edm.LocationPoint': LocationProperty, } _annotation_term_computed = 'Org.OData.Core.V1.Computed' @@ -89,75 +97,145 @@ def _set_object_relationships(self, all_types): ) setattr(entity, name, nav) - def _create_entities(self, all_types, entity_base_class, schemas, depth=1): + def _create_entities(self, progress, all_types, entity_base_class, schemas, depth=1): orphan_entities = [] - with rich.progress.Progress(transient=True, console=self.console, disable=self.quiet) as progress: - schema_task = progress.add_task("Schemas", total=len(schemas)) - for schema in schemas: - entity_task = progress.add_task(f"Creating entities for {schema['name']}", total=len(schema.get("entities"))) - for entity_dict in schema.get('entities'): - progress.update(entity_task, advance=1) - entity_type = entity_dict['type'] - entity_type_alias = entity_dict.get('type_alias') - entity_name = entity_dict['name'] - - if entity_type in all_types: + for schema in schemas: + entity_task = progress.add_task(f"Creating entities for {schema['name']}", total=len(schema.get("entities"))) + for entity_dict in schema.get('entities'): + progress.update(entity_task, advance=1) + entity_type = entity_dict['type'] + entity_type_alias = entity_dict.get('type_alias') + entity_name = entity_dict['name'] + + if entity_type in all_types: + continue + + parent_entity_class = None + + if entity_dict.get('base_type'): + base_type = entity_dict.get('base_type') + parent_entity_class = all_types.get(base_type) + + if parent_entity_class is None: + # base class not yet created + orphan_entities.append(entity_type) continue - parent_entity_class = None - - if entity_dict.get('base_type'): - base_type = entity_dict.get('base_type') - parent_entity_class = all_types.get(base_type) - - if parent_entity_class is None: - # base class not yet created - orphan_entities.append(entity_type) - continue - - super_class = parent_entity_class or entity_base_class - object_dict = dict( - __odata_schema__=entity_dict, - __odata_type__=entity_type, - ) - entity_class = type(entity_name, (super_class,), object_dict) - - all_types[entity_type] = entity_class - if entity_type_alias: - all_types[entity_type_alias] = entity_class - - for prop in entity_dict.get('properties'): - prop_name = prop['name'] - - if hasattr(entity_class, prop_name): - # do not replace existing properties (from Base) - continue - - property_type = all_types.get(prop['type']) - - if property_type and issubclass(property_type, EnumType): - property_instance = EnumTypeProperty(prop_name, enum_class=property_type) - property_instance.is_computed_value = prop['is_computed_value'] - else: - type_ = self.property_type_to_python(prop['type']) - type_options = { - 'primary_key': prop['is_primary_key'], - 'is_collection': prop['is_collection'], - 'is_computed_value': prop['is_computed_value'], - } - property_instance = type_(prop_name, **type_options) - setattr(entity_class, prop_name, property_instance) - - progress.remove_task(entity_task) - progress.update(schema_task, advance=1) + super_class = parent_entity_class or entity_base_class + object_dict = dict( + __odata_schema__=entity_dict, + __odata_type__=entity_type, + ) + entity_class = type(entity_name, (super_class,), object_dict) + + all_types[entity_type] = entity_class + if entity_type_alias: + all_types[entity_type_alias] = entity_class + + for prop in entity_dict.get('properties'): + prop_name = prop['name'] + + if hasattr(entity_class, prop_name): + # do not replace existing properties (from Base) + continue + + property_type = all_types.get(prop['type']) + + if property_type and issubclass(property_type, EnumType): + property_instance = EnumTypeProperty(prop_name, enum_class=property_type) + property_instance.is_computed_value = prop['is_computed_value'] + elif property_type and issubclass(property_type, ComplexType): + property_instance = ComplexTypeProperty(prop_name, type_class=property_type) + property_instance.is_computed_value = prop['is_computed_value'] + property_instance.is_collection = prop["is_collection"] + property_instance.is_nullable = prop['is_nullable'] + else: + type_ = self.property_type_to_python(prop['type']) + type_options = { + 'primary_key': prop['is_primary_key'], + 'is_collection': prop['is_collection'], + 'is_nullable': prop['is_nullable'], + 'is_computed_value': prop['is_computed_value'], + } + property_instance = type_(prop_name, **type_options) + setattr(entity_class, prop_name, property_instance) + + progress.remove_task(entity_task) + if len(orphan_entities) > 0: + if depth > 10: + errmsg = ('Types could not be resolved. ' + 'Orphaned types: {0}').format(', '.join(orphan_entities)) + raise ODataReflectionError(errmsg) + depth += 1 + self._create_entities(progress, all_types, entity_base_class, schemas, depth) + + def _create_complextypes(self, progress, all_types, complex_base_class, schemas, depth=1): + orphan_entities = [] + for schema in schemas: + complex_task = progress.add_task(f"Creating complex types for {schema['name']}", total=len(schema.get("complex_types"))) + for complex_dict in schema.get('complex_types'): + progress.update(complex_task, advance=1) + complex_type = complex_dict['type'] + complex_name = complex_dict['name'] + + if complex_type in all_types: + continue + + parent_complex_class = None + + if complex_dict.get('base_type'): + base_type = complex_dict.get('base_type') + parent_complex_class = all_types.get(base_type) + + if parent_complex_class is None: + # base class not yet created + orphan_entities.append(complex_type) + continue + + super_class = parent_complex_class or complex_base_class + object_dict = dict( + __odata_schema__=complex_dict, + __odata_type__=complex_type, + ) + + entity_class = type(complex_name, (super_class,), object_dict) + + all_types[complex_type] = entity_class - if len(orphan_entities) > 0: - if depth > 10: - errmsg = ('Types could not be resolved. ' - 'Orphaned types: {0}').format(', '.join(orphan_entities)) - raise ODataReflectionError(errmsg) - depth += 1 - self._create_entities(all_types, entity_base_class, schemas, depth) + for prop in complex_dict.get('properties'): + prop_name = prop['name'] + + if hasattr(entity_class, prop_name): + # do not replace existing properties (from Base) + continue + + property_type = all_types.get(prop['type']) + + if property_type and issubclass(property_type, EnumType): + property_instance = EnumTypeProperty(prop_name, enum_class=property_type) + property_instance.is_computed_value = prop['is_computed_value'] + elif property_type and issubclass(property_type, ComplexType): + property_instance = ComplexTypeProperty(prop_name, type_class=property_type) + property_instance.is_computed_value = prop['is_computed_value'] + property_instance.is_nullable = prop['is_nullable'] + else: + type_ = self.property_type_to_python(prop['type']) + type_options = { + 'is_collection': prop['is_collection'], + 'is_nullable': prop['is_nullable'], + 'is_computed_value': prop['is_computed_value'], + } + property_instance = type_(prop_name, **type_options) + setattr(entity_class, prop_name, property_instance) + + progress.remove_task(complex_task) + if len(orphan_entities) > 0: + if depth > 10: + errmsg = ('Types could not be resolved. ' + 'Orphaned types: {0}').format(', '.join(orphan_entities)) + raise ODataReflectionError(errmsg) + depth += 1 + self._create_complextypes(progress, all_types, complex_base_class, schemas, depth) def _create_actions(self, all_types, actions, get_entity_or_prop_from_type): entities = self._get_entities_from_types(all_types) @@ -249,11 +327,11 @@ def get_entity_or_prop_from_type(typename): names = [(i['name'], i['value']) for i in enum_type['members']] created_enum = EnumType(enum_type['name'], names=names) all_types[enum_type['fully_qualified_name']] = created_enum - progress.remove_task(enum_task) progress.update(schema_task, advance=1) - self._create_entities(all_types, base_class, schemas) + self._create_complextypes(progress, all_types, ComplexType, schemas) + self._create_entities(progress, all_types, base_class, schemas) sets = {} for entity_set in rich.progress.track(entity_sets.values(), "Processing entity types ...", console=self.console, transient=True, disable=self.quiet): @@ -359,6 +437,7 @@ def _parse_entity(self, xmlq, entity_element, schema_name, schema_alias): entity = { 'name': entity_name, 'type': entity_type_name, + 'open': str_2_bool(entity_element.attrib.get("OpenType", "false")), 'properties': [], 'navigation_properties': [], } @@ -379,6 +458,7 @@ def _parse_entity(self, xmlq, entity_element, schema_name, schema_alias): for entity_property in xmlq(entity_element, 'edm:Property'): p_name = entity_property.attrib['Name'] p_type = entity_property.attrib['Type'] + p_nullable = str_2_bool(entity_property.attrib.get("Nullable", "true")) is_collection, p_type = self._type_is_collection(p_type) is_computed_value = False @@ -393,6 +473,7 @@ def _parse_entity(self, xmlq, entity_element, schema_name, schema_alias): 'name': p_name, 'type': p_type, 'is_primary_key': p_name in entity_pks, + 'is_nullable': p_nullable, 'is_collection': is_collection, 'is_computed_value': is_computed_value, }) @@ -400,6 +481,7 @@ def _parse_entity(self, xmlq, entity_element, schema_name, schema_alias): for nav_property in xmlq(entity_element, 'edm:NavigationProperty'): p_name = nav_property.attrib['Name'] p_type = nav_property.attrib['Type'] + p_nullable = str_2_bool(nav_property.attrib.get("Nullable", "true")) p_foreign_key = None ref_constraint = xmlq(nav_property, 'edm:ReferentialConstraint') @@ -410,6 +492,7 @@ def _parse_entity(self, xmlq, entity_element, schema_name, schema_alias): entity['navigation_properties'].append({ 'name': p_name, 'type': p_type, + 'is_nullable': p_nullable, 'foreign_key': p_foreign_key, }) return entity @@ -430,6 +513,39 @@ def _parse_enumtype(self, xmlq, enumtype_element, schema_name): }) return enum + def _parse_complextype(self, xmlq, complextype_element, schema_name): + complex_name = complextype_element.attrib['Name'] + + complex_type_name = '.'.join([schema_name, complex_name]) + + complex = { + 'name': complex_name, + 'type': complex_type_name, + 'open': str_2_bool(complextype_element.attrib.get('OpenType', 'false')), + 'properties': [] + } + + base_type = complextype_element.attrib.get('BaseType') + if base_type: + complex['base_type'] = base_type + + for complex_property in xmlq(complextype_element, 'edm:Property'): + p_name = complex_property.attrib['Name'] + p_type = complex_property.attrib['Type'] + p_nullable = str_2_bool(complex_property.attrib.get('Nullable', "true")) + + is_collection, p_type = self._type_is_collection(p_type) + is_computed_value = False + + complex['properties'].append({ + 'name': p_name, + 'type': p_type, + 'is_nullable': p_nullable, + 'is_collection': is_collection, + 'is_computed_value': is_computed_value, + }) + return complex + def parse_document(self, doc): schemas = [] container_sets = {} @@ -459,6 +575,10 @@ def xmlq(node, xpath): enum = self._parse_enumtype(xmlq, enum_type, schema_name) schema_dict['enum_types'].append(enum) + for complex_type in xmlq(schema, 'edm:ComplexType'): + complex_type = self._parse_complextype(xmlq, complex_type, schema_name) + schema_dict['complex_types'].append(complex_type) + for entity_type in xmlq(schema, 'edm:EntityType'): entity = self._parse_entity(xmlq, entity_type, schema_name, schema_alias) schema_dict['entities'].append(entity) diff --git a/odata/navproperty.py b/odata/navproperty.py index 5fd2652..be1d9a9 100644 --- a/odata/navproperty.py +++ b/odata/navproperty.py @@ -48,6 +48,7 @@ class NavigationProperty(object): not inherit from PropertyBase. """ def __init__(self, name, entitycls: Union[type, str], entity_package: str = None, collection=False, foreign_key=None): + super().__init__() from odata.property import PropertyBase self.name = name self.class_package = entity_package @@ -110,6 +111,9 @@ def __set__(self, instance, value): instance.__odata__.set_property_dirty(self) def __getattr__(self, item): + if self.is_collection: + raise AttributeError(f"Field {self.name} is a collection, you need to use any() or all() on the collection") + if item.startswith("__"): raise AttributeError(f"Skipping recursive check for {item}") if self.entitycls: @@ -118,7 +122,7 @@ def __getattr__(self, item): cpy.name = f"{self.name}/{item}" return cpy else: - raise Exception(f"Couldn't find {item} in {self.name}") + raise AttributeError(f"Couldn't find {item} in {self.name}") def navigation_url(self, instance): es = instance.__odata__ diff --git a/odata/property.py b/odata/property.py index 4134cab..93b57d1 100644 --- a/odata/property.py +++ b/odata/property.py @@ -106,6 +106,15 @@ def __str__(self): return f"{self.member} {self.op} {self.value}" +class CollectionQueryFilter(BaseQueryFilter): + def __init__(self, member, op, inner_query): + super().__init__(member, op, inner_query) + + def __str__(self): + l = self.member.lower() + return f"{self.member}/{self.op}({l}: {l}/{self.value})" + + class ParameterizedQueryFilter(BaseQueryFilter): def __str__(self): return f"{self.op}({self.member}, {self.value})" @@ -121,91 +130,10 @@ def __str__(self): return f"({self.member}) {self.op} ({self.value})" -class PropertyBase(object): - """ - A base class for all properties. +class QueryBase(object): - :param name: Name of the property in the endpoint - :param primary_key: This property is a primary key - :param is_collection: This property contains multiple values - """ - def __init__(self, name, primary_key=False, is_collection=False, is_computed_value=False): - """ - :type name: str - :type primary_key: bool - """ + def __init__(self, name): self.name = name - self.primary_key = primary_key - self.is_collection = is_collection - self.is_computed_value = is_computed_value - - def __repr__(self): - return ''.format(self.name) - - def __get__(self, instance, owner): - """ - :type instance: odata.entity.EntityBase - :type owner: odata.entity.EntityBase - """ - if instance is None: - return self - - es = instance.__odata__ - - if self.name in es: - raw_data = es[self.name] - if self.is_collection: - if raw_data is None: - return - - data = [] - for i in raw_data: - data.append(self.deserialize(i)) - return data - else: - return self.deserialize(raw_data) - else: - raise AttributeError() - - def __set__(self, instance, value): - """ - :type instance: odata.entity.EntityBase - """ - - es = instance.__odata__ - - if self.name in es: - if self.is_collection: - data = [] - for i in (value or []): - data.append(self.serialize(i)) - new_value = data - else: - new_value = self.serialize(value) - old_value = es[self.name] - if new_value != old_value: - es[self.name] = new_value - es.set_property_dirty(self) - - def serialize(self, value): - """ - Called when serializing the value to JSON. Implement this method when - creating a new Property class - - :param value: Value given in Python code - :returns: Value that will be used in JSON - """ - raise NotImplementedError() - - def deserialize(self, value): - """ - Called when deserializing the value from JSON to Python. Implement this - method when creating a new Property class - - :param value: Value received in JSON - :returns: Value that will be passed to Python - """ - raise NotImplementedError() def escape_value(self, value): """ @@ -292,6 +220,107 @@ def not_null(self): return SimpleQueryFilter(self.name, "ne", "null") +class CollectionQueryBase(object): + def __init__(self, name: str): + self.name = name + + def any(self, inner_query: BaseQueryFilter) -> BaseQueryFilter: + modified = str(inner_query).replace(f"{self.name}/", "", 1).replace(f"{self.name}.", "", 1) + return CollectionQueryFilter(self.name, "any", modified) + + def all(self, inner_query: BaseQueryFilter) -> BaseQueryFilter: + modified = str(inner_query).replace(f"{self.name}/", "", 1).replace(f"{self.name}.", "", 1) + return CollectionQueryFilter(self.name, "all", modified) + + +class PropertyBase(QueryBase): + """ + A base class for all properties. + + :param name: Name of the property in the endpoint + :param primary_key: This property is a primary key + :param is_collection: This property contains multiple values + """ + def __init__(self, name, primary_key=False, is_collection=False, is_computed_value=False, is_nullable=True): + """ + :type name: str + :type primary_key: bool + """ + super().__init__(name) + self.primary_key = primary_key + self.is_collection = is_collection + self.is_computed_value = is_computed_value + self.is_nullable = is_nullable + + def __repr__(self): + return ''.format(self.name) + + def __get__(self, instance, owner): + """ + :type instance: odata.entity.EntityBase + :type owner: odata.entity.EntityBase + """ + if instance is None: + return self + + es = instance.__odata__ + + if self.name in es: + raw_data = es[self.name] + if self.is_collection: + if raw_data is None: + return + + data = [] + for i in raw_data: + data.append(self.deserialize(i)) + return data + else: + return self.deserialize(raw_data) + else: + raise AttributeError() + + def __set__(self, instance, value): + """ + :type instance: odata.entity.EntityBase + """ + + es = instance.__odata__ + + if self.name in es: + if self.is_collection: + data = [] + for i in (value or []): + data.append(self.serialize(i)) + new_value = data + else: + new_value = self.serialize(value) + old_value = es[self.name] + if new_value != old_value: + es[self.name] = new_value + es.set_property_dirty(self) + + def serialize(self, value): + """ + Called when serializing the value to JSON. Implement this method when + creating a new Property class + + :param value: Value given in Python code + :returns: Value that will be used in JSON + """ + raise NotImplementedError() + + def deserialize(self, value): + """ + Called when deserializing the value from JSON to Python. Implement this + method when creating a new Property class + + :param value: Value received in JSON + :returns: Value that will be passed to Python + """ + raise NotImplementedError() + + class IntegerProperty(PropertyBase): """ Property that stores a plain old integer @@ -404,3 +433,20 @@ def escape_value(self, value): if value is None: return 'null' return str(value) + + +# todo: change to actual support, not string +class LocationProperty(PropertyBase): + """ + Property that stores a location + """ + def serialize(self, value): + return value + + def deserialize(self, value): + return value + + def escape_value(self, value): + if value is None: + return 'null' + return u"'{0}'".format(value.replace("'", "''")) diff --git a/odata/reflect-templates/entity.mako b/odata/reflect-templates/entity.mako index 05a219b..8da6018 100644 --- a/odata/reflect-templates/entity.mako +++ b/odata/reflect-templates/entity.mako @@ -15,7 +15,7 @@ None # Simple properties %for prop in schema['properties']: <% attr = getattr(entity, prop['name']) %>\ -<%include file="property.mako" args="property=attr, values=prop"/> +<%include file="property.mako" args="entity=entity, property=attr, values=prop"/> %endfor # Navigation properties diff --git a/odata/reflect-templates/main.mako b/odata/reflect-templates/main.mako index 5bce437..45e8b8d 100644 --- a/odata/reflect-templates/main.mako +++ b/odata/reflect-templates/main.mako @@ -13,11 +13,13 @@ # ${padding} import datetime -import uuid import decimal - +import uuid +from dataclasses import dataclass from enum import Enum +from typing import Optional +from odata.complextype import ComplexType, ComplexTypeProperty from odata.entity import EntityBase from odata.property import StringProperty, IntegerProperty, NavigationProperty, DatetimeProperty, DecimalProperty, FloatProperty, BooleanProperty, UUIDProperty from odata.enumtype import EnumType, EnumTypeProperty @@ -34,7 +36,16 @@ class ReflectionBase(EntityBase): # ************ End enum type definitions ************ +# ************ Start simple type definitions ************ + +%for type_name in simple_types: +<%include file="simple_type.mako" args="name=type_name, entity=simple_types[type_name]"/> +%endfor + +# ************ End simple type definitions ************ + # ************ Start type definitions ************ + %for type_name in types: <%include file="entity.mako" args="name=type_name, entity=types[type_name]"/> %endfor diff --git a/odata/reflect-templates/property.mako b/odata/reflect-templates/property.mako index de363fe..eb7667f 100644 --- a/odata/reflect-templates/property.mako +++ b/odata/reflect-templates/property.mako @@ -1,11 +1,18 @@ -<%page args="property, values"/>\ +<%page args="entity, property, values"/>\ <% property_name = values['name'].replace("@", "_").replace("-", "_") property_type = type(property) simple_type = property_type.__name__.split(".")[-1] - full_type = type_translations[simple_type] + if simple_type not in type_translations: + print(property, property_type, values) + property.primary_key = False + full_type = values['type'].split(".")[-1] + else: + full_type = type_translations[simple_type] if property.is_collection: full_type = "list[" + simple_type + "]" + if property.is_nullable: + full_type = "Optional[" + full_type + "]" %>\ ${property_name}: ${full_type} = ${simple_type}("${values['name']}"\ % if property.primary_key: @@ -14,8 +21,12 @@ % if property.is_collection: , is_collection=True\ % endif +, is_nullable=${property.is_nullable}\ % if property.is_computed_value: , is_computed_value=True\ + % endif + % if hasattr(property, "type_class"): +, type_class=${property.type_class.__name__}\ % endif % if hasattr(property, 'enum_class'): , enum_class=${property.enum_class.__name__}\ diff --git a/odata/reflect-templates/simple_property.mako b/odata/reflect-templates/simple_property.mako new file mode 100644 index 0000000..847369c --- /dev/null +++ b/odata/reflect-templates/simple_property.mako @@ -0,0 +1,18 @@ +<%page args="entity, property, values"/>\ +<% + property_name = values['name'].replace("@", "_").replace("-", "_") + property_type = type(property) + simple_type = property_type.__name__.split(".")[-1] + static_type = False + if simple_type not in type_translations: + static_type = True + print(property, property_type, values) + full_type = values['type'].split(".")[-1] + else: + full_type = type_translations[simple_type] + if property.is_collection: + full_type = "list[" + simple_type + "]" + if property.is_nullable: + full_type = "Optional[" + full_type + "]" +%>\ + ${property_name}: ${full_type} \ No newline at end of file diff --git a/odata/reflect-templates/simple_type.mako b/odata/reflect-templates/simple_type.mako new file mode 100644 index 0000000..dde3add --- /dev/null +++ b/odata/reflect-templates/simple_type.mako @@ -0,0 +1,14 @@ + +<%page args="name, entity"/>\ +<% short_name = name.split(".")[-1] %>\ +<% base_type = "(" + entity.__bases__[0].__name__ + ")" if len(entity.__bases__) > 0 else '' %>\ +@dataclass +class ${short_name}${base_type}: + <% + schema = entity.__odata_schema__ + %>\ + # Simple properties + %for prop in schema['properties']: +<% attr = getattr(entity, prop['name']) %>\ +<%include file="simple_property.mako" args="entity=entity, property=attr, values=prop"/> + %endfor diff --git a/odata/reflector.py b/odata/reflector.py index 977a003..233fc3c 100644 --- a/odata/reflector.py +++ b/odata/reflector.py @@ -81,29 +81,39 @@ from pathlib import Path from enum import EnumMeta +from mako import exceptions from mako.lookup import TemplateLookup from mako.runtime import Context +from odata.complextype import ComplexType type_translations = { "StringProperty": "str", "IntegerProperty": "int", - "NavigationProperty": "caca", + "NavigationProperty": "caca", # fixme: is this used ? "DatetimeProperty": "datetime.datetime", "DecimalProperty": "decimal.Decimal", "FloatProperty": "float", "BooleanProperty": "bool", "UUIDProperty": "uuid.UUID", - "EnumTypeProperty": "str" + "EnumTypeProperty": "str", + "LocationProperty": "str", } class MetadataReflector: - def __init__(self, metadata_url: str, entities: list["EntitySetCategories"], types: list["EntityBase"], package: str, quiet: bool = False): + def __init__(self, + metadata_url: str, + entities: dict[str, "EntitySetCategories"], + types: dict[str, "EntityBase"], + package: str, + console: rich.console.Console, + quiet: bool = False): self.package = package self.metadata_url = metadata_url self.entities = entities self.types = types + self.console = console self.quiet = quiet def write_reflected_types(self): @@ -114,16 +124,25 @@ def write_reflected_types(self): types = {k: v for k, v in self.types.items() if not isinstance(v, EnumMeta)} enum_types = {k: v for k, v in self.types.items() if isinstance(v, EnumMeta)} + simple_types = {k: v for k, v in self.types.items() if issubclass(v, ComplexType)} + types = {k: v for k, v in types.items() if k not in simple_types} + buffer = StringIO() context = Context(buffer, entities=self.entities, types=types, enum_types=enum_types, + simple_types=simple_types, + all_types=self.types, type_translations=type_translations, package=self.package, metadata_url=self.metadata_url) - with rich.console.Console(quiet=self.quiet).status("Loading metadata"): - template.render_context(context) + with self.console.status("Loading metadata"): + try: + template.render_context(context) + except Exception as ex: + self.console.print(exceptions.text_error_template(lookup).render()) + raise ex output_path = Path(self.package.replace(".", "/")).with_suffix(".py") if not output_path.parent.exists(): diff --git a/odata/service.py b/odata/service.py index 9c2fb0a..96510c3 100644 --- a/odata/service.py +++ b/odata/service.py @@ -82,9 +82,9 @@ class ODataService(object): """ :param url: Endpoint address. Must be an address that can be appended with ``$metadata`` - :param base: Custom base class to use for entities - :param reflect_entities: Create a request to the service for its metadata, and create entity classes automatically. If set to None it will only reflect the entities if package doesn't exist already - :param reflect_output_path: Optional parameter, if reflect_entities is configured it will create all reflected classes at this path + :param base: Custom base class to use for entities. If set to None and reflect_output_packages is defined it will attempt to read the ReflectionBase instance from the configured package name. + :param reflect_entities: Create a request to the service for its metadata, and create entity classes automatically. If set to None it will only reflect the entities if package doesn't already exist. + :param reflect_output_package: Optional parameter, if reflect_entities is configured it will create all reflected classes at this path :param session: Custom Requests session to use for communication with the endpoint :param extra_headers: Any extra headers that need to be passed to the OData service :param auth: Custom Requests auth object to use for credentials @@ -196,7 +196,13 @@ def __repr__(self): return u''.format(self.url) def _write_reflected_types(self, metadata_url: str, package: str): - outputter = MetadataReflector(metadata_url=metadata_url, entities=self.entities, types=self.types, package=package, quiet=self.quiet_progress) + outputter = MetadataReflector( + metadata_url=metadata_url, + entities=self.entities, + types=self.types, + package=package, + console=self.console, + quiet=self.quiet_progress) outputter.write_reflected_types() def create_context(self, auth=None, session=None, extra_headers: dict = None): diff --git a/odata/state.py b/odata/state.py index 2162f0e..508d39a 100644 --- a/odata/state.py +++ b/odata/state.py @@ -11,6 +11,7 @@ import rich.panel import rich.table +from odata.complextype import ComplexTypeProperty from odata.property import PropertyBase, NavigationProperty @@ -62,7 +63,7 @@ def values(self): name = prop.name if prop.is_collection: name += "[]" - if prop.primary_key: + if getattr(prop, "primary_key", False): name += '*' if prop.name in self.dirty: name += ' (dirty)' @@ -155,7 +156,7 @@ def properties(self): props = [] cls = self.entity.__class__ for key, value in inspect.getmembers(cls): - if isinstance(value, PropertyBase): + if isinstance(value, (PropertyBase, ComplexTypeProperty)): props.append((key, value)) return props