diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 04061801..a2e03694 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,11 +1,16 @@ +import datetime +import typing +import warnings +from decimal import Decimal from functools import singledispatch +from typing import Any from sqlalchemy import types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - String) +from graphene import (ID, Boolean, Date, DateTime, Dynamic, Enum, Field, Float, + Int, List, String, Time) from graphene.types.json import JSONString from .batching import get_batch_resolver @@ -14,6 +19,14 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver +from .utils import (registry_sqlalchemy_model_from_str, safe_isinstance, + singledispatchbymatchfunction, value_equals) + +try: + from typing import ForwardRef +except ImportError: + # python 3.6 + from typing import _ForwardRef as ForwardRef try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType @@ -25,7 +38,6 @@ except ImportError: EnumTypeImpl = object - is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -48,6 +60,7 @@ def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_fiel :param dict field_kwargs: :rtype: Dynamic """ + def dynamic_type(): """:rtype: Field|None""" direction = relationship_prop.direction @@ -115,8 +128,7 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'type_' not in field_kwargs: - # TODO The default type should be dependent on the type of the property propety. - field_kwargs['type_'] = String + field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) return Field( resolver=resolver, @@ -240,7 +252,7 @@ def convert_scalar_list_to_list(type, column, registry=None): def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n-1)) + return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n - 1)) @convert_sqlalchemy_type.register(types.ARRAY) @@ -260,3 +272,103 @@ def convert_json_to_string(type, column, registry=None): @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): return JSONString + + +@singledispatchbymatchfunction +def convert_sqlalchemy_hybrid_property_type(arg: Any): + existing_graphql_type = get_global_registry().get_type_for_model(arg) + if existing_graphql_type: + return existing_graphql_type + + # No valid type found, warn and fall back to graphene.String + warnings.warn( + (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." + "Falling back to \"graphene.String\"") + ) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) +def convert_sqlalchemy_hybrid_property_type_str(arg): + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) +def convert_sqlalchemy_hybrid_property_type_int(arg): + return Int + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) +def convert_sqlalchemy_hybrid_property_type_float(arg): + return Float + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(arg): + # The reason Decimal should be serialized as a String is because this is a + # base10 type used in things like money, and string allows it to not + # lose precision (which would happen if we downcasted to a Float, for example) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) +def convert_sqlalchemy_hybrid_property_type_bool(arg): + return Boolean + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) +def convert_sqlalchemy_hybrid_property_type_datetime(arg): + return DateTime + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) +def convert_sqlalchemy_hybrid_property_type_date(arg): + return Date + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) +def convert_sqlalchemy_hybrid_property_type_time(arg): + return Time + + +@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +def convert_sqlalchemy_hybrid_property_type_list_t(arg): + # type is either list[T] or List[T], generic argument at __args__[0] + internal_type = arg.__args__[0] + + graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + + return List(graphql_internal_type) + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) +def convert_sqlalchemy_hybrid_property_forwardref(arg): + """ + Generate a lambda that will resolve the type at runtime + This takes care of self-references + """ + + def forward_reference_solver(): + model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) + if not model: + return String + # Always fall back to string if no ForwardRef type found. + return get_global_registry().get_type_for_model(model) + + return forward_reference_solver + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(arg): + """ + Convert Bare String into a ForwardRef + """ + + return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + + +def convert_hybrid_property_return_type(hybrid_prop): + # Grab the original method's return type annotations from inside the hybrid property + return_type_annotation = hybrid_prop.fget.__annotations__.get('return', str) + + return convert_sqlalchemy_hybrid_property_type(return_type_annotation) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 88e992b9..bda5a863 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,6 +1,9 @@ from __future__ import absolute_import +import datetime import enum +from decimal import Decimal +from typing import List, Tuple from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, func, select) @@ -69,6 +72,26 @@ class Reporter(Base): def hybrid_prop(self): return self.first_name + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_list(self) -> List[int]: + return [1, 2, 3] + column_prop = column_property( select([func.cast(func.count(id), Integer)]), doc="Column property" ) @@ -95,3 +118,102 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) mapper(ReflectedEditor, editor_table) + + +############################################ +# The models below are mainly used in the +# @hybrid_property type inference scenarios +############################################ + + +class ShoppingCartItem(Base): + __tablename__ = "shopping_cart_items" + + id = Column(Integer(), primary_key=True) + + @hybrid_property + def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] + + +class ShoppingCart(Base): + __tablename__ = "shopping_carts" + + id = Column(Integer(), primary_key=True) + + # Standard Library types + + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_decimal(self) -> Decimal: + return Decimal("3.14") + + @hybrid_property + def hybrid_prop_date(self) -> datetime.date: + return datetime.datetime.now().date() + + @hybrid_property + def hybrid_prop_time(self) -> datetime.time: + return datetime.datetime.now().time() + + @hybrid_property + def hybrid_prop_datetime(self) -> datetime.datetime: + return datetime.datetime.now() + + # Lists and Nested Lists + + @hybrid_property + def hybrid_prop_list_int(self) -> List[int]: + return [1, 2, 3] + + @hybrid_property + def hybrid_prop_list_date(self) -> List[datetime.date]: + return [self.hybrid_prop_date, self.hybrid_prop_date, self.hybrid_prop_date] + + @hybrid_property + def hybrid_prop_nested_list_int(self) -> List[List[int]]: + return [self.hybrid_prop_list_int, ] + + @hybrid_property + def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: + return [[self.hybrid_prop_list_int, ], ] + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: + return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] + + # Unsupported Type + @hybrid_property + def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: + return "this will actually", "be a string" + + # Self-references + + @hybrid_property + def hybrid_prop_self_referential(self) -> 'ShoppingCart': + return ShoppingCart(id=1) + + @hybrid_property + def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 57c43058..4b9e74ed 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,4 +1,5 @@ import enum +from typing import Dict, Union import pytest from sqlalchemy import Column, func, select, types @@ -9,9 +10,11 @@ from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene +from graphene import Boolean, Float, Int, Scalar, String from graphene.relay import Node -from graphene.types.datetime import DateTime +from graphene.types.datetime import Date, DateTime, Time from graphene.types.json import JSONString +from graphene.types.structures import List, Structure from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -20,7 +23,8 @@ default_connection_field_factory) from ..registry import Registry, get_global_registry from ..types import SQLAlchemyObjectType -from .models import Article, CompositeFullName, Pet, Reporter +from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, + ShoppingCartItem) def mock_resolver(): @@ -384,3 +388,86 @@ def __init__(self, col1, col2): Registry(), mock_resolver, ) + + +def test_sqlalchemy_hybrid_property_type_inference(): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + interfaces = (Node,) + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + interfaces = (Node,) + + ####################################################### + # Check ShoppingCartItem's Properties and Return Types + ####################################################### + + shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = { + 'hybrid_prop_shopping_cart': List(ShoppingCartType) + } + + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_item_expected_types.items(): + hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. + assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + + ################################################### + # Check ShoppingCart's Properties and Return Types + ################################################### + + shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = { + # Basic types + "hybrid_prop_str": String, + "hybrid_prop_int": Int, + "hybrid_prop_float": Float, + "hybrid_prop_bool": Boolean, + "hybrid_prop_decimal": String, # Decimals should be serialized Strings + "hybrid_prop_date": Date, + "hybrid_prop_time": Time, + "hybrid_prop_datetime": DateTime, + # Lists and Nested Lists + "hybrid_prop_list_int": List(Int), + "hybrid_prop_list_date": List(Date), + "hybrid_prop_nested_list_int": List(List(Int)), + "hybrid_prop_deeply_nested_list_int": List(List(List(Int))), + "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType), + "hybrid_prop_unsupported_type_tuple": String, + # Self Referential List + "hybrid_prop_self_referential": ShoppingCartType, + "hybrid_prop_self_referential_list": List(ShoppingCartType), + } + + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_expected_types.items(): + hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. + assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 1f15fa1a..2d660b67 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,8 +4,8 @@ import sqlalchemy.exc import sqlalchemy.orm.exc -from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, - ObjectType, Schema, String) +from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, + Node, NonNull, ObjectType, Schema, String) from graphene.relay import Connection from .. import utils @@ -74,7 +74,7 @@ class Meta: model = Article interfaces = (Node,) - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Columns "column_prop", # SQLAlchemy retuns column properties first "id", @@ -86,11 +86,16 @@ class Meta: "composite_prop", # Hybrid "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", # Relationship "pets", "articles", "favorite_article", - ] + ]) # column first_name_field = ReporterType._meta.fields['first_name'] @@ -115,6 +120,36 @@ class Meta: # "doc" is ignored by hybrid_property assert hybrid_prop.description is None + # hybrid_property_str + hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + assert hybrid_prop_str.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop_str.description is None + + # hybrid_property_int + hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + assert hybrid_prop_int.type == Int + # "doc" is ignored by hybrid_property + assert hybrid_prop_int.description is None + + # hybrid_property_float + hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + assert hybrid_prop_float.type == Float + # "doc" is ignored by hybrid_property + assert hybrid_prop_float.description is None + + # hybrid_property_bool + hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + assert hybrid_prop_bool.type == Boolean + # "doc" is ignored by hybrid_property + assert hybrid_prop_bool.description is None + + # hybrid_property_list + hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + assert hybrid_prop_list.type == List(Int) + # "doc" is ignored by hybrid_property + assert hybrid_prop_list.description is None + # relationship favorite_article_field = ReporterType._meta.fields['favorite_article'] assert isinstance(favorite_article_field, Dynamic) @@ -166,7 +201,7 @@ class Meta: interfaces = (Node,) use_connection = False - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Fields from ReporterMixin "first_name", "last_name", @@ -182,7 +217,12 @@ class Meta: # Then the automatic SQLAlchemy fields "id", "favorite_pet_kind", - ] + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ]) first_name_field = ReporterType._meta.fields['first_name'] assert isinstance(first_name_field.type, NonNull) @@ -271,7 +311,7 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ "first_name", "last_name", "column_prop", @@ -279,10 +319,15 @@ class Meta: "favorite_pet_kind", "composite_prop", "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", "pets", "articles", "favorite_article", - ] + ]) def test_only_and_exclude_fields(): @@ -387,7 +432,7 @@ class Meta: assert issubclass(CustomReporterType, ObjectType) assert CustomReporterType._meta.model == Reporter - assert len(CustomReporterType._meta.fields) == 11 + assert len(CustomReporterType._meta.fields) == 16 # Test Custom SQLAlchemyObjectType with Custom Options diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 340ad47e..301e782c 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,11 +1,15 @@ import re import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, Optional import pkg_resources from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene_sqlalchemy.registry import get_global_registry + def get_session(context): return context.get("session") @@ -87,7 +91,6 @@ def _deprecated_default_symbol_name(column_name, sort_asc): def _deprecated_object_type_for_model(cls, name): - try: return _deprecated_object_type_cache[cls, name] except KeyError: @@ -152,3 +155,54 @@ def sort_argument_for_model(cls, has_default=True): def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + + +class singledispatchbymatchfunction: + """ + Inspired by @singledispatch, this is a variant that works using a matcher function + instead of relying on the type of the first argument. + The register method can be used to register a new matcher, which is passed as the first argument: + """ + + def __init__(self, default: Callable): + self.registry: Dict[Callable, Callable] = OrderedDict() + self.default = default + + def __call__(self, *args, **kwargs): + for matcher_function, final_method in self.registry.items(): + # Register order is important. First one that matches, runs. + if matcher_function(args[0]): + return final_method(*args, **kwargs) + + # No match, using default. + return self.default(*args, **kwargs) + + def register(self, matcher_function: Callable[[Any], bool]): + + def grab_function_from_outside(f): + self.registry[matcher_function] = f + return self + + return grab_function_from_outside + + +def value_equals(value): + """A simple function that makes the equality based matcher functions for + SingleDispatchByMatchFunction prettier""" + return lambda x: x == value + + +def safe_isinstance(cls): + def safe_isinstance_checker(arg): + try: + return isinstance(arg, cls) + except TypeError: + pass + return safe_isinstance_checker + + +def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + try: + return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + except StopIteration: + pass