diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index ae90001b..f4b805e2 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,14 +3,18 @@ from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import interfaces +from sqlalchemy.orm import interfaces, strategies from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, String) from graphene.types.json import JSONString +from .batching import get_batch_resolver from .enums import enum_for_sa_enum +from .fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) from .registry import get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType @@ -18,6 +22,9 @@ ChoiceType = JSONType = ScalarListType = TSVectorType = object +is_selectin_available = getattr(strategies, 'SelectInLoader', None) + + def get_column_doc(column): return getattr(column, "doc", None) @@ -26,33 +33,82 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs): - direction = relationship_prop.direction - model = relationship_prop.mapper.entity - +def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, + orm_field_name, **field_kwargs): + """ + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param function|None connection_field_factory: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Dynamic + """ def dynamic_type(): - _type = registry.get_type_for_model(model) + """:rtype: Field|None""" + direction = relationship_prop.direction + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + batching_ = batching if is_selectin_available else False - if not _type: + if not child_type: return None + if direction == interfaces.MANYTOONE or not relationship_prop.uselist: - return Field( - _type, - resolver=resolver, - **field_kwargs - ) - elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type.connection: - # TODO Add a way to override connection_field_factory - return connection_field_factory(relationship_prop, registry, **field_kwargs) - return Field( - List(_type), - **field_kwargs - ) + return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, + **field_kwargs) + + if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): + return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, + connection_field_factory, **field_kwargs) return Dynamic(dynamic_type) +def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): + """ + Convert one-to-one or many-to-one relationshsip. Return an object field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + resolver = get_custom_resolver(obj_type, orm_field_name) + if resolver is None: + resolver = get_batch_resolver(relationship_prop) if batching else \ + get_attr_resolver(obj_type, relationship_prop.key) + + return Field(child_type, resolver=resolver, **field_kwargs) + + +def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): + """ + Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param function|None connection_field_factory: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + if not child_type._meta.connection: + return Field(List(child_type), **field_kwargs) + + # TODO Allow override of connection_field_factory and resolver via ORMField + if connection_field_factory is None: + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ + default_connection_field_factory + + return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) + + def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'type' not in field_kwargs: # TODO The default type should be dependent on the type of the property propety. diff --git a/graphene_sqlalchemy/resolver.py b/graphene_sqlalchemy/resolver.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py new file mode 100644 index 00000000..83a6e35d --- /dev/null +++ b/graphene_sqlalchemy/resolvers.py @@ -0,0 +1,26 @@ +from graphene.utils.get_unbound_function import get_unbound_function + + +def get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def get_attr_resolver(obj_type, model_attr): + """ + In order to support field renaming via `ORMField.model_attr`, + we need to define resolver functions for each field. + + :param SQLAlchemyObjectType obj_type: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + return lambda root, _info: getattr(root, model_attr, None) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index d8393fb0..b97002a7 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -6,8 +6,9 @@ import graphene from graphene import relay -from ..fields import BatchSQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType +from ..fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) +from ..types import ORMField, SQLAlchemyObjectType from .models import Article, HairKind, Pet, Reporter from .utils import is_sqlalchemy_version_less_than, to_std_dicts @@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + batching = True class ArticleType(SQLAlchemyObjectType): class Meta: model = Article interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + batching = True class PetType(SQLAlchemyObjectType): class Meta: model = Pet interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + batching = True class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) @@ -513,3 +514,187 @@ def test_many_to_many(session_factory): }, ], } + + +def test_disable_batching_via_ormfield(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + favorite_article = ORMField(batching=False) + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + # Test one-to-many and many-to-many relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + +def test_connection_factory_field_overrides_batching_is_false(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = False + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + else: + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 1 + + +def test_connection_factory_field_overrides_batching_is_true(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + connection_field_factory = default_connection_field_factory + + articles = ORMField(batching=True) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e8051a18..e9ee2379 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -190,9 +190,12 @@ def test_should_jsontype_convert_jsonstring(): def test_should_manytomany_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, registry, default_connection_field_factory, mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -204,7 +207,7 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -220,19 +223,19 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - registry, - default_connection_field_factory, - mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -244,10 +247,7 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -262,10 +262,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -280,10 +277,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ef189b38..ff22cded 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,25 +3,23 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty, strategies) + RelationshipProperty) from sqlalchemy.orm.exc import NoResultFound from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs -from graphene.utils.get_unbound_function import get_unbound_function from graphene.utils.orderedtype import OrderedType -from .batching import get_batch_resolver from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) from .enums import (enum_for_field, sort_argument_for_object_type, sort_enum_for_object_type) -from .fields import default_connection_field_factory from .registry import Registry, get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver from .utils import get_query, is_mapped_class, is_mapped_instance @@ -33,6 +31,7 @@ def __init__( required=None, description=None, deprecation_reason=None, + batching=None, _creation_counter=None, **field_kwargs ): @@ -69,6 +68,8 @@ class Meta: Same behavior as in graphene.Field. Defaults to None. :param str deprecation_reason: Same behavior as in graphene.Field. Defaults to None. + :param bool batching: + Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. :param int _creation_counter: Same behavior as in graphene.Field. """ @@ -80,6 +81,7 @@ class Meta: 'required': required, 'description': description, 'deprecation_reason': deprecation_reason, + 'batching': batching, } common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs @@ -87,7 +89,7 @@ class Meta: def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, connection_field_factory + obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -101,7 +103,8 @@ def construct_fields( :param Registry registry: :param tuple[string] only_fields: :param tuple[string] exclude_fields: - :param function connection_field_factory: + :param bool batching: + :param function|None connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ inspected_model = sqlalchemy.inspect(model) @@ -152,40 +155,23 @@ def construct_fields( for orm_field_name, orm_field in orm_fields.items(): attr_name = orm_field.kwargs.pop('model_attr') attr = all_model_attrs[attr_name] - custom_resolver = _get_custom_resolver(obj_type, orm_field_name) + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column( - attr, - registry, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - **orm_field.kwargs - ) + field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) elif isinstance(attr, RelationshipProperty): + batching_ = orm_field.kwargs.pop('batching', batching) field = convert_sqlalchemy_relationship( - attr, - registry, - connection_field_factory, - custom_resolver or _get_relationship_resolver(obj_type, attr, attr_name), - **orm_field.kwargs - ) + attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " "Field: {}.{}".format(obj_type.__name__, orm_field_name)) - field = convert_sqlalchemy_composite( - attr, - registry, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - ) + field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): - field = convert_sqlalchemy_hybrid_method( - attr, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - **orm_field.kwargs - ) + field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: raise Exception('Property type is not supported') # Should never happen @@ -195,50 +181,6 @@ def construct_fields( return fields -def _get_custom_resolver(obj_type, orm_field_name): - """ - Since `graphene` will call `resolve_` on a field only if it - does not have a `resolver`, we need to re-implement that logic here so - users are able to override the default resolvers that we provide. - """ - resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) - if resolver: - return get_unbound_function(resolver) - - return None - - -def _get_relationship_resolver(obj_type, relationship_prop, model_attr): - """ - Batch SQL queries using Dataloader to avoid the N+1 problem. - SQL batching only works for SQLAlchemy 1.2+ since it depends on - the `selectin` loader. - - :param SQLAlchemyObjectType obj_type: - :param sqlalchemy.orm.properties.RelationshipProperty relationship_prop: - :param str model_attr: the name of the SQLAlchemy attribute - :rtype: Callable - """ - if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: - # TODO Batch many-to-many and one-to-many relationships - return _get_attr_resolver(obj_type, model_attr, model_attr) - - return get_batch_resolver(relationship_prop) - - -def _get_attr_resolver(obj_type, orm_field_name, model_attr): - """ - In order to support field renaming via `ORMField.model_attr`, - we need to define resolver functions for each field. - - :param SQLAlchemyObjectType obj_type: - :param str orm_field_name: - :param str model_attr: the name of the SQLAlchemy attribute - :rtype: Callable - """ - return lambda root, _info: getattr(root, model_attr, None) - - class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): model = None # type: sqlalchemy.Model registry = None # type: sqlalchemy.Registry @@ -260,7 +202,8 @@ def __init_subclass_with_meta__( use_connection=None, interfaces=(), id=None, - connection_field_factory=default_connection_field_factory, + batching=False, + connection_field_factory=None, _meta=None, **options ): @@ -286,6 +229,7 @@ def __init_subclass_with_meta__( registry=registry, only_fields=only_fields, exclude_fields=exclude_fields, + batching=batching, connection_field_factory=connection_field_factory, ), _as=Field,