From 5eb5559a1ac63bdcd42f8142f9d72e99ebd3eda3 Mon Sep 17 00:00:00 2001 From: aubustou Date: Wed, 28 Aug 2019 16:54:23 +0200 Subject: [PATCH 01/67] [Python 2] Add support for unicode field names Check against six.string_types instead of str for python 2 compatibility --- graphene_sqlalchemy/enums.py | 3 ++- graphene_sqlalchemy/fields.py | 3 ++- graphene_sqlalchemy/registry.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index f100be19..0adea107 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,3 +1,4 @@ +import six from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType @@ -62,7 +63,7 @@ def enum_for_field(obj_type, field_name): if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) - if not field_name or not isinstance(field_name, str): + if not field_name or not isinstance(field_name, six.string_types): raise TypeError( "Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 266b5f37..a9f514ba 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,6 +1,7 @@ import warnings from functools import partial +import six from promise import Promise, is_thenable from sqlalchemy.orm.query import Query @@ -35,7 +36,7 @@ def model(self): def get_query(cls, model, info, sort=None, **args): query = get_query(model, info.context) if sort is not None: - if isinstance(sort, str): + if isinstance(sort, six.string_types): query = query.order_by(sort.value) else: query = query.order_by(*(col.value for col in sort)) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index acfa744b..c20bc2ca 100644 --- a/graphene_sqlalchemy/registry.py +++ 
b/graphene_sqlalchemy/registry.py @@ -1,5 +1,6 @@ from collections import defaultdict +import six from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Enum @@ -42,7 +43,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) ) - if not field_name or not isinstance(field_name, str): + if not field_name or not isinstance(field_name, six.string_types): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field From 8ea20865244aa3940f02d9918c1c9289ed23ee9f Mon Sep 17 00:00:00 2001 From: Anton Novosyolov Date: Mon, 9 Sep 2019 17:24:08 +0300 Subject: [PATCH 02/67] Add support for generic SQLAlchemy Array type (#246) --- graphene_sqlalchemy/converter.py | 3 ++- graphene_sqlalchemy/tests/test_converter.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index ef8715ff..9f99c8aa 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -171,8 +171,9 @@ def convert_scalar_list_to_list(type, column, registry=None): return List(String) +@convert_sqlalchemy_type.register(types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_postgres_array_to_list(_type, column, registry=None): +def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) return List(inner_type) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index efee91aa..459a3139 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -297,6 +297,12 @@ def test_should_postgresql_array_convert(): assert field.type.of_type == graphene.Int +def test_should_array_convert(): + field = get_field(types.ARRAY(types.Integer)) + 
assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int + + def test_should_postgresql_json_convert(): assert get_field(postgresql.JSON()).type == graphene.JSONString From 0544f812b37f8e6b49ed3363e7010c08a600be1d Mon Sep 17 00:00:00 2001 From: Maxim Date: Wed, 11 Sep 2019 08:20:18 +0600 Subject: [PATCH 03/67] Add support for Python enums in sqlalchemy_utils.ChoiceType (#240) --- graphene_sqlalchemy/converter.py | 9 ++++++- graphene_sqlalchemy/tests/test_converter.py | 26 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 9f99c8aa..4ff55eed 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,3 +1,5 @@ +from enum import EnumMeta + from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql @@ -163,7 +165,12 @@ def convert_enum_to_enum(type, column, registry=None): @convert_sqlalchemy_type.register(ChoiceType) def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() - return Enum(name, type.choices) + if isinstance(type.choices, EnumMeta): + # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta + # do not use from_enum here because we can have more than one enum column in table + return Enum(name, list((v.name, v.value) for v in type.choices)) + else: + return Enum(name, type.choices) @convert_sqlalchemy_type.register(ScalarListType) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 459a3139..e8051a18 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -145,6 +145,32 @@ def test_should_choice_convert_enum(): assert graphene_type._meta.enum.__members__["en"].value == "English" +def test_should_enum_choice_convert_enum(): + class 
TestEnum(enum.Enum): + es = u"Spanish" + en = u"English" + + field = get_field(ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" + + +def test_should_intenum_choice_convert_enum(): + class TestEnum(enum.IntEnum): + one = 1 + two = 2 + + field = get_field(ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["one"].value == 1 + assert graphene_type._meta.enum.__members__["two"].value == 2 + + def test_should_columproperty_convert(): field = get_field_from_column(column_property( select([func.sum(func.cast(id, types.Integer))]).where(id == 1) From 89c37265012b0e296147a1631b44d8f5d943dc59 Mon Sep 17 00:00:00 2001 From: Clemens Tolboom Date: Wed, 23 Oct 2019 21:27:31 +0200 Subject: [PATCH 04/67] =?UTF-8?q?ValueError:=20The=20options=20'only=5Ffie?= =?UTF-8?q?lds'=20and=20'exclude=5Ffields'=20cannot=20be=20=E2=80=A6=20(#2?= =?UTF-8?q?50)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2ba0d1cb..9b617069 100644 --- a/README.md +++ b/README.md @@ -43,10 +43,10 @@ from graphene_sqlalchemy import SQLAlchemyObjectType class User(SQLAlchemyObjectType): class Meta: model = UserModel - # only return specified fields - only_fields = ("name",) - # exclude specified fields - exclude_fields = ("last_name",) + # use `only_fields` to only expose specific fields ie "name" + # only_fields = ("name",) + # use `exclude_fields` to exclude specific fields ie "last_name" + # exclude_fields = ("last_name",) 
class Query(graphene.ObjectType): users = graphene.List(User) From 98e6fe7c118922ccca1f5df1ab83d90967b7d2e4 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Mon, 18 Nov 2019 14:31:21 -0500 Subject: [PATCH 05/67] Fix N+1 problem for one-to-one and many-to-one relationships (#253) --- graphene_sqlalchemy/resolver.py | 0 graphene_sqlalchemy/tests/conftest.py | 32 ++- graphene_sqlalchemy/tests/test_batching.py | 228 +++++++++++++++++++++ graphene_sqlalchemy/tests/test_query.py | 11 +- graphene_sqlalchemy/tests/utils.py | 8 + graphene_sqlalchemy/types.py | 142 +++++++++++-- setup.cfg | 2 +- setup.py | 3 +- 8 files changed, 379 insertions(+), 47 deletions(-) create mode 100644 graphene_sqlalchemy/resolver.py create mode 100644 graphene_sqlalchemy/tests/test_batching.py create mode 100644 graphene_sqlalchemy/tests/utils.py diff --git a/graphene_sqlalchemy/resolver.py b/graphene_sqlalchemy/resolver.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 9dc390eb..98515051 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,6 +1,6 @@ import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import sessionmaker import graphene @@ -23,19 +23,17 @@ def convert_composite_class(composite, registry): @pytest.yield_fixture(scope="function") -def session(): - db = create_engine(test_db_url) - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) - - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) - - yield session - - # Finalize test here - transaction.rollback() - connection.close() - session.remove() +def session_factory(): + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + + yield sessionmaker(bind=engine) 
+ + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + +@pytest.fixture(scope="function") +def session(session_factory): + return session_factory() diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py new file mode 100644 index 00000000..0881f71e --- /dev/null +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -0,0 +1,228 @@ +import contextlib +import logging + +import pkg_resources +import pytest + +import graphene + +from ..types import SQLAlchemyObjectType +from .models import Article, Reporter +from .utils import to_std_dicts + + +class MockLoggingHandler(logging.Handler): + """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): + self.messages = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.messages.append(record.getMessage()) + + +@contextlib.contextmanager +def mock_sqlalchemy_logging_handler(): + logging.basicConfig() + sql_logger = logging.getLogger('sqlalchemy.engine') + previous_level = sql_logger.level + + sql_logger.setLevel(logging.INFO) + mock_logging_handler = MockLoggingHandler() + mock_logging_handler.setLevel(logging.INFO) + sql_logger.addHandler(mock_logging_handler) + + yield mock_logging_handler + + sql_logger.setLevel(previous_level) + + +def make_fixture(session): + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + +def get_schema(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = 
Article + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, _info): + return session.query(Article).all() + + def resolve_reporters(self, _info): + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + +def is_sqlalchemy_version_less_than(version_string): + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + + +if is_sqlalchemy_version_less_than('1.2'): + pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) + + +def test_many_to_one(session_factory): + session = session_factory() + make_fixture(session) + schema = get_schema(session) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + articles { + headline + reporter { + firstName + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date, ' + 'articles.reporter_id AS articles_reporter_id \n' + 'FROM articles', + '()', + + 'SELECT reporters.id AS reporters_id, ' + '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS 
reporters_favorite_pet_kind \n' + 'FROM reporters \n' + 'WHERE reporters.id IN (?, ?)', + '(1, 2)', + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + + +def test_one_to_one(session_factory): + session = session_factory() + make_fixture(session) + schema = get_schema(session) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?) 
' + 'ORDER BY articles.reporter_id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 74a7249a..45272e0b 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -5,16 +5,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter - - -def to_std_dicts(value): - """Convert nested ordered dicts to normal dicts for better comparison.""" - if isinstance(value, dict): - return {k: to_std_dicts(v) for k, v in value.items()} - elif isinstance(value, list): - return [to_std_dicts(v) for v in value] - else: - return value +from .utils import to_std_dicts def add_test_data(session): diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py new file mode 100644 index 00000000..b59ab0e8 --- /dev/null +++ b/graphene_sqlalchemy/tests/utils.py @@ -0,0 +1,8 @@ +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2b3e5728..23c8288e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,11 +1,12 @@ from collections import OrderedDict import sqlalchemy +from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.inspection import inspect as 
sqlalchemyinspect from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty) + RelationshipProperty, Session, strategies) from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.query import QueryContext from graphene import Field from graphene.relay import Connection, Node @@ -104,7 +105,7 @@ def construct_fields( :param function connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ - inspected_model = sqlalchemyinspect(model) + inspected_model = sqlalchemy.inspect(model) # Gather all the relevant attributes from the SQLAlchemy model in order all_model_attrs = OrderedDict( inspected_model.column_attrs.items() + @@ -152,22 +153,40 @@ def construct_fields( for orm_field_name, orm_field in orm_fields.items(): attr_name = orm_field.kwargs.pop('model_attr') attr = all_model_attrs[attr_name] - resolver = _get_field_resolver(obj_type, orm_field_name, attr_name) + custom_resolver = _get_custom_resolver(obj_type, orm_field_name) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_column( + attr, + registry, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + **orm_field.kwargs + ) elif isinstance(attr, RelationshipProperty): - field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, resolver, - **orm_field.kwargs) + field = convert_sqlalchemy_relationship( + attr, + registry, + connection_field_factory, + custom_resolver or _get_relationship_resolver(obj_type, attr, attr_name), + **orm_field.kwargs + ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. 
" "Field: {}.{}".format(obj_type.__name__, orm_field_name)) - field = convert_sqlalchemy_composite(attr, registry, resolver) + field = convert_sqlalchemy_composite( + attr, + registry, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + ) elif isinstance(attr, hybrid_property): - field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_hybrid_method( + attr, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + **orm_field.kwargs + ) else: raise Exception('Property type is not supported') # Should never happen @@ -177,22 +196,109 @@ def construct_fields( return fields -def _get_field_resolver(obj_type, orm_field_name, model_attr): +def _get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def _get_relationship_resolver(obj_type, relationship_prop, model_attr): + """ + Batch SQL queries using Dataloader to avoid the N+1 problem. + SQL batching only works for SQLAlchemy 1.2+ since it depends on + the `selectin` loader. 
+ + :param SQLAlchemyObjectType obj_type: + :param sqlalchemy.orm.properties.RelationshipProperty relationship_prop: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + child_mapper = relationship_prop.mapper + parent_mapper = relationship_prop.parent + + if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: + # TODO Batch many-to-many and one-to-many relationships + return _get_attr_resolver(obj_type, model_attr, model_attr) + + class NonListRelationshipLoader(dataloader.DataLoader): + cache = False + + def batch_load_fn(self, parents): # pylint: disable=method-hidden + """ + Batch loads the relationship of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 + """ + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... 
+ for parent in parents: + assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = QueryContext(session.query(parent_mapper.entity)) + + loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + + return promise.Promise.resolve([getattr(parent, model_attr) for parent in parents]) + + loader = NonListRelationshipLoader() + + def resolve(root, info): + return loader.load(root) + + return resolve + + +def _get_attr_resolver(obj_type, orm_field_name, model_attr): """ In order to support field renaming via `ORMField.model_attr`, we need to define resolver functions for each field. :param SQLAlchemyObjectType obj_type: - :param model: the SQLAlchemy model - :param str model_attr: the name of SQLAlchemy of the attribute used to resolve the field + :param str orm_field_name: + :param str model_attr: the name of the SQLAlchemy attribute :rtype: Callable """ - # Since `graphene` will call `resolve_` on a field only if it - # does not have a `resolver`, we need to re-implement that logic here. 
- resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) - if resolver: - return get_unbound_function(resolver) - return lambda root, _info: getattr(root, model_attr, None) diff --git a/setup.cfg b/setup.cfg index 0aa80ba9..880c87d6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ max-line-length = 120 no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=app,database,flask,mock,models,nameko,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=app,database,flask,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy diff --git a/setup.py b/setup.py index 66704b28..4e7c4f9c 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,9 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene "graphene>=2.1.3,<3", + "promise>=2.1", # Tests fail with 1.0.19 - "SQLAlchemy>=1.1,<2", + "SQLAlchemy>=1.2,<2", "six>=1.10.0,<2", "singledispatch>=3.4.0.3,<4", ] From d90de4ae8547e6725a1ec7bf4914be55d1fe32de Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Wed, 22 Jan 2020 17:58:55 -0500 Subject: [PATCH 06/67] Fix N+1 problem for one-to-many and many-to-many relationships (#254) This optimization batches what used to be multiple SQL statements into a single SQL statement. For now, you'll have to enable the optimization via the `SQLAlchemyObjectType.Meta.connection_field_factory` (see `test_batching.py`). 
--- .gitignore | 2 + graphene_sqlalchemy/__init__.py | 2 +- graphene_sqlalchemy/batching.py | 69 ++++ graphene_sqlalchemy/fields.py | 38 ++- graphene_sqlalchemy/tests/models.py | 2 +- graphene_sqlalchemy/tests/test_batching.py | 356 +++++++++++++++++++-- graphene_sqlalchemy/tests/test_fields.py | 40 ++- graphene_sqlalchemy/types.py | 69 +--- 8 files changed, 458 insertions(+), 120 deletions(-) create mode 100644 graphene_sqlalchemy/batching.py diff --git a/.gitignore b/.gitignore index d4f71e35..a97b8c21 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ var/ *.egg-info/ .installed.cfg *.egg +.python-version # PyInstaller # Usually these files are written by a python script from a template @@ -47,6 +48,7 @@ nosetests.xml coverage.xml *,cover .pytest_cache/ +.benchmarks/ # Translations *.mo diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 9ed4b0f6..ba71f614 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.2.2" +__version__ = "2.3.0.dev0" __all__ = [ "__version__", diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py new file mode 100644 index 00000000..0665248f --- /dev/null +++ b/graphene_sqlalchemy/batching.py @@ -0,0 +1,69 @@ +import sqlalchemy +from promise import dataloader, promise +from sqlalchemy.orm import Session, strategies +from sqlalchemy.orm.query import QueryContext + + +def get_batch_resolver(relationship_prop): + class RelationshipLoader(dataloader.DataLoader): + cache = False + + def batch_load_fn(self, parents): # pylint: disable=method-hidden + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. 
It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 + """ + child_mapper = relationship_prop.mapper + parent_mapper = relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + + # Should the boolean be set to False? Does it matter for our purposes? 
+ states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = QueryContext(session.query(parent_mapper.entity)) + + loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + + return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) + + loader = RelationshipLoader() + + def resolve(root, info, **args): + return loader.load(root) + + return resolve diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index a9f514ba..840204ae 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -9,6 +9,7 @@ from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .batching import get_batch_resolver from .utils import get_query @@ -33,14 +34,8 @@ def model(self): return self.type._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if isinstance(sort, six.string_types): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) - return query + def get_query(cls, model, info, **args): + return get_query(model, info.context) @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) +# TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): def __init__(self, type, *args, **kwargs): if "sort" not in kwargs and issubclass(type, Connection): @@ -95,6 +91,32 @@ def __init__(self, type, *args, **kwargs): del kwargs["sort"] super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + 
@classmethod + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if isinstance(sort, six.string_types): + query = query.order_by(sort.value) + else: + query = query.order_by(*(col.value for col in sort)) + return query + + +class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): + """ + This is currently experimental. + The API and behavior may change in future versions. + Use at your own risk. + """ + def get_resolver(self, parent_resolver): + return partial(self.connection_resolver, self.resolver, self.type, self.model) + + @classmethod + def from_relationship(cls, relationship, registry, **field_kwargs): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + def default_connection_field_factory(relationship, registry, **field_kwargs): model = relationship.mapper.entity diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 1df28333..88e992b9 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -61,7 +61,7 @@ class Reporter(Base): last_name = Column(String(30), doc="Last name") email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters") + pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 0881f71e..77681069 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -5,9 +5,11 @@ import pytest import graphene +from graphene import relay +from ..fields import 
BatchSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from .models import Article, Reporter +from .models import Article, HairKind, Pet, Reporter from .utils import to_std_dicts @@ -37,46 +39,34 @@ def mock_sqlalchemy_logging_handler(): sql_logger.setLevel(previous_level) -def make_fixture(session): - reporter_1 = Reporter( - first_name='Reporter_1', - ) - session.add(reporter_1) - reporter_2 = Reporter( - first_name='Reporter_2', - ) - session.add(reporter_2) - - article_1 = Article(headline='Article_1') - article_1.reporter = reporter_1 - session.add(article_1) - - article_2 = Article(headline='Article_2') - article_2.reporter = reporter_2 - session.add(article_2) - - session.commit() - session.close() - - -def get_schema(session): +def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class ArticleType(SQLAlchemyObjectType): class Meta: model = Article + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) reporters = graphene.Field(graphene.List(ReporterType)) - def resolve_articles(self, _info): - return session.query(Article).all() + def resolve_articles(self, info): + return info.context.get('session').query(Article).all() - def resolve_reporters(self, _info): - return session.query(Reporter).all() + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() return graphene.Schema(query=Query) @@ -91,8 +81,28 @@ def is_sqlalchemy_version_less_than(version_string): def test_many_to_one(session_factory): session = session_factory() - make_fixture(session) - 
schema = get_schema(session) + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level @@ -115,6 +125,8 @@ def test_many_to_one(session_factory): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + assert len(sql_statements) == 1 return assert messages == [ @@ -160,8 +172,28 @@ def test_many_to_one(session_factory): def test_one_to_one(session_factory): session = session_factory() - make_fixture(session) - schema = get_schema(session) + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level @@ -184,6 +216,8 @@ def test_one_to_one(session_factory): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu + sql_statements = [message for 
message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 return assert messages == [ @@ -226,3 +260,261 @@ def test_one_to_one(session_factory): }, ], } + + +def test_one_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline='Article_3') + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline='Article_4') + article_4.reporter = reporter_2 + session.add(article_4) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS 
reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?) ' + 'ORDER BY articles.reporter_id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], + }, + }, + ], + } + + +def test_many_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + pets(first: 
2) { + edges { + node { + name + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT reporters_1.id AS reporters_1_id, ' + 'pets.id AS pets_id, ' + 'pets.name AS pets_name, ' + 'pets.pet_kind AS pets_pet_kind, ' + 'pets.hair_kind AS pets_hair_kind, ' + 'pets.reporter_id AS pets_reporter_id \n' + 'FROM reporters AS reporters_1 ' + 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' + 'JOIN pets ON pets.id = association_1.pet_id \n' + 'WHERE reporters_1.id IN (?, ?) 
' + 'ORDER BY reporters_1.id, pets.id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], + }, + }, + ], + } diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 875b729d..557ff114 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,9 +1,11 @@ import pytest from promise import Promise -from graphene.relay import Connection +from graphene import ObjectType +from graphene.relay import Connection, Node -from ..fields import SQLAlchemyConnectionField +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField) from ..types import SQLAlchemyObjectType from .models import Editor as EditorModel from .models import Pet as PetModel @@ -12,44 +14,58 @@ class Pet(SQLAlchemyObjectType): class Meta: model = PetModel + interfaces = (Node,) class Editor(SQLAlchemyObjectType): class Meta: model = EditorModel - -class PetConnection(Connection): - class Meta: - node = Pet +## +# SQLAlchemyConnectionField +## def test_promise_connection_resolver(): def resolver(_obj, _info): return Promise.resolve([]) - result = SQLAlchemyConnectionField.connection_resolver( - resolver, PetConnection, Pet, None, None + result = UnsortedSQLAlchemyConnectionField.connection_resolver( + resolver, Pet._meta.connection, Pet, None, None ) assert isinstance(result, Promise) +def test_type_assert_sqlalchemy_object_type(): + with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): + SQLAlchemyConnectionField(ObjectType).type + + +def test_type_assert_object_has_connection(): + with 
pytest.raises(AssertionError, match="doesn't have a connection"): + SQLAlchemyConnectionField(Editor).type + +## +# UnsortedSQLAlchemyConnectionField +## + + def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(PetConnection) + field = SQLAlchemyConnectionField(Pet._meta.connection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(PetConnection, sort=None) + field = SQLAlchemyConnectionField(Pet._meta.connection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(PetConnection, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(Pet._meta.connection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() -def test_init_raises(): +def test_sort_init_raises(): with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 23c8288e..2ed5110e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,12 +1,10 @@ from collections import OrderedDict import sqlalchemy -from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty, Session, strategies) + RelationshipProperty, strategies) from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.query import QueryContext from graphene import Field from graphene.relay import Connection, Node @@ -15,6 +13,7 @@ from graphene.utils.get_unbound_function import get_unbound_function from graphene.utils.orderedtype import OrderedType +from .batching import get_batch_resolver from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, @@ -220,73 +219,11 @@ def _get_relationship_resolver(obj_type, relationship_prop, 
model_attr): :param str model_attr: the name of the SQLAlchemy attribute :rtype: Callable """ - child_mapper = relationship_prop.mapper - parent_mapper = relationship_prop.parent - if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: # TODO Batch many-to-many and one-to-many relationships return _get_attr_resolver(obj_type, model_attr, model_attr) - class NonListRelationshipLoader(dataloader.DataLoader): - cache = False - - def batch_load_fn(self, parents): # pylint: disable=method-hidden - """ - Batch loads the relationship of all the parents as one SQL statement. - - There is no way to do this out-of-the-box with SQLAlchemy but - we can piggyback on some internal APIs of the `selectin` - eager loading strategy. It's a bit hacky but it's preferable - than re-implementing and maintainnig a big chunk of the `selectin` - loader logic ourselves. - - The approach here is to build a regular query that - selects the parent and `selectin` load the relationship. - But instead of having the query emits 2 `SELECT` statements - when callling `all()`, we skip the first `SELECT` statement - and jump right before the `selectin` loader is called. - To accomplish this, we have to construct objects that are - normally built in the first part of the query in order - to call directly `SelectInLoader._load_for_path`. - - TODO Move this logic to a util in the SQLAlchemy repo as per - SQLAlchemy's main maitainer suggestion. - See https://git.io/JewQ7 - """ - session = Session.object_session(parents[0]) - - # These issues are very unlikely to happen in practice... - for parent in parents: - assert parent.__mapper__ is parent_mapper - # All instances must share the same session - assert session is Session.object_session(parent) - # The behavior of `selectin` is undefined if the parent is dirty - assert parent not in session.dirty - - loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - - # Should the boolean be set to False? 
Does it matter for our purposes? - states = [(sqlalchemy.inspect(parent), True) for parent in parents] - - # For our purposes, the query_context will only used to get the session - query_context = QueryContext(session.query(parent_mapper.entity)) - - loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - ) - - return promise.Promise.resolve([getattr(parent, model_attr) for parent in parents]) - - loader = NonListRelationshipLoader() - - def resolve(root, info): - return loader.load(root) - - return resolve + return get_batch_resolver(relationship_prop) def _get_attr_resolver(obj_type, orm_field_name, model_attr): From 631513fe42cb0b0613349f17118faabf879107ac Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Fri, 24 Jan 2020 13:34:53 -0500 Subject: [PATCH 07/67] Add benchmark for connection fields (#259) Add `pytest-benchmark` so we can easily track performance changes over time Others: * disable tests for Python 3.4 * upgrade coveralls --- .travis.yml | 3 - graphene_sqlalchemy/batching.py | 9 +- graphene_sqlalchemy/tests/test_batching.py | 7 +- graphene_sqlalchemy/tests/test_benchmark.py | 226 ++++++++++++++++++++ graphene_sqlalchemy/tests/utils.py | 8 + setup.cfg | 2 +- setup.py | 5 +- tox.ini | 2 +- 8 files changed, 245 insertions(+), 17 deletions(-) create mode 100644 graphene_sqlalchemy/tests/test_benchmark.py diff --git a/.travis.yml b/.travis.yml index 39151a5d..5a988428 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,6 @@ matrix: - env: TOXENV=py27 python: 2.7 # Python 3.5 - - env: TOXENV=py34 - python: 3.4 - # Python 3.5 - env: TOXENV=py35 python: 3.5 # Python 3.6 diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 0665248f..baf01deb 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -5,6 +5,11 @@ def get_batch_resolver(relationship_prop): + + # Cache this across `batch_load_fn` calls + # This is so SQL string generation is cached 
under-the-hood via `bakery` + selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + class RelationshipLoader(dataloader.DataLoader): cache = False @@ -43,15 +48,13 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden # The behavior of `selectin` is undefined if the parent is dirty assert parent not in session.dirty - loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - # Should the boolean be set to False? Does it matter for our purposes? states = [(sqlalchemy.inspect(parent), True) for parent in parents] # For our purposes, the query_context will only used to get the session query_context = QueryContext(session.query(parent_mapper.entity)) - loader._load_for_path( + selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 77681069..d8393fb0 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -1,7 +1,6 @@ import contextlib import logging -import pkg_resources import pytest import graphene @@ -10,7 +9,7 @@ from ..fields import BatchSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from .models import Article, HairKind, Pet, Reporter -from .utils import to_std_dicts +from .utils import is_sqlalchemy_version_less_than, to_std_dicts class MockLoggingHandler(logging.Handler): @@ -71,10 +70,6 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) -def is_sqlalchemy_version_less_than(version_string): - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) - - if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py new file mode 100644 index 
00000000..1e5ee4f1 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -0,0 +1,226 @@ +import pytest +from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend + +import graphene +from graphene import relay + +from ..fields import BatchSQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from .models import Article, HairKind, Pet, Reporter +from .utils import is_sqlalchemy_version_less_than + +if is_sqlalchemy_version_less_than('1.2'): + pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) + + +def get_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, info): + return info.context.get('session').query(Article).all() + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + return graphene.Schema(query=Query) + + +def benchmark_query(session_factory, benchmark, query): + schema = get_schema() + cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) + cached_backend.document_from_string(schema, query) # Prime cache + + @benchmark + def execute_query(): + result = schema.execute( + query, + context_value={"session": session_factory()}, + backend=cached_backend, + ) + assert not result.errors + + +def test_one_to_one(session_factory, benchmark): + session = session_factory() + 
+ reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """) + + +def test_many_to_one(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + articles { + headline + reporter { + firstName + } + } + } + """) + + +def test_one_to_many(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline='Article_3') + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline='Article_4') + article_4.reporter = reporter_2 + session.add(article_4) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + articles(first: 2) 
{ + edges { + node { + headline + } + } + } + } + } + """) + + +def test_many_to_many(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } + } + } + """) diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index b59ab0e8..428757c3 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,3 +1,6 @@ +import pkg_resources + + def to_std_dicts(value): """Convert nested ordered dicts to normal dicts for better comparison.""" if isinstance(value, dict): @@ -6,3 +9,8 @@ def to_std_dicts(value): return [to_std_dicts(v) for v in value] else: return value + + +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) diff --git a/setup.cfg b/setup.cfg index 880c87d6..4e8e5029 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ max-line-length = 120 no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy 
-known_third_party=app,database,flask,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy diff --git a/setup.py b/setup.py index 4e7c4f9c..f16c8ff5 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "mock==2.0.0", "pytest-cov==2.6.1", "sqlalchemy_utils==0.33.9", + "pytest-benchmark==3.2.1", ] setup( @@ -48,8 +49,6 @@ "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", @@ -61,7 +60,7 @@ extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "coveralls==1.7.0", + "coveralls==1.10.0", "pre-commit==1.14.4", ], "test": tests_require, diff --git a/tox.ini b/tox.ini index e55f7d9b..562da2dc 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pre-commit,py{27,34,35,36,37}-sql{11,12,13} +envlist = pre-commit,py{27,35,36,37}-sql{11,12,13} skipsdist = true minversion = 3.7.0 From 6dca2794cd85e3ae77b6c11e3ec81d9fe81175e9 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Tue, 11 Feb 2020 16:31:35 -0500 Subject: [PATCH 08/67] Add class property `connection` to `SQLAlchemyObjectType` (#263) Currently to get the default connection of a `SQLAlchemyObjectType`, you have to go through `_meta`. For example, `PetType._meta.connection`. This adds a public way to get it. 
--- graphene_sqlalchemy/converter.py | 2 +- graphene_sqlalchemy/fields.py | 6 +++--- graphene_sqlalchemy/tests/test_fields.py | 8 ++++---- graphene_sqlalchemy/tests/test_types.py | 10 ++++++++++ graphene_sqlalchemy/types.py | 2 ++ 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 4ff55eed..ae90001b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -42,7 +42,7 @@ def dynamic_type(): **field_kwargs ) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type._meta.connection: + if _type.connection: # TODO Add a way to override connection_field_factory return connection_field_factory(relationship_prop, registry, **field_kwargs) return Field( diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 840204ae..254319f9 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -24,10 +24,10 @@ def type(self): assert issubclass(_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(_type.__name__) - assert _type._meta.connection, "The type {} doesn't have a connection".format( + assert _type.connection, "The type {} doesn't have a connection".format( _type.__name__ ) - return _type._meta.connection + return _type.connection @property def model(self): @@ -115,7 +115,7 @@ def get_resolver(self, parent_resolver): def from_relationship(cls, relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) def default_connection_field_factory(relationship, registry, **field_kwargs): diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py 
index 557ff114..9ed3c4aa 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -31,7 +31,7 @@ def resolver(_obj, _info): return Promise.resolve([]) result = UnsortedSQLAlchemyConnectionField.connection_resolver( - resolver, Pet._meta.connection, Pet, None, None + resolver, Pet.connection, Pet, None, None ) assert isinstance(result, Promise) @@ -51,18 +51,18 @@ def test_type_assert_object_has_connection(): def test_sort_added_by_default(): - field = SQLAlchemyConnectionField(Pet._meta.connection) + field = SQLAlchemyConnectionField(Pet.connection) assert "sort" in field.args assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - field = SQLAlchemyConnectionField(Pet._meta.connection, sort=None) + field = SQLAlchemyConnectionField(Pet.connection, sort=None) assert "sort" not in field.args def test_custom_sort(): - field = SQLAlchemyConnectionField(Pet._meta.connection, sort=Editor.sort_argument()) + field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument()) assert field.args["sort"] == Editor.sort_argument() diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index fda8e659..bf563b6e 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,6 +4,7 @@ from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) +from graphene.relay import Connection from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, @@ -46,6 +47,15 @@ class Meta: assert reporter == reporter_node +def test_connection(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + assert issubclass(ReporterType.connection, Connection) + + def test_sqlalchemy_default_fields(): @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): diff 
--git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2ed5110e..ef189b38 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -325,6 +325,8 @@ def __init_subclass_with_meta__( _meta.connection = connection _meta.id = id or "id" + cls.connection = connection # Public way to get the connection + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options ) From 4c5b4d1972d6ae09d228952fdd0ff2d717d851a3 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Tue, 11 Feb 2020 20:41:50 -0500 Subject: [PATCH 09/67] [documentation] Fix Connection patterns (#264) --- docs/examples.rst | 14 ++------------ docs/tips.rst | 7 +------ docs/tutorial.rst | 14 ++------------ examples/flask_sqlalchemy/schema.py | 8 ++++---- examples/nameko_sqlalchemy/README.md | 1 - examples/nameko_sqlalchemy/database.py | 2 +- examples/nameko_sqlalchemy/schema.py | 15 ++++++--------- graphene_sqlalchemy/tests/test_query.py | 14 +++----------- graphene_sqlalchemy/tests/test_sort_enums.py | 20 ++++++++------------ 9 files changed, 27 insertions(+), 68 deletions(-) diff --git a/docs/examples.rst b/docs/examples.rst index 283a0f5e..2013cfbb 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -13,22 +13,12 @@ Search all Models with Union interfaces = (relay.Node,) - class BookConnection(relay.Connection): - class Meta: - node = Book - - class Author(SQLAlchemyObjectType): class Meta: model = AuthorModel interfaces = (relay.Node,) - class AuthorConnection(relay.Connection): - class Meta: - node = Author - - class SearchResult(graphene.Union): class Meta: types = (Book, Author) @@ -39,8 +29,8 @@ Search all Models with Union search = graphene.List(SearchResult, q=graphene.String()) # List field for search results # Normal Fields - all_books = SQLAlchemyConnectionField(BookConnection) - all_authors = SQLAlchemyConnectionField(AuthorConnection) + all_books = SQLAlchemyConnectionField(Book.connection) + 
all_authors = SQLAlchemyConnectionField(Author.connection) def resolve_search(self, info, **args): q = args.get("q") # Search query diff --git a/docs/tips.rst b/docs/tips.rst index 1fd39107..baa8233f 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -50,13 +50,8 @@ Given the model model = Pet - class PetConnection(Connection): - class Meta: - node = PetNode - - class Query(ObjectType): - allPets = SQLAlchemyConnectionField(PetConnection) + allPets = SQLAlchemyConnectionField(PetNode.connection) some of the allowed queries are diff --git a/docs/tutorial.rst b/docs/tutorial.rst index bc5ee62d..3c4c135e 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -102,28 +102,18 @@ Create ``flask_sqlalchemy/schema.py`` and type the following: interfaces = (relay.Node, ) - class DepartmentConnection(relay.Connection): - class Meta: - node = Department - - class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel interfaces = (relay.Node, ) - class EmployeeConnection(relay.Connection): - class Meta: - node = Employee - - class Query(graphene.ObjectType): node = relay.Node.Field() # Allows sorting over multiple columns, by default over the primary key - all_employees = SQLAlchemyConnectionField(EmployeeConnection) + all_employees = SQLAlchemyConnectionField(Employee.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None) + all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) schema = graphene.Schema(query=Query) diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 9ed09464..ea525e3b 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -29,11 +29,11 @@ class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee, sort=Employee.sort_argument()) + Employee.connection, sort=Employee.sort_argument()) # 
Allows sorting over multiple columns, by default over the primary key - all_roles = SQLAlchemyConnectionField(Role) + all_roles = SQLAlchemyConnectionField(Role.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(Department, sort=None) + all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git a/examples/nameko_sqlalchemy/README.md b/examples/nameko_sqlalchemy/README.md index 6302cb33..e0803895 100644 --- a/examples/nameko_sqlalchemy/README.md +++ b/examples/nameko_sqlalchemy/README.md @@ -46,7 +46,6 @@ Now the following command will setup the database, and start the server: ```bash ./run.sh - ``` Now head on over to postman and send POST request to: diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index 01e76ca6..ca4d4122 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -14,7 +14,7 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. 
Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py index a33cab9b..ced300b3 100644 --- a/examples/nameko_sqlalchemy/schema.py +++ b/examples/nameko_sqlalchemy/schema.py @@ -8,31 +8,28 @@ class Department(SQLAlchemyObjectType): - class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): - class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): - class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) - all_roles = SQLAlchemyConnectionField(Role) + all_employees = SQLAlchemyConnectionField(Employee.connection) + all_roles = SQLAlchemyConnectionField(Role.connection) role = graphene.Field(Role) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 45272e0b..39140814 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,5 +1,5 @@ import graphene -from graphene.relay import Connection, Node +from graphene.relay import Node from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField @@ -96,14 +96,10 @@ class Meta: model = Article interfaces = (Node,) - class ArticleConnection(Connection): - class Meta: - node = ArticleNode - class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) - all_articles = 
SQLAlchemyConnectionField(ArticleConnection) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) def resolve_reporter(self, _info): return session.query(Reporter).first() @@ -230,13 +226,9 @@ class Meta: model = Editor interfaces = (Node,) - class EditorConnection(Connection): - class Meta: - node = EditorNode - class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorConnection) + all_editors = SQLAlchemyConnectionField(EditorNode.connection) query = """ query { diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 1eb106da..d6f6965d 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -2,7 +2,7 @@ import sqlalchemy as sa from graphene import Argument, Enum, List, ObjectType, Schema -from graphene.relay import Connection, Node +from graphene.relay import Node from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType @@ -249,22 +249,18 @@ class Meta: model = Pet interfaces = (Node,) - class PetConnection(Connection): - class Meta: - node = PetNode - class Query(ObjectType): - defaultSort = SQLAlchemyConnectionField(PetConnection) - nameSort = SQLAlchemyConnectionField(PetConnection) - multipleSort = SQLAlchemyConnectionField(PetConnection) - descSort = SQLAlchemyConnectionField(PetConnection) + defaultSort = SQLAlchemyConnectionField(PetNode.connection) + nameSort = SQLAlchemyConnectionField(PetNode.connection) + multipleSort = SQLAlchemyConnectionField(PetNode.connection) + descSort = SQLAlchemyConnectionField(PetNode.connection) singleColumnSort = SQLAlchemyConnectionField( - PetConnection, sort=Argument(PetNode.sort_enum()) + PetNode.connection, sort=Argument(PetNode.sort_enum()) ) noDefaultSort = SQLAlchemyConnectionField( - PetConnection, sort=PetNode.sort_argument(has_default=False) + PetNode.connection, sort=PetNode.sort_argument(has_default=False) ) - 
noSort = SQLAlchemyConnectionField(PetConnection, sort=None) + noSort = SQLAlchemyConnectionField(PetNode.connection, sort=None) query = """ query sortTest { From 7a48d3dd6b297c6d13d01df1f268a16e52f3deb0 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Tue, 11 Feb 2020 21:10:44 -0500 Subject: [PATCH 10/67] Bump promise to 2.3 (#265) It contains a fix for a thread-safety issue in Dataloader. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f16c8ff5..7b350c39 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene "graphene>=2.1.3,<3", - "promise>=2.1", + "promise>=2.3", # Tests fail with 1.0.19 "SQLAlchemy>=1.2,<2", "six>=1.10.0,<2", From 17d535efba03070cbc505d915673e0f24d9ca60c Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Wed, 12 Feb 2020 10:02:34 -0500 Subject: [PATCH 11/67] Add `batching` params (#260) Add parameters to toggle batching on or off. This can be configured at 2 levels: - we can configure all the fields of a type at once via SQLAlchemyObjectType.meta.batching - or we can specify it for a specific field via ORMfield.batching. This trumps SQLAlchemyObjectType.meta.batching. 
--- graphene_sqlalchemy/converter.py | 96 ++++++++-- graphene_sqlalchemy/resolver.py | 0 graphene_sqlalchemy/resolvers.py | 26 +++ graphene_sqlalchemy/tests/test_batching.py | 195 +++++++++++++++++++- graphene_sqlalchemy/tests/test_converter.py | 36 ++-- graphene_sqlalchemy/types.py | 92 ++------- 6 files changed, 325 insertions(+), 120 deletions(-) delete mode 100644 graphene_sqlalchemy/resolver.py create mode 100644 graphene_sqlalchemy/resolvers.py diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index ae90001b..f4b805e2 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,14 +3,18 @@ from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import interfaces +from sqlalchemy.orm import interfaces, strategies from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, String) from graphene.types.json import JSONString +from .batching import get_batch_resolver from .enums import enum_for_sa_enum +from .fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) from .registry import get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType @@ -18,6 +22,9 @@ ChoiceType = JSONType = ScalarListType = TSVectorType = object +is_selectin_available = getattr(strategies, 'SelectInLoader', None) + + def get_column_doc(column): return getattr(column, "doc", None) @@ -26,33 +33,82 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs): - direction = relationship_prop.direction - model = relationship_prop.mapper.entity - +def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, + orm_field_name, 
**field_kwargs): + """ + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param function|None connection_field_factory: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Dynamic + """ def dynamic_type(): - _type = registry.get_type_for_model(model) + """:rtype: Field|None""" + direction = relationship_prop.direction + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + batching_ = batching if is_selectin_available else False - if not _type: + if not child_type: return None + if direction == interfaces.MANYTOONE or not relationship_prop.uselist: - return Field( - _type, - resolver=resolver, - **field_kwargs - ) - elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type.connection: - # TODO Add a way to override connection_field_factory - return connection_field_factory(relationship_prop, registry, **field_kwargs) - return Field( - List(_type), - **field_kwargs - ) + return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, + **field_kwargs) + + if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): + return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, + connection_field_factory, **field_kwargs) return Dynamic(dynamic_type) +def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): + """ + Convert one-to-one or many-to-one relationship. Return an object field.
+ + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + resolver = get_custom_resolver(obj_type, orm_field_name) + if resolver is None: + resolver = get_batch_resolver(relationship_prop) if batching else \ + get_attr_resolver(obj_type, relationship_prop.key) + + return Field(child_type, resolver=resolver, **field_kwargs) + + +def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): + """ + Convert one-to-many or many-to-many relationship. Return a list field or a connection field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param function|None connection_field_factory: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + if not child_type._meta.connection: + return Field(List(child_type), **field_kwargs) + + # TODO Allow override of connection_field_factory and resolver via ORMField + if connection_field_factory is None: + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ + default_connection_field_factory + + return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) + + def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'type' not in field_kwargs: # TODO The default type should be dependent on the type of the property propety.
diff --git a/graphene_sqlalchemy/resolver.py b/graphene_sqlalchemy/resolver.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py new file mode 100644 index 00000000..83a6e35d --- /dev/null +++ b/graphene_sqlalchemy/resolvers.py @@ -0,0 +1,26 @@ +from graphene.utils.get_unbound_function import get_unbound_function + + +def get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def get_attr_resolver(obj_type, model_attr): + """ + In order to support field renaming via `ORMField.model_attr`, + we need to define resolver functions for each field. + + :param SQLAlchemyObjectType obj_type: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + return lambda root, _info: getattr(root, model_attr, None) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index d8393fb0..b97002a7 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -6,8 +6,9 @@ import graphene from graphene import relay -from ..fields import BatchSQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType +from ..fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) +from ..types import ORMField, SQLAlchemyObjectType from .models import Article, HairKind, Pet, Reporter from .utils import is_sqlalchemy_version_less_than, to_std_dicts @@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter interfaces = (relay.Node,) - connection_field_factory = 
BatchSQLAlchemyConnectionField.from_relationship + batching = True class ArticleType(SQLAlchemyObjectType): class Meta: model = Article interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + batching = True class PetType(SQLAlchemyObjectType): class Meta: model = Pet interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + batching = True class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) @@ -513,3 +514,187 @@ def test_many_to_many(session_factory): }, ], } + + +def test_disable_batching_via_ormfield(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + favorite_article = ORMField(batching=False) + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + # 
Test one-to-many and many-to-many relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + +def test_connection_factory_field_overrides_batching_is_false(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = False + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN 
statement in `selectin` + # See https://git.io/JewQu + select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + else: + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 1 + + +def test_connection_factory_field_overrides_batching_is_true(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + connection_field_factory = default_connection_field_factory + + articles = ORMField(batching=True) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e8051a18..e9ee2379 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -190,9 +190,12 @@ def 
test_should_jsontype_convert_jsonstring(): def test_should_manytomany_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, registry, default_connection_field_factory, mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -204,7 +207,7 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -220,19 +223,19 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): - registry = Registry() + class A(SQLAlchemyObjectType): + class Meta: + model = Article + dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - registry, - default_connection_field_factory, - mock_resolver, + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -244,10 +247,7 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Article.reporter.property, A, 
default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -262,10 +262,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -280,10 +277,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, - A._meta.registry, - default_connection_field_factory, - mock_resolver, + Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ef189b38..ff22cded 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,25 +3,23 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty, strategies) + RelationshipProperty) from sqlalchemy.orm.exc import NoResultFound from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs -from graphene.utils.get_unbound_function import get_unbound_function from graphene.utils.orderedtype import OrderedType -from .batching import get_batch_resolver from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) from .enums import (enum_for_field, sort_argument_for_object_type, sort_enum_for_object_type) -from .fields import 
default_connection_field_factory from .registry import Registry, get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver from .utils import get_query, is_mapped_class, is_mapped_instance @@ -33,6 +31,7 @@ def __init__( required=None, description=None, deprecation_reason=None, + batching=None, _creation_counter=None, **field_kwargs ): @@ -69,6 +68,8 @@ class Meta: Same behavior as in graphene.Field. Defaults to None. :param str deprecation_reason: Same behavior as in graphene.Field. Defaults to None. + :param bool batching: + Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. :param int _creation_counter: Same behavior as in graphene.Field. """ @@ -80,6 +81,7 @@ class Meta: 'required': required, 'description': description, 'deprecation_reason': deprecation_reason, + 'batching': batching, } common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs @@ -87,7 +89,7 @@ class Meta: def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, connection_field_factory + obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory ): """ Construct all the fields for a SQLAlchemyObjectType. 
@@ -101,7 +103,8 @@ def construct_fields( :param Registry registry: :param tuple[string] only_fields: :param tuple[string] exclude_fields: - :param function connection_field_factory: + :param bool batching: + :param function|None connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ inspected_model = sqlalchemy.inspect(model) @@ -152,40 +155,23 @@ def construct_fields( for orm_field_name, orm_field in orm_fields.items(): attr_name = orm_field.kwargs.pop('model_attr') attr = all_model_attrs[attr_name] - custom_resolver = _get_custom_resolver(obj_type, orm_field_name) + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column( - attr, - registry, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - **orm_field.kwargs - ) + field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) elif isinstance(attr, RelationshipProperty): + batching_ = orm_field.kwargs.pop('batching', batching) field = convert_sqlalchemy_relationship( - attr, - registry, - connection_field_factory, - custom_resolver or _get_relationship_resolver(obj_type, attr, attr_name), - **orm_field.kwargs - ) + attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. 
" "Field: {}.{}".format(obj_type.__name__, orm_field_name)) - field = convert_sqlalchemy_composite( - attr, - registry, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - ) + field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): - field = convert_sqlalchemy_hybrid_method( - attr, - custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), - **orm_field.kwargs - ) + field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: raise Exception('Property type is not supported') # Should never happen @@ -195,50 +181,6 @@ def construct_fields( return fields -def _get_custom_resolver(obj_type, orm_field_name): - """ - Since `graphene` will call `resolve_` on a field only if it - does not have a `resolver`, we need to re-implement that logic here so - users are able to override the default resolvers that we provide. - """ - resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) - if resolver: - return get_unbound_function(resolver) - - return None - - -def _get_relationship_resolver(obj_type, relationship_prop, model_attr): - """ - Batch SQL queries using Dataloader to avoid the N+1 problem. - SQL batching only works for SQLAlchemy 1.2+ since it depends on - the `selectin` loader. - - :param SQLAlchemyObjectType obj_type: - :param sqlalchemy.orm.properties.RelationshipProperty relationship_prop: - :param str model_attr: the name of the SQLAlchemy attribute - :rtype: Callable - """ - if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: - # TODO Batch many-to-many and one-to-many relationships - return _get_attr_resolver(obj_type, model_attr, model_attr) - - return get_batch_resolver(relationship_prop) - - -def _get_attr_resolver(obj_type, orm_field_name, model_attr): - """ - In order to support field renaming via `ORMField.model_attr`, - we need to define resolver functions for each field. 
- - :param SQLAlchemyObjectType obj_type: - :param str orm_field_name: - :param str model_attr: the name of the SQLAlchemy attribute - :rtype: Callable - """ - return lambda root, _info: getattr(root, model_attr, None) - - class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): model = None # type: sqlalchemy.Model registry = None # type: sqlalchemy.Registry @@ -260,7 +202,8 @@ def __init_subclass_with_meta__( use_connection=None, interfaces=(), id=None, - connection_field_factory=default_connection_field_factory, + batching=False, + connection_field_factory=None, _meta=None, **options ): @@ -286,6 +229,7 @@ def __init_subclass_with_meta__( registry=registry, only_fields=only_fields, exclude_fields=exclude_fields, + batching=batching, connection_field_factory=connection_field_factory, ), _as=Field, From 3c3442e17d4d01ba7cde2679025cec3bc7263306 Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Wed, 12 Feb 2020 10:18:54 -0500 Subject: [PATCH 12/67] Release 2.3.0.dev1 (#266) Use this to test batching --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index ba71f614..392ee1ab 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.3.0.dev0" +__version__ = "2.3.0.dev1" __all__ = [ "__version__", From 421f8e48d169a91e20328108c6f56ae0987d21b8 Mon Sep 17 00:00:00 2001 From: Daniel Pepper Date: Mon, 24 Feb 2020 15:06:27 -0800 Subject: [PATCH 13/67] Fix deprecation warning in tests (#268) Use sqlalchemy.types.LargeBinary instead of Binary --- graphene_sqlalchemy/tests/test_converter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e9ee2379..f0fc1802 100644 --- 
a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -50,7 +50,8 @@ class Model(declarative_base()): def test_should_unknown_sqlalchemy_field_raise_exception(): re_err = "Don't know how to convert the SQLAlchemy field" with pytest.raises(Exception, match=re_err): - get_field(types.Binary()) + # support legacy Binary type and subsequent LargeBinary + get_field(getattr(types, 'LargeBinary', types.Binary)()) def test_should_date_convert_string(): From 849217a7731cbcfcf64432ccf306f4e4001328f2 Mon Sep 17 00:00:00 2001 From: Chris Berks Date: Thu, 4 Jun 2020 21:13:05 +0100 Subject: [PATCH 14/67] Add support for Non-Null SQLAlchemyConnectionField (#261) * Add support for Non-Null SQLAlchemyConnectionField * Remove implicit ORDER BY clause to fix tests with SQLAlchemy 1.3.16 --- graphene_sqlalchemy/fields.py | 50 ++++++++++++++++------ graphene_sqlalchemy/tests/test_batching.py | 8 ++-- graphene_sqlalchemy/tests/test_fields.py | 16 ++++++- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 254319f9..780fcbf0 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -5,6 +5,7 @@ from promise import Promise, is_thenable from sqlalchemy.orm.query import Query +from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -19,19 +20,26 @@ def type(self): from .types import SQLAlchemyObjectType _type = super(ConnectionField, self).type - if issubclass(_type, Connection): + nullable_type = get_nullable_type(_type) + if issubclass(nullable_type, Connection): return _type - assert issubclass(_type, SQLAlchemyObjectType), ( + assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" - ).format(_type.__name__) 
- assert _type.connection, "The type {} doesn't have a connection".format( - _type.__name__ + ).format(nullable_type.__name__) + assert ( + nullable_type.connection + ), "The type {} doesn't have a connection".format( + nullable_type.__name__ ) - return _type.connection + assert _type == nullable_type, ( + "Passing a SQLAlchemyObjectType instance is deprecated. " + "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" + ) + return nullable_type.connection @property def model(self): - return self.type._meta.node._meta.model + return get_nullable_type(self.type)._meta.node._meta.model @classmethod def get_query(cls, model, info, **args): @@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg return on_resolve(resolved) def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.model) + return partial( + self.connection_resolver, + parent_resolver, + get_nullable_type(self.type), + self.model, + ) # TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): def __init__(self, type, *args, **kwargs): - if "sort" not in kwargs and issubclass(type, Connection): + nullable_type = get_nullable_type(type) + if "sort" not in kwargs and issubclass(nullable_type, Connection): # Let super class raise if type is not a Connection try: - kwargs.setdefault("sort", type.Edge.node._type.sort_argument()) + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) except (AttributeError, TypeError): raise TypeError( 'Cannot create sort argument for {}. A model is required. 
Set the "sort" argument' " to None to disabling the creation of the sort query argument".format( - type.__name__ + nullable_type.__name__ ) ) elif "sort" in kwargs and kwargs["sort"] is None: @@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): The API and behavior may change in future versions. Use at your own risk. """ + def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, self.resolver, self.type, self.model) + return partial( + self.connection_resolver, + self.resolver, + get_nullable_type(self.type), + self.model, + ) @classmethod def from_relationship(cls, relationship, registry, **field_kwargs): @@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory(): ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index b97002a7..fc646a3c 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -233,8 +233,7 @@ def test_one_to_one(session_factory): 'articles.headline AS articles_headline, ' 'articles.pub_date AS articles_pub_date \n' 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?) ' - 'ORDER BY articles.reporter_id', + 'WHERE articles.reporter_id IN (?, ?)', '(1, 2)' ] @@ -337,8 +336,7 @@ def test_one_to_many(session_factory): 'articles.headline AS articles_headline, ' 'articles.pub_date AS articles_pub_date \n' 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?) ' - 'ORDER BY articles.reporter_id', + 'WHERE articles.reporter_id IN (?, ?)', '(1, 2)' ] @@ -470,7 +468,7 @@ def test_many_to_many(session_factory): 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' 'JOIN pets ON pets.id = association_1.pet_id \n' 'WHERE reporters_1.id IN (?, ?) 
' - 'ORDER BY reporters_1.id, pets.id', + 'ORDER BY pets.id', '(1, 2)' ] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 9ed3c4aa..357055e3 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,7 +1,7 @@ import pytest from promise import Promise -from graphene import ObjectType +from graphene import NonNull, ObjectType from graphene.relay import Connection, Node from ..fields import (SQLAlchemyConnectionField, @@ -26,6 +26,20 @@ class Meta: ## +def test_nonnull_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(NonNull(Pet.connection)) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_required_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(Pet.connection, required=True) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + def test_promise_connection_resolver(): def resolver(_obj, _info): return Promise.resolve([]) From 20ecaeadf2144b88555a3daf1a04e31b7f2ff95a Mon Sep 17 00:00:00 2001 From: Julien Nakache Date: Thu, 4 Jun 2020 17:39:34 -0400 Subject: [PATCH 15/67] Release 2.3.0 (#278) --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 392ee1ab..3945d506 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.3.0.dev1" +__version__ = "2.3.0" __all__ = [ "__version__", From cba727ca4cb344350f92d93ad01f6a3b183c11c6 Mon Sep 17 00:00:00 2001 From: Mel van Londen Date: Wed, 15 Sep 2021 23:36:47 -0700 Subject: [PATCH 16/67] Move from travis to github actions (#316) * move from travis 
to github actions * add flake8 to tox * add flake8 as env in tox * add flake8 to setup * remove sqlalchemy 1.1 in tests * fix flake8 exclude * move coveralls to github action * fix coverall github action config * move coveralls to tox * move coveralls dep to test list * add coverage command * move coveralls back into github action * modify coverage output --- .github/workflows/deploy.yml | 26 ++++++++++++++++++++ .github/workflows/lint.yml | 22 +++++++++++++++++ .github/workflows/tests.yml | 38 +++++++++++++++++++++++++++++ .travis.yml | 47 ------------------------------------ setup.py | 2 +- tox.ini | 28 ++++++++++++++++++--- 6 files changed, 111 insertions(+), 52 deletions(-) create mode 100644 .github/workflows/deploy.yml create mode 100644 .github/workflows/lint.yml create mode 100644 .github/workflows/tests.yml delete mode 100644 .travis.yml diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 00000000..50ffc6ad --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,26 @@ +name: 🚀 Deploy to PyPI + +on: + push: + tags: + - 'v*' + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Build wheel and source tarball + run: | + pip install wheel + python setup.py sdist bdist_wheel + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@v1.1.0 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..3fc35f9d --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,22 @@ +name: Lint + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | 
+ python -m pip install --upgrade pip + pip install tox + - name: Run lint 💅 + run: tox + env: + TOXENV: flake8 \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..4adb26f6 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,38 @@ +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + max-parallel: 10 + matrix: + sql-alchemy: ["1.2", "1.3"] + python-version: ["3.6", "3.7", "3.8", "3.9"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions + - name: Test with tox + run: tox + env: + SQLALCHEMY: ${{ matrix.sql-alchemy }} + TOXENV: ${{ matrix.toxenv }} + - name: Upload coverage.xml + if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }} + uses: actions/upload-artifact@v2 + with: + name: graphene-sqlalchemy-coverage + path: coverage.xml + if-no-files-found: error + - name: Upload coverage.xml to codecov + if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }} + uses: codecov/codecov-action@v1 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 5a988428..00000000 --- a/.travis.yml +++ /dev/null @@ -1,47 +0,0 @@ -language: python -matrix: - include: - # Python 2.7 - - env: TOXENV=py27 - python: 2.7 - # Python 3.5 - - env: TOXENV=py35 - python: 3.5 - # Python 3.6 - - env: TOXENV=py36 - python: 3.6 - # Python 3.7 - - env: TOXENV=py37 - python: 3.7 - dist: xenial - # SQLAlchemy 1.1 - - env: TOXENV=py37-sql11 - python: 3.7 - dist: xenial - # SQLAlchemy 1.2 - - env: TOXENV=py37-sql12 - python: 3.7 - dist: xenial - # SQLAlchemy 1.3 - - env: TOXENV=py37-sql13 - python: 3.7 - dist: xenial - # Pre-commit - - env: TOXENV=pre-commit - 
python: 3.7 - dist: xenial -install: pip install .[dev] -script: tox -after_success: coveralls -cache: - directories: - - $HOME/.cache/pip - - $HOME/.cache/pre-commit -deploy: - provider: pypi - user: syrusakbary - on: - tags: true - password: - secure: q0ey31cWljGB30l43aEd1KIPuAHRutzmsd2lBb/2zvD79ReBrzvCdFAkH2xcyo4Volk3aazQQTNUIurnTuvBxmtqja0e+gUaO5LdOcokVdOGyLABXh7qhd2kdvbTDWgSwA4EWneLGXn/SjXSe0f3pCcrwc6WDcLAHxtffMvO9gulpYQtUoOqXfMipMOkRD9iDWTJBsSo3trL70X1FHOVr6Yqi0mfkX2Y/imxn6wlTWRz28Ru94xrj27OmUnCv7qcG0taO8LNlUCquNFAr2sZ+l+U/GkQrrM1y+ehPz3pmI0cCCd7SX/7+EG9ViZ07BZ31nk4pgnqjmj3nFwqnCE/4IApGnduqtrMDF63C9TnB1TU8oJmbbUCu4ODwRpBPZMnwzaHsLnrpdrB89/98NtTfujdrh3U5bVB+t33yxrXVh+FjgLYj9PVeDixpFDn6V/Xcnv4BbRMNOhXIQT7a7/5b99RiXBjCk6KRu+Jdu5DZ+3G4Nbr4oim3kZFPUHa555qbzTlwAfkrQxKv3C3OdVJR7eGc9ADsbHyEJbdPNAh/T+xblXTXLS3hPYDvgM+WEGy3CytBDG3JVcXm25ZP96EDWjweJ7MyfylubhuKj/iR1Y1wiHeIsYq9CqRrFQUWL8gFJBfmgjs96xRXXXnvyLtKUKpKw3wFg5cR/6FnLeYZ8k= - distributions: "sdist bdist_wheel" diff --git a/setup.py b/setup.py index 7b350c39..e20a1750 100644 --- a/setup.py +++ b/setup.py @@ -60,8 +60,8 @@ extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "coveralls==1.10.0", "pre-commit==1.14.4", + "flake8==3.7.9", ], "test": tests_require, }, diff --git a/tox.ini b/tox.ini index 562da2dc..69d84f92 100644 --- a/tox.ini +++ b/tox.ini @@ -1,20 +1,40 @@ [tox] -envlist = pre-commit,py{27,35,36,37}-sql{11,12,13} +envlist = pre-commit,py{27,35,36,37,38,39}-sql{12,13},flake8 skipsdist = true minversion = 3.7.0 +[gh-actions] +python = + 2.7: py27 + 3.5: py35 + 3.6: py36 + 3.7: py37 + 3.8: py38 + 3.9: py39 + +[gh-actions:env] +SQLALCHEMY = + 1.2: sql12 + 1.3: sql13 + [testenv] +passenv = GITHUB_* deps = .[test] - sql11: sqlalchemy>=1.1,<1.2 sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 commands = - pytest graphene_sqlalchemy --cov=graphene_sqlalchemy {posargs} + pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} 
[testenv:pre-commit] -basepython=python3.7 +basepython=python3.9 deps = .[dev] commands = pre-commit {posargs:run --all-files} + +[testenv:flake8] +basepython = python3.9 +deps = -e.[dev] +commands = + flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120 \ No newline at end of file From d6dd67e388b58247dd4a03c29bf6e625d6e4e230 Mon Sep 17 00:00:00 2001 From: Ricardo Madriz Date: Mon, 20 Sep 2021 21:53:03 -0600 Subject: [PATCH 17/67] Graphene v3 (tests) (#317) Co-authored-by: Jonathan Ehwald Co-authored-by: Zbigniew Siciarz Co-authored-by: Cole Lin --- .github/workflows/tests.yml | 50 +++---- .gitignore | 3 + graphene_sqlalchemy/__init__.py | 2 +- graphene_sqlalchemy/batching.py | 49 +++++-- graphene_sqlalchemy/converter.py | 17 ++- graphene_sqlalchemy/enums.py | 3 +- graphene_sqlalchemy/fields.py | 66 +++++---- graphene_sqlalchemy/registry.py | 3 +- graphene_sqlalchemy/tests/conftest.py | 2 +- graphene_sqlalchemy/tests/test_batching.py | 138 ++++++------------ graphene_sqlalchemy/tests/test_benchmark.py | 10 +- graphene_sqlalchemy/tests/test_converter.py | 2 +- graphene_sqlalchemy/tests/test_query_enums.py | 4 +- graphene_sqlalchemy/tests/test_sort_enums.py | 2 +- graphene_sqlalchemy/tests/test_types.py | 8 +- graphene_sqlalchemy/tests/utils.py | 9 +- graphene_sqlalchemy/types.py | 8 +- graphene_sqlalchemy/utils.py | 6 + setup.cfg | 2 +- setup.py | 26 ++-- tox.ini | 6 +- 21 files changed, 196 insertions(+), 220 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4adb26f6..a9a3bd5d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,31 +8,31 @@ jobs: strategy: max-parallel: 10 matrix: - sql-alchemy: ["1.2", "1.3"] + sql-alchemy: ["1.2", "1.3", "1.4"] python-version: ["3.6", "3.7", "3.8", "3.9"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: 
Install dependencies - run: | - python -m pip install --upgrade pip - pip install tox tox-gh-actions - - name: Test with tox - run: tox - env: - SQLALCHEMY: ${{ matrix.sql-alchemy }} - TOXENV: ${{ matrix.toxenv }} - - name: Upload coverage.xml - if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }} - uses: actions/upload-artifact@v2 - with: - name: graphene-sqlalchemy-coverage - path: coverage.xml - if-no-files-found: error - - name: Upload coverage.xml to codecov - if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }} - uses: codecov/codecov-action@v1 \ No newline at end of file + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox tox-gh-actions + - name: Test with tox + run: tox + env: + SQLALCHEMY: ${{ matrix.sql-alchemy }} + TOXENV: ${{ matrix.toxenv }} + - name: Upload coverage.xml + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} + uses: actions/upload-artifact@v2 + with: + name: graphene-sqlalchemy-coverage + path: coverage.xml + if-no-files-found: error + - name: Upload coverage.xml to codecov + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} + uses: codecov/codecov-action@v1 diff --git a/.gitignore b/.gitignore index a97b8c21..c4a735fe 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ target/ # Databases *.sqlite3 .vscode + +# mypy cache +.mypy_cache/ diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 3945d506..060bd13b 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.3.0" +__version__ = "3.0.0b1" __all__ = [ "__version__", diff --git 
a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index baf01deb..85cc8855 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,8 +1,10 @@ +import aiodataloader import sqlalchemy -from promise import dataloader, promise from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext +from .utils import is_sqlalchemy_version_less_than + def get_batch_resolver(relationship_prop): @@ -10,10 +12,10 @@ def get_batch_resolver(relationship_prop): # This is so SQL string generation is cached under-the-hood via `bakery` selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - class RelationshipLoader(dataloader.DataLoader): + class RelationshipLoader(aiodataloader.DataLoader): cache = False - def batch_load_fn(self, parents): # pylint: disable=method-hidden + async def batch_load_fn(self, parents): """ Batch loads the relationships of all the parents as one SQL statement. @@ -52,21 +54,36 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden states = [(sqlalchemy.inspect(parent), True) for parent in parents] # For our purposes, the query_context will only used to get the session - query_context = QueryContext(session.query(parent_mapper.entity)) - - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - ) - - return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) + query_context = None + if is_sqlalchemy_version_less_than('1.4'): + query_context = QueryContext(session.query(parent_mapper.entity)) + else: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + + if is_sqlalchemy_version_less_than('1.4'): + selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper + ) + else: + selectin_loader._load_for_path( + query_context, + 
parent_mapper._path_registry, + states, + None, + child_mapper, + None + ) + + return [getattr(parent, relationship_prop.key) for parent in parents] loader = RelationshipLoader() - def resolve(root, info, **args): - return loader.load(root) + async def resolve(root, info, **args): + return await loader.load(root) return resolve diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index f4b805e2..1720e3d8 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,6 +1,5 @@ -from enum import EnumMeta +from functools import singledispatch -from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies @@ -21,6 +20,11 @@ except ImportError: ChoiceType = JSONType = ScalarListType = TSVectorType = object +try: + from sqlalchemy_utils.types.choice import EnumTypeImpl +except ImportError: + EnumTypeImpl = object + is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -110,9 +114,9 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): - if 'type' not in field_kwargs: + if 'type_' not in field_kwargs: # TODO The default type should be dependent on the type of the property propety. 
- field_kwargs['type'] = String + field_kwargs['type_'] = String return Field( resolver=resolver, @@ -156,7 +160,8 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) + + field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) field_kwargs.setdefault('required', not is_column_nullable(column)) field_kwargs.setdefault('description', get_column_doc(column)) @@ -221,7 +226,7 @@ def convert_enum_to_enum(type, column, registry=None): @convert_sqlalchemy_type.register(ChoiceType) def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() - if isinstance(type.choices, EnumMeta): + if isinstance(type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table return Enum(name, list((v.name, v.value) for v in type.choices)) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 0adea107..f100be19 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,4 +1,3 @@ -import six from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType @@ -63,7 +62,7 @@ def enum_for_field(obj_type, field_name): if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) - if not field_name or not isinstance(field_name, six.string_types): + if not field_name or not isinstance(field_name, str): raise TypeError( "Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 
780fcbf0..a22a3ae7 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,17 +1,18 @@ +import enum import warnings from functools import partial -import six from promise import Promise, is_thenable from sqlalchemy.orm.query import Query from graphene import NonNull from graphene.relay import Connection, ConnectionField -from graphene.relay.connection import PageInfo -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphene.relay.connection import connection_adapter, page_info_adapter +from graphql_relay.connection.arrayconnection import \ + connection_from_array_slice from .batching import get_batch_resolver -from .utils import get_query +from .utils import EnumValue, get_query class UnsortedSQLAlchemyConnectionField(ConnectionField): @@ -19,10 +20,10 @@ class UnsortedSQLAlchemyConnectionField(ConnectionField): def type(self): from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - nullable_type = get_nullable_type(_type) + type_ = super(ConnectionField, self).type + nullable_type = get_nullable_type(type_) if issubclass(nullable_type, Connection): - return _type + return type_ assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(nullable_type.__name__) @@ -31,7 +32,7 @@ def type(self): ), "The type {} doesn't have a connection".format( nullable_type.__name__ ) - assert _type == nullable_type, ( + assert type_ == nullable_type, ( "Passing a SQLAlchemyObjectType instance is deprecated. 
" "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" ) @@ -53,15 +54,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved): _len = resolved.count() else: _len = len(resolved) - connection = connection_from_list_slice( - resolved, - args, + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, slice_start=0, - list_length=_len, - list_slice_length=_len, - connection_type=connection_type, - pageinfo_type=PageInfo, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, edge_type=connection_type.Edge, + page_info_type=page_info_adapter, ) connection.iterable = resolved connection.length = _len @@ -77,7 +82,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg return on_resolve(resolved) - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.connection_resolver, parent_resolver, @@ -88,8 +93,8 @@ def get_resolver(self, parent_resolver): # TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): - def __init__(self, type, *args, **kwargs): - nullable_type = get_nullable_type(type) + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) if "sort" not in kwargs and issubclass(nullable_type, Connection): # Let super class raise if type is not a Connection try: @@ -103,16 +108,25 @@ def __init__(self, type, *args, **kwargs): ) elif "sort" in kwargs and kwargs["sort"] is None: del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @classmethod def get_query(cls, model, info, sort=None, **args): query = get_query(model, info.context) if sort is not None: 
- if isinstance(sort, six.string_types): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) return query @@ -123,7 +137,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): Use at your own risk. """ - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.connection_resolver, self.resolver, @@ -148,13 +162,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): __connectionFactory = UnsortedSQLAlchemyConnectionField -def createConnectionField(_type, **field_kwargs): +def createConnectionField(type_, **field_kwargs): warnings.warn( 'createConnectionField is deprecated and will be removed in the next ' 'major version. 
Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', DeprecationWarning, ) - return __connectionFactory(_type, **field_kwargs) + return __connectionFactory(type_, **field_kwargs) def registerConnectionFieldFactory(factoryMethod): diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index c20bc2ca..acfa744b 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,6 +1,5 @@ from collections import defaultdict -import six from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Enum @@ -43,7 +42,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) ) - if not field_name or not isinstance(field_name, six.string_types): + if not field_name or not isinstance(field_name, str): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 98515051..34ba9d8a 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -22,7 +22,7 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def session_factory(): engine = create_engine(test_db_url) Base.metadata.create_all(engine) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index fc646a3c..1896900b 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -1,3 +1,4 @@ +import ast import contextlib import logging @@ -9,8 +10,9 @@ from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than from .models 
import Article, HairKind, Pet, Reporter -from .utils import is_sqlalchemy_version_less_than, to_std_dicts +from .utils import remove_cache_miss_stat, to_std_dicts class MockLoggingHandler(logging.Handler): @@ -75,7 +77,8 @@ def resolve_reporters(self, info): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) -def test_many_to_one(session_factory): +@pytest.mark.asyncio +async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( @@ -103,7 +106,7 @@ def test_many_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { articles { headline @@ -125,26 +128,12 @@ def test_many_to_one(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT articles.id AS articles_id, ' - 'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date, ' - 'articles.reporter_id AS articles_reporter_id \n' - 'FROM articles', - '()', - - 'SELECT reporters.id AS reporters_id, ' - '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters \n' - 'WHERE reporters.id IN (?, ?)', - '(1, 2)', - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -166,7 +155,8 @@ def test_many_to_one(session_factory): } -def 
test_one_to_one(session_factory): +@pytest.mark.asyncio +async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( @@ -194,7 +184,7 @@ def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -216,26 +206,12 @@ def test_one_to_one(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT articles.reporter_id AS articles_reporter_id, ' - 'articles.id AS articles_id, ' - 'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date \n' - 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?)', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -257,7 +233,8 @@ def test_one_to_one(session_factory): } -def test_one_to_many(session_factory): +@pytest.mark.asyncio +async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( @@ -293,7 +270,7 @@ def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level 
session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -319,26 +296,12 @@ def test_one_to_many(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT articles.reporter_id AS articles_reporter_id, ' - 'articles.id AS articles_id, ' - 'articles.headline AS articles_headline, ' - 'articles.pub_date AS articles_pub_date \n' - 'FROM articles \n' - 'WHERE articles.reporter_id IN (?, ?)', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -382,7 +345,8 @@ def test_one_to_many(session_factory): } -def test_many_to_many(session_factory): +@pytest.mark.asyncio +async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( @@ -420,7 +384,7 @@ def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = await schema.execute_async(""" query { reporters { firstName @@ -446,31 +410,12 @@ def test_many_to_many(session_factory): assert len(sql_statements) == 1 return - assert messages == [ - 'BEGIN (implicit)', - - 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM 
reporters) AS anon_1, ' - 'reporters.id AS reporters_id, ' - 'reporters.first_name AS reporters_first_name, ' - 'reporters.last_name AS reporters_last_name, ' - 'reporters.email AS reporters_email, ' - 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' - 'FROM reporters', - '()', - - 'SELECT reporters_1.id AS reporters_1_id, ' - 'pets.id AS pets_id, ' - 'pets.name AS pets_name, ' - 'pets.pet_kind AS pets_pet_kind, ' - 'pets.hair_kind AS pets_hair_kind, ' - 'pets.reporter_id AS pets_reporter_id \n' - 'FROM reporters AS reporters_1 ' - 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' - 'JOIN pets ON pets.id = association_1.pet_id \n' - 'WHERE reporters_1.id IN (?, ?) ' - 'ORDER BY pets.id', - '(1, 2)' - ] + if not is_sqlalchemy_version_less_than('1.4'): + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] assert not result.errors result = to_std_dicts(result.data) @@ -586,7 +531,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 -def test_connection_factory_field_overrides_batching_is_false(session_factory): +@pytest.mark.asyncio +async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() reporter_1 = Reporter(first_name='Reporter_1') session.add(reporter_1) @@ -620,7 +566,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + await schema.execute_async(""" query { reporters { articles { diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 1e5ee4f1..11e9d0e0 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ 
b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,13 +1,11 @@ import pytest -from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend import graphene from graphene import relay -from ..fields import BatchSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter -from .utils import is_sqlalchemy_version_less_than if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) @@ -18,19 +16,16 @@ class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class ArticleType(SQLAlchemyObjectType): class Meta: model = Article interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class PetType(SQLAlchemyObjectType): class Meta: model = Pet interfaces = (relay.Node,) - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) @@ -47,15 +42,12 @@ def resolve_reporters(self, info): def benchmark_query(session_factory, benchmark, query): schema = get_schema() - cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) - cached_backend.document_from_string(schema, query) # Prime cache @benchmark def execute_query(): result = schema.execute( query, context_value={"session": session_factory()}, - backend=cached_backend, ) assert not result.errors diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f0fc1802..3196d003 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -51,7 +51,7 @@ def test_should_unknown_sqlalchemy_field_raise_exception(): re_err = "Don't know how to convert the SQLAlchemy field" with 
pytest.raises(Exception, match=re_err): # support legacy Binary type and subsequent LargeBinary - get_field(getattr(types, 'LargeBinary', types.Binary)()) + get_field(getattr(types, 'LargeBinary', types.BINARY)()) def test_should_date_convert_string(): diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index ec585d57..5166c45f 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -32,7 +32,7 @@ def resolve_reporters(self, _info): def resolve_pets(self, _info, kind): query = session.query(Pet) if kind: - query = query.filter_by(pet_kind=kind) + query = query.filter_by(pet_kind=kind.value) return query query = """ @@ -131,7 +131,7 @@ class Query(graphene.ObjectType): def resolve_pet(self, info, kind=None): query = session.query(Pet) if kind: - query = query.filter(Pet.pet_kind == kind) + query = query.filter(Pet.pet_kind == kind.value) return query.first() query = """ diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index d6f6965d..6291d4f8 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -354,7 +354,7 @@ def makeNodes(nodeList): """ result = schema.execute(queryError, context_value={"session": session}) assert result.errors is not None - assert '"sort" has invalid value' in result.errors[0].message + assert 'cannot represent non-enum value' in result.errors[0].message queryNoSort = """ query sortTest { diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index bf563b6e..32f01509 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,6 +1,6 @@ -import mock +from unittest import mock + import pytest -import six # noqa F401 from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) @@ -136,10 +136,10 @@ 
class Meta: # columns email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(model_attr='email', type=Int) + email_v2 = ORMField(model_attr='email', type_=Int) # column_property - column_prop = ORMField(type=String) + column_prop = ORMField(type_=String) # composite composite_prop = ORMField() diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index 428757c3..c90ee476 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,4 +1,4 @@ -import pkg_resources +import re def to_std_dicts(value): @@ -11,6 +11,7 @@ def to_std_dicts(value): return value -def is_sqlalchemy_version_less_than(version_string): - """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) +def remove_cache_miss_stat(message): + """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" + # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 + return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ff22cded..72f06c06 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -27,7 +27,7 @@ class ORMField(OrderedType): def __init__( self, model_attr=None, - type=None, + type_=None, required=None, description=None, deprecation_reason=None, @@ -49,7 +49,7 @@ class MyType(SQLAlchemyObjectType): class Meta: model = MyModel - id = ORMField(type=graphene.Int) + id = ORMField(type_=graphene.Int) name = ORMField(required=True) -> MyType.id will be of type Int (vs ID). @@ -58,7 +58,7 @@ class Meta: :param str model_attr: Name of the SQLAlchemy model attribute used to resolve this field. Default to the name of the attribute referencing the ORMField. 
- :param type: + :param type_: Default to the type mapping in converter.py. :param str description: Default to the `doc` attribute of the SQLAlchemy column property. @@ -77,7 +77,7 @@ class Meta: # The is only useful for documentation and auto-completion common_kwargs = { 'model_attr': model_attr, - 'type': type, + 'type_': type_, 'required': required, 'description': description, 'deprecation_reason': deprecation_reason, diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 7139eefc..b30c0eb4 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,6 +1,7 @@ import re import warnings +import pkg_resources from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError @@ -140,3 +141,8 @@ def sort_argument_for_model(cls, has_default=True): enum.default = None return Argument(List(enum), default_value=enum.default) + + +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) diff --git a/setup.cfg b/setup.cfg index 4e8e5029..f36334d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ max-line-length = 120 no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=aiodataloader,app,database,flask,models,nameko,pkg_resources,promise,pytest,schema,setuptools,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy diff --git a/setup.py b/setup.py index e20a1750..da49f1d4 100644 --- a/setup.py +++ b/setup.py @@ -13,24 
+13,18 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene - "graphene>=2.1.3,<3", + "graphene>=3.0.0b7", "promise>=2.3", - # Tests fail with 1.0.19 - "SQLAlchemy>=1.2,<2", - "six>=1.10.0,<2", - "singledispatch>=3.4.0.3,<4", + "SQLAlchemy>=1.1,<2", + "aiodataloader>=0.2.0,<1.0", ] -try: - import enum -except ImportError: # Python < 2.7 and Python 3.3 - requirements.append("enum34 >= 1.1.6") tests_require = [ - "pytest==4.3.1", - "mock==2.0.0", - "pytest-cov==2.6.1", - "sqlalchemy_utils==0.33.9", - "pytest-benchmark==3.2.1", + "pytest>=6.2.0,<7.0", + "pytest-asyncio>=0.15.1", + "pytest-cov>=2.11.0,<3.0", + "sqlalchemy_utils>=0.37.0,<1.0", + "pytest-benchmark>=3.4.0,<4.0", ] setup( @@ -46,12 +40,10 @@ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay graphene", diff --git a/tox.ini b/tox.ini index 69d84f92..a2843f05 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pre-commit,py{27,35,36,37,38,39}-sql{12,13},flake8 +envlist = pre-commit,py{36,37,38,39}-sql{11,12,13,14} skipsdist = true minversion = 3.7.0 @@ -16,6 +16,7 @@ python = SQLALCHEMY = 1.2: sql12 1.3: sql13 + 1.4: sql14 [testenv] passenv = GITHUB_* @@ -23,6 +24,7 @@ deps = .[test] sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 + sql14: sqlalchemy>=1.4,<1.5 commands = pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} @@ -37,4 +39,4 @@ commands = basepython = python3.9 deps = -e.[dev] commands = - flake8 --exclude setup.py,docs,examples,tests,.tox 
--max-line-length 120 \ No newline at end of file + flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120 From 57cd7866848d5a6ad174710be2a09c1339c159e8 Mon Sep 17 00:00:00 2001 From: Mel van Londen Date: Mon, 20 Sep 2021 21:44:25 -0700 Subject: [PATCH 18/67] Build clean up (#318) --- .github/workflows/deploy.yml | 2 +- tox.ini | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 50ffc6ad..a9f74233 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -3,7 +3,7 @@ name: 🚀 Deploy to PyPI on: push: tags: - - 'v*' + - '*' jobs: build: diff --git a/tox.ini b/tox.ini index a2843f05..b8ce0618 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,10 @@ [tox] -envlist = pre-commit,py{36,37,38,39}-sql{11,12,13,14} +envlist = pre-commit,py{36,37,38,39}-sql{12,13,14} skipsdist = true minversion = 3.7.0 [gh-actions] python = - 2.7: py27 - 3.5: py35 3.6: py36 3.7: py37 3.8: py38 From 7bf0aa5cabae99e5f8be46ac2a2f0faf30093c97 Mon Sep 17 00:00:00 2001 From: Kyle Quinn <29496224+quinnkj@users.noreply.github.com> Date: Fri, 8 Apr 2022 06:41:48 -0400 Subject: [PATCH 19/67] I resolved spelling and capitalization mistakes. (#290) For ~~instaling~~installing ~~g~~Graphene, just run this command in your shell. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9b617069..04692973 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://gra ## Installation -For instaling graphene, just run this command in your shell +For installing Graphene, just run this command in your shell. 
```bash pip install "graphene-sqlalchemy>=2.0" @@ -34,7 +34,7 @@ class UserModel(Base): last_name = Column(String) ``` -To create a GraphQL schema for it you simply have to write the following: +To create a GraphQL schema for it, you simply have to write the following: ```python import graphene From 771f4f58f589878820b681598e4a3f4502be00ad Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 27 Apr 2022 21:31:38 +0200 Subject: [PATCH 20/67] Fix for import from graphql-relay-py (#329) (#330) * Add newlines to make pre-commit happy * Fix import from graphql_relay The module name was deprecated, but all imports should be made from the top level anyway. --- .github/workflows/deploy.yml | 2 +- .github/workflows/lint.yml | 2 +- graphene_sqlalchemy/fields.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index a9f74233..1ae7b4b6 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -23,4 +23,4 @@ jobs: uses: pypa/gh-action-pypi-publish@v1.1.0 with: user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3fc35f9d..559326c4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,4 +19,4 @@ jobs: - name: Run lint 💅 run: tox env: - TOXENV: flake8 \ No newline at end of file + TOXENV: flake8 diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index a22a3ae7..d7a83392 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -8,8 +8,7 @@ from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter -from graphql_relay.connection.arrayconnection import \ - connection_from_array_slice +from graphql_relay import connection_from_array_slice from .batching 
import get_batch_resolver from .utils import EnumValue, get_query From 869a55b3e48b63f1a86f7fbc167b2710be004dc4 Mon Sep 17 00:00:00 2001 From: Jacob Beard Date: Wed, 27 Apr 2022 20:31:37 -0400 Subject: [PATCH 21/67] Add support for N-Dimensional Arrays Fixes #288 --- graphene_sqlalchemy/converter.py | 6 +++++- graphene_sqlalchemy/tests/test_converter.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 1720e3d8..04061801 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -239,11 +239,15 @@ def convert_scalar_list_to_list(type, column, registry=None): return List(String) +def init_array_list_recursive(inner_type, n): + return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n-1)) + + @convert_sqlalchemy_type.register(types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return List(inner_type) + return List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) @convert_sqlalchemy_type.register(postgresql.HSTORE) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 3196d003..57c43058 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -324,6 +324,21 @@ def test_should_array_convert(): assert field.type.of_type == graphene.Int +def test_should_2d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=2)) + assert isinstance(field.type, graphene.List) + assert isinstance(field.type.of_type, graphene.List) + assert field.type.of_type.of_type == graphene.Int + + +def test_should_3d_array_convert(): + field = get_field(types.ARRAY(types.Integer, dimensions=3)) + assert isinstance(field.type, graphene.List) + assert 
isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.List) + assert field.type.of_type.of_type.of_type == graphene.Int + + def test_should_postgresql_json_convert(): assert get_field(postgresql.JSON()).type == graphene.JSONString From 5da2048f15f16f6e2443a2c2471e9626069dbb04 Mon Sep 17 00:00:00 2001 From: Connor Brinton Date: Thu, 28 Apr 2022 12:05:09 -0400 Subject: [PATCH 22/67] =?UTF-8?q?=F0=9F=A5=85=20Don't=20suppress=20SQLAlch?= =?UTF-8?q?emy=20errors=20when=20mapping=20classes=20(#169)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These changes modify graphene-sqlalchemy so as not to suppress errors coming from SQLAlchemy when attempting to map classes. Previously this made the debugging experience difficult since issues with SQLAlchemy models would produce an unclear error message from graphene-sqlalchemy. With these changes, the SQLAlchemy error is propagated to the end-user, allowing them to correct the real issue quickly. 
Fixes #121 --- .pre-commit-config.yaml | 8 +-- graphene_sqlalchemy/tests/test_types.py | 65 +++++++++++++++++++++++++ graphene_sqlalchemy/types.py | 8 +-- graphene_sqlalchemy/utils.py | 8 ++- 4 files changed, 81 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 136f8e7a..1c67ab03 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ default_language_version: python: python3.7 repos: -- repo: git://github.com/pre-commit/pre-commit-hooks +- repo: https://github.com/pre-commit/pre-commit-hooks rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 hooks: - id: check-merge-conflict @@ -11,15 +11,15 @@ repos: exclude: ^docs/.*$ - id: trailing-whitespace exclude: README.md -- repo: git://github.com/PyCQA/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 hooks: - id: flake8 -- repo: git://github.com/asottile/seed-isort-config +- repo: https://github.com/asottile/seed-isort-config rev: v1.7.0 hooks: - id: seed-isort-config -- repo: git://github.com/pre-commit/mirrors-isort +- repo: https://github.com/pre-commit/mirrors-isort rev: v4.3.4 hooks: - id: isort diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 32f01509..1f15fa1a 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,11 +1,14 @@ from unittest import mock import pytest +import sqlalchemy.exc +import sqlalchemy.orm.exc from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) from graphene.relay import Connection +from .. 
import utils from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField, createConnectionField, @@ -492,3 +495,65 @@ class Meta: def test_deprecated_createConnectionField(): with pytest.warns(DeprecationWarning): createConnectionField(None) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unique_errors_propagate(class_mapper_mock): + # Define unique error to detect + class UniqueError(Exception): + pass + + # Mock class_mapper effect + class_mapper_mock.side_effect = UniqueError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleOne(SQLAlchemyObjectType): + class Meta(object): + model = Article + except UniqueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, UniqueError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_argument_errors_propagate(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleTwo(SQLAlchemyObjectType): + class Meta(object): + model = Article + except sqlalchemy.exc.ArgumentError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, sqlalchemy.exc.ArgumentError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unmapped_errors_reformat(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleThree(SQLAlchemyObjectType): + class Meta(object): + model = Article + except ValueError as 
e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, ValueError) + assert "You need to pass a valid SQLAlchemy Model" in str(error) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 72f06c06..ac69b697 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -207,9 +207,11 @@ def __init_subclass_with_meta__( _meta=None, **options ): - assert is_mapped_class(model), ( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' - ).format(cls.__name__, model) + # Make sure model is a valid SQLAlchemy model + if not is_mapped_class(model): + raise ValueError( + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + ) if not registry: registry = get_global_registry() diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index b30c0eb4..340ad47e 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -27,7 +27,13 @@ def get_query(model, context): def is_mapped_class(cls): try: class_mapper(cls) - except (ArgumentError, UnmappedClassError): + except ArgumentError as error: + # Only handle ArgumentErrors for non-class objects + if "Class object expected" in str(error): + return False + raise + except UnmappedClassError: + # Unmapped classes return false return False else: return True From 0820da77d94d947e35325ba40eea96462bc890a7 Mon Sep 17 00:00:00 2001 From: Cadu Date: Fri, 29 Apr 2022 20:07:58 -0300 Subject: [PATCH 23/67] Support setting @hybrid_property's return type from the functions type annotations. (#340) Adds support for automatic type conversion for @hybrid_property's using converters similar to @convert_sqlalchemy_type.register(). Currently, all basic types and (nested) Lists are supported. This feature replaces the old default string conversion. 
String conversion is still used as a fallback in case no compatible converter was found to ensure backward compatibility. Thank you @conao3 & @flipbit03! --- graphene_sqlalchemy/converter.py | 124 +++++++++++++++++++- graphene_sqlalchemy/tests/models.py | 122 +++++++++++++++++++ graphene_sqlalchemy/tests/test_converter.py | 91 +++++++++++++- graphene_sqlalchemy/tests/test_types.py | 63 ++++++++-- graphene_sqlalchemy/utils.py | 56 ++++++++- 5 files changed, 438 insertions(+), 18 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 04061801..a2e03694 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,11 +1,16 @@ +import datetime +import typing +import warnings +from decimal import Decimal from functools import singledispatch +from typing import Any from sqlalchemy import types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - String) +from graphene import (ID, Boolean, Date, DateTime, Dynamic, Enum, Field, Float, + Int, List, String, Time) from graphene.types.json import JSONString from .batching import get_batch_resolver @@ -14,6 +19,14 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver +from .utils import (registry_sqlalchemy_model_from_str, safe_isinstance, + singledispatchbymatchfunction, value_equals) + +try: + from typing import ForwardRef +except ImportError: + # python 3.6 + from typing import _ForwardRef as ForwardRef try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType @@ -25,7 +38,6 @@ except ImportError: EnumTypeImpl = object - is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -48,6 +60,7 @@ def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_fiel :param dict field_kwargs: :rtype: 
Dynamic """ + def dynamic_type(): """:rtype: Field|None""" direction = relationship_prop.direction @@ -115,8 +128,7 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'type_' not in field_kwargs: - # TODO The default type should be dependent on the type of the property propety. - field_kwargs['type_'] = String + field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) return Field( resolver=resolver, @@ -240,7 +252,7 @@ def convert_scalar_list_to_list(type, column, registry=None): def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n-1)) + return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n - 1)) @convert_sqlalchemy_type.register(types.ARRAY) @@ -260,3 +272,103 @@ def convert_json_to_string(type, column, registry=None): @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): return JSONString + + +@singledispatchbymatchfunction +def convert_sqlalchemy_hybrid_property_type(arg: Any): + existing_graphql_type = get_global_registry().get_type_for_model(arg) + if existing_graphql_type: + return existing_graphql_type + + # No valid type found, warn and fall back to graphene.String + warnings.warn( + (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." 
+ "Falling back to \"graphene.String\"") + ) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) +def convert_sqlalchemy_hybrid_property_type_str(arg): + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) +def convert_sqlalchemy_hybrid_property_type_int(arg): + return Int + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) +def convert_sqlalchemy_hybrid_property_type_float(arg): + return Float + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(arg): + # The reason Decimal should be serialized as a String is because this is a + # base10 type used in things like money, and string allows it to not + # lose precision (which would happen if we downcasted to a Float, for example) + return String + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) +def convert_sqlalchemy_hybrid_property_type_bool(arg): + return Boolean + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) +def convert_sqlalchemy_hybrid_property_type_datetime(arg): + return DateTime + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) +def convert_sqlalchemy_hybrid_property_type_date(arg): + return Date + + +@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) +def convert_sqlalchemy_hybrid_property_type_time(arg): + return Time + + +@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +def convert_sqlalchemy_hybrid_property_type_list_t(arg): + # type is either list[T] or List[T], generic argument at __args__[0] + internal_type = arg.__args__[0] + + graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + + return List(graphql_internal_type) + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) +def 
convert_sqlalchemy_hybrid_property_forwardref(arg): + """ + Generate a lambda that will resolve the type at runtime + This takes care of self-references + """ + + def forward_reference_solver(): + model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) + if not model: + return String + # Always fall back to string if no ForwardRef type found. + return get_global_registry().get_type_for_model(model) + + return forward_reference_solver + + +@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(arg): + """ + Convert Bare String into a ForwardRef + """ + + return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + + +def convert_hybrid_property_return_type(hybrid_prop): + # Grab the original method's return type annotations from inside the hybrid property + return_type_annotation = hybrid_prop.fget.__annotations__.get('return', str) + + return convert_sqlalchemy_hybrid_property_type(return_type_annotation) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 88e992b9..bda5a863 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,6 +1,9 @@ from __future__ import absolute_import +import datetime import enum +from decimal import Decimal +from typing import List, Tuple from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, func, select) @@ -69,6 +72,26 @@ class Reporter(Base): def hybrid_prop(self): return self.first_name + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_list(self) -> List[int]: + return [1, 2, 3] + column_prop = column_property( select([func.cast(func.count(id), Integer)]), doc="Column 
property" ) @@ -95,3 +118,102 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) mapper(ReflectedEditor, editor_table) + + +############################################ +# The models below are mainly used in the +# @hybrid_property type inference scenarios +############################################ + + +class ShoppingCartItem(Base): + __tablename__ = "shopping_cart_items" + + id = Column(Integer(), primary_key=True) + + @hybrid_property + def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] + + +class ShoppingCart(Base): + __tablename__ = "shopping_carts" + + id = Column(Integer(), primary_key=True) + + # Standard Library types + + @hybrid_property + def hybrid_prop_str(self) -> str: + return self.first_name + + @hybrid_property + def hybrid_prop_int(self) -> int: + return 42 + + @hybrid_property + def hybrid_prop_float(self) -> float: + return 42.3 + + @hybrid_property + def hybrid_prop_bool(self) -> bool: + return True + + @hybrid_property + def hybrid_prop_decimal(self) -> Decimal: + return Decimal("3.14") + + @hybrid_property + def hybrid_prop_date(self) -> datetime.date: + return datetime.datetime.now().date() + + @hybrid_property + def hybrid_prop_time(self) -> datetime.time: + return datetime.datetime.now().time() + + @hybrid_property + def hybrid_prop_datetime(self) -> datetime.datetime: + return datetime.datetime.now() + + # Lists and Nested Lists + + @hybrid_property + def hybrid_prop_list_int(self) -> List[int]: + return [1, 2, 3] + + @hybrid_property + def hybrid_prop_list_date(self) -> List[datetime.date]: + return [self.hybrid_prop_date, self.hybrid_prop_date, self.hybrid_prop_date] + + @hybrid_property + def hybrid_prop_nested_list_int(self) -> List[List[int]]: + return [self.hybrid_prop_list_int, ] + + @hybrid_property + def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: + return [[self.hybrid_prop_list_int, ], ] + + # Other SQLAlchemy Instances + 
@hybrid_property + def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + # Other SQLAlchemy Instances + @hybrid_property + def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: + return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] + + # Unsupported Type + @hybrid_property + def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: + return "this will actually", "be a string" + + # Self-references + + @hybrid_property + def hybrid_prop_self_referential(self) -> 'ShoppingCart': + return ShoppingCart(id=1) + + @hybrid_property + def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + return [ShoppingCart(id=1)] diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 57c43058..4b9e74ed 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,4 +1,5 @@ import enum +from typing import Dict, Union import pytest from sqlalchemy import Column, func, select, types @@ -9,9 +10,11 @@ from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene +from graphene import Boolean, Float, Int, Scalar, String from graphene.relay import Node -from graphene.types.datetime import DateTime +from graphene.types.datetime import Date, DateTime, Time from graphene.types.json import JSONString +from graphene.types.structures import List, Structure from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -20,7 +23,8 @@ default_connection_field_factory) from ..registry import Registry, get_global_registry from ..types import SQLAlchemyObjectType -from .models import Article, CompositeFullName, Pet, Reporter +from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, + ShoppingCartItem) def mock_resolver(): @@ -384,3 +388,86 @@ def __init__(self, col1, col2): Registry(), mock_resolver, ) + + +def 
test_sqlalchemy_hybrid_property_type_inference(): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + interfaces = (Node,) + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + interfaces = (Node,) + + ####################################################### + # Check ShoppingCartItem's Properties and Return Types + ####################################################### + + shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = { + 'hybrid_prop_shopping_cart': List(ShoppingCartType) + } + + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_item_expected_types.items(): + hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. 
+ assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + + ################################################### + # Check ShoppingCart's Properties and Return Types + ################################################### + + shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = { + # Basic types + "hybrid_prop_str": String, + "hybrid_prop_int": Int, + "hybrid_prop_float": Float, + "hybrid_prop_bool": Boolean, + "hybrid_prop_decimal": String, # Decimals should be serialized Strings + "hybrid_prop_date": Date, + "hybrid_prop_time": Time, + "hybrid_prop_datetime": DateTime, + # Lists and Nested Lists + "hybrid_prop_list_int": List(Int), + "hybrid_prop_list_date": List(Date), + "hybrid_prop_nested_list_int": List(List(Int)), + "hybrid_prop_deeply_nested_list_int": List(List(List(Int))), + "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType), + "hybrid_prop_unsupported_type_tuple": String, + # Self Referential List + "hybrid_prop_self_referential": ShoppingCartType, + "hybrid_prop_self_referential_list": List(ShoppingCartType), + } + + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys() + ]) + + for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_expected_types.items(): + hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] + + # this is a simple way of showing the failed property name + # instead of having to unroll the loop. 
+ assert ( + (hybrid_prop_name, str(hybrid_prop_field.type)) == + (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + ) + assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 1f15fa1a..2d660b67 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,8 +4,8 @@ import sqlalchemy.exc import sqlalchemy.orm.exc -from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, - ObjectType, Schema, String) +from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, + Node, NonNull, ObjectType, Schema, String) from graphene.relay import Connection from .. import utils @@ -74,7 +74,7 @@ class Meta: model = Article interfaces = (Node,) - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Columns "column_prop", # SQLAlchemy retuns column properties first "id", @@ -86,11 +86,16 @@ class Meta: "composite_prop", # Hybrid "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", # Relationship "pets", "articles", "favorite_article", - ] + ]) # column first_name_field = ReporterType._meta.fields['first_name'] @@ -115,6 +120,36 @@ class Meta: # "doc" is ignored by hybrid_property assert hybrid_prop.description is None + # hybrid_property_str + hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + assert hybrid_prop_str.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop_str.description is None + + # hybrid_property_int + hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + assert hybrid_prop_int.type == Int + # "doc" is ignored by hybrid_property + assert hybrid_prop_int.description is None + + # hybrid_property_float + hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + 
assert hybrid_prop_float.type == Float + # "doc" is ignored by hybrid_property + assert hybrid_prop_float.description is None + + # hybrid_property_bool + hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + assert hybrid_prop_bool.type == Boolean + # "doc" is ignored by hybrid_property + assert hybrid_prop_bool.description is None + + # hybrid_property_list + hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + assert hybrid_prop_list.type == List(Int) + # "doc" is ignored by hybrid_property + assert hybrid_prop_list.description is None + # relationship favorite_article_field = ReporterType._meta.fields['favorite_article'] assert isinstance(favorite_article_field, Dynamic) @@ -166,7 +201,7 @@ class Meta: interfaces = (Node,) use_connection = False - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Fields from ReporterMixin "first_name", "last_name", @@ -182,7 +217,12 @@ class Meta: # Then the automatic SQLAlchemy fields "id", "favorite_pet_kind", - ] + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ]) first_name_field = ReporterType._meta.fields['first_name'] assert isinstance(first_name_field.type, NonNull) @@ -271,7 +311,7 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert list(ReporterType._meta.fields.keys()) == [ + assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ "first_name", "last_name", "column_prop", @@ -279,10 +319,15 @@ class Meta: "favorite_pet_kind", "composite_prop", "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", "pets", "articles", "favorite_article", - ] + ]) def test_only_and_exclude_fields(): @@ -387,7 +432,7 @@ class Meta: assert issubclass(CustomReporterType, ObjectType) assert CustomReporterType._meta.model == Reporter - assert 
len(CustomReporterType._meta.fields) == 11 + assert len(CustomReporterType._meta.fields) == 16 # Test Custom SQLAlchemyObjectType with Custom Options diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 340ad47e..301e782c 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,11 +1,15 @@ import re import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, Optional import pkg_resources from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene_sqlalchemy.registry import get_global_registry + def get_session(context): return context.get("session") @@ -87,7 +91,6 @@ def _deprecated_default_symbol_name(column_name, sort_asc): def _deprecated_object_type_for_model(cls, name): - try: return _deprecated_object_type_cache[cls, name] except KeyError: @@ -152,3 +155,54 @@ def sort_argument_for_model(cls, has_default=True): def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + + +class singledispatchbymatchfunction: + """ + Inspired by @singledispatch, this is a variant that works using a matcher function + instead of relying on the type of the first argument. + The register method can be used to register a new matcher, which is passed as the first argument: + """ + + def __init__(self, default: Callable): + self.registry: Dict[Callable, Callable] = OrderedDict() + self.default = default + + def __call__(self, *args, **kwargs): + for matcher_function, final_method in self.registry.items(): + # Register order is important. First one that matches, runs. + if matcher_function(args[0]): + return final_method(*args, **kwargs) + + # No match, using default. 
+ return self.default(*args, **kwargs) + + def register(self, matcher_function: Callable[[Any], bool]): + + def grab_function_from_outside(f): + self.registry[matcher_function] = f + return self + + return grab_function_from_outside + + +def value_equals(value): + """A simple function that makes the equality based matcher functions for + SingleDispatchByMatchFunction prettier""" + return lambda x: x == value + + +def safe_isinstance(cls): + def safe_isinstance_checker(arg): + try: + return isinstance(arg, cls) + except TypeError: + pass + return safe_isinstance_checker + + +def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + try: + return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + except StopIteration: + pass From b0aa63c968b2d8880310e8ae53c03280446508b0 Mon Sep 17 00:00:00 2001 From: Cadu Date: Tue, 3 May 2022 16:23:11 -0300 Subject: [PATCH 24/67] Added suport for Optional[T] in @hybrid_property's type annotation inference. (#343) Automatic @hybrid_property type conversion now supports Optionals. 
--- graphene_sqlalchemy/converter.py | 12 ++++++++++++ graphene_sqlalchemy/tests/models.py | 8 +++++++- graphene_sqlalchemy/tests/test_converter.py | 2 ++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index a2e03694..a9da6231 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -331,6 +331,18 @@ def convert_sqlalchemy_hybrid_property_type_time(arg): return Time +@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) == typing.Union) +def convert_sqlalchemy_hybrid_property_type_option_t(arg): + # Option is actually Union[T, ] + + # Just get the T out of the list of arguments by filtering out the NoneType + internal_type = next(filter(lambda x: not type(None) == x, arg.__args__)) + + graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + + return graphql_internal_type + + @convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) def convert_sqlalchemy_hybrid_property_type_list_t(arg): # type is either list[T] or List[T], generic argument at __args__[0] diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index bda5a863..bda46e1c 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -3,7 +3,7 @@ import datetime import enum from decimal import Decimal -from typing import List, Tuple +from typing import List, Optional, Tuple from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, func, select) @@ -217,3 +217,9 @@ def hybrid_prop_self_referential(self) -> 'ShoppingCart': @hybrid_property def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: return [ShoppingCart(id=1)] + + # Optional[T] + + @hybrid_property + def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: + return None diff --git 
a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 4b9e74ed..70e11713 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -452,6 +452,8 @@ class Meta: # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, "hybrid_prop_self_referential_list": List(ShoppingCartType), + # Optionals + "hybrid_prop_optional_self_referential": ShoppingCartType, } assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ From a47dbb369c5e7a057affe7cb98fa15d3acf29cd8 Mon Sep 17 00:00:00 2001 From: Bryan Malyn Date: Thu, 5 May 2022 09:52:47 -0500 Subject: [PATCH 25/67] Pick up the docstrings of hybrid properties (#344) --- graphene_sqlalchemy/converter.py | 3 +++ graphene_sqlalchemy/tests/models.py | 5 +++++ graphene_sqlalchemy/tests/test_types.py | 17 ++++++++++++++++- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index a9da6231..5d75984b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -130,6 +130,9 @@ def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'type_' not in field_kwargs: field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) + if 'description' not in field_kwargs: + field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) + return Field( resolver=resolver, **field_kwargs diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index bda46e1c..e41adb51 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -68,6 +68,11 @@ class Reporter(Base): articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) + @hybrid_property + def hybrid_prop_with_doc(self): + """Docstring test""" + return self.first_name + @hybrid_property def hybrid_prop(self): 
return self.first_name diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 2d660b67..9a2e992d 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -85,6 +85,7 @@ class Meta: # Composite "composite_prop", # Hybrid + "hybrid_prop_with_doc", "hybrid_prop", "hybrid_prop_str", "hybrid_prop_int", @@ -150,6 +151,12 @@ class Meta: # "doc" is ignored by hybrid_property assert hybrid_prop_list.description is None + # hybrid_prop_with_doc + hybrid_prop_with_doc = ReporterType._meta.fields['hybrid_prop_with_doc'] + assert hybrid_prop_with_doc.type == String + # docstring is picked up from hybrid_prop_with_doc + assert hybrid_prop_with_doc.description == "Docstring test" + # relationship favorite_article_field = ReporterType._meta.fields['favorite_article'] assert isinstance(favorite_article_field, Dynamic) @@ -183,6 +190,7 @@ class Meta: composite_prop = ORMField() # hybrid_property + hybrid_prop_with_doc = ORMField(description='Overridden') hybrid_prop = ORMField(description='Overridden') # relationships @@ -210,6 +218,7 @@ class Meta: "email_v2", "column_prop", "composite_prop", + "hybrid_prop_with_doc", "hybrid_prop", "favorite_article", "articles", @@ -250,6 +259,11 @@ class Meta: assert hybrid_prop_field.description == "Overridden" assert hybrid_prop_field.deprecation_reason is None + hybrid_prop_with_doc_field = ReporterType._meta.fields['hybrid_prop_with_doc'] + assert hybrid_prop_with_doc_field.type == String + assert hybrid_prop_with_doc_field.description == "Overridden" + assert hybrid_prop_with_doc_field.deprecation_reason is None + column_prop_field_v2 = ReporterType._meta.fields['column_prop'] assert column_prop_field_v2.type == String assert column_prop_field_v2.description is None @@ -318,6 +332,7 @@ class Meta: "email", "favorite_pet_kind", "composite_prop", + "hybrid_prop_with_doc", "hybrid_prop", "hybrid_prop_str", "hybrid_prop_int", @@ -432,7 +447,7 @@ class 
Meta: assert issubclass(CustomReporterType, ObjectType) assert CustomReporterType._meta.model == Reporter - assert len(CustomReporterType._meta.fields) == 16 + assert len(CustomReporterType._meta.fields) == 17 # Test Custom SQLAlchemyObjectType with Custom Options From 294d529711d4c3360b14e0a6bcb8a484fdb19704 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 2 Jun 2022 11:08:19 +0200 Subject: [PATCH 26/67] Update README.md --- README.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 04692973..68719f4d 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,14 @@ -Please read [UPGRADE-v2.0.md](https://github.com/graphql-python/graphene/blob/master/UPGRADE-v2.0.md) -to learn how to upgrade to Graphene `2.0`. +Version 3.0 is in beta stage. Please read https://github.com/graphql-python/graphene-sqlalchemy/issues/348 to learn about progress and changes in upcoming +beta releases. --- -# ![Graphene Logo](http://graphene-python.org/favicon.png) Graphene-SQLAlchemy [![Build Status](https://travis-ci.org/graphql-python/graphene-sqlalchemy.svg?branch=master)](https://travis-ci.org/graphql-python/graphene-sqlalchemy) [![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy.svg)](https://badge.fury.io/py/graphene-sqlalchemy) [![Coverage Status](https://coveralls.io/repos/graphql-python/graphene-sqlalchemy/badge.svg?branch=master&service=github)](https://coveralls.io/github/graphql-python/graphene-sqlalchemy?branch=master) +# ![Graphene Logo](http://graphene-python.org/favicon.png) Graphene-SQLAlchemy +[![Build Status](https://github.com/graphql-python/graphene-sqlalchemy/workflows/Tests/badge.svg)](https://github.com/graphql-python/graphene-sqlalchemy/actions) +[![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy.svg)](https://badge.fury.io/py/graphene-sqlalchemy) +![GitHub release (latest by date including 
pre-releases)](https://img.shields.io/github/v/release/graphql-python/graphene-sqlalchemy?color=green&include_prereleases&label=latest) +[![codecov](https://codecov.io/gh/graphql-python/graphene-sqlalchemy/branch/master/graph/badge.svg?token=Zi5S1TikeN)](https://codecov.io/gh/graphql-python/graphene-sqlalchemy) + A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://graphene-python.org/). @@ -13,7 +18,7 @@ A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://gra For installing Graphene, just run this command in your shell. ```bash -pip install "graphene-sqlalchemy>=2.0" +pip install "graphene-sqlalchemy>=3" ``` ## Examples From f16d434b716b5602f1406b4fd0e2309bbe6f1fa4 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 3 Jun 2022 12:27:52 +0200 Subject: [PATCH 27/67] Add Python 3.10 & Update Build Scripts (#352) This PR drops tests for Python 3.6 and updates the build scripts. --- .flake8 | 4 +++ .github/workflows/deploy.yml | 8 ++--- .github/workflows/lint.yml | 8 ++--- .github/workflows/tests.yml | 14 ++++---- .pre-commit-config.yaml | 33 ++++++++---------- graphene_sqlalchemy/__init__.py | 2 +- graphene_sqlalchemy/converter.py | 3 +- graphene_sqlalchemy/types.py | 60 ++++++++++++++++---------------- graphene_sqlalchemy/utils.py | 1 + setup.py | 7 ++-- tox.ini | 8 ++--- 11 files changed, 76 insertions(+), 72 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..30f6dedd --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +ignore = E203,W503 +exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs +max-line-length = 120 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 1ae7b4b6..9cc136a1 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -10,11 +10,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 + - 
uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 with: - python-version: 3.9 + python-version: '3.10' - name: Build wheel and source tarball run: | pip install wheel diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 559326c4..9352dbe5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,11 +7,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 with: - python-version: 3.9 + python-version: '3.10' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a9a3bd5d..de78190d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,12 +9,12 @@ jobs: max-parallel: 10 matrix: sql-alchemy: ["1.2", "1.3", "1.4"] - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -27,12 +27,12 @@ jobs: SQLALCHEMY: ${{ matrix.sql-alchemy }} TOXENV: ${{ matrix.toxenv }} - name: Upload coverage.xml - if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} - uses: actions/upload-artifact@v2 + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.10' }} + uses: actions/upload-artifact@v3 with: name: graphene-sqlalchemy-coverage path: coverage.xml if-no-files-found: error - name: Upload coverage.xml to codecov - if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }} - uses: codecov/codecov-action@v1 + if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.10' }} + 
uses: codecov/codecov-action@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c67ab03..66db3814 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,25 +1,22 @@ default_language_version: - python: python3.7 + python: python3.10 repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 hooks: - - id: check-merge-conflict - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer + - id: check-merge-conflict + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer exclude: ^docs/.*$ - - id: trailing-whitespace + - id: trailing-whitespace exclude: README.md -- repo: https://github.com/PyCQA/flake8 - rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 + - repo: https://github.com/pycqa/isort + rev: 5.10.1 hooks: - - id: flake8 -- repo: https://github.com/asottile/seed-isort-config - rev: v1.7.0 + - id: isort + name: isort (python) + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.0 hooks: - - id: seed-isort-config -- repo: https://github.com/pre-commit/mirrors-isort - rev: v4.3.4 - hooks: - - id: isort + - id: flake8 diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 060bd13b..18d34f1d 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ -from .types import SQLAlchemyObjectType from .fields import SQLAlchemyConnectionField +from .types import SQLAlchemyObjectType from .utils import get_query, get_session __version__ = "3.0.0b1" diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 5d75984b..60e14ddd 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -29,7 +29,8 @@ from typing import _ForwardRef as ForwardRef try: - from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType + from sqlalchemy_utils 
import (ChoiceType, JSONType, ScalarListType, + TSVectorType) except ImportError: ChoiceType = JSONType = ScalarListType = TSVectorType = object diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ac69b697..e6c3d14c 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -25,15 +25,15 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - _creation_counter=None, - **field_kwargs + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + _creation_counter=None, + **field_kwargs ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -89,7 +89,7 @@ class Meta: def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory + obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory ): """ Construct all the fields for a SQLAlchemyObjectType. 
@@ -110,11 +110,11 @@ def construct_fields( inspected_model = sqlalchemy.inspect(model) # Gather all the relevant attributes from the SQLAlchemy model in order all_model_attrs = OrderedDict( - inspected_model.column_attrs.items() + - inspected_model.composites.items() + - [(name, item) for name, item in inspected_model.all_orm_descriptors.items() - if isinstance(item, hybrid_property)] + - inspected_model.relationships.items() + inspected_model.column_attrs.items() + + inspected_model.composites.items() + + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property)] + + inspected_model.relationships.items() ) # Filter out excluded fields @@ -191,21 +191,21 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options ): # Make sure model is a valid SQLAlchemy model if not is_mapped_class(model): diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 301e782c..084f9b86 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -140,6 +140,7 @@ def sort_argument_for_model(cls, has_default=True): ) from graphene import Argument, List + from .enums import sort_enum_for_object_type enum = sort_enum_for_object_type( diff --git a/setup.py b/setup.py index da49f1d4..ac9ad7e6 100644 --- a/setup.py +++ b/setup.py @@ -41,9 +41,10 @@ "Intended Audience :: 
Developers", "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay graphene", @@ -52,8 +53,8 @@ extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "pre-commit==1.14.4", - "flake8==3.7.9", + "pre-commit==2.19", + "flake8==4.0.0", ], "test": tests_require, }, diff --git a/tox.ini b/tox.ini index b8ce0618..2802dee0 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,14 @@ [tox] -envlist = pre-commit,py{36,37,38,39}-sql{12,13,14} +envlist = pre-commit,py{37,38,39,310}-sql{12,13,14} skipsdist = true minversion = 3.7.0 [gh-actions] python = - 3.6: py36 3.7: py37 3.8: py38 3.9: py39 + 3.10: py310 [gh-actions:env] SQLALCHEMY = @@ -27,14 +27,14 @@ commands = pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} [testenv:pre-commit] -basepython=python3.9 +basepython=python3.10 deps = .[dev] commands = pre-commit {posargs:run --all-files} [testenv:flake8] -basepython = python3.9 +basepython = python3.10 deps = -e.[dev] commands = flake8 --exclude setup.py,docs,examples,tests,.tox --max-line-length 120 From a70256962f57cb4fd4bd2d72ec87e59703fd6e74 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 15 Jul 2022 10:40:37 +0200 Subject: [PATCH 28/67] Native support for additional Type Converters (#353) * Fields generated from Hybrid Properties & Type hints now support Unions (Union[ObjectType1,OT2] or ObjectType1 | OT2) * Support for Variant and types.JSON Columns * BREAKING: Date&Time now convert to their corresponding graphene scalars instead of String. 
* BREAKING: PG UUID & sqlalchemy_utils.UUIDType now convert to graphene.UUID instead of graphene.String * Change: Sort Enums & ChoiceType enums are now generated from Column.key instead of Column.name, see #330 Signed-off-by: Erik Wrede Co-authored-by: Nicolas Delaby Co-authored-by: davidcim Co-authored-by: Viktor Pegy Co-authored-by: Ian Epperson --- graphene_sqlalchemy/converter.py | 204 +++++++++++----- graphene_sqlalchemy/enums.py | 4 +- graphene_sqlalchemy/registry.py | 38 ++- graphene_sqlalchemy/tests/models.py | 10 +- graphene_sqlalchemy/tests/test_converter.py | 243 +++++++++++++++---- graphene_sqlalchemy/tests/test_registry.py | 56 ++++- graphene_sqlalchemy/tests/test_sort_enums.py | 25 +- graphene_sqlalchemy/tests/test_types.py | 2 +- graphene_sqlalchemy/tests/test_utils.py | 8 +- graphene_sqlalchemy/utils.py | 9 +- 10 files changed, 468 insertions(+), 131 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 60e14ddd..1e7846eb 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,16 +1,16 @@ import datetime +import sys import typing import warnings from decimal import Decimal from functools import singledispatch -from typing import Any +from typing import Any, cast -from sqlalchemy import types +from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -from graphene import (ID, Boolean, Date, DateTime, Dynamic, Enum, Field, Float, - Int, List, String, Time) +import graphene from graphene.types.json import JSONString from .batching import get_batch_resolver @@ -19,8 +19,9 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import (registry_sqlalchemy_model_from_str, safe_isinstance, - singledispatchbymatchfunction, value_equals) +from .utils import (DummyImport, registry_sqlalchemy_model_from_str, 
+ safe_isinstance, singledispatchbymatchfunction, + value_equals) try: from typing import ForwardRef @@ -29,15 +30,14 @@ from typing import _ForwardRef as ForwardRef try: - from sqlalchemy_utils import (ChoiceType, JSONType, ScalarListType, - TSVectorType) + from sqlalchemy_utils.types.choice import EnumTypeImpl except ImportError: - ChoiceType = JSONType = ScalarListType = TSVectorType = object + EnumTypeImpl = object try: - from sqlalchemy_utils.types.choice import EnumTypeImpl + import sqlalchemy_utils as sqa_utils except ImportError: - EnumTypeImpl = object + sqa_utils = DummyImport() is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -79,7 +79,7 @@ def dynamic_type(): return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, connection_field_factory, **field_kwargs) - return Dynamic(dynamic_type) + return graphene.Dynamic(dynamic_type) def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): @@ -100,7 +100,7 @@ def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_ resolver = get_batch_resolver(relationship_prop) if batching else \ get_attr_resolver(obj_type, relationship_prop.key) - return Field(child_type, resolver=resolver, **field_kwargs) + return graphene.Field(child_type, resolver=resolver, **field_kwargs) def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): @@ -117,7 +117,7 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) if not child_type._meta.connection: - return Field(List(child_type), **field_kwargs) + return graphene.Field(graphene.List(child_type), **field_kwargs) # TODO Allow override of connection_field_factory and resolver via ORMField if connection_field_factory is None: @@ -134,7 +134,7 @@ def 
convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'description' not in field_kwargs: field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) - return Field( + return graphene.Field( resolver=resolver, **field_kwargs ) @@ -181,7 +181,7 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): field_kwargs.setdefault('required', not is_column_nullable(column)) field_kwargs.setdefault('description', get_column_doc(column)) - return Field( + return graphene.Field( resolver=resolver, **field_kwargs ) @@ -195,75 +195,90 @@ def convert_sqlalchemy_type(type, column, registry=None): ) -@convert_sqlalchemy_type.register(types.Date) -@convert_sqlalchemy_type.register(types.Time) -@convert_sqlalchemy_type.register(types.String) -@convert_sqlalchemy_type.register(types.Text) -@convert_sqlalchemy_type.register(types.Unicode) -@convert_sqlalchemy_type.register(types.UnicodeText) -@convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(sqa_types.String) +@convert_sqlalchemy_type.register(sqa_types.Text) +@convert_sqlalchemy_type.register(sqa_types.Unicode) +@convert_sqlalchemy_type.register(sqa_types.UnicodeText) @convert_sqlalchemy_type.register(postgresql.INET) @convert_sqlalchemy_type.register(postgresql.CIDR) -@convert_sqlalchemy_type.register(TSVectorType) +@convert_sqlalchemy_type.register(sqa_utils.TSVectorType) +@convert_sqlalchemy_type.register(sqa_utils.EmailType) +@convert_sqlalchemy_type.register(sqa_utils.URLType) +@convert_sqlalchemy_type.register(sqa_utils.IPAddressType) def convert_column_to_string(type, column, registry=None): - return String + return graphene.String + + +@convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(sqa_utils.UUIDType) +def convert_column_to_uuid(type, column, registry=None): + return graphene.UUID -@convert_sqlalchemy_type.register(types.DateTime) +@convert_sqlalchemy_type.register(sqa_types.DateTime) def 
convert_column_to_datetime(type, column, registry=None): - from graphene.types.datetime import DateTime - return DateTime + return graphene.DateTime -@convert_sqlalchemy_type.register(types.SmallInteger) -@convert_sqlalchemy_type.register(types.Integer) +@convert_sqlalchemy_type.register(sqa_types.Time) +def convert_column_to_time(type, column, registry=None): + return graphene.Time + + +@convert_sqlalchemy_type.register(sqa_types.Date) +def convert_column_to_date(type, column, registry=None): + return graphene.Date + + +@convert_sqlalchemy_type.register(sqa_types.SmallInteger) +@convert_sqlalchemy_type.register(sqa_types.Integer) def convert_column_to_int_or_id(type, column, registry=None): - return ID if column.primary_key else Int + return graphene.ID if column.primary_key else graphene.Int -@convert_sqlalchemy_type.register(types.Boolean) +@convert_sqlalchemy_type.register(sqa_types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean + return graphene.Boolean -@convert_sqlalchemy_type.register(types.Float) -@convert_sqlalchemy_type.register(types.Numeric) -@convert_sqlalchemy_type.register(types.BigInteger) +@convert_sqlalchemy_type.register(sqa_types.Float) +@convert_sqlalchemy_type.register(sqa_types.Numeric) +@convert_sqlalchemy_type.register(sqa_types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float + return graphene.Float -@convert_sqlalchemy_type.register(types.Enum) +@convert_sqlalchemy_type.register(sqa_types.Enum) def convert_enum_to_enum(type, column, registry=None): return lambda: enum_for_sa_enum(type, registry or get_global_registry()) # TODO Make ChoiceType conversion consistent with other enums -@convert_sqlalchemy_type.register(ChoiceType) +@convert_sqlalchemy_type.register(sqa_utils.ChoiceType) def convert_choice_to_enum(type, column, registry=None): - name = "{}_{}".format(column.table.name, column.name).upper() + name = "{}_{}".format(column.table.name, column.key).upper() 
if isinstance(type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table - return Enum(name, list((v.name, v.value) for v in type.choices)) + return graphene.Enum(name, list((v.name, v.value) for v in type.choices)) else: - return Enum(name, type.choices) + return graphene.Enum(name, type.choices) -@convert_sqlalchemy_type.register(ScalarListType) +@convert_sqlalchemy_type.register(sqa_utils.ScalarListType) def convert_scalar_list_to_list(type, column, registry=None): - return List(String) + return graphene.List(graphene.String) def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n - 1)) + return inner_type if n == 0 else graphene.List(init_array_list_recursive(inner_type, n - 1)) -@convert_sqlalchemy_type.register(types.ARRAY) +@convert_sqlalchemy_type.register(sqa_types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) + return graphene.List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) @convert_sqlalchemy_type.register(postgresql.HSTORE) @@ -273,38 +288,50 @@ def convert_json_to_string(type, column, registry=None): return JSONString -@convert_sqlalchemy_type.register(JSONType) +@convert_sqlalchemy_type.register(sqa_utils.JSONType) +@convert_sqlalchemy_type.register(sqa_types.JSON) def convert_json_type_to_string(type, column, registry=None): return JSONString +@convert_sqlalchemy_type.register(sqa_types.Variant) +def convert_variant_to_impl_type(type, column, registry=None): + return convert_sqlalchemy_type(type.impl, column, registry=registry) + + @singledispatchbymatchfunction def 
convert_sqlalchemy_hybrid_property_type(arg: Any): existing_graphql_type = get_global_registry().get_type_for_model(arg) if existing_graphql_type: return existing_graphql_type + if isinstance(arg, type(graphene.ObjectType)): + return arg + + if isinstance(arg, type(graphene.Scalar)): + return arg + # No valid type found, warn and fall back to graphene.String warnings.warn( (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." "Falling back to \"graphene.String\"") ) - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) def convert_sqlalchemy_hybrid_property_type_str(arg): - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) def convert_sqlalchemy_hybrid_property_type_int(arg): - return Int + return graphene.Int @convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) def convert_sqlalchemy_hybrid_property_type_float(arg): - return Float + return graphene.Float @convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) @@ -312,39 +339,85 @@ def convert_sqlalchemy_hybrid_property_type_decimal(arg): # The reason Decimal should be serialized as a String is because this is a # base10 type used in things like money, and string allows it to not # lose precision (which would happen if we downcasted to a Float, for example) - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) def convert_sqlalchemy_hybrid_property_type_bool(arg): - return Boolean + return graphene.Boolean @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) def convert_sqlalchemy_hybrid_property_type_datetime(arg): - return DateTime + return graphene.DateTime @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) def convert_sqlalchemy_hybrid_property_type_date(arg): - return Date + return graphene.Date 
@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) def convert_sqlalchemy_hybrid_property_type_time(arg): - return Time + return graphene.Time -@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) == typing.Union) -def convert_sqlalchemy_hybrid_property_type_option_t(arg): - # Option is actually Union[T, ] +def is_union(arg) -> bool: + if sys.version_info >= (3, 10): + from types import UnionType + + if isinstance(arg, UnionType): + return True + return getattr(arg, '__origin__', None) == typing.Union + + +def graphene_union_for_py_union(obj_types: typing.List[graphene.ObjectType], registry) -> graphene.Union: + union_type = registry.get_union_for_object_types(obj_types) + + if union_type is None: + # Union Name is name of the three + union_name = ''.join(sorted([obj_type._meta.name for obj_type in obj_types])) + union_type = graphene.Union(union_name, obj_types) + registry.register_union_type(union_type, obj_types) + + return union_type + + +@convert_sqlalchemy_hybrid_property_type.register(is_union) +def convert_sqlalchemy_hybrid_property_union(arg): + """ + Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. + Since Optionals are internally represented as Union[T, ], they are handled here as well. + + The GQL Spec currently only allows for ObjectType unions: + GraphQL Unions represent an object that could be one of a list of GraphQL Object types, but provides for no + guaranteed fields between those types. + That's why we have to check for the nested types to be instances of graphene.ObjectType, except for the union case. + + type(x) == _types.UnionType is necessary to support X | Y notation, but might break in future python releases. 
+ """ + from .registry import get_global_registry + # Option is actually Union[T, ] # Just get the T out of the list of arguments by filtering out the NoneType - internal_type = next(filter(lambda x: not type(None) == x, arg.__args__)) + nested_types = list(filter(lambda x: not type(None) == x, arg.__args__)) - graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + # Map the graphene types to the nested types. + # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... + graphene_types = list(map(convert_sqlalchemy_hybrid_property_type, nested_types)) + + # If only one type is left after filtering out NoneType, the Union was an Optional + if len(graphene_types) == 1: + return graphene_types[0] + + # Now check if every type is instance of an ObjectType + if not all(isinstance(graphene_type, type(graphene.ObjectType)) for graphene_type in graphene_types): + raise ValueError("Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. 
" + "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " + "or use an ORMField to override this behaviour.") - return graphql_internal_type + return graphene_union_for_py_union(cast(typing.List[graphene.ObjectType], list(graphene_types)), + get_global_registry()) @convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) @@ -354,7 +427,7 @@ def convert_sqlalchemy_hybrid_property_type_list_t(arg): graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) - return List(graphql_internal_type) + return graphene.List(graphql_internal_type) @convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) @@ -363,11 +436,12 @@ def convert_sqlalchemy_hybrid_property_forwardref(arg): Generate a lambda that will resolve the type at runtime This takes care of self-references """ + from .registry import get_global_registry def forward_reference_solver(): model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) if not model: - return String + return graphene.String # Always fall back to string if no ForwardRef type found. 
return get_global_registry().get_type_for_model(model) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index f100be19..a2ed17ad 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -144,9 +144,9 @@ def sort_enum_for_object_type( column = orm_field.columns[0] if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(column.name, True) + asc_name = get_name(column.key, True) asc_value = EnumValue(asc_name, column.asc()) - desc_name = get_name(column.name, False) + desc_name = get_name(column.key, False) desc_value = EnumValue(desc_name, column.desc()) if column.primary_key: default.append(asc_value) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index acfa744b..80470d9b 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,7 +1,9 @@ from collections import defaultdict +from typing import List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType +import graphene from graphene import Enum @@ -13,12 +15,13 @@ def __init__(self): self._registry_composites = {} self._registry_enums = {} self._registry_sort_enums = {} + self._registry_unions = {} def register(self, obj_type): - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -37,7 +40,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -55,7 +58,7 @@ def register_composite_converter(self, composite, converter): def get_converter_for_composite(self, composite): return 
self._registry_composites.get(composite) - def register_enum(self, sa_enum, graphene_enum): + def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum): if not isinstance(sa_enum, SQLAlchemyEnumType): raise TypeError( "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) @@ -67,14 +70,14 @@ def register_enum(self, sa_enum, graphene_enum): self._registry_enums[sa_enum] = graphene_enum - def get_graphene_enum_for_sa_enum(self, sa_enum): + def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): return self._registry_enums.get(sa_enum) - def register_sort_enum(self, obj_type, sort_enum): - from .types import SQLAlchemyObjectType + def register_sort_enum(self, obj_type, sort_enum: Enum): + from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -83,9 +86,26 @@ def register_sort_enum(self, obj_type, sort_enum): raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) self._registry_sort_enums[obj_type] = sort_enum - def get_sort_enum_for_object_type(self, obj_type): + def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) + def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]]): + if not isinstance(union, graphene.Union): + raise TypeError( + "Expected graphene.Union, but got: {!r}".format(union) + ) + + for obj_type in obj_types: + if not isinstance(obj_type, type(graphene.ObjectType)): + raise TypeError( + "Expected Graphene ObjectType, but got: {!r}".format(obj_type) + ) + + self._registry_unions[frozenset(obj_types)] = union + + def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]): + return self._registry_unions.get(frozenset(obj_types)) + registry = None diff --git 
a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index e41adb51..dc399ee0 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -5,8 +5,8 @@ from decimal import Decimal from typing import List, Optional, Tuple -from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, - func, select) +from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, Numeric, + String, Table, func, select) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import column_property, composite, mapper, relationship @@ -228,3 +228,9 @@ def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: @hybrid_property def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: return None + + +class KeyedModel(Base): + __tablename__ = "test330" + id = Column(Integer(), primary_key=True) + reporter_number = Column("% reporter_number", Numeric, key="reporter_number") diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 70e11713..a6c2b1bf 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,28 +1,28 @@ import enum +import sys from typing import Dict, Union import pytest +import sqlalchemy_utils as sqa_utils from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene -from graphene import Boolean, Float, Int, Scalar, String from graphene.relay import Node -from graphene.types.datetime import Date, DateTime, Time -from graphene.types.json import JSONString -from 
graphene.types.structures import List, Structure +from graphene.types.structures import Structure from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) from ..fields import (UnsortedSQLAlchemyConnectionField, default_connection_field_factory) from ..registry import Registry, get_global_registry -from ..types import SQLAlchemyObjectType +from ..types import ORMField, SQLAlchemyObjectType from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, ShoppingCartItem) @@ -51,23 +51,117 @@ class Model(declarative_base()): return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) -def test_should_unknown_sqlalchemy_field_raise_exception(): - re_err = "Don't know how to convert the SQLAlchemy field" - with pytest.raises(Exception, match=re_err): - # support legacy Binary type and subsequent LargeBinary - get_field(getattr(types, 'LargeBinary', types.BINARY)()) +def get_hybrid_property_type(prop_method): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + prop = prop_method + + column_prop = inspect(Model).all_orm_descriptors['prop'] + return convert_sqlalchemy_hybrid_method(column_prop, mock_resolver(), **ORMField().kwargs) + + +def test_hybrid_prop_int(): + @hybrid_property + def prop_method() -> int: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_hybrid_prop_scalar_union_310(): + @hybrid_property + def prop_method() -> int | str: + return "not allowed in gql schema" + + with pytest.raises(ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. 
\.*"): + get_hybrid_property_type(prop_method) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_hybrid_prop_scalar_union_and_optional_310(): + """Checks if the use of Optionals does not interfere with non-conform scalar return types""" + + @hybrid_property + def prop_method() -> int | None: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_should_union_work_310(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + + @hybrid_property + def prop_method() -> Union[PetType, ShoppingCartType]: + return None + + @hybrid_property + def prop_method_2() -> Union[ShoppingCartType, PetType]: + return None + + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type + + assert isinstance(field_type_1, graphene.Union) + assert field_type_1 is field_type_2 + + # TODO verify types of the union + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_should_union_work_310(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + + @hybrid_property + def prop_method() -> PetType | ShoppingCartType: + return None + @hybrid_property + def prop_method_2() -> ShoppingCartType | PetType: + return None -def test_should_date_convert_string(): - assert get_field(types.Date()).type == graphene.String + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type + + 
assert isinstance(field_type_1, graphene.Union) + assert field_type_1 is field_type_2 def test_should_datetime_convert_datetime(): - assert get_field(types.DateTime()).type == DateTime + assert get_field(types.DateTime()).type == graphene.DateTime + +def test_should_time_convert_time(): + assert get_field(types.Time()).type == graphene.Time -def test_should_time_convert_string(): - assert get_field(types.Time()).type == graphene.String + +def test_should_date_convert_date(): + assert get_field(types.Date()).type == graphene.Date def test_should_string_convert_string(): @@ -86,6 +180,30 @@ def test_should_unicodetext_convert_string(): assert get_field(types.UnicodeText()).type == graphene.String +def test_should_tsvector_convert_string(): + assert get_field(sqa_utils.TSVectorType()).type == graphene.String + + +def test_should_email_convert_string(): + assert get_field(sqa_utils.EmailType()).type == graphene.String + + +def test_should_URL_convert_string(): + assert get_field(sqa_utils.URLType()).type == graphene.String + + +def test_should_IPaddress_convert_string(): + assert get_field(sqa_utils.IPAddressType()).type == graphene.String + + +def test_should_inet_convert_string(): + assert get_field(postgresql.INET()).type == graphene.String + + +def test_should_cidr_convert_string(): + assert get_field(postgresql.CIDR()).type == graphene.String + + def test_should_enum_convert_enum(): field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) field_type = field.type() @@ -142,7 +260,7 @@ def test_should_numeric_convert_float(): def test_should_choice_convert_enum(): - field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + field = get_field(sqa_utils.ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -155,7 +273,7 @@ class TestEnum(enum.Enum): es = u"Spanish" en = u"English" - field = 
get_field(ChoiceType(TestEnum, impl=types.String())) + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -163,12 +281,32 @@ class TestEnum(enum.Enum): assert graphene_type._meta.enum.__members__["en"].value == "English" +def test_choice_enum_column_key_name_issue_301(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. See #330 + """ + + class TestEnum(enum.Enum): + es = u"Spanish" + en = u"English" + + testChoice = Column("% descuento1", sqa_utils.ChoiceType(TestEnum, impl=types.String()), key="descuento1") + field = get_field_from_column(testChoice) + + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_DESCUENTO1" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" + + def test_should_intenum_choice_convert_enum(): class TestEnum(enum.IntEnum): one = 1 two = 2 - field = get_field(ChoiceType(TestEnum, impl=types.String())) + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -185,13 +323,22 @@ def test_should_columproperty_convert(): def test_should_scalar_list_convert_list(): - field = get_field(ScalarListType()) + field = get_field(sqa_utils.ScalarListType()) assert isinstance(field.type, graphene.List) assert field.type.of_type == graphene.String def test_should_jsontype_convert_jsonstring(): - assert get_field(JSONType()).type == JSONString + assert get_field(sqa_utils.JSONType()).type == graphene.JSONString + assert get_field(types.JSON).type == graphene.JSONString + + +def test_should_variant_int_convert_int(): + 
assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int + + +def test_should_variant_string_convert_string(): + assert get_field(types.Variant(types.String(), {})).type == graphene.String def test_should_manytomany_convert_connectionorlist(): @@ -291,7 +438,11 @@ class Meta: def test_should_postgresql_uuid_convert(): - assert get_field(postgresql.UUID()).type == graphene.String + assert get_field(postgresql.UUID()).type == graphene.UUID + + +def test_should_sqlalchemy_utils_uuid_convert(): + assert get_field(sqa_utils.UUIDType()).type == graphene.UUID def test_should_postgresql_enum_convert(): @@ -405,8 +556,8 @@ class Meta: # Check ShoppingCartItem's Properties and Return Types ####################################################### - shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = { - 'hybrid_prop_shopping_cart': List(ShoppingCartType) + shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { + 'hybrid_prop_shopping_cart': graphene.List(ShoppingCartType) } assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ @@ -421,9 +572,9 @@ class Meta: # this is a simple way of showing the failed property name # instead of having to unroll the loop. 
- assert ( - (hybrid_prop_name, str(hybrid_prop_field.type)) == - (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), ) assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property @@ -431,27 +582,27 @@ class Meta: # Check ShoppingCart's Properties and Return Types ################################################### - shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = { + shopping_cart_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { # Basic types - "hybrid_prop_str": String, - "hybrid_prop_int": Int, - "hybrid_prop_float": Float, - "hybrid_prop_bool": Boolean, - "hybrid_prop_decimal": String, # Decimals should be serialized Strings - "hybrid_prop_date": Date, - "hybrid_prop_time": Time, - "hybrid_prop_datetime": DateTime, + "hybrid_prop_str": graphene.String, + "hybrid_prop_int": graphene.Int, + "hybrid_prop_float": graphene.Float, + "hybrid_prop_bool": graphene.Boolean, + "hybrid_prop_decimal": graphene.String, # Decimals should be serialized Strings + "hybrid_prop_date": graphene.Date, + "hybrid_prop_time": graphene.Time, + "hybrid_prop_datetime": graphene.DateTime, # Lists and Nested Lists - "hybrid_prop_list_int": List(Int), - "hybrid_prop_list_date": List(Date), - "hybrid_prop_nested_list_int": List(List(Int)), - "hybrid_prop_deeply_nested_list_int": List(List(List(Int))), + "hybrid_prop_list_int": graphene.List(graphene.Int), + "hybrid_prop_list_date": graphene.List(graphene.Date), + "hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)), + "hybrid_prop_deeply_nested_list_int": graphene.List(graphene.List(graphene.List(graphene.Int))), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, - "hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType), - "hybrid_prop_unsupported_type_tuple": String, + "hybrid_prop_shopping_cart_item_list": 
graphene.List(ShoppingCartItemType), + "hybrid_prop_unsupported_type_tuple": graphene.String, # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, - "hybrid_prop_self_referential_list": List(ShoppingCartType), + "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), # Optionals "hybrid_prop_optional_self_referential": ShoppingCartType, } @@ -468,8 +619,8 @@ class Meta: # this is a simple way of showing the failed property name # instead of having to unroll the loop. - assert ( - (hybrid_prop_name, str(hybrid_prop_field.type)) == - (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), ) assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 0403c4f0..f451f355 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,12 +1,13 @@ import pytest from sqlalchemy.types import Enum as SQLAlchemyEnum +import graphene from graphene import Enum as GrapheneEnum from ..registry import Registry from ..types import SQLAlchemyObjectType from ..utils import EnumValue -from .models import Pet +from .models import Pet, Reporter def test_register_object_type(): @@ -126,3 +127,56 @@ class Meta: re_err = r"Expected Graphene Enum, but got: .*PetType.*" with pytest.raises(TypeError, match=re_err): reg.register_sort_enum(PetType, PetType) + + +def test_register_union(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = graphene.Union('ReporterPet', tuple(union_types)) + + reg.register_union_type(union, union_types) + + assert 
reg.get_union_for_object_types(union_types) == union + # Order should not matter + assert reg.get_union_for_object_types([ReporterType, PetType]) == union + + +def test_register_union_scalar(): + reg = Registry() + + union_types = [graphene.String, graphene.Int] + union = graphene.Union('StringInt', tuple(union_types)) + + re_err = r"Expected Graphene ObjectType, but got: .*String.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) + + +def test_register_union_incorrect_types(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = PetType + + re_err = r"Expected graphene.Union, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 6291d4f8..e2510abc 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -7,7 +7,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from ..utils import to_type_name -from .models import Base, HairKind, Pet +from .models import Base, HairKind, KeyedModel, Pet from .test_query import to_std_dicts @@ -383,3 +383,26 @@ def makeNodes(nodeList): assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] ] + + +def test_sort_enum_from_key_issue_330(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. 
See #330 + """ + + class KeyedType(SQLAlchemyObjectType): + class Meta: + model = KeyedModel + + sort_enum = KeyedType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "KeyedTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "REPORTER_NUMBER_ASC", + "REPORTER_NUMBER_DESC", + ] + assert str(sort_enum.REPORTER_NUMBER_ASC.value.value) == 'test330."% reporter_number" ASC' + assert str(sort_enum.REPORTER_NUMBER_DESC.value.value) == 'test330."% reporter_number" DESC' diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 9a2e992d..00e8b3af 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -76,7 +76,7 @@ class Meta: assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Columns - "column_prop", # SQLAlchemy retuns column properties first + "column_prop", "id", "first_name", "last_name", diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index e13d919c..de359e05 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -3,8 +3,8 @@ from graphene import Enum, List, ObjectType, Schema, String -from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, - to_enum_value_name, to_type_name) +from ..utils import (DummyImport, get_session, sort_argument_for_model, + sort_enum_for_model, to_enum_value_name, to_type_name) from .models import Base, Editor, Pet @@ -99,3 +99,7 @@ class MultiplePK(Base): assert set(arg.default_value) == set( (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") ) + +def test_dummy_import(): + dummy_module = DummyImport() + assert dummy_module.foo == object diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 084f9b86..f6ee9b62 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -8,8 +8,6 @@ from 
sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError -from graphene_sqlalchemy.registry import get_global_registry - def get_session(context): return context.get("session") @@ -203,7 +201,14 @@ def safe_isinstance_checker(arg): def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + from graphene_sqlalchemy.registry import get_global_registry try: return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) except StopIteration: pass + + +class DummyImport: + """The dummy module returns 'object' for a query for any member""" + def __getattr__(self, name): + return object From dfee3e9417cdb8a6ec67b5cd79ee203ce4f72ed7 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 15 Jul 2022 10:43:41 +0200 Subject: [PATCH 29/67] Release new beta --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 18d34f1d..c5400cee 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .types import SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "3.0.0b1" +__version__ = "3.0.0b2" __all__ = [ "__version__", From a03a8b19fe5d2927adedb979da66146babf898ed Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 8 Sep 2022 11:30:10 +0200 Subject: [PATCH 30/67] Use Graphene DataLoader in graphene>=3.1.1 (#360) * Use Graphene DataLoader in graphene>=3.1.1 --- graphene_sqlalchemy/batching.py | 21 +++++++++++++++++++-- graphene_sqlalchemy/utils.py | 9 ++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 85cc8855..e56b1e4c 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,13 +1,30 @@ +"""The dataloader uses "select in loading" strategy to load related 
entities.""" +from typing import Any + import aiodataloader import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import is_sqlalchemy_version_less_than +from .utils import (is_graphene_version_less_than, + is_sqlalchemy_version_less_than) -def get_batch_resolver(relationship_prop): +def get_data_loader_impl() -> Any: # pragma: no cover + """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, + aiodataloader is used in conjunction with older versions of graphene""" + if is_graphene_version_less_than("3.1.1"): + from aiodataloader import DataLoader + else: + from graphene.utils.dataloader import DataLoader + + return DataLoader + +DataLoader = get_data_loader_impl() + + +def get_batch_resolver(relationship_prop): # Cache this across `batch_load_fn` calls # This is so SQL string generation is cached under-the-hood via `bakery` selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index f6ee9b62..27117c0c 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -151,11 +151,16 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_sqlalchemy_version_less_than(version_string): +def is_sqlalchemy_version_less_than(version_string): # pragma: no cover """Check the installed SQLAlchemy version""" return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return pkg_resources.get_distribution('graphene').parsed_version < pkg_resources.parse_version(version_string) + + class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function @@ 
-197,6 +202,7 @@ def safe_isinstance_checker(arg): return isinstance(arg, cls) except TypeError: pass + return safe_isinstance_checker @@ -210,5 +216,6 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: class DummyImport: """The dummy module returns 'object' for a query for any member""" + def __getattr__(self, name): return object From bb7af4b60f35dbd69ce64967eeac04ef6522c8fc Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 8 Sep 2022 11:31:08 +0200 Subject: [PATCH 31/67] 3.0.0b3 --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index c5400cee..33345815 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .types import SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "3.0.0b2" +__version__ = "3.0.0b3" __all__ = [ "__version__", From 43df4ebbd6bcf67b501e3acc04e99664f8382f11 Mon Sep 17 00:00:00 2001 From: Paul Schweizer Date: Fri, 9 Sep 2022 18:59:11 +0200 Subject: [PATCH 32/67] feat: Support Sorting in Batch ConnectionFields & Deprecate UnsortedConnectionField (#355) * Enable sorting when batching is enabled * Deprecate UnsortedSQLAlchemyConnectionField and resetting RelationshipLoader between queries * Use field_name instead of column.key to build sort enum names to ensure the enum will get the actual field_name * Adjust batching test to honor different select in query structure in sqla1.2 * Ensure that UnsortedSQLAlchemyConnectionField skips sort argument if it gets passed. 
* add test for batch sorting with custom ormfield Co-authored-by: Sabar Dasgupta --- graphene_sqlalchemy/batching.py | 178 ++++---- graphene_sqlalchemy/enums.py | 4 +- graphene_sqlalchemy/fields.py | 116 ++--- graphene_sqlalchemy/tests/models.py | 18 + graphene_sqlalchemy/tests/test_batching.py | 467 +++++++++++++++------ graphene_sqlalchemy/tests/test_fields.py | 8 + 6 files changed, 534 insertions(+), 257 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index e56b1e4c..f6f14a6e 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,5 +1,6 @@ """The dataloader uses "select in loading" strategy to load related entities.""" -from typing import Any +from asyncio import get_event_loop +from typing import Any, Dict import aiodataloader import sqlalchemy @@ -10,6 +11,90 @@ is_sqlalchemy_version_less_than) +class RelationshipLoader(aiodataloader.DataLoader): + cache = False + + def __init__(self, relationship_prop, selectin_loader): + super().__init__() + self.relationship_prop = relationship_prop + self.selectin_loader = selectin_loader + + async def batch_load_fn(self, parents): + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. 
+ + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 + """ + child_mapper = self.relationship_prop.mapper + parent_mapper = self.relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = None + if is_sqlalchemy_version_less_than('1.4'): + query_context = QueryContext(session.query(parent_mapper.entity)) + else: + parent_mapper_query = session.query(parent_mapper.entity) + query_context = parent_mapper_query._compile_context() + + if is_sqlalchemy_version_less_than('1.4'): + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + else: + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + ) + return [ + getattr(parent, self.relationship_prop.key) for parent in parents + ] + + +# Cache this across `batch_load_fn` calls +# This is so SQL string generation is cached under-the-hood via `bakery` +# Caching the relationship loader for each relationship prop. +RELATIONSHIP_LOADERS_CACHE: Dict[ + sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader +] = {} + + def get_data_loader_impl() -> Any: # pragma: no cover """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. 
To preserve backward-compatibility, aiodataloader is used in conjunction with older versions of graphene""" @@ -25,80 +110,23 @@ def get_data_loader_impl() -> Any: # pragma: no cover def get_batch_resolver(relationship_prop): - # Cache this across `batch_load_fn` calls - # This is so SQL string generation is cached under-the-hood via `bakery` - selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - - class RelationshipLoader(aiodataloader.DataLoader): - cache = False - - async def batch_load_fn(self, parents): - """ - Batch loads the relationships of all the parents as one SQL statement. - - There is no way to do this out-of-the-box with SQLAlchemy but - we can piggyback on some internal APIs of the `selectin` - eager loading strategy. It's a bit hacky but it's preferable - than re-implementing and maintainnig a big chunk of the `selectin` - loader logic ourselves. - - The approach here is to build a regular query that - selects the parent and `selectin` load the relationship. - But instead of having the query emits 2 `SELECT` statements - when callling `all()`, we skip the first `SELECT` statement - and jump right before the `selectin` loader is called. - To accomplish this, we have to construct objects that are - normally built in the first part of the query in order - to call directly `SelectInLoader._load_for_path`. - - TODO Move this logic to a util in the SQLAlchemy repo as per - SQLAlchemy's main maitainer suggestion. - See https://git.io/JewQ7 - """ - child_mapper = relationship_prop.mapper - parent_mapper = relationship_prop.parent - session = Session.object_session(parents[0]) - - # These issues are very unlikely to happen in practice... 
- for parent in parents: - # assert parent.__mapper__ is parent_mapper - # All instances must share the same session - assert session is Session.object_session(parent) - # The behavior of `selectin` is undefined if the parent is dirty - assert parent not in session.dirty - - # Should the boolean be set to False? Does it matter for our purposes? - states = [(sqlalchemy.inspect(parent), True) for parent in parents] - - # For our purposes, the query_context will only used to get the session - query_context = None - if is_sqlalchemy_version_less_than('1.4'): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: - parent_mapper_query = session.query(parent_mapper.entity) - query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than('1.4'): - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper - ) - else: - selectin_loader._load_for_path( - query_context, - parent_mapper._path_registry, - states, - None, - child_mapper, - None - ) - - return [getattr(parent, relationship_prop.key) for parent in parents] - - loader = RelationshipLoader() + """Get the resolve function for the given relationship.""" + + def _get_loader(relationship_prop): + """Retrieve the cached loader of the given relationship.""" + loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) + if loader is None or loader.loop != get_event_loop(): + selectin_loader = strategies.SelectInLoader( + relationship_prop, (('lazy', 'selectin'),) + ) + loader = RelationshipLoader( + relationship_prop=relationship_prop, + selectin_loader=selectin_loader, + ) + RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader + return loader + + loader = _get_loader(relationship_prop) async def resolve(root, info, **args): return await loader.load(root) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index a2ed17ad..19f40b7f 100644 --- a/graphene_sqlalchemy/enums.py +++ 
b/graphene_sqlalchemy/enums.py @@ -144,9 +144,9 @@ def sort_enum_for_object_type( column = orm_field.columns[0] if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(column.key, True) + asc_name = get_name(field_name, True) asc_value = EnumValue(asc_name, column.asc()) - desc_name = get_name(column.key, False) + desc_name = get_name(field_name, False) desc_value = EnumValue(desc_name, column.desc()) if column.primary_key: default.append(asc_value) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d7a83392..9b4b8436 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -14,7 +14,7 @@ from .utils import EnumValue, get_query -class UnsortedSQLAlchemyConnectionField(ConnectionField): +class SQLAlchemyConnectionField(ConnectionField): @property def type(self): from .types import SQLAlchemyObjectType @@ -37,13 +37,45 @@ def type(self): ) return nullable_type.connection + def __init__(self, type_, *args, **kwargs): + nullable_type = get_nullable_type(type_) + if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. 
Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + nullable_type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) + @property def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, **args): - return get_query(model, info.context) + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if not isinstance(sort, list): + sort = [sort] + sort_args = [] + # ensure consistent handling of graphene Enums, enum values and + # plain strings + for item in sort: + if isinstance(item, enum.Enum): + sort_args.append(item.value.value) + elif isinstance(item, EnumValue): + sort_args.append(item.value) + else: + sort_args.append(item) + query = query.order_by(*sort_args) + return query @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -90,59 +122,49 @@ def wrap_resolve(self, parent_resolver): ) -# TODO Rename this to SortableSQLAlchemyConnectionField -class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +# TODO Remove in next major version +class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField): def __init__(self, type_, *args, **kwargs): - nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and issubclass(nullable_type, Connection): - # Let super class raise if type is not a Connection - try: - kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) - except (AttributeError, TypeError): - raise TypeError( - 'Cannot create sort argument for {}. A model is required. 
Set the "sort" argument' - " to None to disabling the creation of the sort query argument".format( - nullable_type.__name__ - ) - ) - elif "sort" in kwargs and kwargs["sort"] is None: - del kwargs["sort"] - super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) - - @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if not isinstance(sort, list): - sort = [sort] - sort_args = [] - # ensure consistent handling of graphene Enums, enum values and - # plain strings - for item in sort: - if isinstance(item, enum.Enum): - sort_args.append(item.value.value) - elif isinstance(item, EnumValue): - sort_args.append(item.value) - else: - sort_args.append(item) - query = query.order_by(*sort_args) - return query + if "sort" in kwargs and kwargs["sort"] is not None: + warnings.warn( + "UnsortedSQLAlchemyConnectionField does not support sorting. " + "All sorting arguments will be ignored." + ) + kwargs["sort"] = None + warnings.warn( + "UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyConnectionField instead and either don't " + "provide the `sort` argument or set it to None if you do not want sorting.", + DeprecationWarning, + ) + super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) -class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): +class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): """ This is currently experimental. The API and behavior may change in future versions. Use at your own risk. 
""" - def wrap_resolve(self, parent_resolver): - return partial( - self.connection_resolver, - self.resolver, - get_nullable_type(self.type), - self.model, - ) + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + if root is None: + resolved = resolver(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + else: + relationship_prop = None + for relationship in root.__class__.__mapper__.relationships: + if relationship.mapper.class_ == model: + relationship_prop = relationship + break + resolved = get_batch_resolver(relationship_prop)(root, info, **args) + on_resolve = partial(cls.resolve_connection, connection_type, root, info, args) + + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) @classmethod def from_relationship(cls, relationship, registry, **field_kwargs): diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index dc399ee0..c7a1d664 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -110,6 +110,24 @@ class Article(Base): headline = Column(String(100)) pub_date = Column(Date()) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) class ReflectedEditor(type): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 1896900b..fc4e6649 
100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -5,13 +5,13 @@ import pytest import graphene -from graphene import relay +from graphene import Connection, relay from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType from ..utils import is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reporter +from .models import Article, HairKind, Pet, Reader, Reporter from .utils import remove_cache_miss_stat, to_std_dicts @@ -73,6 +73,40 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) +def get_full_relay_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = BatchSQLAlchemyConnectionField(ArticleType.connection) + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + readers = BatchSQLAlchemyConnectionField(ReaderType.connection) + + return graphene.Schema(query=Query) + + if is_sqlalchemy_version_less_than('1.2'): pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) @@ -82,11 +116,11 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -138,20 +172,20 @@ async def 
test_many_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - "firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], } @@ -160,11 +194,11 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -185,14 +219,14 @@ async def test_one_to_one(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - favoriteArticle { - headline - } + query { + reporters { + firstName + favoriteArticle { + headline + } + } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -216,20 +250,20 @@ async def test_one_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], } @@ -238,11 +272,11 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) 
session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -271,18 +305,18 @@ async def test_one_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - articles(first: 2) { - edges { - node { - headline - } + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -306,42 +340,42 @@ async def test_one_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_1", - }, - }, - { - "node": { - "headline": "Article_2", - }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_3", + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], }, - }, - { - "node": { - "headline": "Article_4", + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], }, - }, - ], - }, - }, - ], + }, + ], } @@ -350,11 +384,11 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name='Reporter_1', ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name='Reporter_2', ) session.add(reporter_2) @@ -385,18 +419,18 @@ async def test_many_to_many(session_factory): # Starts new session to fully reset the engine / connection logging level session = 
session_factory() result = await schema.execute_async(""" - query { - reporters { - firstName - pets(first: 2) { - edges { - node { - name - } + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } } - } } - } """, context_value={"session": session}) messages = sqlalchemy_logging_handler.messages @@ -420,42 +454,42 @@ async def test_many_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_1", - }, - }, - { - "node": { - "name": "Pet_2", + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_3", + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], }, - }, - { - "node": { - "name": "Pet_4", - }, - }, - ], - }, - }, - ], + }, + ], } @@ -531,6 +565,70 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 +@pytest.mark.asyncio +def test_batch_sorting_with_custom_ormfield(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + batching = True + connection_class = Connection + + firstname = ORMField(model_attr="first_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = BatchSQLAlchemyConnectionField(ReporterType.connection) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = 
Reporter + interfaces = (relay.Node,) + batching = True + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters(sort: [FIRSTNAME_DESC]) { + edges { + node { + firstname + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + assert result == { + "reporters": {"edges": [ + {"node": { + "firstname": "Reporter_2", + }}, + {"node": { + "firstname": "Reporter_1", + }}, + ]} + } + select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message] + assert len(select_statements) == 2 + + @pytest.mark.asyncio async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() @@ -642,3 +740,106 @@ def resolve_reporters(self, info): select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] assert len(select_statements) == 2 + + +@pytest.mark.asyncio +async def test_batching_across_nested_relay_schema(session_factory): + session = session_factory() + + for first_name in "fgerbhjikzutzxsdfdqqa": + reporter = Reporter( + first_name=first_name, + ) + session.add(reporter) + article = Article(headline='Article') + article.reporter = reporter + session.add(article) + reader = Reader(name='Reader') + reader.articles = [article] + session.add(reader) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = await schema.execute_async(""" + query { + reporters { + edges { + node { + 
firstName + articles { + edges { + node { + id + readers { + edges { + node { + name + } + } + } + } + } + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + result = to_std_dicts(result.data) + select_statements = [message for message in messages if 'SELECT' in message] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than('1.3'): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] + + +@pytest.mark.asyncio +async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_factory): + session = session_factory() + + for first_name, email in zip("cadbbb", "aaabac"): + reporter_1 = Reporter( + first_name=first_name, + email=email + ) + session.add(reporter_1) + article_1 = Article(headline="headline") + article_1.reporter = reporter_1 + session.add(article_1) + + session.commit() + session.close() + + schema = get_full_relay_schema() + + session = session_factory() + result = await schema.execute_async(""" + query { + reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { + edges { + node { + firstName + email + } + } + } + } + """, context_value={"session": session}) + + result = to_std_dicts(result.data) + assert [ + r["node"]["firstName"] + r["node"]["email"] + for r in result["reporters"]["edges"] + ] == ['aa', 'ba', 'bb', 'bc', 'ca', 'da'] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 357055e3..2782da89 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -64,6 +64,14 @@ def test_type_assert_object_has_connection(): ## +def test_unsorted_connection_field_removes_sort_arg_if_passed(): + 
editor = UnsortedSQLAlchemyConnectionField( + Editor.connection, + sort=Editor.sort_argument(has_default=True) + ) + assert "sort" not in editor.args + + def test_sort_added_by_default(): field = SQLAlchemyConnectionField(Pet.connection) assert "sort" in field.args From b3657b069424c1b9f5ae136a1355f685554df761 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 12 Sep 2022 21:36:55 +0200 Subject: [PATCH 33/67] Add Black to pre-commit (#361) This commits re-formats the codebase using black --- .flake8 | 4 - .pre-commit-config.yaml | 10 +- docs/conf.py | 87 ++--- examples/flask_sqlalchemy/database.py | 23 +- examples/flask_sqlalchemy/models.py | 22 +- examples/flask_sqlalchemy/schema.py | 9 +- examples/nameko_sqlalchemy/app.py | 76 +++-- examples/nameko_sqlalchemy/database.py | 23 +- examples/nameko_sqlalchemy/models.py | 22 +- examples/nameko_sqlalchemy/service.py | 4 +- graphene_sqlalchemy/batching.py | 13 +- graphene_sqlalchemy/converter.py | 155 ++++++--- graphene_sqlalchemy/enums.py | 16 +- graphene_sqlalchemy/fields.py | 36 +- graphene_sqlalchemy/registry.py | 18 +- graphene_sqlalchemy/resolvers.py | 2 +- graphene_sqlalchemy/tests/conftest.py | 2 +- graphene_sqlalchemy/tests/models.py | 44 ++- graphene_sqlalchemy/tests/test_batching.py | 268 +++++++++------ graphene_sqlalchemy/tests/test_benchmark.py | 84 +++-- graphene_sqlalchemy/tests/test_converter.py | 197 +++++++---- graphene_sqlalchemy/tests/test_enums.py | 29 +- graphene_sqlalchemy/tests/test_fields.py | 8 +- graphene_sqlalchemy/tests/test_query.py | 22 +- graphene_sqlalchemy/tests/test_query_enums.py | 47 ++- graphene_sqlalchemy/tests/test_reflected.py | 1 - graphene_sqlalchemy/tests/test_registry.py | 4 +- graphene_sqlalchemy/tests/test_sort_enums.py | 12 +- graphene_sqlalchemy/tests/test_types.py | 309 ++++++++++-------- graphene_sqlalchemy/tests/test_utils.py | 18 +- graphene_sqlalchemy/types.py | 151 +++++---- graphene_sqlalchemy/utils.py | 19 +- setup.cfg | 4 +- 33 files changed, 1041 
insertions(+), 698 deletions(-) delete mode 100644 .flake8 diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 30f6dedd..00000000 --- a/.flake8 +++ /dev/null @@ -1,4 +0,0 @@ -[flake8] -ignore = E203,W503 -exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs -max-line-length = 120 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66db3814..470a29eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.10 + python: python3.7 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -16,6 +16,14 @@ repos: hooks: - id: isort name: isort (python) + - repo: https://github.com/asottile/pyupgrade + rev: v2.37.3 + hooks: + - id: pyupgrade + - repo: https://github.com/psf/black + rev: 22.6.0 + hooks: + - id: black - repo: https://github.com/PyCQA/flake8 rev: 4.0.0 hooks: diff --git a/docs/conf.py b/docs/conf.py index 3fa6391d..9c9fc1d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,6 +1,6 @@ import os -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" # -*- coding: utf-8 -*- # @@ -34,46 +34,46 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", ] if not on_rtd: extensions += [ - 'sphinx.ext.githubpages', + "sphinx.ext.githubpages", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. 
# # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Graphene Django' -copyright = u'Graphene 2016' -author = u'Syrus Akbary' +project = "Graphene Django" +copyright = "Graphene 2016" +author = "Syrus Akbary" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = u'1.0' +version = "1.0" # The full version, including alpha/beta/rc tags. -release = u'1.0.dev' +release = "1.0.dev" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -94,7 +94,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -116,7 +116,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -175,7 +175,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -255,34 +255,30 @@ # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. 
-htmlhelp_basename = 'Graphenedoc' +htmlhelp_basename = "Graphenedoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Graphene.tex', u'Graphene Documentation', - u'Syrus Akbary', 'manual'), + (master_doc, "Graphene.tex", "Graphene Documentation", "Syrus Akbary", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -323,8 +319,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'graphene_django', u'Graphene Django Documentation', - [author], 1) + (master_doc, "graphene_django", "Graphene Django Documentation", [author], 1) ] # If true, show URL addresses after external links. 
@@ -338,9 +333,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Graphene-Django', u'Graphene Django Documentation', - author, 'Graphene Django', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Graphene-Django", + "Graphene Django Documentation", + author, + "Graphene Django", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. @@ -414,7 +415,7 @@ # epub_post_files = [] # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # The depth of the table of contents in toc.ncx. # @@ -446,4 +447,4 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index ca4d4122..74ec7ca9 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -15,24 +15,25 @@ def init_db(): # they will be registered properly on the metadata. 
Otherwise # you will have to import them first before calling init_db() from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index efbbe690..38f0fd0a 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, 
ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index ea525e3b..c4a91e63 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -10,26 +10,27 @@ class Department(SQLAlchemyObjectType): class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee.connection, sort=Employee.sort_argument()) + Employee.connection, sort=Employee.sort_argument() + ) # Allows sorting over multiple columns, by default over the primary key all_roles = SQLAlchemyConnectionField(Role.connection) # Disable sorting over this field diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index 05352529..64d305ea 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,37 +1,45 @@ from database import db_session, init_db from schema import schema -from graphql_server import 
(HttpQueryError, default_format_error, - encode_execution_results, json_encode, - load_json_body, run_http_query) - - -class App(): - def __init__(self): - init_db() - - def query(self, request): - data = self.parse_body(request) - execution_results, params = run_http_query( - schema, - 'post', - data) - result, status_code = encode_execution_results( - execution_results, - format_error=default_format_error,is_batch=False, encode=json_encode) - return result - - def parse_body(self,request): - # We use mimetype here since we don't need the other - # information provided by content_type - content_type = request.mimetype - if content_type == 'application/graphql': - return {'query': request.data.decode('utf8')} - - elif content_type == 'application/json': - return load_json_body(request.data.decode('utf8')) - - elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): - return request.form - - return {} +from graphql_server import ( + HttpQueryError, + default_format_error, + encode_execution_results, + json_encode, + load_json_body, + run_http_query, +) + + +class App: + def __init__(self): + init_db() + + def query(self, request): + data = self.parse_body(request) + execution_results, params = run_http_query(schema, "post", data) + result, status_code = encode_execution_results( + execution_results, + format_error=default_format_error, + is_batch=False, + encode=json_encode, + ) + return result + + def parse_body(self, request): + # We use mimetype here since we don't need the other + # information provided by content_type + content_type = request.mimetype + if content_type == "application/graphql": + return {"query": request.data.decode("utf8")} + + elif content_type == "application/json": + return load_json_body(request.data.decode("utf8")) + + elif content_type in ( + "application/x-www-form-urlencoded", + "multipart/form-data", + ): + return request.form + + return {} diff --git a/examples/nameko_sqlalchemy/database.py 
b/examples/nameko_sqlalchemy/database.py index ca4d4122..74ec7ca9 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -2,10 +2,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker -engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) +engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() @@ -15,24 +15,25 @@ def init_db(): # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() from models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) # Create the fixtures - engineering = Department(name='Engineering') + engineering = Department(name="Engineering") db_session.add(engineering) - hr = Department(name='Human Resources') + hr = Department(name="Human Resources") db_session.add(hr) - manager = Role(name='manager') + manager = Role(name="manager") db_session.add(manager) - engineer = Role(name='engineer') + engineer = Role(name="engineer") db_session.add(engineer) - peter = Employee(name='Peter', department=engineering, role=engineer) + peter = Employee(name="Peter", department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering, role=engineer) + roy = Employee(name="Roy", department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr, role=manager) + tracy = Employee(name="Tracy", department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/nameko_sqlalchemy/models.py b/examples/nameko_sqlalchemy/models.py index 
efbbe690..38f0fd0a 100644 --- a/examples/nameko_sqlalchemy/models.py +++ b/examples/nameko_sqlalchemy/models.py @@ -4,35 +4,31 @@ class Department(Base): - __tablename__ = 'department' + __tablename__ = "department" id = Column(Integer, primary_key=True) name = Column(String) class Role(Base): - __tablename__ = 'roles' + __tablename__ = "roles" role_id = Column(Integer, primary_key=True) name = Column(String) class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String) # Use default=func.now() to set the default hiring time # of an Employee to be the current time when an # Employee record was created hired_on = Column(DateTime, default=func.now()) - department_id = Column(Integer, ForeignKey('department.id')) - role_id = Column(Integer, ForeignKey('roles.role_id')) + department_id = Column(Integer, ForeignKey("department.id")) + role_id = Column(Integer, ForeignKey("roles.role_id")) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( - Department, - backref=backref('employees', - uselist=True, - cascade='delete,all')) + Department, backref=backref("employees", uselist=True, cascade="delete,all") + ) role = relationship( - Role, - backref=backref('roles', - uselist=True, - cascade='delete,all')) + Role, backref=backref("roles", uselist=True, cascade="delete,all") + ) diff --git a/examples/nameko_sqlalchemy/service.py b/examples/nameko_sqlalchemy/service.py index d9c519c9..7f4c5078 100644 --- a/examples/nameko_sqlalchemy/service.py +++ b/examples/nameko_sqlalchemy/service.py @@ -4,8 +4,8 @@ class DepartmentService: - name = 'department' + name = "department" - @http('POST', '/graphql') + @http("POST", "/graphql") def query(self, request): return App().query(request) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index f6f14a6e..275d5904 100644 --- a/graphene_sqlalchemy/batching.py +++ 
b/graphene_sqlalchemy/batching.py @@ -7,8 +7,7 @@ from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import (is_graphene_version_less_than, - is_sqlalchemy_version_less_than) +from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than class RelationshipLoader(aiodataloader.DataLoader): @@ -59,13 +58,13 @@ async def batch_load_fn(self, parents): # For our purposes, the query_context will only used to get the session query_context = None - if is_sqlalchemy_version_less_than('1.4'): + if is_sqlalchemy_version_less_than("1.4"): query_context = QueryContext(session.query(parent_mapper.entity)) else: parent_mapper_query = session.query(parent_mapper.entity) query_context = parent_mapper_query._compile_context() - if is_sqlalchemy_version_less_than('1.4'): + if is_sqlalchemy_version_less_than("1.4"): self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, @@ -82,9 +81,7 @@ async def batch_load_fn(self, parents): child_mapper, None, ) - return [ - getattr(parent, self.relationship_prop.key) for parent in parents - ] + return [getattr(parent, self.relationship_prop.key) for parent in parents] # Cache this across `batch_load_fn` calls @@ -117,7 +114,7 @@ def _get_loader(relationship_prop): loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None) if loader is None or loader.loop != get_event_loop(): selectin_loader = strategies.SelectInLoader( - relationship_prop, (('lazy', 'selectin'),) + relationship_prop, (("lazy", "selectin"),) ) loader = RelationshipLoader( relationship_prop=relationship_prop, diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 1e7846eb..d1873c2b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -15,13 +15,16 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import (BatchSQLAlchemyConnectionField, - 
default_connection_field_factory) +from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import (DummyImport, registry_sqlalchemy_model_from_str, - safe_isinstance, singledispatchbymatchfunction, - value_equals) +from .utils import ( + DummyImport, + registry_sqlalchemy_model_from_str, + safe_isinstance, + singledispatchbymatchfunction, + value_equals, +) try: from typing import ForwardRef @@ -39,7 +42,7 @@ except ImportError: sqa_utils = DummyImport() -is_selectin_available = getattr(strategies, 'SelectInLoader', None) +is_selectin_available = getattr(strategies, "SelectInLoader", None) def get_column_doc(column): @@ -50,8 +53,14 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, - orm_field_name, **field_kwargs): +def convert_sqlalchemy_relationship( + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, +): """ :param sqlalchemy.RelationshipProperty relationship_prop: :param SQLAlchemyObjectType obj_type: @@ -65,24 +74,34 @@ def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_fiel def dynamic_type(): """:rtype: Field|None""" direction = relationship_prop.direction - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) batching_ = batching if is_selectin_available else False if not child_type: return None if direction == interfaces.MANYTOONE or not relationship_prop.uselist: - return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, - **field_kwargs) + return _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching_, 
orm_field_name, **field_kwargs + ) if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, - connection_field_factory, **field_kwargs) + return _convert_o2m_or_m2m_relationship( + relationship_prop, + obj_type, + batching_, + connection_field_factory, + **field_kwargs, + ) return graphene.Dynamic(dynamic_type) -def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): +def _convert_o2o_or_m2o_relationship( + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs +): """ Convert one-to-one or many-to-one relationshsip. Return an object field. @@ -93,17 +112,24 @@ def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_ :param dict field_kwargs: :rtype: Field """ - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) resolver = get_custom_resolver(obj_type, orm_field_name) if resolver is None: - resolver = get_batch_resolver(relationship_prop) if batching else \ - get_attr_resolver(obj_type, relationship_prop.key) + resolver = ( + get_batch_resolver(relationship_prop) + if batching + else get_attr_resolver(obj_type, relationship_prop.key) + ) return graphene.Field(child_type, resolver=resolver, **field_kwargs) -def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): +def _convert_o2m_or_m2m_relationship( + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs +): """ Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. 
@@ -114,30 +140,34 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn :param dict field_kwargs: :rtype: Field """ - child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + child_type = obj_type._meta.registry.get_type_for_model( + relationship_prop.mapper.entity + ) if not child_type._meta.connection: return graphene.Field(graphene.List(child_type), **field_kwargs) # TODO Allow override of connection_field_factory and resolver via ORMField if connection_field_factory is None: - connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ - default_connection_field_factory - - return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) + connection_field_factory = ( + BatchSQLAlchemyConnectionField.from_relationship + if batching + else default_connection_field_factory + ) + + return connection_field_factory( + relationship_prop, obj_type._meta.registry, **field_kwargs + ) def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): - if 'type_' not in field_kwargs: - field_kwargs['type_'] = convert_hybrid_property_return_type(hybrid_prop) + if "type_" not in field_kwargs: + field_kwargs["type_"] = convert_hybrid_property_return_type(hybrid_prop) - if 'description' not in field_kwargs: - field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) + if "description" not in field_kwargs: + field_kwargs["description"] = getattr(hybrid_prop, "__doc__", None) - return graphene.Field( - resolver=resolver, - **field_kwargs - ) + return graphene.Field(resolver=resolver, **field_kwargs) def convert_sqlalchemy_composite(composite_prop, registry, resolver): @@ -177,14 +207,14 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) - 
field_kwargs.setdefault('required', not is_column_nullable(column)) - field_kwargs.setdefault('description', get_column_doc(column)) - - return graphene.Field( - resolver=resolver, - **field_kwargs + field_kwargs.setdefault( + "type_", + convert_sqlalchemy_type(getattr(column, "type", None), column, registry), ) + field_kwargs.setdefault("required", not is_column_nullable(column)) + field_kwargs.setdefault("description", get_column_doc(column)) + + return graphene.Field(resolver=resolver, **field_kwargs) @singledispatch @@ -271,14 +301,20 @@ def convert_scalar_list_to_list(type, column, registry=None): def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else graphene.List(init_array_list_recursive(inner_type, n - 1)) + return ( + inner_type + if n == 0 + else graphene.List(init_array_list_recursive(inner_type, n - 1)) + ) @convert_sqlalchemy_type.register(sqa_types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return graphene.List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) + return graphene.List( + init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1) + ) @convert_sqlalchemy_type.register(postgresql.HSTORE) @@ -313,8 +349,8 @@ def convert_sqlalchemy_hybrid_property_type(arg: Any): # No valid type found, warn and fall back to graphene.String warnings.warn( - (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." - "Falling back to \"graphene.String\"") + f'I don\'t know how to generate a GraphQL type out of a "{arg}" type.' 
+ 'Falling back to "graphene.String"' ) return graphene.String @@ -368,15 +404,17 @@ def is_union(arg) -> bool: if isinstance(arg, UnionType): return True - return getattr(arg, '__origin__', None) == typing.Union + return getattr(arg, "__origin__", None) == typing.Union -def graphene_union_for_py_union(obj_types: typing.List[graphene.ObjectType], registry) -> graphene.Union: +def graphene_union_for_py_union( + obj_types: typing.List[graphene.ObjectType], registry +) -> graphene.Union: union_type = registry.get_union_for_object_types(obj_types) if union_type is None: # Union Name is name of the three - union_name = ''.join(sorted([obj_type._meta.name for obj_type in obj_types])) + union_name = "".join(sorted(obj_type._meta.name for obj_type in obj_types)) union_type = graphene.Union(union_name, obj_types) registry.register_union_type(union_type, obj_types) @@ -411,16 +449,25 @@ def convert_sqlalchemy_hybrid_property_union(arg): return graphene_types[0] # Now check if every type is instance of an ObjectType - if not all(isinstance(graphene_type, type(graphene.ObjectType)) for graphene_type in graphene_types): - raise ValueError("Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " - "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " - "or use an ORMField to override this behaviour.") - - return graphene_union_for_py_union(cast(typing.List[graphene.ObjectType], list(graphene_types)), - get_global_registry()) + if not all( + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types + ): + raise ValueError( + "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " + "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " + "or use an ORMField to override this behaviour." 
+ ) + + return graphene_union_for_py_union( + cast(typing.List[graphene.ObjectType], list(graphene_types)), + get_global_registry(), + ) -@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) +@convert_sqlalchemy_hybrid_property_type.register( + lambda x: getattr(x, "__origin__", None) in [list, typing.List] +) def convert_sqlalchemy_hybrid_property_type_list_t(arg): # type is either list[T] or List[T], generic argument at __args__[0] internal_type = arg.__args__[0] @@ -459,6 +506,6 @@ def convert_sqlalchemy_hybrid_property_bare_str(arg): def convert_hybrid_property_return_type(hybrid_prop): # Grab the original method's return type annotations from inside the hybrid property - return_type_annotation = hybrid_prop.fget.__annotations__.get('return', str) + return_type_annotation = hybrid_prop.fget.__annotations__.get("return", str) return convert_sqlalchemy_hybrid_property_type(return_type_annotation) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 19f40b7f..97f8997c 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -18,9 +18,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): The Enum value names are converted to upper case if necessary. 
""" if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum_class = sa_enum.enum_class if enum_class: if all(to_enum_value_name(key) == key for key in enum_class.__members__): @@ -45,9 +43,7 @@ def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): def enum_for_sa_enum(sa_enum, registry): """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)) enum = registry.get_graphene_enum_for_sa_enum(sa_enum) if not enum: enum = _convert_sa_to_graphene_enum(sa_enum) @@ -60,11 +56,9 @@ def enum_for_field(obj_type, field_name): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + raise TypeError("Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): - raise TypeError( - "Expected a field name, but got: {!r}".format(field_name)) + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) if orm_field is None: @@ -166,7 +160,7 @@ def sort_argument_for_object_type( get_symbol_name=None, has_default=True, ): - """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + """ "Returns Graphene Argument for sorting the given SQLAlchemyObjectType. 
Parameters - obj_type : SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 9b4b8436..2cb53c55 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -26,9 +26,7 @@ def type(self): assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(nullable_type.__name__) - assert ( - nullable_type.connection - ), "The type {} doesn't have a connection".format( + assert nullable_type.connection, "The type {} doesn't have a connection".format( nullable_type.__name__ ) assert type_ == nullable_type, ( @@ -39,7 +37,11 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) - if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection): + if ( + "sort" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): # Let super class raise if type is not a Connection try: kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) @@ -151,7 +153,9 @@ class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField): def connection_resolver(cls, resolver, connection_type, model, root, info, **args): if root is None: resolved = resolver(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + on_resolve = partial( + cls.resolve_connection, connection_type, model, info, args + ) else: relationship_prop = None for relationship in root.__class__.__mapper__.relationships: @@ -159,7 +163,9 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg relationship_prop = relationship break resolved = get_batch_resolver(relationship_prop)(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection_type, root, info, args) + on_resolve = partial( + cls.resolve_connection, connection_type, root, info, args + ) if is_thenable(resolved): return 
Promise.resolve(resolved).then(on_resolve) @@ -170,7 +176,11 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg def from_relationship(cls, relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + return cls( + model_type.connection, + resolver=get_batch_resolver(relationship), + **field_kwargs + ) def default_connection_field_factory(relationship, registry, **field_kwargs): @@ -185,8 +195,8 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): def createConnectionField(type_, **field_kwargs): warnings.warn( - 'createConnectionField is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "createConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) return __connectionFactory(type_, **field_kwargs) @@ -194,8 +204,8 @@ def createConnectionField(type_, **field_kwargs): def registerConnectionFieldFactory(factoryMethod): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory @@ -204,8 +214,8 @@ def registerConnectionFieldFactory(factoryMethod): def unregisterConnectionFieldFactory(): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. 
Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 80470d9b..8f2bc9e7 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -20,8 +20,9 @@ def __init__(self): def register(self, obj_type): from .types import SQLAlchemyObjectType + if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -40,7 +41,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -76,8 +77,9 @@ def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): def register_sort_enum(self, obj_type, sort_enum: Enum): from .types import SQLAlchemyObjectType + if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -89,11 +91,11 @@ def register_sort_enum(self, obj_type, sort_enum: Enum): def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) - def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]]): + def register_union_type( + self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]] + ): if not isinstance(union, graphene.Union): - raise TypeError( - "Expected graphene.Union, but got: 
{!r}".format(union) - ) + raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) for obj_type in obj_types: if not isinstance(obj_type, type(graphene.ObjectType)): @@ -103,7 +105,7 @@ def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphe self._registry_unions[frozenset(obj_types)] = union - def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]): + def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): return self._registry_unions.get(frozenset(obj_types)) diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py index 83a6e35d..e8e61911 100644 --- a/graphene_sqlalchemy/resolvers.py +++ b/graphene_sqlalchemy/resolvers.py @@ -7,7 +7,7 @@ def get_custom_resolver(obj_type, orm_field_name): does not have a `resolver`, we need to re-implement that logic here so users are able to override the default resolvers that we provide. """ - resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + resolver = getattr(obj_type, "resolve_{}".format(orm_field_name), None) if resolver: return get_unbound_function(resolver) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 34ba9d8a..357ad96e 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -8,7 +8,7 @@ from ..registry import reset_global_registry from .models import Base, CompositeFullName -test_db_url = 'sqlite://' # use in-memory database for tests +test_db_url = "sqlite://" # use in-memory database for tests @pytest.fixture(autouse=True) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index c7a1d664..fd5d3b21 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -5,8 +5,18 @@ from decimal import Decimal from typing import List, Optional, Tuple -from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, Numeric, - String, 
Table, func, select) +from sqlalchemy import ( + Column, + Date, + Enum, + ForeignKey, + Integer, + Numeric, + String, + Table, + func, + select, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import column_property, composite, mapper, relationship @@ -15,8 +25,8 @@ class HairKind(enum.Enum): - LONG = 'long' - SHORT = 'short' + LONG = "long" + SHORT = "short" Base = declarative_base() @@ -64,7 +74,9 @@ class Reporter(Base): last_name = Column(String(30), doc="Last name") email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") + pets = relationship( + "Pet", secondary=association_table, backref="reporters", order_by="Pet.id" + ) articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) @@ -101,7 +113,9 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) class Article(Base): @@ -155,7 +169,7 @@ class ShoppingCartItem(Base): id = Column(Integer(), primary_key=True) @hybrid_property - def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + def hybrid_prop_shopping_cart(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] @@ -210,11 +224,17 @@ def hybrid_prop_list_date(self) -> List[datetime.date]: @hybrid_property def hybrid_prop_nested_list_int(self) -> List[List[int]]: - return [self.hybrid_prop_list_int, ] + return [ + self.hybrid_prop_list_int, + ] @hybrid_property def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: - return [[self.hybrid_prop_list_int, ], ] + return [ + [ + self.hybrid_prop_list_int, + ], + ] # Other SQLAlchemy Instances 
@hybrid_property @@ -234,17 +254,17 @@ def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: # Self-references @hybrid_property - def hybrid_prop_self_referential(self) -> 'ShoppingCart': + def hybrid_prop_self_referential(self) -> "ShoppingCart": return ShoppingCart(id=1) @hybrid_property - def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + def hybrid_prop_self_referential_list(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] # Optional[T] @hybrid_property - def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: + def hybrid_prop_optional_self_referential(self) -> Optional["ShoppingCart"]: return None diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index fc4e6649..90df0279 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -7,8 +7,7 @@ import graphene from graphene import Connection, relay -from ..fields import (BatchSQLAlchemyConnectionField, - default_connection_field_factory) +from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reader, Reporter @@ -17,6 +16,7 @@ class MockLoggingHandler(logging.Handler): """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): self.messages = [] logging.Handler.__init__(self, *args, **kwargs) @@ -28,7 +28,7 @@ def emit(self, record): @contextlib.contextmanager def mock_sqlalchemy_logging_handler(): logging.basicConfig() - sql_logger = logging.getLogger('sqlalchemy.engine') + sql_logger = logging.getLogger("sqlalchemy.engine") previous_level = sql_logger.level sql_logger.setLevel(logging.INFO) @@ -65,10 +65,10 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return 
info.context.get('session').query(Article).all() + return info.context.get("session").query(Article).all() def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() return graphene.Schema(query=Query) @@ -107,8 +107,8 @@ class Query(graphene.ObjectType): return graphene.Schema(query=Query) -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) @pytest.mark.asyncio @@ -116,19 +116,19 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) @@ -140,7 +140,8 @@ async def test_many_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { articles { headline @@ -149,20 +150,26 @@ async def test_many_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # 
SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN reporters" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -194,19 +201,19 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) @@ -218,7 +225,8 @@ async def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -227,20 +235,26 @@ async def test_one_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for 
message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -272,27 +286,27 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) @@ -304,7 +318,8 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -317,20 +332,26 @@ async def test_one_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched 
SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -384,27 +405,27 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -418,7 +439,8 @@ async def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -431,20 +453,26 @@ async 
def test_many_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -495,9 +523,9 @@ async def test_many_to_many(session_factory): def test_disable_batching_via_ormfield(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -520,7 +548,7 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) @@ -528,7 +556,8 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { favoriteArticle { @@ -536,17 +565,24 @@ def resolve_reporters(self, info): 
} } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -558,19 +594,25 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 @pytest.mark.asyncio def test_batch_sorting_with_custom_ormfield(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -601,7 +643,8 @@ class Meta: with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = schema.execute(""" + result = schema.execute( + """ query { reporters(sort: [FIRSTNAME_DESC]) { edges { @@ -611,30 +654,42 @@ class Meta: } } } - """, context_value={"session": session}) + """, + context_value={"session": 
session}, + ) messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) assert result == { - "reporters": {"edges": [ - {"node": { - "firstname": "Reporter_2", - }}, - {"node": { - "firstname": "Reporter_1", - }}, - ]} + "reporters": { + "edges": [ + { + "node": { + "firstname": "Reporter_2", + } + }, + { + "node": { + "firstname": "Reporter_1", + } + }, + ] + } } - select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM reporters" in message + ] assert len(select_statements) == 2 @pytest.mark.asyncio async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -657,14 +712,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - await schema.execute_async(""" + await schema.execute_async( + """ query { reporters { articles { @@ -676,24 +732,34 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # 
SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] else: - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 1 def test_connection_factory_field_overrides_batching_is_true(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -716,14 +782,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -735,10 +802,16 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 @@ 
-751,10 +824,10 @@ async def test_batching_across_nested_relay_schema(session_factory): first_name=first_name, ) session.add(reporter) - article = Article(headline='Article') + article = Article(headline="Article") article.reporter = reporter session.add(article) - reader = Reader(name='Reader') + reader = Reader(name="Reader") reader.articles = [article] session.add(reader) @@ -766,7 +839,8 @@ async def test_batching_across_nested_relay_schema(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { edges { @@ -790,14 +864,16 @@ async def test_batching_across_nested_relay_schema(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) - select_statements = [message for message in messages if 'SELECT' in message] + select_statements = [message for message in messages if "SELECT" in message] assert len(select_statements) == 4 assert select_statements[-1].startswith("SELECT articles_1.id") - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): assert select_statements[-2].startswith("SELECT reporters_1.id") assert "WHERE reporters_1.id IN" in select_statements[-2] else: @@ -810,10 +886,7 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f session = session_factory() for first_name, email in zip("cadbbb", "aaabac"): - reporter_1 = Reporter( - first_name=first_name, - email=email - ) + reporter_1 = Reporter(first_name=first_name, email=email) session.add(reporter_1) article_1 = Article(headline="headline") article_1.reporter = reporter_1 @@ -825,7 +898,8 @@ async def 
test_sorting_can_be_used_with_batching_when_using_full_relay(session_f schema = get_full_relay_schema() session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters(sort: [FIRST_NAME_ASC, EMAIL_ASC]) { edges { @@ -836,10 +910,12 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) result = to_std_dicts(result.data) assert [ r["node"]["firstName"] + r["node"]["email"] for r in result["reporters"]["edges"] - ] == ['aa', 'ba', 'bb', 'bc', 'ca', 'da'] + ] == ["aa", "ba", "bb", "bc", "ca", "da"] diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 11e9d0e0..bb105edd 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -7,8 +7,8 @@ from ..utils import is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) def get_schema(): @@ -32,10 +32,10 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return info.context.get('session').query(Article).all() + return info.context.get("session").query(Article).all() def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() return graphene.Schema(query=Query) @@ -46,8 +46,8 @@ def benchmark_query(session_factory, benchmark, query): @benchmark def execute_query(): result = schema.execute( - query, - context_value={"session": session_factory()}, + query, + 
context_value={"session": session_factory()}, ) assert not result.errors @@ -56,26 +56,29 @@ def test_one_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -84,33 +87,37 @@ def test_one_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { articles { headline @@ -119,41 +126,45 @@ def test_many_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_one_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + 
first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -166,34 +177,35 @@ def test_one_to_many(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -202,7 +214,10 @@ def test_many_to_many(session_factory, benchmark): session.commit() session.close() - 
benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -215,4 +230,5 @@ def test_many_to_many(session_factory, benchmark): } } } - """) + """, + ) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index a6c2b1bf..812b4cea 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -15,16 +15,23 @@ from graphene.relay import Node from graphene.types.structures import Structure -from ..converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_hybrid_method, - convert_sqlalchemy_relationship) -from ..fields import (UnsortedSQLAlchemyConnectionField, - default_connection_field_factory) +from ..converter import ( + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, +) +from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry from ..types import ORMField, SQLAlchemyObjectType -from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, - ShoppingCartItem) +from .models import ( + Article, + CompositeFullName, + Pet, + Reporter, + ShoppingCart, + ShoppingCartItem, +) def mock_resolver(): @@ -33,32 +40,34 @@ def mock_resolver(): def get_field(sqlalchemy_type, **column_kwargs): class Model(declarative_base()): - __tablename__ = 'model' + __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) - column_prop = inspect(Model).column_attrs['column'] + column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) def get_field_from_column(column_): class Model(declarative_base()): - __tablename__ = 'model' + 
__tablename__ = "model" id_ = Column(types.Integer, primary_key=True) column = column_ - column_prop = inspect(Model).column_attrs['column'] + column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) def get_hybrid_property_type(prop_method): class Model(declarative_base()): - __tablename__ = 'model' + __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) prop = prop_method - column_prop = inspect(Model).all_orm_descriptors['prop'] - return convert_sqlalchemy_hybrid_method(column_prop, mock_resolver(), **ORMField().kwargs) + column_prop = inspect(Model).all_orm_descriptors["prop"] + return convert_sqlalchemy_hybrid_method( + column_prop, mock_resolver(), **ORMField().kwargs + ) def test_hybrid_prop_int(): @@ -69,19 +78,25 @@ def prop_method() -> int: assert get_hybrid_property_type(prop_method).type == graphene.Int -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_hybrid_prop_scalar_union_310(): @hybrid_property def prop_method() -> int | str: return "not allowed in gql schema" - with pytest.raises(ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*"): + with pytest.raises( + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. 
\.*", + ): get_hybrid_property_type(prop_method) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_hybrid_prop_scalar_union_and_optional_310(): """Checks if the use of Optionals does not interfere with non-conform scalar return types""" @@ -92,8 +107,7 @@ def prop_method() -> int | None: assert get_hybrid_property_type(prop_method).type == graphene.Int -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") -def test_should_union_work_310(): +def test_should_union_work(): reg = Registry() class PetType(SQLAlchemyObjectType): @@ -123,7 +137,9 @@ def prop_method_2() -> Union[ShoppingCartType, PetType]: # TODO verify types of the union -@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" +) def test_should_union_work_310(): reg = Registry() @@ -244,7 +260,9 @@ def test_should_integer_convert_int(): def test_should_primary_integer_convert_id(): - assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID) + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( + graphene.ID + ) def test_should_boolean_convert_boolean(): @@ -260,7 +278,7 @@ def test_should_numeric_convert_float(): def test_should_choice_convert_enum(): - field = get_field(sqa_utils.ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + field = get_field(sqa_utils.ChoiceType([("es", "Spanish"), ("en", "English")])) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -270,8 +288,8 @@ def test_should_choice_convert_enum(): def test_should_enum_choice_convert_enum(): class 
TestEnum(enum.Enum): - es = u"Spanish" - en = u"English" + es = "Spanish" + en = "English" field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type @@ -288,10 +306,14 @@ def test_choice_enum_column_key_name_issue_301(): """ class TestEnum(enum.Enum): - es = u"Spanish" - en = u"English" + es = "Spanish" + en = "English" - testChoice = Column("% descuento1", sqa_utils.ChoiceType(TestEnum, impl=types.String()), key="descuento1") + testChoice = Column( + "% descuento1", + sqa_utils.ChoiceType(TestEnum, impl=types.String()), + key="descuento1", + ) field = get_field_from_column(testChoice) graphene_type = field.type @@ -315,9 +337,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): - field = get_field_from_column(column_property( - select([func.sum(func.cast(id, types.Integer))]).where(id == 1) - )) + field = get_field_from_column( + column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1)) + ) assert field.type == graphene.Int @@ -347,7 +369,11 @@ class Meta: model = Article dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -359,7 +385,11 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -375,7 +405,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + 
default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) @@ -387,7 +421,11 @@ class Meta: model = Article dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -399,7 +437,11 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -414,7 +456,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', + Article.reporter.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -429,7 +475,11 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', + Reporter.favorite_article.property, + A, + default_connection_field_factory, + True, + "orm_field_name", ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -457,7 +507,9 @@ def test_should_postgresql_enum_convert(): def test_should_postgresql_py_enum_convert(): - field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers")) + field = get_field( + postgresql.ENUM(enum.Enum("TwoNumbers", 
"one two"), name="two_numbers") + ) field_type = field.type() assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) @@ -519,7 +571,11 @@ def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) field = convert_sqlalchemy_composite( - composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"), + composite( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + doc="Custom Help Text", + ), registry, mock_resolver, ) @@ -535,7 +591,10 @@ def __init__(self, col1, col2): re_err = "Don't know how to convert the composite field" with pytest.raises(Exception, match=re_err): convert_sqlalchemy_composite( - composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))), + composite( + CompositeFullName, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + ), Registry(), mock_resolver, ) @@ -557,17 +616,22 @@ class Meta: ####################################################### shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { - 'hybrid_prop_shopping_cart': graphene.List(ShoppingCartType) + "hybrid_prop_shopping_cart": graphene.List(ShoppingCartType) } - assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ - # Columns - "id", - # Append Hybrid Properties from Above - *shopping_cart_item_expected_types.keys() - ]) + assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_item_expected_types.keys(), + ] + ) - for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_item_expected_types.items(): + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] # this is a simple way of showing the failed property name @@ 
-576,7 +640,9 @@ class Meta: hybrid_prop_name, str(hybrid_prop_expected_return_type), ) - assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + assert ( + hybrid_prop_field.description is None + ) # "doc" is ignored by hybrid property ################################################### # Check ShoppingCart's Properties and Return Types @@ -596,7 +662,9 @@ class Meta: "hybrid_prop_list_int": graphene.List(graphene.Int), "hybrid_prop_list_date": graphene.List(graphene.Date), "hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)), - "hybrid_prop_deeply_nested_list_int": graphene.List(graphene.List(graphene.List(graphene.Int))), + "hybrid_prop_deeply_nested_list_int": graphene.List( + graphene.List(graphene.List(graphene.Int)) + ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), "hybrid_prop_unsupported_type_tuple": graphene.String, @@ -607,14 +675,19 @@ class Meta: "hybrid_prop_optional_self_referential": ShoppingCartType, } - assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted([ - # Columns - "id", - # Append Hybrid Properties from Above - *shopping_cart_expected_types.keys() - ]) + assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted( + [ + # Columns + "id", + # Append Hybrid Properties from Above + *shopping_cart_expected_types.keys(), + ] + ) - for hybrid_prop_name, hybrid_prop_expected_return_type in shopping_cart_expected_types.items(): + for ( + hybrid_prop_name, + hybrid_prop_expected_return_type, + ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] # this is a simple way of showing the failed property name @@ -623,4 +696,6 @@ class Meta: hybrid_prop_name, str(hybrid_prop_expected_return_type), ) - assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property + assert ( + hybrid_prop_field.description is None + ) 
# "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index ca376964..cd97a00e 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -54,7 +54,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): @@ -65,7 +65,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): assert [ (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() - ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): @@ -80,36 +80,35 @@ class PetType(SQLAlchemyObjectType): class Meta: model = Pet - enum = enum_for_field(PetType, 'pet_kind') + enum = enum_for_field(PetType, "pet_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "PetKind" assert [ - (key, value.value) - for key, value in enum._meta.enum.__members__.items() - ] == [("CAT", 'cat'), ("DOG", 'dog')] - enum2 = enum_for_field(PetType, 'pet_kind') + (key, value.value) for key, value in enum._meta.enum.__members__.items() + ] == [("CAT", "cat"), ("DOG", "dog")] + enum2 = enum_for_field(PetType, "pet_kind") assert enum2 is enum - enum2 = PetType.enum_for_field('pet_kind') + enum2 = PetType.enum_for_field("pet_kind") assert enum2 is enum - enum = enum_for_field(PetType, 'hair_kind') + enum = enum_for_field(PetType, "hair_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "HairKind" assert enum._meta.enum is HairKind - enum2 = PetType.enum_for_field('hair_kind') + enum2 = PetType.enum_for_field("hair_kind") 
assert enum2 is enum re_err = r"Cannot get PetType\.other_kind" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'other_kind') + enum_for_field(PetType, "other_kind") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('other_kind') + PetType.enum_for_field("other_kind") re_err = r"PetType\.name does not map to enum column" with pytest.raises(TypeError, match=re_err): - enum_for_field(PetType, 'name') + enum_for_field(PetType, "name") with pytest.raises(TypeError, match=re_err): - PetType.enum_for_field('name') + PetType.enum_for_field("name") re_err = r"Expected a field name, but got: None" with pytest.raises(TypeError, match=re_err): @@ -119,4 +118,4 @@ class Meta: re_err = "Expected SQLAlchemyObjectType, but got: None" with pytest.raises(TypeError, match=re_err): - enum_for_field(None, 'other_kind') + enum_for_field(None, "other_kind") diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 2782da89..9fed146d 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -4,8 +4,7 @@ from graphene import NonNull, ObjectType from graphene.relay import Connection, Node -from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField) +from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from .models import Editor as EditorModel from .models import Pet as PetModel @@ -21,6 +20,7 @@ class Editor(SQLAlchemyObjectType): class Meta: model = EditorModel + ## # SQLAlchemyConnectionField ## @@ -59,6 +59,7 @@ def test_type_assert_object_has_connection(): with pytest.raises(AssertionError, match="doesn't have a connection"): SQLAlchemyConnectionField(Editor).type + ## # UnsortedSQLAlchemyConnectionField ## @@ -66,8 +67,7 @@ def test_type_assert_object_has_connection(): def test_unsorted_connection_field_removes_sort_arg_if_passed(): editor = 
UnsortedSQLAlchemyConnectionField( - Editor.connection, - sort=Editor.sort_argument(has_default=True) + Editor.connection, sort=Editor.sort_argument(has_default=True) ) assert "sort" not in editor.args diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 39140814..c7a173df 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -9,19 +9,17 @@ def add_test_data(session): - reporter = Reporter( - first_name='John', last_name='Doe', favorite_pet_kind='cat') + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) session.add(pet) pet.reporters.append(reporter) - article = Article(headline='Hi!') + article = Article(headline="Hi!") article.reporter = reporter session.add(article) - reporter = Reporter( - first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") session.add(reporter) - pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) pet.reporters.append(reporter) session.add(pet) editor = Editor(name="Jack") @@ -163,12 +161,12 @@ class Meta: model = Reporter interfaces = (Node,) - first_name_v2 = ORMField(model_attr='first_name') - hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') - column_prop_v2 = ORMField(model_attr='column_prop') + first_name_v2 = ORMField(model_attr="first_name") + hybrid_prop_v2 = ORMField(model_attr="hybrid_prop") + column_prop_v2 = ORMField(model_attr="column_prop") composite_prop = ORMField() - favorite_article_v2 = ORMField(model_attr='favorite_article') - articles_v2 = ORMField(model_attr='articles') + favorite_article_v2 = ORMField(model_attr="favorite_article") + articles_v2 = 
ORMField(model_attr="articles") class ArticleType(SQLAlchemyObjectType): class Meta: diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 5166c45f..923bbed1 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -9,7 +9,6 @@ def test_query_pet_kinds(session): add_test_data(session) class PetType(SQLAlchemyObjectType): - class Meta: model = Pet @@ -20,8 +19,9 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - pets = graphene.List(PetType, kind=graphene.Argument( - PetType.enum_for_field('pet_kind'))) + pets = graphene.List( + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) def resolve_reporter(self, _info): return session.query(Reporter).first() @@ -58,27 +58,24 @@ def resolve_pets(self, _info, kind): } """ expected = { - 'reporter': { - 'firstName': 'John', - 'lastName': 'Doe', - 'email': None, - 'favoritePetKind': 'CAT', - 'pets': [{ - 'name': 'Garfield', - 'petKind': 'CAT' - }] + "reporter": { + "firstName": "John", + "lastName": "Doe", + "email": None, + "favoritePetKind": "CAT", + "pets": [{"name": "Garfield", "petKind": "CAT"}], }, - 'reporters': [{ - 'firstName': 'John', - 'favoritePetKind': 'CAT', - }, { - 'firstName': 'Jane', - 'favoritePetKind': 'DOG', - }], - 'pets': [{ - 'name': 'Lassie', - 'petKind': 'DOG' - }] + "reporters": [ + { + "firstName": "John", + "favoritePetKind": "CAT", + }, + { + "firstName": "Jane", + "favoritePetKind": "DOG", + }, + ], + "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) result = schema.execute(query) @@ -125,8 +122,8 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field( - PetType, - kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) def resolve_pet(self, info, 
kind=None): query = session.query(Pet) diff --git a/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/tests/test_reflected.py index 46e10de9..a3f6c4aa 100644 --- a/graphene_sqlalchemy/tests/test_reflected.py +++ b/graphene_sqlalchemy/tests/test_reflected.py @@ -1,4 +1,3 @@ - from graphene import ObjectType from ..registry import Registry diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index f451f355..cb7e9034 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -142,7 +142,7 @@ class Meta: model = Reporter union_types = [PetType, ReporterType] - union = graphene.Union('ReporterPet', tuple(union_types)) + union = graphene.Union("ReporterPet", tuple(union_types)) reg.register_union_type(union, union_types) @@ -155,7 +155,7 @@ def test_register_union_scalar(): reg = Registry() union_types = [graphene.String, graphene.Int] - union = graphene.Union('StringInt', tuple(union_types)) + union = graphene.Union("StringInt", tuple(union_types)) re_err = r"Expected Graphene ObjectType, but got: .*String.*" with pytest.raises(TypeError, match=re_err): diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index e2510abc..11c7c9a7 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -354,7 +354,7 @@ def makeNodes(nodeList): """ result = schema.execute(queryError, context_value={"session": session}) assert result.errors is not None - assert 'cannot represent non-enum value' in result.errors[0].message + assert "cannot represent non-enum value" in result.errors[0].message queryNoSort = """ query sortTest { @@ -404,5 +404,11 @@ class Meta: "REPORTER_NUMBER_ASC", "REPORTER_NUMBER_DESC", ] - assert str(sort_enum.REPORTER_NUMBER_ASC.value.value) == 'test330."% reporter_number" ASC' - assert str(sort_enum.REPORTER_NUMBER_DESC.value.value) == 'test330."% 
reporter_number" DESC' + assert ( + str(sort_enum.REPORTER_NUMBER_ASC.value.value) + == 'test330."% reporter_number" ASC' + ) + assert ( + str(sort_enum.REPORTER_NUMBER_DESC.value.value) + == 'test330."% reporter_number" DESC' + ) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 00e8b3af..4afb120d 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,16 +4,31 @@ import sqlalchemy.exc import sqlalchemy.orm.exc -from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, - Node, NonNull, ObjectType, Schema, String) +from graphene import ( + Boolean, + Dynamic, + Field, + Float, + GlobalID, + Int, + List, + Node, + NonNull, + ObjectType, + Schema, + String, +) from graphene.relay import Connection from .. import utils from ..converter import convert_sqlalchemy_composite -from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField, createConnectionField, - registerConnectionFieldFactory, - unregisterConnectionFieldFactory) +from ..fields import ( + SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + createConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory, +) from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, CompositeFullName, Pet, Reporter @@ -21,6 +36,7 @@ def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): pass @@ -28,6 +44,7 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): class Meta: model = 1 @@ -45,7 +62,7 @@ class Meta: reporter = Reporter() session.add(reporter) session.commit() - info = mock.Mock(context={'session': session}) + info = 
mock.Mock(context={"session": session}) reporter_node = ReporterType.get_node(info, reporter.id) assert reporter == reporter_node @@ -74,91 +91,93 @@ class Meta: model = Article interfaces = (Node,) - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Columns - "column_prop", - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - # Composite - "composite_prop", - # Hybrid - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - # Relationship - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Columns + "column_prop", + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + # Composite + "composite_prop", + # Hybrid + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + # Relationship + "pets", + "articles", + "favorite_article", + ] + ) # column - first_name_field = ReporterType._meta.fields['first_name'] + first_name_field = ReporterType._meta.fields["first_name"] assert first_name_field.type == String assert first_name_field.description == "First name" # column_property - column_prop_field = ReporterType._meta.fields['column_prop'] + column_prop_field = ReporterType._meta.fields["column_prop"] assert column_prop_field.type == Int # "doc" is ignored by column_property assert column_prop_field.description is None # composite - full_name_field = ReporterType._meta.fields['composite_prop'] + full_name_field = ReporterType._meta.fields["composite_prop"] assert full_name_field.type == String # "doc" is ignored by composite assert full_name_field.description is None # hybrid_property - hybrid_prop = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop.type == String # "doc" is ignored 
by hybrid_property assert hybrid_prop.description is None # hybrid_property_str - hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + hybrid_prop_str = ReporterType._meta.fields["hybrid_prop_str"] assert hybrid_prop_str.type == String # "doc" is ignored by hybrid_property assert hybrid_prop_str.description is None # hybrid_property_int - hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + hybrid_prop_int = ReporterType._meta.fields["hybrid_prop_int"] assert hybrid_prop_int.type == Int # "doc" is ignored by hybrid_property assert hybrid_prop_int.description is None # hybrid_property_float - hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + hybrid_prop_float = ReporterType._meta.fields["hybrid_prop_float"] assert hybrid_prop_float.type == Float # "doc" is ignored by hybrid_property assert hybrid_prop_float.description is None # hybrid_property_bool - hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + hybrid_prop_bool = ReporterType._meta.fields["hybrid_prop_bool"] assert hybrid_prop_bool.type == Boolean # "doc" is ignored by hybrid_property assert hybrid_prop_bool.description is None # hybrid_property_list - hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + hybrid_prop_list = ReporterType._meta.fields["hybrid_prop_list"] assert hybrid_prop_list.type == List(Int) # "doc" is ignored by hybrid_property assert hybrid_prop_list.description is None # hybrid_prop_with_doc - hybrid_prop_with_doc = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc = ReporterType._meta.fields["hybrid_prop_with_doc"] assert hybrid_prop_with_doc.type == String # docstring is picked up from hybrid_prop_with_doc assert hybrid_prop_with_doc.description == "Docstring test" # relationship - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert 
favorite_article_field.type().type == ArticleType assert favorite_article_field.type().description is None @@ -172,7 +191,7 @@ def convert_composite_class(composite, registry): class ReporterMixin(object): # columns first_name = ORMField(required=True) - last_name = ORMField(description='Overridden') + last_name = ORMField(description="Overridden") class ReporterType(SQLAlchemyObjectType, ReporterMixin): class Meta: @@ -180,8 +199,8 @@ class Meta: interfaces = (Node,) # columns - email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(model_attr='email', type_=Int) + email = ORMField(deprecation_reason="Overridden") + email_v2 = ORMField(model_attr="email", type_=Int) # column_property column_prop = ORMField(type_=String) @@ -190,13 +209,13 @@ class Meta: composite_prop = ORMField() # hybrid_property - hybrid_prop_with_doc = ORMField(description='Overridden') - hybrid_prop = ORMField(description='Overridden') + hybrid_prop_with_doc = ORMField(description="Overridden") + hybrid_prop = ORMField(description="Overridden") # relationships - favorite_article = ORMField(description='Overridden') - articles = ORMField(deprecation_reason='Overridden') - pets = ORMField(description='Overridden') + favorite_article = ORMField(description="Overridden") + articles = ORMField(deprecation_reason="Overridden") + pets = ORMField(description="Overridden") class ArticleType(SQLAlchemyObjectType): class Meta: @@ -209,99 +228,101 @@ class Meta: interfaces = (Node,) use_connection = False - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Fields from ReporterMixin - "first_name", - "last_name", - # Fields from ReporterType - "email", - "email_v2", - "column_prop", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "favorite_article", - "articles", - "pets", - # Then the automatic SQLAlchemy fields - "id", - "favorite_pet_kind", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - 
]) - - first_name_field = ReporterType._meta.fields['first_name'] + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Fields from ReporterMixin + "first_name", + "last_name", + # Fields from ReporterType + "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "favorite_article", + "articles", + "pets", + # Then the automatic SQLAlchemy fields + "id", + "favorite_pet_kind", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ] + ) + + first_name_field = ReporterType._meta.fields["first_name"] assert isinstance(first_name_field.type, NonNull) assert first_name_field.type.of_type == String assert first_name_field.description == "First name" assert first_name_field.deprecation_reason is None - last_name_field = ReporterType._meta.fields['last_name'] + last_name_field = ReporterType._meta.fields["last_name"] assert last_name_field.type == String assert last_name_field.description == "Overridden" assert last_name_field.deprecation_reason is None - email_field = ReporterType._meta.fields['email'] + email_field = ReporterType._meta.fields["email"] assert email_field.type == String assert email_field.description == "Email" assert email_field.deprecation_reason == "Overridden" - email_field_v2 = ReporterType._meta.fields['email_v2'] + email_field_v2 = ReporterType._meta.fields["email_v2"] assert email_field_v2.type == Int assert email_field_v2.description == "Email" assert email_field_v2.deprecation_reason is None - hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop_field = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop_field.type == String assert hybrid_prop_field.description == "Overridden" assert hybrid_prop_field.deprecation_reason is None - hybrid_prop_with_doc_field = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc_field = ReporterType._meta.fields["hybrid_prop_with_doc"] 
assert hybrid_prop_with_doc_field.type == String assert hybrid_prop_with_doc_field.description == "Overridden" assert hybrid_prop_with_doc_field.deprecation_reason is None - column_prop_field_v2 = ReporterType._meta.fields['column_prop'] + column_prop_field_v2 = ReporterType._meta.fields["column_prop"] assert column_prop_field_v2.type == String assert column_prop_field_v2.description is None assert column_prop_field_v2.deprecation_reason is None - composite_prop_field = ReporterType._meta.fields['composite_prop'] + composite_prop_field = ReporterType._meta.fields["composite_prop"] assert composite_prop_field.type == String assert composite_prop_field.description is None assert composite_prop_field.deprecation_reason is None - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert favorite_article_field.type().type == ArticleType - assert favorite_article_field.type().description == 'Overridden' + assert favorite_article_field.type().description == "Overridden" - articles_field = ReporterType._meta.fields['articles'] + articles_field = ReporterType._meta.fields["articles"] assert isinstance(articles_field, Dynamic) assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) assert articles_field.type().deprecation_reason == "Overridden" - pets_field = ReporterType._meta.fields['pets'] + pets_field = ReporterType._meta.fields["pets"] assert isinstance(pets_field, Dynamic) assert isinstance(pets_field.type().type, List) assert pets_field.type().type.of_type == PetType - assert pets_field.type().description == 'Overridden' + assert pets_field.type().description == "Overridden" def test_invalid_model_attr(): err_msg = ( - "Cannot map ORMField to a model attribute.\n" - "Field: 'ReporterType.first_name'" + "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with 
pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - first_name = ORMField(model_attr='does_not_exist') + first_name = ORMField(model_attr="does_not_exist") def test_only_fields(): @@ -325,29 +346,32 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - "first_name", - "last_name", - "column_prop", - "email", - "favorite_pet_kind", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + "first_name", + "last_name", + "column_prop", + "email", + "favorite_pet_kind", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + "pets", + "articles", + "favorite_article", + ] + ) def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -372,14 +396,14 @@ def test_resolvers(session): class ReporterMixin(object): def resolve_id(root, _info): - return 'ID' + return "ID" class ReporterType(ReporterMixin, SQLAlchemyObjectType): class Meta: model = Reporter email = ORMField() - email_v2 = ORMField(model_attr='email') + email_v2 = ORMField(model_attr="email") favorite_pet_kind = Field(String) favorite_pet_kind_v2 = Field(String) @@ -387,10 +411,10 @@ def resolve_last_name(root, _info): return root.last_name.upper() def resolve_email_v2(root, _info): - return root.email + '_V2' + return root.email + "_V2" def resolve_favorite_pet_kind_v2(root, _info): - return str(root.favorite_pet_kind) + '_V2' + 
return str(root.favorite_pet_kind) + "_V2" class Query(ObjectType): reporter = Field(ReporterType) @@ -398,12 +422,18 @@ class Query(ObjectType): def resolve_reporter(self, _info): return session.query(Reporter).first() - reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat') + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) session.add(reporter) session.commit() schema = Schema(query=Query) - result = schema.execute(""" + result = schema.execute( + """ query { reporter { id @@ -415,27 +445,29 @@ def resolve_reporter(self, _info): favoritePetKindV2 } } - """) + """ + ) assert not result.errors # Custom resolver on a base class - assert result.data['reporter']['id'] == 'ID' + assert result.data["reporter"]["id"] == "ID" # Default field + default resolver - assert result.data['reporter']['firstName'] == 'first_name' + assert result.data["reporter"]["firstName"] == "first_name" # Default field + custom resolver - assert result.data['reporter']['lastName'] == 'LAST_NAME' + assert result.data["reporter"]["lastName"] == "LAST_NAME" # ORMField + default resolver - assert result.data['reporter']['email'] == 'email' + assert result.data["reporter"]["email"] == "email" # ORMField + custom resolver - assert result.data['reporter']['emailV2'] == 'email_V2' + assert result.data["reporter"]["emailV2"] == "email_V2" # Field + default resolver - assert result.data['reporter']['favoritePetKind'] == 'cat' + assert result.data["reporter"]["favoritePetKind"] == "cat" # Field + custom resolver - assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2' + assert result.data["reporter"]["favoritePetKindV2"] == "cat_V2" # Test Custom SQLAlchemyObjectType Implementation + def test_custom_objecttype_registered(): class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): class Meta: @@ -463,9 +495,9 @@ class Meta: def __init_subclass_with_meta__(cls, 
custom_option=None, **options): _meta = CustomOptions(cls) _meta.custom_option = custom_option - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super( + SQLAlchemyObjectTypeWithCustomOptions, cls + ).__init_subclass_with_meta__(_meta=_meta, **options) class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): class Meta: @@ -479,6 +511,7 @@ class Meta: # Tests for connection_field_factory + class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): pass @@ -494,7 +527,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), UnsortedSQLAlchemyConnectionField + ) def test_custom_connection_field_factory(): @@ -514,7 +549,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_registerConnectionFieldFactory(): @@ -531,7 +568,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_unregisterConnectionFieldFactory(): @@ -549,7 +588,9 @@ class Meta: model = Article interfaces = (Node,) - assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert not isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_createConnectionField(): @@ -557,7 +598,7 @@ def test_deprecated_createConnectionField(): createConnectionField(None) -@mock.patch(utils.__name__ + '.class_mapper') 
+@mock.patch(utils.__name__ + ".class_mapper") def test_unique_errors_propagate(class_mapper_mock): # Define unique error to detect class UniqueError(Exception): @@ -569,9 +610,11 @@ class UniqueError(Exception): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleOne(SQLAlchemyObjectType): class Meta(object): model = Article + except UniqueError as e: error = e @@ -580,7 +623,7 @@ class Meta(object): assert isinstance(error, UniqueError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_argument_errors_propagate(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError @@ -588,9 +631,11 @@ def test_argument_errors_propagate(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleTwo(SQLAlchemyObjectType): class Meta(object): model = Article + except sqlalchemy.exc.ArgumentError as e: error = e @@ -599,7 +644,7 @@ class Meta(object): assert isinstance(error, sqlalchemy.exc.ArgumentError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_unmapped_errors_reformat(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) @@ -607,9 +652,11 @@ def test_unmapped_errors_reformat(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleThree(SQLAlchemyObjectType): class Meta(object): model = Article + except ValueError as e: error = e diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index de359e05..75328280 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -3,8 +3,14 @@ from graphene import Enum, List, ObjectType, Schema, 
String -from ..utils import (DummyImport, get_session, sort_argument_for_model, - sort_enum_for_model, to_enum_value_name, to_type_name) +from ..utils import ( + DummyImport, + get_session, + sort_argument_for_model, + sort_enum_for_model, + to_enum_value_name, + to_type_name, +) from .models import Base, Editor, Pet @@ -96,9 +102,11 @@ class MultiplePK(Base): with pytest.warns(DeprecationWarning): arg = sort_argument_for_model(MultiplePK) - assert set(arg.default_value) == set( - (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") - ) + assert set(arg.default_value) == { + MultiplePK.foo.name + "_asc", + MultiplePK.bar.name + "_asc", + } + def test_dummy_import(): dummy_module = DummyImport() diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e6c3d14c..fe48e9eb 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -2,8 +2,7 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty) +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound from graphene import Field @@ -12,12 +11,17 @@ from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -from .converter import (convert_sqlalchemy_column, - convert_sqlalchemy_composite, - convert_sqlalchemy_hybrid_method, - convert_sqlalchemy_relationship) -from .enums import (enum_for_field, sort_argument_for_object_type, - sort_enum_for_object_type) +from .converter import ( + convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship, +) +from .enums import ( + enum_for_field, + sort_argument_for_object_type, + sort_enum_for_object_type, +) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import 
get_query, is_mapped_class, is_mapped_instance @@ -25,15 +29,15 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - _creation_counter=None, - **field_kwargs + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + _creation_counter=None, + **field_kwargs ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -76,20 +80,28 @@ class Meta: super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { - 'model_attr': model_attr, - 'type_': type_, - 'required': required, - 'description': description, - 'deprecation_reason': deprecation_reason, - 'batching': batching, + "model_attr": model_attr, + "type_": type_, + "required": required, + "description": description, + "deprecation_reason": deprecation_reason, + "batching": batching, + } + common_kwargs = { + kwarg: value for kwarg, value in common_kwargs.items() if value is not None } - common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs self.kwargs.update(common_kwargs) def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. 
@@ -112,15 +124,20 @@ def construct_fields( all_model_attrs = OrderedDict( inspected_model.column_attrs.items() + inspected_model.composites.items() - + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() - if isinstance(item, hybrid_property)] + + [ + (name, item) + for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property) + ] + inspected_model.relationships.items() ) # Filter out excluded fields auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): - if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): + if (only_fields and attr_name not in only_fields) or ( + attr_name in exclude_fields + ): continue auto_orm_field_names.append(attr_name) @@ -135,13 +152,15 @@ def construct_fields( # Set the model_attr if not set for orm_field_name, orm_field in custom_orm_fields_items: - attr_name = orm_field.kwargs.get('model_attr', orm_field_name) + attr_name = orm_field.kwargs.get("model_attr", orm_field_name) if attr_name not in all_model_attrs: - raise ValueError(( - "Cannot map ORMField to a model attribute.\n" - "Field: '{}.{}'" - ).format(obj_type.__name__, orm_field_name,)) - orm_field.kwargs['model_attr'] = attr_name + raise ValueError( + ("Cannot map ORMField to a model attribute.\n" "Field: '{}.{}'").format( + obj_type.__name__, + orm_field_name, + ) + ) + orm_field.kwargs["model_attr"] = attr_name # Merge automatic fields with custom ORM fields orm_fields = OrderedDict(custom_orm_fields_items) @@ -153,27 +172,38 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): - attr_name = orm_field.kwargs.pop('model_attr') + attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] - resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) + resolver = get_custom_resolver(obj_type, orm_field_name) or 
get_attr_resolver( + obj_type, attr_name + ) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_column( + attr, registry, resolver, **orm_field.kwargs + ) elif isinstance(attr, RelationshipProperty): - batching_ = orm_field.kwargs.pop('batching', batching) + batching_ = orm_field.kwargs.pop("batching", batching) field = convert_sqlalchemy_relationship( - attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) + attr, + obj_type, + connection_field_factory, + batching_, + orm_field_name, + **orm_field.kwargs + ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " - "Field: {}.{}".format(obj_type.__name__, orm_field_name)) + "Field: {}.{}".format(obj_type.__name__, orm_field_name) + ) field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: - raise Exception('Property type is not supported') # Should never happen + raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field @@ -191,26 +221,27 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + 
use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options ): # Make sure model is a valid SQLAlchemy model if not is_mapped_class(model): raise ValueError( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + "You need to pass a valid SQLAlchemy Model in " + '{}.Meta, received "{}".'.format(cls.__name__, model) ) if not registry: @@ -222,7 +253,9 @@ def __init_subclass_with_meta__( ).format(cls.__name__, registry) if only_fields and exclude_fields: - raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") + raise ValueError( + "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." + ) sqla_fields = yank_fields_from_attrs( construct_fields( @@ -240,7 +273,7 @@ def __init_subclass_with_meta__( if use_connection is None and interfaces: use_connection = any( - (issubclass(interface, Node) for interface in interfaces) + issubclass(interface, Node) for interface in interfaces ) if use_connection and not connection: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 27117c0c..54bb8402 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -153,12 +153,16 @@ def sort_argument_for_model(cls, has_default=True): def is_sqlalchemy_version_less_than(version_string): # pragma: no cover """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) def is_graphene_version_less_than(version_string): # pragma: no cover """Check the installed graphene version""" - return pkg_resources.get_distribution('graphene').parsed_version < pkg_resources.parse_version(version_string) + return pkg_resources.get_distribution( + 
"graphene" + ).parsed_version < pkg_resources.parse_version(version_string) class singledispatchbymatchfunction: @@ -182,7 +186,6 @@ def __call__(self, *args, **kwargs): return self.default(*args, **kwargs) def register(self, matcher_function: Callable[[Any], bool]): - def grab_function_from_outside(f): self.registry[matcher_function] = f return self @@ -192,7 +195,7 @@ def grab_function_from_outside(f): def value_equals(value): """A simple function that makes the equality based matcher functions for - SingleDispatchByMatchFunction prettier""" + SingleDispatchByMatchFunction prettier""" return lambda x: x == value @@ -208,8 +211,14 @@ def safe_isinstance_checker(arg): def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: from graphene_sqlalchemy.registry import get_global_registry + try: - return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + return next( + filter( + lambda x: x.__name__ == model_name, + list(get_global_registry()._registry.keys()), + ) + ) except StopIteration: pass diff --git a/setup.cfg b/setup.cfg index f36334d8..e479585c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,10 +2,12 @@ test=pytest [flake8] -exclude = setup.py,docs/*,examples/*,tests +ignore = E203,W503 +exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs,setup.py,docs/*,examples/*,tests max-line-length = 120 [isort] +profile = black no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy From 0a765a1a0324f0c48e55ae2f0264dc95f094bc1b Mon Sep 17 00:00:00 2001 From: Cadu Date: Tue, 13 Sep 2022 04:22:08 -0300 Subject: [PATCH 34/67] Made Relationshiploader utilize the new and improved DataLoader implementation housed inside graphene, if possible (graphene >=3.1.1) (#362) --- graphene_sqlalchemy/batching.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff 
--git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 275d5904..0800d0e2 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -2,7 +2,6 @@ from asyncio import get_event_loop from typing import Any, Dict -import aiodataloader import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext @@ -10,7 +9,21 @@ from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than -class RelationshipLoader(aiodataloader.DataLoader): +def get_data_loader_impl() -> Any: # pragma: no cover + """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, + aiodataloader is used in conjunction with older versions of graphene""" + if is_graphene_version_less_than("3.1.1"): + from aiodataloader import DataLoader + else: + from graphene.utils.dataloader import DataLoader + + return DataLoader + + +DataLoader = get_data_loader_impl() + + +class RelationshipLoader(DataLoader): cache = False def __init__(self, relationship_prop, selectin_loader): @@ -92,20 +105,6 @@ async def batch_load_fn(self, parents): ] = {} -def get_data_loader_impl() -> Any: # pragma: no cover - """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. 
To preserve backward-compatibility, - aiodataloader is used in conjunction with older versions of graphene""" - if is_graphene_version_less_than("3.1.1"): - from aiodataloader import DataLoader - else: - from graphene.utils.dataloader import DataLoader - - return DataLoader - - -DataLoader = get_data_loader_impl() - - def get_batch_resolver(relationship_prop): """Get the resolve function for the given relationship.""" From 75abf0b4b3af24c60df87852c18493174fc4daf3 Mon Sep 17 00:00:00 2001 From: Cadu Date: Sat, 1 Oct 2022 09:36:59 -0300 Subject: [PATCH 35/67] feat: Add support for UUIDs in `@hybrid_property`-ies (#363) --- graphene_sqlalchemy/converter.py | 6 ++++++ graphene_sqlalchemy/tests/models.py | 16 ++++++++++++++++ graphene_sqlalchemy/tests/test_converter.py | 4 ++++ 3 files changed, 26 insertions(+) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index d1873c2b..d3ae8123 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,6 +1,7 @@ import datetime import sys import typing +import uuid import warnings from decimal import Decimal from functools import singledispatch @@ -398,6 +399,11 @@ def convert_sqlalchemy_hybrid_property_type_time(arg): return graphene.Time +@convert_sqlalchemy_hybrid_property_type.register(value_equals(uuid.UUID)) +def convert_sqlalchemy_hybrid_property_type_uuid(arg): + return graphene.UUID + + def is_union(arg) -> bool: if sys.version_info >= (3, 10): from types import UnionType diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index fd5d3b21..b433982d 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -2,6 +2,7 @@ import datetime import enum +import uuid from decimal import Decimal from typing import List, Optional, Tuple @@ -267,6 +268,21 @@ def hybrid_prop_self_referential_list(self) -> List["ShoppingCart"]: def hybrid_prop_optional_self_referential(self) -> Optional["ShoppingCart"]: 
return None + # UUIDS + @hybrid_property + def hybrid_prop_uuid(self) -> uuid.UUID: + return uuid.uuid4() + + @hybrid_property + def hybrid_prop_uuid_list(self) -> List[uuid.UUID]: + return [ + uuid.uuid4(), + ] + + @hybrid_property + def hybrid_prop_optional_uuid(self) -> Optional[uuid.UUID]: + return None + class KeyedModel(Base): __tablename__ = "test330" diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 812b4cea..b9a1c152 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -673,6 +673,10 @@ class Meta: "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), # Optionals "hybrid_prop_optional_self_referential": ShoppingCartType, + # UUIDs + "hybrid_prop_uuid": graphene.UUID, + "hybrid_prop_optional_uuid": graphene.UUID, + "hybrid_prop_uuid_list": graphene.List(graphene.UUID), } assert sorted(list(ShoppingCartType._meta.fields.keys())) == sorted( From 8bfa1e92003aa801481b50e0bd4603445570c066 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 21 Nov 2022 21:15:10 +0100 Subject: [PATCH 36/67] chore: limit CI runs to master pushes & PRs (#366) --- .github/workflows/tests.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index de78190d..428eca1d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,7 +1,12 @@ name: Tests -on: [push, pull_request] - +on: + push: + branches: + - 'master' + pull_request: + branches: + - '*' jobs: test: runs-on: ubuntu-latest From 2edeae98b79dc98d9fe9214df3f4ccd9d2bdbe30 Mon Sep 17 00:00:00 2001 From: Frederick Polgardy Date: Mon, 28 Nov 2022 10:13:03 -0700 Subject: [PATCH 37/67] feat: Support GQL interfaces for polymorphic SQLA models (#365) * Support GQL interfaces for polymorphic SQLA models using SQLALchemyInterface and SQLAlchemyBase. 
fixes #313 Co-authored-by: Erik Wrede Co-authored-by: Erik Wrede --- docs/inheritance.rst | 107 +++++++++++++++ graphene_sqlalchemy/__init__.py | 3 +- graphene_sqlalchemy/registry.py | 21 +-- graphene_sqlalchemy/tests/models.py | 36 +++++ graphene_sqlalchemy/tests/test_query.py | 67 +++++++++- graphene_sqlalchemy/tests/test_registry.py | 4 +- graphene_sqlalchemy/tests/test_types.py | 104 ++++++++++++++- graphene_sqlalchemy/types.py | 145 +++++++++++++++++++-- 8 files changed, 447 insertions(+), 40 deletions(-) create mode 100644 docs/inheritance.rst diff --git a/docs/inheritance.rst b/docs/inheritance.rst new file mode 100644 index 00000000..ee16f062 --- /dev/null +++ b/docs/inheritance.rst @@ -0,0 +1,107 @@ +Inheritance Examples +==================== + +Create interfaces from inheritance relationships +------------------------------------------------ + +SQLAlchemy has excellent support for class inheritance hierarchies. +These hierarchies can be represented in your GraphQL schema by means +of interfaces_. Much like ObjectTypes, Interfaces in +Graphene-SQLAlchemy are able to infer their fields and relationships +from the attributes of their underlying SQLAlchemy model: + +.. _interfaces: https://docs.graphene-python.org/en/latest/types/interfaces/ + +.. 
code:: python + + from sqlalchemy import Column, Date, Integer, String + from sqlalchemy.ext.declarative import declarative_base + + import graphene + from graphene import relay + from graphene_sqlalchemy import SQLAlchemyInterface, SQLAlchemyObjectType + + Base = declarative_base() + + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + } + + class Employee(Person): + hire_date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "employee", + } + + class Customer(Person): + first_purchase_date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "customer", + } + + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (relay.Node, PersonType) + + class CustomerType(SQLAlchemyObjectType): + class Meta: + model = Customer + interfaces = (relay.Node, PersonType) + +Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must +be linked to an abstract Model that does not specify a `polymorphic_identity`, +because we cannot return instances of interfaces from a GraphQL query. +If Person specified a `polymorphic_identity`, instances of Person could +be inserted into and returned by the database, potentially causing +Persons to be returned to the resolvers. + +When querying on the base type, you can refer directly to common fields, +and fields on concrete implementations using the `... on` syntax: + + +.. code:: + + people { + name + birthDate + ... on EmployeeType { + hireDate + } + ... on CustomerType { + firstPurchaseDate + } + } + + +Please note that by default, the "polymorphic_on" column is *not* +generated as a field on types that use polymorphic inheritance, as +this is considered an implementation detail. 
The idiomatic way to +retrieve the concrete GraphQL type of an object is to query for the +`__typename` field. +To override this behavior, an `ORMField` needs to be created +for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* +as it promotes ambiguous schema design. + +If your SQLAlchemy model only specifies a relationship to the +base type, you will need to explicitly pass your concrete implementation +class to the Schema constructor via the `types=` argument: + +.. code:: python + + schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType]) + +See also: `Graphene Interfaces `_ diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 33345815..fb32379c 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,11 +1,12 @@ from .fields import SQLAlchemyConnectionField -from .types import SQLAlchemyObjectType +from .types import SQLAlchemyInterface, SQLAlchemyObjectType from .utils import get_query, get_session __version__ = "3.0.0b3" __all__ = [ "__version__", + "SQLAlchemyInterface", "SQLAlchemyObjectType", "SQLAlchemyConnectionField", "get_query", diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 8f2bc9e7..cc4b02b7 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -18,15 +18,10 @@ def __init__(self): self._registry_unions = {} def register(self, obj_type): + from .types import SQLAlchemyBase - from .types import SQLAlchemyObjectType - - if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType - ): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) - ) + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) assert obj_type._meta.registry == self, "Registry for a Model have to match." 
# assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # 'SQLAlchemy model "{}" already associated with ' @@ -38,14 +33,10 @@ def get_type_for_model(self, model): return self._registry.get(model) def register_orm_field(self, obj_type, field_name, orm_field): - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyBase - if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType - ): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) - ) + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index b433982d..4fe91462 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -288,3 +288,39 @@ class KeyedModel(Base): __tablename__ = "test330" id = Column(Integer(), primary_key=True) reporter_number = Column("% reporter_number", Numeric, key="reporter_number") + + +############################################ +# For interfaces +############################################ + + +class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + } + +class NonAbstractPerson(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "non_abstract_person" + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_identity": "person", + } + +class Employee(Person): + hire_date = Column(Date()) + + __mapper_args__ = { + 
"polymorphic_identity": "employee", + } diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index c7a173df..456254fc 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,10 +1,21 @@ +from datetime import date + import graphene from graphene.relay import Node from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField -from ..types import ORMField, SQLAlchemyObjectType -from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter +from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType +from .models import ( + Article, + CompositeFullName, + Editor, + Employee, + HairKind, + Person, + Pet, + Reporter, +) from .utils import to_std_dicts @@ -334,3 +345,55 @@ class Mutation(graphene.ObjectType): assert not result.errors result = to_std_dicts(result.data) assert result == expected + + +def add_person_data(session): + bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1)) + session.add(bob) + joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1)) + session.add(joe) + jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1)) + session.add(jen) + session.commit() + + +def test_interface_query_on_base_type(session): + add_person_data(session) + + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + def resolve_people(self, _info): + return session.query(Person).all() + + schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType]) + result = schema.execute( + """ + query { + people { + __typename + name + birthDate + ... 
on EmployeeType { + hireDate + } + } + } + """ + ) + + assert not result.errors + assert len(result.data["people"]) == 3 + assert result.data["people"][0]["__typename"] == "EmployeeType" + assert result.data["people"][0]["name"] == "Bob" + assert result.data["people"][0]["birthDate"] == "1990-01-01" + assert result.data["people"][0]["hireDate"] == "2015-01-01" diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index cb7e9034..68b5404f 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -28,7 +28,7 @@ def test_register_incorrect_object_type(): class Spam: pass - re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + re_err = "Expected SQLAlchemyBase, but got: .*Spam" with pytest.raises(TypeError, match=re_err): reg.register(Spam) @@ -51,7 +51,7 @@ def test_register_orm_field_incorrect_types(): class Spam: pass - re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + re_err = "Expected SQLAlchemyBase, but got: .*Spam" with pytest.raises(TypeError, match=re_err): reg.register_orm_field(Spam, "name", Pet.name) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 4afb120d..813fb134 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,9 +1,9 @@ +import re from unittest import mock import pytest import sqlalchemy.exc import sqlalchemy.orm.exc - from graphene import ( Boolean, Dynamic, @@ -20,6 +20,7 @@ ) from graphene.relay import Connection +from .models import Article, CompositeFullName, Employee, Person, Pet, Reporter, NonAbstractPerson from .. 
import utils from ..converter import convert_sqlalchemy_composite from ..fields import ( @@ -29,14 +30,17 @@ registerConnectionFieldFactory, unregisterConnectionFieldFactory, ) -from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions -from .models import Article, CompositeFullName, Pet, Reporter +from ..types import ( + ORMField, + SQLAlchemyInterface, + SQLAlchemyObjectType, + SQLAlchemyObjectTypeOptions, +) def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): - class Character1(SQLAlchemyObjectType): pass @@ -44,7 +48,6 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): - class Character(SQLAlchemyObjectType): class Meta: model = 1 @@ -317,7 +320,6 @@ def test_invalid_model_attr(): "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with pytest.raises(ValueError, match=err_msg): - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -371,7 +373,6 @@ class Meta: def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): - class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -509,6 +510,95 @@ class Meta: assert ReporterWithCustomOptions._meta.custom_option == "custom_option" +def test_interface_with_polymorphic_identity(): + with pytest.raises(AssertionError, + match=re.escape('PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")')): + class PersonType(SQLAlchemyInterface): + class Meta: + model = NonAbstractPerson + + +def test_interface_inherited_fields(): + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in 
EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # `type` should *not* be in this list because it's the polymorphic_on + # discriminator for Person + assert list(EmployeeType._meta.fields.keys()) == [ + "id", + "name", + "birth_date", + "hire_date", + ] + + +def test_interface_type_field_orm_override(): + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + type = ORMField() + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # type should be in this list because we used ORMField + # to force its presence on the model + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ + "id", + "name", + "type", + "birth_date", + "hire_date", + ]) + + +def test_interface_custom_resolver(): + class PersonType(SQLAlchemyInterface): + class Meta: + model = Person + + custom_field = Field(String) + + class EmployeeType(SQLAlchemyObjectType): + class Meta: + model = Employee + interfaces = (Node, PersonType) + + assert PersonType in EmployeeType._meta.interfaces + + name_field = EmployeeType._meta.fields["name"] + assert name_field.type == String + + # type should be in this list because we used ORMField + # to force its presence on the model + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ + "id", + "name", + "custom_field", + "birth_date", + "hire_date", + ]) + + # Tests for connection_field_factory diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index fe48e9eb..e0ada38e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -7,6 +7,8 @@ from graphene import Field from graphene.relay import Connection, Node +from graphene.types.base import BaseType +from graphene.types.interface import Interface, 
InterfaceOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType @@ -94,6 +96,18 @@ class Meta: self.kwargs.update(common_kwargs) +def get_polymorphic_on(model): + """ + Check whether this model is a polymorphic type, and if so return the name + of the discriminator field (`polymorphic_on`), so that it won't be automatically + generated as an ORMField. + """ + if hasattr(model, "__mapper__") and model.__mapper__.polymorphic_on is not None: + polymorphic_on = model.__mapper__.polymorphic_on + if isinstance(polymorphic_on, sqlalchemy.Column): + return polymorphic_on.name + + def construct_fields( obj_type, model, @@ -133,10 +147,13 @@ def construct_fields( ) # Filter out excluded fields + polymorphic_on = get_polymorphic_on(model) auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): - if (only_fields and attr_name not in only_fields) or ( - attr_name in exclude_fields + if ( + (only_fields and attr_name not in only_fields) + or (attr_name in exclude_fields) + or attr_name == polymorphic_on ): continue auto_orm_field_names.append(attr_name) @@ -211,14 +228,12 @@ def construct_fields( return fields -class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): - model = None # type: sqlalchemy.Model - registry = None # type: sqlalchemy.Registry - connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] - id = None # type: str - +class SQLAlchemyBase(BaseType): + """ + This class contains initialization code that is common to both ObjectTypes + and Interfaces. You typically don't need to use it directly. + """ -class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__( cls, @@ -237,6 +252,11 @@ def __init_subclass_with_meta__( _meta=None, **options ): + # We always want to bypass this hook unless we're defining a concrete + # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. 
+ if not _meta: + return + # Make sure model is a valid SQLAlchemy model if not is_mapped_class(model): raise ValueError( @@ -290,9 +310,6 @@ def __init_subclass_with_meta__( "The connection must be a Connection. Received {}" ).format(connection.__name__) - if not _meta: - _meta = SQLAlchemyObjectTypeOptions(cls) - _meta.model = model _meta.registry = registry @@ -306,7 +323,7 @@ def __init_subclass_with_meta__( cls.connection = connection # Public way to get the connection - super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + super(SQLAlchemyBase, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options ) @@ -345,3 +362,105 @@ def enum_for_field(cls, field_name): sort_enum = classmethod(sort_enum_for_object_type) sort_argument = classmethod(sort_argument_for_object_type) + + +class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] + id = None # type: str + + +class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): + """ + This type represents the GraphQL ObjectType. It reflects on the + given SQLAlchemy model, and automatically generates an ObjectType + using the column and relationship information defined there. 
+ + Usage: + + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String()) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel + """ + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + if not _meta: + _meta = SQLAlchemyObjectTypeOptions(cls) + + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + +class SQLAlchemyInterfaceOptions(InterfaceOptions): + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] + id = None # type: str + + +class SQLAlchemyInterface(SQLAlchemyBase, Interface): + """ + This type represents the GraphQL Interface. It reflects on the + given SQLAlchemy model, and automatically generates an Interface + using the column and relationship information defined there. This + is used to construct interface relationships based on polymorphic + inheritance hierarchies in SQLAlchemy. + + Please note that by default, the "polymorphic_on" column is *not* + generated as a field on types that use polymorphic inheritance, as + this is considered an implementation detail. The idiomatic way to + retrieve the concrete GraphQL type of an object is to query for the + `__typename` field. 
+ + Usage (using joined table inheritance): + + class MyBaseModel(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + + __mapper_args__ = { + "polymorphic_on": type, + } + + class MyChildModel(Base): + date = Column(Date()) + + __mapper_args__ = { + "polymorphic_identity": "child", + } + + class MyBaseType(SQLAlchemyInterface): + class Meta: + model = MyBaseModel + + class MyChildType(SQLAlchemyObjectType): + class Meta: + model = MyChildModel + interfaces = (MyBaseType,) + """ + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + if not _meta: + _meta = SQLAlchemyInterfaceOptions(cls) + + super(SQLAlchemyInterface, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + # make sure that the model doesn't have a polymorphic_identity defined + if hasattr(_meta.model, "__mapper__"): + polymorphic_identity = _meta.model.__mapper__.polymorphic_identity + assert ( + polymorphic_identity is None + ), '{}: An interface cannot map to a concrete type (polymorphic_identity is "{}")'.format( + cls.__name__, polymorphic_identity + ) From 32d0d184c74386886b8e67763c6b7db836b323ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jendrik=20J=C3=B6rdening?= Date: Wed, 21 Dec 2022 14:08:32 +0100 Subject: [PATCH 38/67] feat: support for async sessions (#350) * feat(async): add support for async sessions This PR brings experimental support for async sessions in SQLAlchemyConnectionFields. Batching is not yet supported and will be subject to a later PR. 
Co-authored-by: Jendrik Co-authored-by: Erik Wrede --- .github/workflows/tests.yml | 2 +- docs/inheritance.rst | 66 +++- graphene_sqlalchemy/batching.py | 13 +- graphene_sqlalchemy/fields.py | 50 ++- graphene_sqlalchemy/tests/conftest.py | 48 ++- graphene_sqlalchemy/tests/models.py | 17 +- graphene_sqlalchemy/tests/models_batching.py | 91 +++++ graphene_sqlalchemy/tests/test_batching.py | 360 ++++++++++-------- graphene_sqlalchemy/tests/test_benchmark.py | 127 ++++-- graphene_sqlalchemy/tests/test_enums.py | 5 +- graphene_sqlalchemy/tests/test_query.py | 190 +++++++-- graphene_sqlalchemy/tests/test_query_enums.py | 91 ++++- graphene_sqlalchemy/tests/test_sort_enums.py | 16 +- graphene_sqlalchemy/tests/test_types.py | 103 +++-- graphene_sqlalchemy/tests/utils.py | 9 + graphene_sqlalchemy/types.py | 31 +- graphene_sqlalchemy/utils.py | 39 +- setup.py | 5 +- 18 files changed, 931 insertions(+), 332 deletions(-) create mode 100644 graphene_sqlalchemy/tests/models_batching.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 428eca1d..7632fd38 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,6 @@ name: Tests -on: +on: push: branches: - 'master' diff --git a/docs/inheritance.rst b/docs/inheritance.rst index ee16f062..74732162 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -3,7 +3,7 @@ Inheritance Examples Create interfaces from inheritance relationships ------------------------------------------------ - +.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. SQLAlchemy has excellent support for class inheritance hierarchies. These hierarchies can be represented in your GraphQL schema by means of interfaces_. 
Much like ObjectTypes, Interfaces in @@ -40,7 +40,7 @@ from the attributes of their underlying SQLAlchemy model: __mapper_args__ = { "polymorphic_identity": "employee", } - + class Customer(Person): first_purchase_date = Column(Date()) @@ -56,17 +56,17 @@ from the attributes of their underlying SQLAlchemy model: class Meta: model = Employee interfaces = (relay.Node, PersonType) - + class CustomerType(SQLAlchemyObjectType): class Meta: model = Customer interfaces = (relay.Node, PersonType) -Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must -be linked to an abstract Model that does not specify a `polymorphic_identity`, -because we cannot return instances of interfaces from a GraphQL query. -If Person specified a `polymorphic_identity`, instances of Person could -be inserted into and returned by the database, potentially causing +Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must +be linked to an abstract Model that does not specify a `polymorphic_identity`, +because we cannot return instances of interfaces from a GraphQL query. +If Person specified a `polymorphic_identity`, instances of Person could +be inserted into and returned by the database, potentially causing Persons to be returned to the resolvers. When querying on the base type, you can refer directly to common fields, @@ -85,15 +85,19 @@ and fields on concrete implementations using the `... on` syntax: firstPurchaseDate } } - - + + +.. danger:: + When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications. + See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`. + Please note that by default, the "polymorphic_on" column is *not* generated as a field on types that use polymorphic inheritance, as -this is considered an implentation detail. 
The idiomatic way to +this is considered an implementation detail. The idiomatic way to retrieve the concrete GraphQL type of an object is to query for the -`__typename` field. +`__typename` field. To override this behavior, an `ORMField` needs to be created -for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* +for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* as it promotes abiguous schema design If your SQLAlchemy model only specifies a relationship to the @@ -103,5 +107,39 @@ class to the Schema constructor via the `types=` argument: .. code:: python schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType]) - + + See also: `Graphene Interfaces `_ + +Eager Loading & Using with AsyncSession +-------------------- +When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. +This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. +To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: + +.. code:: python + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session + } + +Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers: + +.. 
code:: python + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + async def resolve_people(self, _info): + return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all() + +Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR. + +For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs `_. diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 0800d0e2..23b6712e 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than def get_data_loader_impl() -> Any: # pragma: no cover @@ -71,19 +71,19 @@ async def batch_load_fn(self, parents): # For our purposes, the query_context will only used to get the session query_context = None - if is_sqlalchemy_version_less_than("1.4"): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: parent_mapper_query = session.query(parent_mapper.entity) query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than("1.4"): + else: + query_context = QueryContext(session.query(parent_mapper.entity)) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, None, child_mapper, + None, ) else: self.selectin_loader._load_for_path( @@ -92,7 +92,6 @@ async def batch_load_fn(self, parents): states, None, child_mapper, - None, ) return [getattr(parent, self.relationship_prop.key) for parent in parents] diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py 
index 2cb53c55..6dbc134f 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,7 +11,10 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class SQLAlchemyConnectionField(ConnectionField): @@ -81,8 +84,49 @@ def get_query(cls, model, info, sort=None, **args): @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): + session = get_session(info.context) + if resolved is None: + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return await cls.resolve_connection_async( + connection_type, model, info, args, resolved + ) + + return get_result() + + else: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() + else: + _len = len(resolved) + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, + slice_start=0, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, + edge_type=connection_type.Edge, + page_info_type=page_info_adapter, + ) + connection.iterable = resolved + connection.length = _len + return connection + + @classmethod + async def resolve_connection_async( + cls, connection_type, model, info, args, resolved + ): + session = get_session(info.context) if resolved is None: - resolved = cls.get_query(model, info, **args) + query = cls.get_query(model, info, **args) + resolved = (await session.scalars(query)).all() if isinstance(resolved, Query): _len = resolved.count() else: @@ -179,7 +223,7 @@ def from_relationship(cls, relationship, registry, 
**field_kwargs): return cls( model_type.connection, resolver=get_batch_resolver(relationship), - **field_kwargs + **field_kwargs, ) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 357ad96e..89b357a4 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,14 +1,17 @@ import pytest +import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker import graphene +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry from .models import Base, CompositeFullName -test_db_url = "sqlite://" # use in-memory database for tests +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @pytest.fixture(autouse=True) @@ -22,18 +25,49 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(scope="function") -def session_factory(): - engine = create_engine(test_db_url) - Base.metadata.create_all(engine) +@pytest.fixture(params=[False, True]) +def async_session(request): + return request.param + + +@pytest.fixture +def test_db_url(async_session: bool): + if async_session: + return "sqlite+aiosqlite://" + else: + return "sqlite://" - yield sessionmaker(bind=engine) +@pytest.mark.asyncio +@pytest_asyncio.fixture(scope="function") +async def session_factory(async_session: bool, test_db_url: str): + if async_session: + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") + engine = create_async_engine(test_db_url) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) + await engine.dispose() + else: + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + yield 
sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + +@pytest_asyncio.fixture(scope="function") +async def sync_session_factory(): + engine = create_engine("sqlite://") + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) # SQLite in-memory db is deleted when its connection is closed. # https://www.sqlite.org/inmemorydb.html engine.dispose() -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") def session(session_factory): return session_factory() diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 4fe91462..ee286585 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -20,7 +20,7 @@ ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import column_property, composite, mapper, relationship +from sqlalchemy.orm import backref, column_property, composite, mapper, relationship PetKind = Enum("cat", "dog", name="pet_kind") @@ -76,10 +76,16 @@ class Reporter(Base): email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) pets = relationship( - "Pet", secondary=association_table, backref="reporters", order_by="Pet.id" + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + lazy="selectin", ) - articles = relationship("Article", backref="reporter") - favorite_article = relationship("Article", uselist=False) + articles = relationship( + "Article", backref=backref("reporter", lazy="selectin"), lazy="selectin" + ) + favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property def hybrid_prop_with_doc(self): @@ -304,8 +310,10 @@ class Person(Base): __tablename__ = "person" __mapper_args__ = { "polymorphic_on": type, + "with_polymorphic": "*", # needed 
for eager loading in async session } + class NonAbstractPerson(Base): id = Column(Integer(), primary_key=True) type = Column(String()) @@ -318,6 +326,7 @@ class NonAbstractPerson(Base): "polymorphic_identity": "person", } + class Employee(Person): hire_date = Column(Date()) diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py new file mode 100644 index 00000000..6f1c42ff --- /dev/null +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import + +import enum + +from sqlalchemy import ( + Column, + Date, + Enum, + ForeignKey, + Integer, + String, + Table, + func, + select, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import column_property, relationship + +PetKind = Enum("cat", "dog", name="pet_kind") + + +class HairKind(enum.Enum): + LONG = "long" + SHORT = "short" + + +Base = declarative_base() + +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class Reporter(Base): + __tablename__ = "reporters" + + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") + favorite_pet_kind = Column(PetKind) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + ) + articles = relationship("Article", backref="reporter") + favorite_article = relationship("Article", uselist=False) + + column_prop = column_property( + 
select([func.cast(func.count(id), Integer)]), doc="Column property" + ) + + +class Article(Base): + __tablename__ = "articles" + id = Column(Integer(), primary_key=True) + headline = Column(String(100)) + pub_date = Column(Date()) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 90df0279..5eccd5fc 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -3,15 +3,23 @@ import logging import pytest +from sqlalchemy import select import graphene from graphene import Connection, relay from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reader, Reporter -from .utils import remove_cache_miss_stat, to_std_dicts +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) +from .models_batching import Article, HairKind, Pet, Reader, Reporter +from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class MockLoggingHandler(logging.Handler): @@ -41,6 +49,44 @@ def mock_sqlalchemy_logging_handler(): 
sql_logger.setLevel(previous_level) +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + batching = True + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + batching = True + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -65,14 +111,20 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return info.context.get("session").query(Article).all() + session = get_session(info.context) + return session.query(Article).all() def resolve_reporters(self, info): - return info.context.get("session").query(Reporter).all() + session = get_session(info.context) + return session.query(Reporter).all() return graphene.Schema(query=Query) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) + + def get_full_relay_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -107,14 +159,11 @@ class Query(graphene.ObjectType): return graphene.Schema(query=Query) 
-if is_sqlalchemy_version_less_than("1.2"): - pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) - - @pytest.mark.asyncio -async def test_many_to_one(session_factory): - session = session_factory() - +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_many_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -135,26 +184,43 @@ async def test_many_to_one(session_factory): session.commit() session.close() - schema = get_schema() - with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ - query { - articles { - headline - reporter { - firstName + query { + articles { + headline + reporter { + firstName + } + } } - } - } - """, + """, context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + assert len(messages) == 5 if is_sqlalchemy_version_less_than("1.3"): @@ -169,37 +235,19 @@ async def test_many_to_one(session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) assert ast.literal_eval(messages[2]) == () assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors - result = to_std_dicts(result.data) - assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - 
"firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], - } - @pytest.mark.asyncio -async def test_one_to_one(session_factory): - session = session_factory() - +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_one_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -220,26 +268,43 @@ async def test_one_to_one(session_factory): session.commit() session.close() - schema = get_schema() - with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + + session = sync_session_factory() result = await schema.execute_async( """ - query { - reporters { - firstName - favoriteArticle { - headline - } - } + query { + reporters { + firstName + favoriteArticle { + headline + } } + } """, context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } assert len(messages) == 5 if is_sqlalchemy_version_less_than("1.3"): @@ -254,36 +319,17 @@ async def test_one_to_one(session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) assert ast.literal_eval(messages[2]) == () assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors - result = to_std_dicts(result.data) - assert result == { - "reporters": [ - { - "firstName": 
"Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], - } - @pytest.mark.asyncio -async def test_one_to_many(session_factory): - session = session_factory() +async def test_one_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -309,7 +355,6 @@ async def test_one_to_many(session_factory): article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) - session.commit() session.close() @@ -317,7 +362,8 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -337,27 +383,6 @@ async def test_one_to_many(session_factory): ) messages = sqlalchemy_logging_handler.messages - assert len(messages) == 5 - - if is_sqlalchemy_version_less_than("1.3"): - # The batched SQL statement generated is different in 1.2.x - # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` - # See https://git.io/JewQu - sql_statements = [ - message - for message in messages - if "SELECT" in message and "JOIN articles" in message - ] - assert len(sql_statements) == 1 - return - - if not is_sqlalchemy_version_less_than("1.4"): - messages[2] = remove_cache_miss_stat(messages[2]) - messages[4] = remove_cache_miss_stat(messages[4]) - - assert ast.literal_eval(messages[2]) == () - assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors result = to_std_dicts(result.data) assert result == { @@ -398,11 +423,31 @@ async def test_one_to_many(session_factory): }, ], } + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # 
SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] @pytest.mark.asyncio -async def test_many_to_many(session_factory): - session = session_factory() +async def test_many_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -430,15 +475,14 @@ async def test_many_to_many(session_factory): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -458,27 +502,6 @@ async def test_many_to_many(session_factory): ) messages = sqlalchemy_logging_handler.messages - assert len(messages) == 5 - - if is_sqlalchemy_version_less_than("1.3"): - # The batched SQL statement generated is different in 1.2.x - # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` - # See https://git.io/JewQu - sql_statements = [ - message - for message in messages - if "SELECT" in message and "JOIN pets" in message - ] - assert len(sql_statements) == 1 - return - - if not is_sqlalchemy_version_less_than("1.4"): - messages[2] = remove_cache_miss_stat(messages[2]) - messages[4] = remove_cache_miss_stat(messages[4]) - - assert ast.literal_eval(messages[2]) == () - assert 
sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors result = to_std_dicts(result.data) assert result == { @@ -520,9 +543,30 @@ async def test_many_to_many(session_factory): ], } + assert len(messages) == 5 -def test_disable_batching_via_ormfield(session_factory): - session = session_factory() + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +def test_disable_batching_via_ormfield(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -555,7 +599,7 @@ def resolve_reporters(self, info): # Test one-to-one and many-to-one relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -580,7 +624,7 @@ def resolve_reporters(self, info): # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -607,9 +651,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 -@pytest.mark.asyncio -def 
test_batch_sorting_with_custom_ormfield(session_factory): - session = session_factory() +def test_batch_sorting_with_custom_ormfield(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -642,7 +685,7 @@ class Meta: # Test one-to-one and many-to-one relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = schema.execute( """ query { @@ -658,7 +701,7 @@ class Meta: context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages - + assert not result.errors result = to_std_dicts(result.data) assert result == { "reporters": { @@ -685,8 +728,10 @@ class Meta: @pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_false(session_factory): - session = session_factory() +async def test_connection_factory_field_overrides_batching_is_false( + sync_session_factory, +): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -718,7 +763,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() await schema.execute_async( """ query { @@ -755,8 +800,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 1 -def test_connection_factory_field_overrides_batching_is_true(session_factory): - session = session_factory() +def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = 
Reporter(first_name="Reporter_2") @@ -788,7 +833,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -816,7 +861,9 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_batching_across_nested_relay_schema(session_factory): +async def test_batching_across_nested_relay_schema( + session_factory, async_session: bool +): session = session_factory() for first_name in "fgerbhjikzutzxsdfdqqa": @@ -831,8 +878,8 @@ async def test_batching_across_nested_relay_schema(session_factory): reader.articles = [article] session.add(reader) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() @@ -871,14 +918,17 @@ async def test_batching_across_nested_relay_schema(session_factory): result = to_std_dicts(result.data) select_statements = [message for message in messages if "SELECT" in message] - assert len(select_statements) == 4 - assert select_statements[-1].startswith("SELECT articles_1.id") - if is_sqlalchemy_version_less_than("1.3"): - assert select_statements[-2].startswith("SELECT reporters_1.id") - assert "WHERE reporters_1.id IN" in select_statements[-2] + if async_session: + assert len(select_statements) == 2 # TODO: Figure out why async has less calls else: - assert select_statements[-2].startswith("SELECT articles.reporter_id") - assert "WHERE articles.reporter_id IN" in select_statements[-2] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than("1.3"): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert 
select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] @pytest.mark.asyncio @@ -892,8 +942,8 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f article_1.reporter = reporter_1 session.add(article_1) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index bb105edd..dc656f41 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,16 +1,61 @@ +import asyncio + import pytest +from sqlalchemy import select import graphene from graphene import relay from ..types import SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) from .models import Article, HairKind, Pet, Reporter +from .utils import eventually_await_session +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession if is_sqlalchemy_version_less_than("1.2"): pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, 
AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -40,20 +85,30 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) -def benchmark_query(session_factory, benchmark, query): - schema = get_schema() +async def benchmark_query(session, benchmark, schema, query): + import nest_asyncio - @benchmark - def execute_query(): - result = schema.execute( - query, - context_value={"session": session_factory()}, + nest_asyncio.apply() + loop = asyncio.get_event_loop() + result = benchmark( + lambda: loop.run_until_complete( + schema.execute_async(query, context_value={"session": session}) ) - assert not result.errors + ) + assert not result.errors + + +@pytest.fixture(params=[get_schema, get_async_schema]) +def schema_provider(request, async_session): + if async_session and request.param == get_schema: + pytest.skip("Cannot test sync schema with async sessions") + return request.param -def test_one_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_one(session_factory, benchmark, schema_provider): session = session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", @@ -72,12 +127,13 @@ def test_one_to_one(session_factory, benchmark): article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { @@ -91,9 +147,10 @@ def 
test_one_to_one(session_factory, benchmark): ) -def test_many_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_one(session_factory, benchmark, schema_provider): session = session_factory() - + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -110,13 +167,14 @@ def test_many_to_one(session_factory, benchmark): article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) + await eventually_await_session(session, "flush") + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - session.commit() - session.close() - - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { articles { @@ -130,8 +188,10 @@ def test_many_to_one(session_factory, benchmark): ) -def test_one_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_many(session_factory, benchmark, schema_provider): session = session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", @@ -158,12 +218,13 @@ def test_one_to_many(session_factory, benchmark): article_4.reporter = reporter_2 session.add(article_4) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { @@ -181,9 +242,10 @@ def test_one_to_many(session_factory, benchmark): ) -def test_many_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_many(session_factory, benchmark, schema_provider): session = session_factory() - + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -211,12 +273,13 @@ def test_many_to_many(session_factory, benchmark): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - session.commit() - session.close() + 
await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index cd97a00e..3de6904b 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -85,7 +85,10 @@ class Meta: assert enum._meta.name == "PetKind" assert [ (key, value.value) for key, value in enum._meta.enum.__members__.items() - ] == [("CAT", "cat"), ("DOG", "dog")] + ] == [ + ("CAT", "cat"), + ("DOG", "dog"), + ] enum2 = enum_for_field(PetType, "pet_kind") assert enum2 is enum enum2 = PetType.enum_for_field("pet_kind") diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 456254fc..055a87f8 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,11 +1,15 @@ from datetime import date +import pytest +from sqlalchemy import select + import graphene from graphene.relay import Node from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from .models import ( Article, CompositeFullName, @@ -16,10 +20,13 @@ Pet, Reporter, ) -from .utils import to_std_dicts +from .utils import eventually_await_session, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession -def add_test_data(session): +async def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) @@ -35,11 +42,12 @@ def add_test_data(session): session.add(pet) editor = Editor(name="Jack") 
session.add(editor) - session.commit() + await eventually_await_session(session, "commit") -def test_query_fields(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_fields(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -53,10 +61,16 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) query = """ @@ -82,14 +96,15 @@ def resolve_reporters(self, _info): "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_query_node(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_node_sync(session): + await add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -111,6 +126,14 @@ class Query(graphene.ObjectType): all_articles = SQLAlchemyConnectionField(ArticleNode.connection) def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return (await 
session.scalars(select(Reporter))).first() + + return get_result() + return session.query(Reporter).first() query = """ @@ -154,14 +177,100 @@ def resolve_reporter(self, _info): "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + result = schema.execute(query, context_value={"session": session}) + assert result.errors + else: + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +@pytest.mark.asyncio +async def test_query_node_async(session): + await add_test_data(session) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + node = Node.Field() + reporter = graphene.Field(ReporterNode) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) + + def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + + return session.query(Reporter).first() + + query = """ + query { + reporter { + id + firstName + articles { + edges { + node { + headline + } + } + } + } + allArticles { + edges { + node { + headline + } + } + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... 
on ArticleNode { + headline + } + } + } + """ + expected = { + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "John", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_orm_field(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_orm_field(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -187,7 +296,10 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() query = """ @@ -221,14 +333,15 @@ def resolve_reporter(self, _info): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_custom_identifier(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_custom_identifier(session): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -262,14 +375,15 @@ class Query(graphene.ObjectType): } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert 
not result.errors result = to_std_dicts(result.data) assert result == expected -def test_mutation(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_mutation(session, session_factory): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -282,8 +396,11 @@ class Meta: interfaces = (Node,) @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name="Cookie Monster") + async def get_node(cls, id, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() + return session.query(Reporter).first() class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -298,11 +415,14 @@ class Arguments: ok = graphene.Boolean() article = graphene.Field(ArticleNode) - def mutate(self, info, headline, reporter_id): + async def mutate(self, info, headline, reporter_id): + reporter = await ReporterNode.get_node(reporter_id, info) new_article = Article(headline=headline, reporter_id=reporter_id) + reporter.articles = [*reporter.articles, new_article] + session = get_session(info.context) + session.add(reporter) - session.add(new_article) - session.commit() + await eventually_await_session(session, "commit") ok = True return CreateArticle(article=new_article, ok=ok) @@ -341,24 +461,28 @@ class Mutation(graphene.ObjectType): } schema = graphene.Schema(query=Query, mutation=Mutation) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def add_person_data(session): +async def add_person_data(session): bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1)) session.add(bob) joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1)) 
session.add(joe) jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1)) session.add(jen) - session.commit() + await eventually_await_session(session, "commit") -def test_interface_query_on_base_type(session): - add_person_data(session) +@pytest.mark.asyncio +async def test_interface_query_on_base_type(session_factory): + session = session_factory() + await add_person_data(session) class PersonType(SQLAlchemyInterface): class Meta: @@ -372,11 +496,13 @@ class Meta: class Query(graphene.ObjectType): people = graphene.Field(graphene.List(PersonType)) - def resolve_people(self, _info): + async def resolve_people(self, _info): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Person))).all() return session.query(Person).all() schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType]) - result = schema.execute( + result = await schema.execute_async( """ query { people { diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 923bbed1..14c87f74 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,12 +1,22 @@ +import pytest +from sqlalchemy import select + import graphene +from graphene_sqlalchemy.tests.utils import eventually_await_session +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + -def test_query_pet_kinds(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_pet_kinds(session, session_factory): + await add_test_data(session) + await eventually_await_session(session, "close") class PetType(SQLAlchemyObjectType): class Meta: @@ -23,13 +33,25 @@ class 
Query(graphene.ObjectType): PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) ) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) - def resolve_pets(self, _info, kind): + async def resolve_pets(self, _info, kind): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).unique().all() query = session.query(Pet) if kind: query = query.filter_by(pet_kind=kind.value) @@ -78,13 +100,16 @@ def resolve_pets(self, _info, kind): "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors assert result.data == expected -def test_query_more_enums(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_more_enums(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -93,7 +118,10 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType) - def resolve_pet(self, _info): + async def resolve_pet(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Pet))).first() 
return session.query(Pet).first() query = """ @@ -107,14 +135,15 @@ def resolve_pet(self, _info): """ expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -125,7 +154,13 @@ class Query(graphene.ObjectType): PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) ) - def resolve_pet(self, info, kind=None): + async def resolve_pet(self, info, kind=None): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).first() query = session.query(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -142,19 +177,24 @@ def resolve_pet(self, info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "CAT"}) + result = await schema.execute_async( + query, variables={"kind": "CAT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "DOG"}) + result = await schema.execute_async( + query, variables={"kind": "DOG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) assert result == expected -def test_py_enum_as_argument(session): - add_test_data(session) 
+@pytest.mark.asyncio +async def test_py_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -166,7 +206,14 @@ class Query(graphene.ObjectType): kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), ) - def resolve_pet(self, _info, kind=None): + async def resolve_pet(self, _info, kind=None): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return ( + await session.scalars( + select(Pet).filter(Pet.hair_kind == HairKind(kind)) + ) + ).first() query = session.query(Pet) if kind: # enum arguments are expected to be strings, not PyEnums @@ -184,11 +231,15 @@ def resolve_pet(self, _info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) + result = await schema.execute_async( + query, variables={"kind": "SHORT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "LONG"}) + result = await schema.execute_async( + query, variables={"kind": "LONG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 11c7c9a7..f8f1ff8c 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -9,16 +9,17 @@ from ..utils import to_type_name from .models import Base, HairKind, KeyedModel, Pet from .test_query import to_std_dicts +from .utils import eventually_await_session -def add_pets(session): +async def add_pets(session): pets = [ Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=2, name="Barf", 
pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), ] session.add_all(pets) - session.commit() + await eventually_await_session(session, "commit") def test_sort_enum(): @@ -241,8 +242,9 @@ def get_symbol_name(column_name, sort_asc=True): assert sort_arg.default_value == ["IdUp"] -def test_sort_query(session): - add_pets(session) +@pytest.mark.asyncio +async def test_sort_query(session): + await add_pets(session) class PetNode(SQLAlchemyObjectType): class Meta: @@ -336,7 +338,7 @@ def makeNodes(nodeList): } # yapf: disable schema = Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -352,7 +354,7 @@ def makeNodes(nodeList): } } """ - result = schema.execute(queryError, context_value={"session": session}) + result = await schema.execute_async(queryError, context_value={"session": session}) assert result.errors is not None assert "cannot represent non-enum value" in result.errors[0].message @@ -375,7 +377,7 @@ def makeNodes(nodeList): } """ - result = schema.execute(queryNoSort, context_value={"session": session}) + result = await schema.execute_async(queryNoSort, context_value={"session": session}) assert not result.errors # TODO: SQLite usually returns the results ordered by primary key, # so we cannot test this way whether sorting actually happens or not. 
diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 813fb134..66328427 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,6 +4,9 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc +from graphql.pyutils import is_awaitable +from sqlalchemy import select + from graphene import ( Boolean, Dynamic, @@ -20,7 +23,6 @@ ) from graphene.relay import Connection -from .models import Article, CompositeFullName, Employee, Person, Pet, Reporter, NonAbstractPerson from .. import utils from ..converter import convert_sqlalchemy_composite from ..fields import ( @@ -36,11 +38,26 @@ SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions, ) +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 +from .models import ( + Article, + CompositeFullName, + Employee, + NonAbstractPerson, + Person, + Pet, + Reporter, +) +from .utils import eventually_await_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): pass @@ -48,12 +65,14 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): class Meta: model = 1 -def test_sqlalchemy_node(session): +@pytest.mark.asyncio +async def test_sqlalchemy_node(session): class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -64,9 +83,11 @@ class Meta: reporter = Reporter() session.add(reporter) - session.commit() + await eventually_await_session(session, "commit") info = mock.Mock(context={"session": session}) reporter_node = ReporterType.get_node(info, reporter.id) + if is_awaitable(reporter_node): + reporter_node = await reporter_node assert reporter == reporter_node @@ -97,7 
+118,7 @@ class Meta: assert sorted(list(ReporterType._meta.fields.keys())) == sorted( [ # Columns - "column_prop", + "column_prop", # SQLAlchemy retuns column properties first "id", "first_name", "last_name", @@ -320,6 +341,7 @@ def test_invalid_model_attr(): "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -373,6 +395,7 @@ class Meta: def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -392,9 +415,19 @@ class Meta: assert first_name_field.type == Int -def test_resolvers(session): +@pytest.mark.asyncio +async def test_resolvers(session): """Test that the correct resolver functions are called""" + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) + session.add(reporter) + await eventually_await_session(session, "commit") + class ReporterMixin(object): def resolve_id(root, _info): return "ID" @@ -420,20 +453,14 @@ def resolve_favorite_pet_kind_v2(root, _info): class Query(ObjectType): reporter = Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = utils.get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - reporter = Reporter( - first_name="first_name", - last_name="last_name", - email="email", - favorite_pet_kind="cat", - ) - session.add(reporter) - session.commit() - schema = Schema(query=Query) - result = schema.execute( + result = await schema.execute_async( """ query { reporter { @@ -446,7 +473,8 @@ def resolve_reporter(self, _info): favoritePetKindV2 } } 
- """ + """, + context_value={"session": session}, ) assert not result.errors @@ -511,8 +539,13 @@ class Meta: def test_interface_with_polymorphic_identity(): - with pytest.raises(AssertionError, - match=re.escape('PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")')): + with pytest.raises( + AssertionError, + match=re.escape( + 'PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")' + ), + ): + class PersonType(SQLAlchemyInterface): class Meta: model = NonAbstractPerson @@ -562,13 +595,15 @@ class Meta: # type should be in this list because we used ORMField # to force its presence on the model - assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ - "id", - "name", - "type", - "birth_date", - "hire_date", - ]) + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "type", + "birth_date", + "hire_date", + ] + ) def test_interface_custom_resolver(): @@ -590,13 +625,15 @@ class Meta: # type should be in this list because we used ORMField # to force its presence on the model - assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ - "id", - "name", - "custom_field", - "birth_date", - "hire_date", - ]) + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "custom_field", + "birth_date", + "hire_date", + ] + ) # Tests for connection_field_factory diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index c90ee476..4a118243 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,3 +1,4 @@ +import inspect import re @@ -15,3 +16,11 @@ def remove_cache_miss_stat(message): """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 return re.sub(r"\[generated in 
\d+.?\d*s\]\s", "", message) + + +async def eventually_await_session(session, func, *args): + + if inspect.iscoroutinefunction(getattr(session, func)): + await getattr(session, func)(*args) + else: + getattr(session, func)(*args) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e0ada38e..226d1e82 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,6 @@ from collections import OrderedDict +from inspect import isawaitable +from typing import Any import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property @@ -26,7 +28,16 @@ ) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_query, + get_session, + is_mapped_class, + is_mapped_instance, +) + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class ORMField(OrderedType): @@ -334,6 +345,11 @@ def __init_subclass_with_meta__( def is_type_of(cls, root, info): if isinstance(root, cls): return True + if isawaitable(root): + raise Exception( + "Received coroutine instead of sql alchemy model. 
" + "You seem to use an async engine with synchronous schema execution" + ) if not is_mapped_instance(root): raise Exception(('Received incompatible instance "{}".').format(root)) return isinstance(root, cls._meta.model) @@ -345,6 +361,19 @@ def get_query(cls, info): @classmethod def get_node(cls, info, id): + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None + + session = get_session(info.context) + if isinstance(session, AsyncSession): + + async def get_result() -> Any: + return await session.get(cls._meta.model, id) + + return get_result() try: return cls.get_query(info).get(id) except NoResultFound: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 54bb8402..62c71d8d 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -4,11 +4,34 @@ from typing import Any, Callable, Dict, Optional import pkg_resources +from sqlalchemy import select from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) + + +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return pkg_resources.get_distribution( + "graphene" + ).parsed_version < pkg_resources.parse_version(version_string) + + +SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False + +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + + SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True + + def get_session(context): return context.get("session") @@ -22,6 +45,8 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read 
more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return select(model) query = session.query(model) return query @@ -151,20 +176,6 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_sqlalchemy_version_less_than(version_string): # pragma: no cover - """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution( - "SQLAlchemy" - ).parsed_version < pkg_resources.parse_version(version_string) - - -def is_graphene_version_less_than(version_string): # pragma: no cover - """Check the installed graphene version""" - return pkg_resources.get_distribution( - "graphene" - ).parsed_version < pkg_resources.parse_version(version_string) - - class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function diff --git a/setup.py b/setup.py index ac9ad7e6..9122baf2 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,13 @@ tests_require = [ "pytest>=6.2.0,<7.0", - "pytest-asyncio>=0.15.1", + "pytest-asyncio>=0.18.3", "pytest-cov>=2.11.0,<3.0", "sqlalchemy_utils>=0.37.0,<1.0", "pytest-benchmark>=3.4.0,<4.0", + "aiosqlite>=0.17.0", + "nest-asyncio", + "greenlet", ] setup( From a03e74dbe37024b2f75fd785e799bd236f64650e Mon Sep 17 00:00:00 2001 From: Vladislav Zahrevsky Date: Mon, 2 Jan 2023 16:16:25 +0200 Subject: [PATCH 39/67] docs: fix installation instruction (#372) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Install instructions in the README.md fails with an error: „Could not find a version that satisfies the requirement graphene-sqlalchemy>=3“ This is because v3 is in beta. Therefore, installing with '--pre' fixes the problem. 
--- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 68719f4d..6e96f91e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ A [SQLAlchemy](http://www.sqlalchemy.org/) integration for [Graphene](http://gra For installing Graphene, just run this command in your shell. ```bash -pip install "graphene-sqlalchemy>=3" +pip install --pre "graphene-sqlalchemy" ``` ## Examples From 20418356a3e2fecc0896ab424eb7154fca016900 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Wed, 4 Jan 2023 10:08:20 +0100 Subject: [PATCH 40/67] refactor!: use the same conversion system for hybrids and columns (#371) * refactor!: use the same conversion system for hybrids and columns fix: insert missing create_type in union conversion Breaking Change: convert_sqlalchemy_type now uses a matcher function Breaking Change: convert_sqlalchemy type's column and registry arguments must now be keyword arguments Breaking Change: convert_sqlalchemy_type support for subtypes is dropped, each column type must be explicitly registered Breaking Change: The hybrid property default column type is no longer a string. If no matching column type was found, an exception will be raised. 
Signed-off-by: Erik Wrede * fix: catch import error in older sqlalchemy versions Signed-off-by: Erik Wrede * fix: union test for 3.10 Signed-off-by: Erik Wrede * fix: use type and value for all columns Signed-off-by: Erik Wrede * refactor: rename value_equals to column_type_eq Signed-off-by: Erik Wrede * tests: add tests for string fallback removal of hybrid property chore: change the exception types Signed-off-by: Erik Wrede * chore: refactor converter for object types and scalars Signed-off-by: Erik Wrede * chore: remove string fallback from forward references Signed-off-by: Erik Wrede * chore: adjust comment Signed-off-by: Erik Wrede * fix: fix regression on id types from last commit Signed-off-by: Erik Wrede * refactor: made registry calls in converters lazy Signed-off-by: Erik Wrede * fix: DeclarativeMeta import path adjusted for sqa<1.4 Signed-off-by: Erik Wrede Signed-off-by: Erik Wrede --- graphene_sqlalchemy/converter.py | 388 ++++++++++++-------- graphene_sqlalchemy/registry.py | 6 +- graphene_sqlalchemy/tests/models.py | 11 +- graphene_sqlalchemy/tests/test_converter.py | 121 +++++- graphene_sqlalchemy/tests/test_registry.py | 4 +- graphene_sqlalchemy/utils.py | 25 +- 6 files changed, 380 insertions(+), 175 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index d3ae8123..7c5330b3 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -2,13 +2,12 @@ import sys import typing import uuid -import warnings from decimal import Decimal -from functools import singledispatch -from typing import Any, cast +from typing import Any, Optional, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import interfaces, strategies import graphene @@ -17,16 +16,31 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum from .fields import 
BatchSQLAlchemyConnectionField, default_connection_field_factory -from .registry import get_global_registry +from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, DummyImport, + column_type_eq, registry_sqlalchemy_model_from_str, safe_isinstance, + safe_issubclass, singledispatchbymatchfunction, - value_equals, ) +# Import path changed in 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.orm import DeclarativeMeta +else: + from sqlalchemy.ext.declarative import DeclarativeMeta + +# We just use MapperProperties for type hints, they don't exist in sqlalchemy < 1.4 +try: + from sqlalchemy import MapperProperty +except ImportError: + # sqlalchemy < 1.4 + MapperProperty = Any + try: from typing import ForwardRef except ImportError: @@ -207,10 +221,15 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - + column_type = getattr(column, "type", None) + # The converter expects a type to find the right conversion function. + # If we get an instance instead, we need to convert it to a type. + # The conversion function will still be able to access the instance via the column argument. 
+ if not isinstance(column_type, type): + column_type = type(column_type) field_kwargs.setdefault( "type_", - convert_sqlalchemy_type(getattr(column, "type", None), column, registry), + convert_sqlalchemy_type(column_type, column=column, registry=registry), ) field_kwargs.setdefault("required", not is_column_nullable(column)) field_kwargs.setdefault("description", get_column_doc(column)) @@ -218,86 +237,178 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): return graphene.Field(resolver=resolver, **field_kwargs) -@singledispatch -def convert_sqlalchemy_type(type, column, registry=None): - raise Exception( - "Don't know how to convert the SQLAlchemy field %s (%s)" - % (column, column.__class__) +@singledispatchbymatchfunction +def convert_sqlalchemy_type( # noqa + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + # No valid type found, raise an error + + raise TypeError( + "Don't know how to convert the SQLAlchemy field %s (%s, %s). 
" + "Please add a type converter or set the type manually using ORMField(type_=your_type)" + % (column, column.__class__ or "no column provided", type_arg) ) -@convert_sqlalchemy_type.register(sqa_types.String) -@convert_sqlalchemy_type.register(sqa_types.Text) -@convert_sqlalchemy_type.register(sqa_types.Unicode) -@convert_sqlalchemy_type.register(sqa_types.UnicodeText) -@convert_sqlalchemy_type.register(postgresql.INET) -@convert_sqlalchemy_type.register(postgresql.CIDR) -@convert_sqlalchemy_type.register(sqa_utils.TSVectorType) -@convert_sqlalchemy_type.register(sqa_utils.EmailType) -@convert_sqlalchemy_type.register(sqa_utils.URLType) -@convert_sqlalchemy_type.register(sqa_utils.IPAddressType) -def convert_column_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) +def convert_sqlalchemy_model_using_registry( + type_arg: Any, registry: Registry = None, **kwargs +): + registry_ = registry or get_global_registry() + + def get_type_from_registry(): + existing_graphql_type = registry_.get_type_for_model(type_arg) + if existing_graphql_type: + return existing_graphql_type + + raise TypeError( + "No model found in Registry for type %s. " + "Only references to SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." 
% type_arg + ) + + return get_type_from_registry() + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.ObjectType)) +def convert_object_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(safe_issubclass(graphene.Scalar)) +def convert_scalar_type(type_arg: Any, **kwargs): + return type_arg + + +@convert_sqlalchemy_type.register(column_type_eq(str)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Unicode)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.UnicodeText)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.INET)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.CIDR)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.TSVectorType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.EmailType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.URLType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.IPAddressType)) +def convert_column_to_string(type_arg: Any, **kwargs): return graphene.String -@convert_sqlalchemy_type.register(postgresql.UUID) -@convert_sqlalchemy_type.register(sqa_utils.UUIDType) -def convert_column_to_uuid(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(postgresql.UUID)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) +@convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) +def convert_column_to_uuid( + type_arg: Any, + **kwargs, +): return graphene.UUID -@convert_sqlalchemy_type.register(sqa_types.DateTime) -def convert_column_to_datetime(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) +def convert_column_to_datetime( + type_arg: Any, + **kwargs, +): return graphene.DateTime 
-@convert_sqlalchemy_type.register(sqa_types.Time) -def convert_column_to_time(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.time)) +def convert_column_to_time( + type_arg: Any, + **kwargs, +): return graphene.Time -@convert_sqlalchemy_type.register(sqa_types.Date) -def convert_column_to_date(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) +@convert_sqlalchemy_type.register(column_type_eq(datetime.date)) +def convert_column_to_date( + type_arg: Any, + **kwargs, +): return graphene.Date -@convert_sqlalchemy_type.register(sqa_types.SmallInteger) -@convert_sqlalchemy_type.register(sqa_types.Integer) -def convert_column_to_int_or_id(type, column, registry=None): - return graphene.ID if column.primary_key else graphene.Int +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.SmallInteger)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) +@convert_sqlalchemy_type.register(column_type_eq(int)) +def convert_column_to_int_or_id( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + # fixme drop the primary key processing from here in another pr + if column is not None: + if getattr(column, "primary_key", False) is True: + return graphene.ID + return graphene.Int -@convert_sqlalchemy_type.register(sqa_types.Boolean) -def convert_column_to_boolean(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) +@convert_sqlalchemy_type.register(column_type_eq(bool)) +def convert_column_to_boolean( + type_arg: Any, + **kwargs, +): return graphene.Boolean -@convert_sqlalchemy_type.register(sqa_types.Float) -@convert_sqlalchemy_type.register(sqa_types.Numeric) -@convert_sqlalchemy_type.register(sqa_types.BigInteger) -def convert_column_to_float(type, column, registry=None): 
+@convert_sqlalchemy_type.register(column_type_eq(float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Float)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) +def convert_column_to_float( + type_arg: Any, + **kwargs, +): return graphene.Float -@convert_sqlalchemy_type.register(sqa_types.Enum) -def convert_enum_to_enum(type, column, registry=None): - return lambda: enum_for_sa_enum(type, registry or get_global_registry()) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) +def convert_enum_to_enum( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Enum conversion requires a column") + + return lambda: enum_for_sa_enum(column.type, registry or get_global_registry()) # TODO Make ChoiceType conversion consistent with other enums -@convert_sqlalchemy_type.register(sqa_utils.ChoiceType) -def convert_choice_to_enum(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) +def convert_choice_to_enum( + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("ChoiceType conversion requires a column") + name = "{}_{}".format(column.table.name, column.key).upper() - if isinstance(type.type_impl, EnumTypeImpl): + if isinstance(column.type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table - return graphene.Enum(name, list((v.name, v.value) for v in type.choices)) + return graphene.Enum(name, list((v.name, v.value) 
for v in column.type.choices)) else: - return graphene.Enum(name, type.choices) + return graphene.Enum(name, column.type.choices) -@convert_sqlalchemy_type.register(sqa_utils.ScalarListType) -def convert_scalar_list_to_list(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) +def convert_scalar_list_to_list( + type_arg: Any, + **kwargs, +): return graphene.List(graphene.String) @@ -309,108 +420,79 @@ def init_array_list_recursive(inner_type, n): ) -@convert_sqlalchemy_type.register(sqa_types.ARRAY) -@convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_array_to_list(_type, column, registry=None): - inner_type = convert_sqlalchemy_type(column.type.item_type, column) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) +def convert_array_to_list( + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("SQL-Array conversion requires a column") + item_type = column.type.item_type + if not isinstance(item_type, type): + item_type = type(item_type) + inner_type = convert_sqlalchemy_type( + item_type, column=column, registry=registry, **kwargs + ) return graphene.List( init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1) ) -@convert_sqlalchemy_type.register(postgresql.HSTORE) -@convert_sqlalchemy_type.register(postgresql.JSON) -@convert_sqlalchemy_type.register(postgresql.JSONB) -def convert_json_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(postgresql.HSTORE)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) +@convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) +def convert_json_to_string( + type_arg: Any, + **kwargs, +): return JSONString 
-@convert_sqlalchemy_type.register(sqa_utils.JSONType) -@convert_sqlalchemy_type.register(sqa_types.JSON) -def convert_json_type_to_string(type, column, registry=None): +@convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) +def convert_json_type_to_string( + type_arg: Any, + **kwargs, +): return JSONString -@convert_sqlalchemy_type.register(sqa_types.Variant) -def convert_variant_to_impl_type(type, column, registry=None): - return convert_sqlalchemy_type(type.impl, column, registry=registry) - - -@singledispatchbymatchfunction -def convert_sqlalchemy_hybrid_property_type(arg: Any): - existing_graphql_type = get_global_registry().get_type_for_model(arg) - if existing_graphql_type: - return existing_graphql_type - - if isinstance(arg, type(graphene.ObjectType)): - return arg - - if isinstance(arg, type(graphene.Scalar)): - return arg - - # No valid type found, warn and fall back to graphene.String - warnings.warn( - f'I don\'t know how to generate a GraphQL type out of a "{arg}" type.' 
- 'Falling back to "graphene.String"' +@convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) +def convert_variant_to_impl_type( + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, +): + if column is None or isinstance(column, hybrid_property): + raise Exception("Vaiant conversion requires a column") + + type_impl = column.type.impl + if not isinstance(type_impl, type): + type_impl = type(type_impl) + return convert_sqlalchemy_type( + type_impl, column=column, registry=registry, **kwargs ) - return graphene.String - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) -def convert_sqlalchemy_hybrid_property_type_str(arg): - return graphene.String - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) -def convert_sqlalchemy_hybrid_property_type_int(arg): - return graphene.Int - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) -def convert_sqlalchemy_hybrid_property_type_float(arg): - return graphene.Float -@convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) -def convert_sqlalchemy_hybrid_property_type_decimal(arg): +@convert_sqlalchemy_type.register(column_type_eq(Decimal)) +def convert_sqlalchemy_hybrid_property_type_decimal(type_arg: Any, **kwargs): # The reason Decimal should be serialized as a String is because this is a # base10 type used in things like money, and string allows it to not # lose precision (which would happen if we downcasted to a Float, for example) return graphene.String -@convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) -def convert_sqlalchemy_hybrid_property_type_bool(arg): - return graphene.Boolean - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) -def convert_sqlalchemy_hybrid_property_type_datetime(arg): - return graphene.DateTime - - 
-@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) -def convert_sqlalchemy_hybrid_property_type_date(arg): - return graphene.Date - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) -def convert_sqlalchemy_hybrid_property_type_time(arg): - return graphene.Time - - -@convert_sqlalchemy_hybrid_property_type.register(value_equals(uuid.UUID)) -def convert_sqlalchemy_hybrid_property_type_uuid(arg): - return graphene.UUID - - -def is_union(arg) -> bool: +def is_union(type_arg: Any, **kwargs) -> bool: if sys.version_info >= (3, 10): from types import UnionType - if isinstance(arg, UnionType): + if isinstance(type_arg, UnionType): return True - return getattr(arg, "__origin__", None) == typing.Union + return getattr(type_arg, "__origin__", None) == typing.Union def graphene_union_for_py_union( @@ -421,14 +503,14 @@ def graphene_union_for_py_union( if union_type is None: # Union Name is name of the three union_name = "".join(sorted(obj_type._meta.name for obj_type in obj_types)) - union_type = graphene.Union(union_name, obj_types) + union_type = graphene.Union.create_type(union_name, types=obj_types) registry.register_union_type(union_type, obj_types) return union_type -@convert_sqlalchemy_hybrid_property_type.register(is_union) -def convert_sqlalchemy_hybrid_property_union(arg): +@convert_sqlalchemy_type.register(is_union) +def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): """ Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. Since Optionals are internally represented as Union[T, ], they are handled here as well. 
@@ -444,11 +526,11 @@ def convert_sqlalchemy_hybrid_property_union(arg): # Option is actually Union[T, ] # Just get the T out of the list of arguments by filtering out the NoneType - nested_types = list(filter(lambda x: not type(None) == x, arg.__args__)) + nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__)) # Map the graphene types to the nested types. # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... - graphene_types = list(map(convert_sqlalchemy_hybrid_property_type, nested_types)) + graphene_types = list(map(convert_sqlalchemy_type, nested_types)) # If only one type is left after filtering out NoneType, the Union was an Optional if len(graphene_types) == 1: @@ -471,20 +553,20 @@ def convert_sqlalchemy_hybrid_property_union(arg): ) -@convert_sqlalchemy_hybrid_property_type.register( +@convert_sqlalchemy_type.register( lambda x: getattr(x, "__origin__", None) in [list, typing.List] ) -def convert_sqlalchemy_hybrid_property_type_list_t(arg): +def convert_sqlalchemy_hybrid_property_type_list_t(type_arg: Any, **kwargs): # type is either list[T] or List[T], generic argument at __args__[0] - internal_type = arg.__args__[0] + internal_type = type_arg.__args__[0] - graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + graphql_internal_type = convert_sqlalchemy_type(internal_type, **kwargs) return graphene.List(graphql_internal_type) -@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) -def convert_sqlalchemy_hybrid_property_forwardref(arg): +@convert_sqlalchemy_type.register(safe_isinstance(ForwardRef)) +def convert_sqlalchemy_hybrid_property_forwardref(type_arg: Any, **kwargs): """ Generate a lambda that will resolve the type at runtime This takes care of self-references @@ -492,26 +574,36 @@ def convert_sqlalchemy_hybrid_property_forwardref(arg): from .registry import get_global_registry def forward_reference_solver(): - model = 
registry_sqlalchemy_model_from_str(arg.__forward_arg__) + model = registry_sqlalchemy_model_from_str(type_arg.__forward_arg__) if not model: - return graphene.String + raise TypeError( + "No model found in Registry for forward reference for type %s. " + "Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed." % type_arg + ) # Always fall back to string if no ForwardRef type found. return get_global_registry().get_type_for_model(model) return forward_reference_solver -@convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(str)) -def convert_sqlalchemy_hybrid_property_bare_str(arg): +@convert_sqlalchemy_type.register(safe_isinstance(str)) +def convert_sqlalchemy_hybrid_property_bare_str(type_arg: str, **kwargs): """ Convert Bare String into a ForwardRef """ - return convert_sqlalchemy_hybrid_property_type(ForwardRef(arg)) + return convert_sqlalchemy_type(ForwardRef(type_arg), **kwargs) def convert_hybrid_property_return_type(hybrid_prop): # Grab the original method's return type annotations from inside the hybrid property - return_type_annotation = hybrid_prop.fget.__annotations__.get("return", str) + return_type_annotation = hybrid_prop.fget.__annotations__.get("return", None) + if not return_type_annotation: + raise TypeError( + "Cannot convert hybrid property type {} to a valid graphene type. 
" + "Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.".format(hybrid_prop) + ) - return convert_sqlalchemy_hybrid_property_type(return_type_annotation) + return convert_sqlalchemy_type(return_type_annotation, column=hybrid_prop) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index cc4b02b7..3c463013 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -83,13 +83,13 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) def register_union_type( - self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]] + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] ): - if not isinstance(union, graphene.Union): + if not issubclass(union, graphene.Union): raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) for obj_type in obj_types: - if not isinstance(obj_type, type(graphene.ObjectType)): + if not issubclass(obj_type, graphene.ObjectType): raise TypeError( "Expected Graphene ObjectType, but got: {!r}".format(obj_type) ) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index ee286585..9531aaaa 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -4,7 +4,7 @@ import enum import uuid from decimal import Decimal -from typing import List, Optional, Tuple +from typing import List, Optional from sqlalchemy import ( Column, @@ -88,12 +88,12 @@ class Reporter(Base): favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property - def hybrid_prop_with_doc(self): + def hybrid_prop_with_doc(self) -> str: """Docstring test""" return self.first_name @hybrid_property - def hybrid_prop(self): + def hybrid_prop(self) -> str: return self.first_name @hybrid_property @@ -253,11 +253,6 @@ def 
hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: return [ShoppingCartItem(id=1), ShoppingCartItem(id=2)] - # Unsupported Type - @hybrid_property - def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: - return "this will actually", "be a string" - # Self-references @hybrid_property diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index b9a1c152..e903396f 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,6 +1,6 @@ import enum import sys -from typing import Dict, Union +from typing import Dict, Tuple, Union import pytest import sqlalchemy_utils as sqa_utils @@ -20,6 +20,7 @@ convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship, + convert_sqlalchemy_type, ) from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry @@ -78,6 +79,110 @@ def prop_method() -> int: assert get_hybrid_property_type(prop_method).type == graphene.Int +def test_hybrid_unknown_annotation(): + @hybrid_property + def hybrid_prop(self): + return "This should fail" + + with pytest.raises( + TypeError, + match=r"(.*)Please make sure to annotate the return type of the hybrid property or use the " + "type_ attribute of ORMField to set the type.(.*)", + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_prop_no_type_annotation(): + @hybrid_property + def hybrid_prop(self) -> Tuple[str, str]: + return "This should Fail because", "we don't support Tuples in GQL" + + with pytest.raises( + TypeError, match=r"(.*)Don't know how to convert the SQLAlchemy field(.*)" + ): + get_hybrid_property_type(hybrid_prop) + + +def test_hybrid_invalid_forward_reference(): + class MyTypeNotInRegistry: + pass + + @hybrid_property + def hybrid_prop(self) -> 
"MyTypeNotInRegistry": + return MyTypeNotInRegistry() + + with pytest.raises( + TypeError, + match=r"(.*)Only forward references to other SQLAlchemy Models mapped to " + "SQLAlchemyObjectTypes are allowed.(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_object_type(): + class MyObjectType(graphene.ObjectType): + string = graphene.String() + + @hybrid_property + def hybrid_prop(self) -> MyObjectType: + return MyObjectType() + + assert get_hybrid_property_type(hybrid_prop).type == MyObjectType + + +def test_hybrid_prop_scalar_type(): + @hybrid_property + def hybrid_prop(self) -> graphene.String: + return "This should work" + + assert get_hybrid_property_type(hybrid_prop).type == graphene.String + + +def test_hybrid_prop_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "This shouldn't work" + + with pytest.raises(TypeError, match=r"(.*)No model found in Registry for type(.*)"): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> ShoppingCartItem: + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + +def test_hybrid_prop_forward_ref_not_mapped_to_graphene_type(): + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "This shouldn't work" + + with pytest.raises( + TypeError, + match=r"(.*)No model found in Registry for forward reference for type(.*)", + ): + get_hybrid_property_type(hybrid_prop).type + + +def test_hybrid_prop_forward_ref_mapped_to_graphene_type(): + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + + @hybrid_property + def hybrid_prop(self) -> "ShoppingCartItem": + return "Dummy return value" + + get_hybrid_property_type(hybrid_prop).type == ShoppingCartType + + 
@pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -131,11 +236,10 @@ def prop_method_2() -> Union[ShoppingCartType, PetType]: field_type_1 = get_hybrid_property_type(prop_method).type field_type_2 = get_hybrid_property_type(prop_method_2).type - assert isinstance(field_type_1, graphene.Union) + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] assert field_type_1 is field_type_2 - # TODO verify types of the union - @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" @@ -164,10 +268,16 @@ def prop_method_2() -> ShoppingCartType | PetType: field_type_1 = get_hybrid_property_type(prop_method).type field_type_2 = get_hybrid_property_type(prop_method_2).type - assert isinstance(field_type_1, graphene.Union) + assert issubclass(field_type_1, graphene.Union) + assert field_type_1._meta.types == [PetType, ShoppingCartType] assert field_type_1 is field_type_2 +def test_should_unknown_type_raise_error(): + with pytest.raises(Exception): + converted_type = convert_sqlalchemy_type(ZeroDivisionError) # noqa + + def test_should_datetime_convert_datetime(): assert get_field(types.DateTime()).type == graphene.DateTime @@ -667,7 +777,6 @@ class Meta: ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), - "hybrid_prop_unsupported_type_tuple": graphene.String, # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 68b5404f..e54f08b1 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -142,7 +142,7 @@ class Meta: model = Reporter union_types = [PetType, ReporterType] - union = 
graphene.Union("ReporterPet", tuple(union_types)) + union = graphene.Union.create_type("ReporterPet", types=tuple(union_types)) reg.register_union_type(union, union_types) @@ -155,7 +155,7 @@ def test_register_union_scalar(): reg = Registry() union_types = [graphene.String, graphene.Int] - union = graphene.Union("StringInt", tuple(union_types)) + union = graphene.Union.create_type("StringInt", types=union_types) re_err = r"Expected Graphene ObjectType, but got: .*String.*" with pytest.raises(TypeError, match=re_err): diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 62c71d8d..1bf361f1 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -196,18 +196,17 @@ def __call__(self, *args, **kwargs): # No match, using default. return self.default(*args, **kwargs) - def register(self, matcher_function: Callable[[Any], bool]): - def grab_function_from_outside(f): - self.registry[matcher_function] = f - return self + def register(self, matcher_function: Callable[[Any], bool], func=None): + if func is None: + return lambda f: self.register(matcher_function, f) + self.registry[matcher_function] = func + return func - return grab_function_from_outside - -def value_equals(value): +def column_type_eq(value: Any) -> Callable[[Any], bool]: """A simple function that makes the equality based matcher functions for SingleDispatchByMatchFunction prettier""" - return lambda x: x == value + return lambda x: (x == value) def safe_isinstance(cls): @@ -220,6 +219,16 @@ def safe_isinstance_checker(arg): return safe_isinstance_checker +def safe_issubclass(cls): + def safe_issubclass_checker(arg): + try: + return issubclass(arg, cls) + except TypeError: + pass + + return safe_issubclass_checker + + def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: from graphene_sqlalchemy.registry import get_global_registry From d3a4320c1c5f9ef6b23ec3ac7fea2f567360ddaa Mon Sep 17 00:00:00 2001 From: Frederick Polgardy Date: Fri, 13 
Jan 2023 05:12:16 -0700 Subject: [PATCH 41/67] feat!: Stricter non-null fields for relationships (#367) to-many relationships are now non-null by default. (List[MyType] -> List[MyType!]!) The behavior can be adjusted back to legacy using `converter.set_non_null_many_relationships(False)` or using an `ORMField` manually setting the type for more granular Adjustments --- graphene_sqlalchemy/converter.py | 42 ++++++++++++++++++++- graphene_sqlalchemy/tests/test_converter.py | 35 +++++++++++++++++ graphene_sqlalchemy/tests/test_types.py | 6 ++- 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 7c5330b3..26f5b3a7 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -59,6 +59,39 @@ is_selectin_available = getattr(strategies, "SelectInLoader", None) +""" +Flag for whether to generate stricter non-null fields for many-relationships. + +For many-relationships, both the list element and the list field itself will be +non-null by default. This better matches ORM semantics, where there is always a +list for a many relationship (even if it is empty), and it never contains None. + +This option can be set to False to revert to pre-3.0 behavior. + +For example, given a User model with many Comments: + + class User(Base): + comments = relationship("Comment") + +The Schema will be: + + type User { + comments: [Comment!]! 
+ } + +When set to False, the pre-3.0 behavior gives: + + type User { + comments: [Comment] + } +""" +use_non_null_many_relationships = True + + +def set_non_null_many_relationships(non_null_flag): + global use_non_null_many_relationships + use_non_null_many_relationships = non_null_flag + def get_column_doc(column): return getattr(column, "doc", None) @@ -160,7 +193,14 @@ def _convert_o2m_or_m2m_relationship( ) if not child_type._meta.connection: - return graphene.Field(graphene.List(child_type), **field_kwargs) + # check if we need to use non-null fields + list_type = ( + graphene.NonNull(graphene.List(graphene.NonNull(child_type))) + if use_non_null_many_relationships + else graphene.List(child_type) + ) + + return graphene.Field(list_type, **field_kwargs) # TODO Allow override of connection_field_factory and resolver via ORMField if connection_field_factory is None: diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e903396f..b4c6eb24 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -21,6 +21,7 @@ convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship, convert_sqlalchemy_type, + set_non_null_many_relationships, ) from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry @@ -71,6 +72,16 @@ class Model(declarative_base()): ) +@pytest.fixture +def use_legacy_many_relationships(): + set_non_null_many_relationships(False) + try: + yield + finally: + set_non_null_many_relationships(True) + + + def test_hybrid_prop_int(): @hybrid_property def prop_method() -> int: @@ -501,6 +512,30 @@ class Meta: True, "orm_field_name", ) + # field should be [A!]! 
+ assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert isinstance(graphene_type.type, graphene.NonNull) + assert isinstance(graphene_type.type.of_type, graphene.List) + assert isinstance(graphene_type.type.of_type.of_type, graphene.NonNull) + assert graphene_type.type.of_type.of_type.of_type == A + + +@pytest.mark.usefixtures("use_legacy_many_relationships") +def test_should_manytomany_convert_connectionorlist_list_legacy(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, + A, + default_connection_field_factory, + True, + "orm_field_name", + ) + # field should be [A] assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 66328427..3de443d5 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -331,8 +331,10 @@ class Meta: pets_field = ReporterType._meta.fields["pets"] assert isinstance(pets_field, Dynamic) - assert isinstance(pets_field.type().type, List) - assert pets_field.type().type.of_type == PetType + assert isinstance(pets_field.type().type, NonNull) + assert isinstance(pets_field.type().type.of_type, List) + assert isinstance(pets_field.type().type.of_type.of_type, NonNull) + assert pets_field.type().type.of_type.of_type.of_type == PetType assert pets_field.type().description == "Overridden" From 1708fcf1881d2af73a59fd6e23f08beb036483c6 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 27 Jan 2023 11:11:38 +0100 Subject: [PATCH 42/67] fix: allow type converter inheritance again (#377) * fix: Make ORMField(type_) work in case there is no registered converter * revert/feat!: Type Converters support subtypes again. 
this feature adjusts the conversion system to use the MRO of a supplied class * tests: add test cases for mro & orm field fixes * tests: use custom type instead of BIGINT due to version incompatibilities --- graphene_sqlalchemy/converter.py | 15 ++++---- graphene_sqlalchemy/tests/models.py | 38 ++++++++++++++++++++ graphene_sqlalchemy/tests/test_converter.py | 39 ++++++++++++++++++++- graphene_sqlalchemy/utils.py | 18 +++++++--- 4 files changed, 98 insertions(+), 12 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 26f5b3a7..8c7cd7a1 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -261,16 +261,17 @@ def inner(fn): def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - column_type = getattr(column, "type", None) # The converter expects a type to find the right conversion function. # If we get an instance instead, we need to convert it to a type. # The conversion function will still be able to access the instance via the column argument. 
- if not isinstance(column_type, type): - column_type = type(column_type) - field_kwargs.setdefault( - "type_", - convert_sqlalchemy_type(column_type, column=column, registry=registry), - ) + if "type_" not in field_kwargs: + column_type = getattr(column, "type", None) + if not isinstance(column_type, type): + column_type = type(column_type) + field_kwargs.setdefault( + "type_", + convert_sqlalchemy_type(column_type, column=column, registry=registry), + ) field_kwargs.setdefault("required", not is_column_nullable(column)) field_kwargs.setdefault("description", get_column_doc(column)) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 9531aaaa..5acbc6fd 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -21,6 +21,8 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship +from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter +from sqlalchemy.sql.type_api import TypeEngine PetKind = Enum("cat", "dog", name="pet_kind") @@ -328,3 +330,39 @@ class Employee(Person): __mapper_args__ = { "polymorphic_identity": "employee", } + + +############################################ +# Custom Test Models +############################################ + + +class CustomIntegerColumn(_LookupExpressionAdapter, TypeEngine): + """ + Custom Column Type that our converters don't recognize + Adapted from sqlalchemy.Integer + """ + + """A type for ``int`` integers.""" + + __visit_name__ = "integer" + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + @property + def python_type(self): + return int + + def literal_processor(self, dialect): + def process(value): + return str(int(value)) + + return process + + +class CustomColumnModel(Base): + __tablename__ = "customcolumnmodel" + + id = Column(Integer(), primary_key=True) + custom_col = 
Column(CustomIntegerColumn) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index b4c6eb24..f70a50f0 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -3,6 +3,7 @@ from typing import Dict, Tuple, Union import pytest +import sqlalchemy import sqlalchemy_utils as sqa_utils from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql @@ -29,6 +30,7 @@ from .models import ( Article, CompositeFullName, + CustomColumnModel, Pet, Reporter, ShoppingCart, @@ -81,7 +83,6 @@ def use_legacy_many_relationships(): set_non_null_many_relationships(True) - def test_hybrid_prop_int(): @hybrid_property def prop_method() -> int: @@ -745,6 +746,42 @@ def __init__(self, col1, col2): ) +def test_raise_exception_unkown_column_type(): + with pytest.raises( + Exception, + match="Don't know how to convert the SQLAlchemy field customcolumnmodel.custom_col", + ): + + class A(SQLAlchemyObjectType): + class Meta: + model = CustomColumnModel + + +def test_prioritize_orm_field_unkown_column_type(): + class A(SQLAlchemyObjectType): + class Meta: + model = CustomColumnModel + + custom_col = ORMField(type_=graphene.Int) + + assert A._meta.fields["custom_col"].type == graphene.Int + + +def test_match_supertype_from_mro_correct_order(): + """ + BigInt and Integer are both superclasses of BIGINT, but a custom converter exists for BigInt that maps to Float. + We expect the correct MRO order to be used and conversion by the nearest match. BIGINT should be converted to Float, + just like BigInt, not to Int like integer which is further up in the MRO. 
+ """ + + class BIGINT(sqlalchemy.types.BigInteger): + pass + + field = get_field_from_column(Column(BIGINT)) + + assert field.type == graphene.Float + + def test_sqlalchemy_hybrid_property_type_inference(): class ShoppingCartItemType(SQLAlchemyObjectType): class Meta: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 1bf361f1..ac5be88d 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,6 +1,7 @@ import re import warnings from collections import OrderedDict +from functools import _c3_mro from typing import Any, Callable, Dict, Optional import pkg_resources @@ -188,10 +189,19 @@ def __init__(self, default: Callable): self.default = default def __call__(self, *args, **kwargs): - for matcher_function, final_method in self.registry.items(): - # Register order is important. First one that matches, runs. - if matcher_function(args[0]): - return final_method(*args, **kwargs) + matched_arg = args[0] + try: + mro = _c3_mro(matched_arg) + except Exception: + # In case of tuples or similar types, we can't use the MRO. + # Fall back to just matching the original argument. + mro = [matched_arg] + + for cls in mro: + for matcher_function, final_method in self.registry.items(): + # Register order is important. First one that matches, runs. + if matcher_function(cls): + return final_method(*args, **kwargs) # No match, using default. 
return self.default(*args, **kwargs) From 185a662d70dbbc8eaa5c127c1ffe7fe547460d98 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 24 Feb 2023 12:36:39 +0100 Subject: [PATCH 43/67] docs: add docs pipeline --- .github/workflows/docs.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/docs.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..89f44467 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,19 @@ +name: Deploy Docs + +# Runs on pushes targeting the default branch +on: + push: + branches: [master] + +jobs: + pages: + runs-on: ubuntu-22.04 + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + permissions: + pages: write + id-token: write + steps: + - id: deployment + uses: sphinx-notes/pages@v3 From 686613d432e3710c9236e507ad6349e20b242657 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 24 Feb 2023 13:21:25 +0100 Subject: [PATCH 44/67] docs: extend docs and add autodoc api docs --- docs/api.rst | 4 ++ docs/index.rst | 5 +- docs/inheritance.rst | 2 +- docs/relay.rst | 43 ++++++++++++++++ docs/starter.rst | 118 +++++++++++++++++++++++++++++++++++++++++++ docs/tips.rst | 2 +- 6 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 docs/api.rst create mode 100644 docs/relay.rst create mode 100644 docs/starter.rst diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 00000000..66935c7f --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,4 @@ +API Reference +==== + +.. automodule::graphene_sqlalchemy \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 81b2f316..ea30fc8f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,6 +6,9 @@ Contents: .. 
toctree:: :maxdepth: 0 - tutorial + starter + inheritance tips examples + tutorial + api diff --git a/docs/inheritance.rst b/docs/inheritance.rst index 74732162..ae80c3b6 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -3,7 +3,7 @@ Inheritance Examples Create interfaces from inheritance relationships ------------------------------------------------ -.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. +.. note:: If you're using `AsyncSession`, please check the section `Eager Loading & Using with AsyncSession`_. SQLAlchemy has excellent support for class inheritance hierarchies. These hierarchies can be represented in your GraphQL schema by means of interfaces_. Much like ObjectTypes, Interfaces in diff --git a/docs/relay.rst b/docs/relay.rst new file mode 100644 index 00000000..2cce3b71 --- /dev/null +++ b/docs/relay.rst @@ -0,0 +1,43 @@ +Relay +==== + +:code:`graphene-sqlalchemy` comes with pre-defined +connection fields to quickly create a functioning relay API. +Using the :code:`SQLAlchemyConnectionField`, you have access to relay pagination, +sorting and filtering (filtering is coming soon!). + +To be used in a relay connection, your :code:`SQLAlchemyObjectType` must implement +the :code:`Node` interface from :code:`graphene.relay`. This handles the creation of +the :code:`Connection` and :code:`Edge` types automatically. + +The following example creates a relay-paginated connection: + + + +.. 
code:: python + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(Enum('cat', 'dog', name='pet_kind'), nullable=False) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces=(Node,) + + + class Query(ObjectType): + all_pets = SQLAlchemyConnectionField(PetNode.connection) + +To disable sorting on the connection, you can set :code:`sort` to :code:`None` the +:code:`SQLAlchemyConnectionField`: + + +.. code:: python + + class Query(ObjectType): + all_pets = SQLAlchemyConnectionField(PetNode.connection, sort=None) + diff --git a/docs/starter.rst b/docs/starter.rst new file mode 100644 index 00000000..a288f998 --- /dev/null +++ b/docs/starter.rst @@ -0,0 +1,118 @@ +Getting Started +==== + +Welcome to the graphene-sqlalchemy documentation! +Graphene is a powerful Python library for building GraphQL APIs, +and SQLAlchemy is a popular ORM (Object-Relational Mapping) +tool for working with databases. When combined, graphene-sqlalchemy +allows developers to quickly and easily create a GraphQL API that +seamlessly interacts with a SQLAlchemy-managed database. +It is fully compatible with SQLAlchemy 1.4 and 2.0. +This documentation provides detailed instructions on how to get +started with graphene-sqlalchemy, including installation, setup, +and usage examples. + +Installation +------------ + +To install :code:`graphene-sqlalchemy`, just run this command in your shell: + +.. code:: bash + + pip install --pre "graphene-sqlalchemy" + +Examples +-------- + +Here is a simple SQLAlchemy model: + +.. 
code:: python + + from sqlalchemy import Column, Integer, String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class UserModel(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + name = Column(String) + last_name = Column(String) + +To create a GraphQL schema for it, you simply have to write the +following: + +.. code:: python + + import graphene + from graphene_sqlalchemy import SQLAlchemyObjectType + + class User(SQLAlchemyObjectType): + class Meta: + model = UserModel + # use `only_fields` to only expose specific fields ie "name" + # only_fields = ("name",) + # use `exclude_fields` to exclude specific fields ie "last_name" + # exclude_fields = ("last_name",) + + class Query(graphene.ObjectType): + users = graphene.List(User) + + def resolve_users(self, info): + query = User.get_query(info) # SQLAlchemy query + return query.all() + + schema = graphene.Schema(query=Query) + +Then you can simply query the schema: + +.. code:: python + + query = ''' + query { + users { + name, + lastName + } + } + ''' + result = schema.execute(query, context_value={'session': db_session}) + + +It is important to provide a session for graphene-sqlalchemy to resolve the models. +In this example, it is provided using the GraphQL context. See :ref:`querying` for +other ways to implement this. + +You may also subclass SQLAlchemyObjectType by providing +``abstract = True`` in your subclasses Meta: + +.. 
code:: python + + from graphene_sqlalchemy import SQLAlchemyObjectType + + class ActiveSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def get_node(cls, info, id): + return cls.get_query(info).filter( + and_(cls._meta.model.deleted_at==None, + cls._meta.model.id==id) + ).first() + + class User(ActiveSQLAlchemyObjectType): + class Meta: + model = UserModel + + class Query(graphene.ObjectType): + users = graphene.List(User) + + def resolve_users(self, info): + query = User.get_query(info) # SQLAlchemy query + return query.all() + + schema = graphene.Schema(query=Query) + +More complex inhertiance using SQLAlchemy's polymorphic models is also supported. +You can check out :doc:`inheritance` for a guide. diff --git a/docs/tips.rst b/docs/tips.rst index baa8233f..daee1731 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -4,7 +4,7 @@ Tips Querying -------- - +.. _querying: In order to make querying against the database work, there are two alternatives: - Set the db session when you do the execution: From aa668d100880532c264d1c12c5e64df9d715b546 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 24 Feb 2023 13:26:33 +0100 Subject: [PATCH 45/67] docs: add relay to index --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index ea30fc8f..b663752a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,6 +8,7 @@ Contents: starter inheritance + relay tips examples tutorial From 39a64e1810921cba06f06d2dbe54fd4cd7546f76 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 24 Feb 2023 13:46:54 +0100 Subject: [PATCH 46/67] docs: fix sphinx problems and add autodoc --- docs/api.rst | 18 ++++++++++++++++-- docs/conf.py | 5 ++++- docs/inheritance.rst | 3 ++- docs/relay.rst | 2 +- docs/starter.rst | 4 ++-- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 66935c7f..acdcbf1a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,4 +1,18 @@ API 
Reference -==== +============== -.. automodule::graphene_sqlalchemy \ No newline at end of file +SQLAlchemyObjectType +-------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyObjectType + +SQLAlchemyInterface +------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyInterface + +ORMField +-------------------- +.. autoclass:: graphene_sqlalchemy.fields.ORMField + +SQLAlchemyConnectionField +------------------------- +.. autoclass:: graphene_sqlalchemy.SQLAlchemyConnectionField \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 9c9fc1d7..b660fc81 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,7 +23,10 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) +import os +import sys +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. @@ -80,7 +83,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: diff --git a/docs/inheritance.rst b/docs/inheritance.rst index ae80c3b6..277f87ea 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -112,12 +112,13 @@ class to the Schema constructor via the `types=` argument: See also: `Graphene Interfaces `_ Eager Loading & Using with AsyncSession --------------------- +---------------------------------------- When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. 
To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: .. code:: python + class Person(Base): id = Column(Integer(), primary_key=True) type = Column(String()) diff --git a/docs/relay.rst b/docs/relay.rst index 2cce3b71..7b733c76 100644 --- a/docs/relay.rst +++ b/docs/relay.rst @@ -1,5 +1,5 @@ Relay -==== +========== :code:`graphene-sqlalchemy` comes with pre-defined connection fields to quickly create a functioning relay API. diff --git a/docs/starter.rst b/docs/starter.rst index a288f998..6e09ab00 100644 --- a/docs/starter.rst +++ b/docs/starter.rst @@ -1,5 +1,5 @@ Getting Started -==== +================= Welcome to the graphene-sqlalchemy documentation! Graphene is a powerful Python library for building GraphQL APIs, @@ -80,7 +80,7 @@ Then you can simply query the schema: It is important to provide a session for graphene-sqlalchemy to resolve the models. -In this example, it is provided using the GraphQL context. See :ref:`querying` for +In this example, it is provided using the GraphQL context. See :doc:`tips` for other ways to implement this. 
You may also subclass SQLAlchemyObjectType by providing From e175f8784e89de85b716cccebe6a4911d6224293 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 24 Feb 2023 15:53:53 +0100 Subject: [PATCH 47/67] housekeeping: add issue management workflow --- .github/workflows/manage_issues.yml | 49 +++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/manage_issues.yml diff --git a/.github/workflows/manage_issues.yml b/.github/workflows/manage_issues.yml new file mode 100644 index 00000000..5876acb5 --- /dev/null +++ b/.github/workflows/manage_issues.yml @@ -0,0 +1,49 @@ +name: Issue Manager + +on: + schedule: + - cron: "0 0 * * *" + issue_comment: + types: + - created + issues: + types: + - labeled + pull_request_target: + types: + - labeled + workflow_dispatch: + +permissions: + issues: write + pull-requests: write + +concurrency: + group: lock + +jobs: + lock-old-closed-issues: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v4 + with: + issue-inactive-days: '180' + process-only: 'issues' + issue-comment: > + This issue has been automatically locked since there + has not been any recent activity after it was closed. + Please open a new issue for related topics referencing + this issue. + close-labelled-issues: + runs-on: ubuntu-latest + steps: + - uses: tiangolo/issue-manager@0.4.0 + with: + token: ${{ secrets.GITHUB_TOKEN }} + config: > + { + "needs-reply": { + "delay": 2200000, + "message": "This issue was closed due to inactivity. If your request is still relevant, please open a new issue referencing this one and provide all of the requested information." 
+ } + } From ba0597f7cbaa4dda3d48c534940ca635de8f4494 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 09:32:08 -0800 Subject: [PATCH 48/67] chore: limit lint runs to master pushes and PRs (#382) --- .github/workflows/lint.yml | 8 +++++++- .github/workflows/tests.yml | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9352dbe5..355a94d2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,12 @@ name: Lint -on: [push, pull_request] +on: + push: + branches: + - 'master' + pull_request: + branches: + - '*' jobs: build: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7632fd38..8b3cadfc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,6 +7,7 @@ on: pull_request: branches: - '*' + jobs: test: runs-on: ubuntu-latest From 506f58c10dd2cf5e2301b9e4fe42db090e7baaeb Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 09:35:35 -0800 Subject: [PATCH 49/67] fix: warnings in docs build (#383) --- docs/api.rst | 4 +-- docs/conf.py | 2 +- docs/inheritance.rst | 8 +++++- docs/requirements.txt | 1 + docs/tips.rst | 1 + graphene_sqlalchemy/types.py | 54 +++++++++++++++++++----------------- 6 files changed, 41 insertions(+), 29 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index acdcbf1a..237cf1b0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,8 +11,8 @@ SQLAlchemyInterface ORMField -------------------- -.. autoclass:: graphene_sqlalchemy.fields.ORMField +.. autoclass:: graphene_sqlalchemy.types.ORMField SQLAlchemyConnectionField ------------------------- -.. autoclass:: graphene_sqlalchemy.SQLAlchemyConnectionField \ No newline at end of file +.. 
autoclass:: graphene_sqlalchemy.SQLAlchemyConnectionField diff --git a/docs/conf.py b/docs/conf.py index b660fc81..1d8830b6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -178,7 +178,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] +# html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/docs/inheritance.rst b/docs/inheritance.rst index 277f87ea..d7fcca9d 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -1,9 +1,13 @@ Inheritance Examples ==================== + Create interfaces from inheritance relationships ------------------------------------------------ -.. note:: If you're using `AsyncSession`, please check the section `Eager Loading & Using with AsyncSession`_. + +.. note:: + If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. + SQLAlchemy has excellent support for class inheritance hierarchies. These hierarchies can be represented in your GraphQL schema by means of interfaces_. Much like ObjectTypes, Interfaces in @@ -111,8 +115,10 @@ class to the Schema constructor via the `types=` argument: See also: `Graphene Interfaces `_ + Eager Loading & Using with AsyncSession ---------------------------------------- + When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. 
To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: diff --git a/docs/requirements.txt b/docs/requirements.txt index 666a8c9d..220b7cfb 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ +sphinx # Docs template http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/docs/tips.rst b/docs/tips.rst index daee1731..a3ed69ed 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -5,6 +5,7 @@ Tips Querying -------- .. _querying: + In order to make querying against the database work, there are two alternatives: - Set the db session when you do the execution: diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 226d1e82..66db1e64 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -408,13 +408,15 @@ class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): Usage: - class MyModel(Base): - id = Column(Integer(), primary_key=True) - name = Column(String()) + .. code-block:: python - class MyType(SQLAlchemyObjectType): - class Meta: - model = MyModel + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String()) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel """ @classmethod @@ -450,30 +452,32 @@ class SQLAlchemyInterface(SQLAlchemyBase, Interface): Usage (using joined table inheritance): - class MyBaseModel(Base): - id = Column(Integer(), primary_key=True) - type = Column(String()) - name = Column(String()) + .. 
code-block:: python - __mapper_args__ = { - "polymorphic_on": type, - } + class MyBaseModel(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) - class MyChildModel(Base): - date = Column(Date()) + __mapper_args__ = { + "polymorphic_on": type, + } - __mapper_args__ = { - "polymorphic_identity": "child", - } + class MyChildModel(Base): + date = Column(Date()) - class MyBaseType(SQLAlchemyInterface): - class Meta: - model = MyBaseModel + __mapper_args__ = { + "polymorphic_identity": "child", + } - class MyChildType(SQLAlchemyObjectType): - class Meta: - model = MyChildModel - interfaces = (MyBaseType,) + class MyBaseType(SQLAlchemyInterface): + class Meta: + model = MyBaseModel + + class MyChildType(SQLAlchemyObjectType): + class Meta: + model = MyChildModel + interfaces = (MyBaseType,) """ @classmethod From 3720a23ddd3bdbd8da644f9066e3b136406765c5 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 27 Feb 2023 21:31:00 +0100 Subject: [PATCH 50/67] release: 3.0.0b4 --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index fb32379c..253e1d9c 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .types import SQLAlchemyInterface, SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "3.0.0b3" +__version__ = "3.0.0b4" __all__ = [ "__version__", From 2ca659a7840635a6058f032b9c00488534a07820 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 12:54:03 -0800 Subject: [PATCH 51/67] docs: update PyPI page (#384) --- README.rst | 102 ----------------------------------------------------- setup.py | 5 ++- 2 files changed, 4 insertions(+), 103 deletions(-) delete mode 100644 README.rst diff --git a/README.rst b/README.rst deleted file mode 100644 index d82b8071..00000000 --- a/README.rst +++ /dev/null @@ -1,102 +0,0 @@ 
-Please read -`UPGRADE-v2.0.md `__ -to learn how to upgrade to Graphene ``2.0``. - --------------- - -|Graphene Logo| Graphene-SQLAlchemy |Build Status| |PyPI version| |Coverage Status| -=================================================================================== - -A `SQLAlchemy `__ integration for -`Graphene `__. - -Installation ------------- - -For instaling graphene, just run this command in your shell - -.. code:: bash - - pip install "graphene-sqlalchemy>=2.0" - -Examples --------- - -Here is a simple SQLAlchemy model: - -.. code:: python - - from sqlalchemy import Column, Integer, String - from sqlalchemy.orm import backref, relationship - - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() - - class UserModel(Base): - __tablename__ = 'department' - id = Column(Integer, primary_key=True) - name = Column(String) - last_name = Column(String) - -To create a GraphQL schema for it you simply have to write the -following: - -.. code:: python - - from graphene_sqlalchemy import SQLAlchemyObjectType - - class User(SQLAlchemyObjectType): - class Meta: - model = UserModel - - class Query(graphene.ObjectType): - users = graphene.List(User) - - def resolve_users(self, info): - query = User.get_query(info) # SQLAlchemy query - return query.all() - - schema = graphene.Schema(query=Query) - -Then you can simply query the schema: - -.. code:: python - - query = ''' - query { - users { - name, - lastName - } - } - ''' - result = schema.execute(query, context_value={'session': db_session}) - -To learn more check out the following `examples `__: - -- **Full example**: `Flask SQLAlchemy - example `__ - -Contributing ------------- - -After cloning this repo, ensure dependencies are installed by running: - -.. code:: sh - - python setup.py install - -After developing, the full test suite can be evaluated by running: - -.. code:: sh - - python setup.py test # Use --pytest-args="-v -s" for verbose mode - -.. 
|Graphene Logo| image:: http://graphene-python.org/favicon.png -.. |Build Status| image:: https://travis-ci.org/graphql-python/graphene-sqlalchemy.svg?branch=master - :target: https://travis-ci.org/graphql-python/graphene-sqlalchemy -.. |PyPI version| image:: https://badge.fury.io/py/graphene-sqlalchemy.svg - :target: https://badge.fury.io/py/graphene-sqlalchemy -.. |Coverage Status| image:: https://coveralls.io/repos/graphql-python/graphene-sqlalchemy/badge.svg?branch=master&service=github - :target: https://coveralls.io/github/graphql-python/graphene-sqlalchemy?branch=master diff --git a/setup.py b/setup.py index 9122baf2..ad8bd3b9 100644 --- a/setup.py +++ b/setup.py @@ -34,8 +34,11 @@ name="graphene-sqlalchemy", version=version, description="Graphene SQLAlchemy integration", - long_description=open("README.rst").read(), + long_description=open("README.md").read(), url="https://github.com/graphql-python/graphene-sqlalchemy", + project_urls={ + "Documentation": "https://docs.graphene-python.org/projects/sqlalchemy/en/latest", + }, author="Syrus Akbary", author_email="me@syrusakbary.com", license="MIT", From 882205d9f4fe8d89669d4f81ac74b4ef39b46d7e Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 27 Feb 2023 13:39:08 -0800 Subject: [PATCH 52/67] fix: set README content_type (#385) --- README.md | 8 ++++---- setup.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6e96f91e..29da89da 100644 --- a/README.md +++ b/README.md @@ -109,11 +109,11 @@ schema = graphene.Schema(query=Query) ### Full Examples -To learn more check out the following [examples](examples/): +To learn more check out the following [examples](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/): -- [Flask SQLAlchemy example](examples/flask_sqlalchemy) -- [Nameko SQLAlchemy example](examples/nameko_sqlalchemy) +- [Flask SQLAlchemy 
example](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/flask_sqlalchemy) +- [Nameko SQLAlchemy example](https://github.com/graphql-python/graphene-sqlalchemy/tree/master/examples/nameko_sqlalchemy) ## Contributing -See [CONTRIBUTING.md](/CONTRIBUTING.md) +See [CONTRIBUTING.md](https://github.com/graphql-python/graphene-sqlalchemy/blob/master/CONTRIBUTING.md) diff --git a/setup.py b/setup.py index ad8bd3b9..0f9ec817 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ version=version, description="Graphene SQLAlchemy integration", long_description=open("README.md").read(), + long_description_content_type="text/markdown", url="https://github.com/graphql-python/graphene-sqlalchemy", project_urls={ "Documentation": "https://docs.graphene-python.org/projects/sqlalchemy/en/latest", From d0668cc82dfd349aa418dd6fc16d54e80162960a Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Sun, 14 May 2023 21:49:03 +0200 Subject: [PATCH 53/67] feat: SQLAlchemy 2.0 support (#368) This PR updates the dataloader and unit tests to be compatible with sqlalchemy 2.0 --- .github/workflows/tests.yml | 4 +- .gitignore | 1 + graphene_sqlalchemy/batching.py | 20 ++++++++- graphene_sqlalchemy/tests/models.py | 23 +++++++--- graphene_sqlalchemy/tests/models_batching.py | 5 ++- graphene_sqlalchemy/tests/test_converter.py | 47 +++++++++++++------- graphene_sqlalchemy/tests/utils.py | 13 +++++- graphene_sqlalchemy/utils.py | 8 +++- setup.py | 2 +- tox.ini | 8 +++- 10 files changed, 100 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8b3cadfc..c471166a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,8 +14,8 @@ jobs: strategy: max-parallel: 10 matrix: - sql-alchemy: ["1.2", "1.3", "1.4"] - python-version: ["3.7", "3.8", "3.9", "3.10"] + sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ] + python-version: [ "3.7", "3.8", "3.9", "3.10" ] steps: - uses: actions/checkout@v3 diff --git 
a/.gitignore b/.gitignore index c4a735fe..47a82df0 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ __pycache__/ .Python env/ .venv/ +venv/ build/ develop-eggs/ dist/ diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 23b6712e..a5804516 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -5,8 +5,13 @@ import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext +from sqlalchemy.util import immutabledict -from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, + is_graphene_version_less_than, +) def get_data_loader_impl() -> Any: # pragma: no cover @@ -76,7 +81,18 @@ async def batch_load_fn(self, parents): query_context = parent_mapper_query._compile_context() else: query_context = QueryContext(session.query(parent_mapper.entity)) - if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + if SQL_VERSION_HIGHER_EQUAL_THAN_2: # pragma: no cover + self.selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + None, + None, # recursion depth can be none + immutabledict(), # default value for selectinload->lazyload + ) + elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 5acbc6fd..b638b5d4 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -16,14 +16,23 @@ String, Table, func, - select, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship -from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter from sqlalchemy.sql.type_api import 
TypeEngine +from graphene_sqlalchemy.tests.utils import wrap_select_func +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 + +# fmt: off +import sqlalchemy +if SQL_VERSION_HIGHER_EQUAL_THAN_2: + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip +else: + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip +# fmt: on + PetKind = Enum("cat", "dog", name="pet_kind") @@ -119,7 +128,7 @@ def hybrid_prop_list(self) -> List[int]: return [1, 2, 3] column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" ) composite_prop = composite( @@ -163,7 +172,11 @@ def __subclasses__(cls): editor_table = Table("editors", Base.metadata, autoload=True) -mapper(ReflectedEditor, editor_table) +# TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4 +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + Base.registry.map_imperatively(ReflectedEditor, editor_table) +else: + mapper(ReflectedEditor, editor_table) ############################################ @@ -337,7 +350,7 @@ class Employee(Person): ############################################ -class CustomIntegerColumn(_LookupExpressionAdapter, TypeEngine): +class CustomIntegerColumn(HasExpressionLookup, TypeEngine): """ Custom Column Type that our converters don't recognize Adapted from sqlalchemy.Integer diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 6f1c42ff..5dde366f 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -11,11 +11,12 @@ String, Table, func, - select, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship +from graphene_sqlalchemy.tests.utils import wrap_select_func + PetKind = Enum("cat", 
"dog", name="pet_kind") @@ -61,7 +62,7 @@ class Reporter(Base): favorite_article = relationship("Article", uselist=False) column_prop = column_property( - select([func.cast(func.count(id), Integer)]), doc="Column property" + wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property" ) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f70a50f0..884af7d6 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -2,20 +2,28 @@ import sys from typing import Dict, Tuple, Union +import graphene import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from sqlalchemy import Column, func, select, types +from graphene.relay import Node +from graphene.types.structures import Structure +from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -import graphene -from graphene.relay import Node -from graphene.types.structures import Structure - +from .models import ( + Article, + CompositeFullName, + Pet, + Reporter, + ShoppingCart, + ShoppingCartItem, +) +from .utils import wrap_select_func from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -27,6 +35,7 @@ from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory from ..registry import Registry, get_global_registry from ..types import ORMField, SQLAlchemyObjectType +from ..utils import is_sqlalchemy_version_less_than from .models import ( Article, CompositeFullName, @@ -204,9 +213,9 @@ def prop_method() -> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. 
\.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -460,7 +469,7 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): field = get_field_from_column( - column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1)) + column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) ) assert field.type == graphene.Int @@ -477,10 +486,18 @@ def test_should_jsontype_convert_jsonstring(): assert get_field(types.JSON).type == graphene.JSONString +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) def test_should_variant_int_convert_int(): assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int +@pytest.mark.skipif( + (not is_sqlalchemy_version_less_than("2.0.0b1")), + reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy", +) def test_should_variant_string_convert_string(): assert get_field(types.Variant(types.String(), {})).type == graphene.String @@ -811,8 +828,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -823,7 +840,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -870,8 +887,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ 
-882,5 +899,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index 4a118243..6e843316 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,6 +1,10 @@ import inspect import re +from sqlalchemy import select + +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 + def to_std_dicts(value): """Convert nested ordered dicts to normal dicts for better comparison.""" @@ -18,8 +22,15 @@ def remove_cache_miss_stat(message): return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) -async def eventually_await_session(session, func, *args): +def wrap_select_func(query): + # TODO remove this when we drop support for sqa < 2.0 + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + return select(query) + else: + return select([query]) + +async def eventually_await_session(session, func, *args): if inspect.iscoroutinefunction(getattr(session, func)): await getattr(session, func)(*args) else: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index ac5be88d..bb9386e8 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -27,12 +27,18 @@ def is_graphene_version_less_than(version_string): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False -if not is_sqlalchemy_version_less_than("1.4"): +if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover from sqlalchemy.ext.asyncio import AsyncSession SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True +SQL_VERSION_HIGHER_EQUAL_THAN_2 = False + +if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover + SQL_VERSION_HIGHER_EQUAL_THAN_2 = True + + def get_session(context): return context.get("session") diff --git a/setup.py b/setup.py index 0f9ec817..fdace116 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 
@@ # To keep things simple, we only support newer versions of Graphene "graphene>=3.0.0b7", "promise>=2.3", - "SQLAlchemy>=1.1,<2", + "SQLAlchemy>=1.1", "aiodataloader>=0.2.0,<1.0", ] diff --git a/tox.ini b/tox.ini index 2802dee0..9ce901e4 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pre-commit,py{37,38,39,310}-sql{12,13,14} +envlist = pre-commit,py{37,38,39,310}-sql{12,13,14,20} skipsdist = true minversion = 3.7.0 @@ -15,6 +15,7 @@ SQLALCHEMY = 1.2: sql12 1.3: sql13 1.4: sql14 + 2.0: sql20 [testenv] passenv = GITHUB_* @@ -23,8 +24,11 @@ deps = sql12: sqlalchemy>=1.2,<1.3 sql13: sqlalchemy>=1.3,<1.4 sql14: sqlalchemy>=1.4,<1.5 + sql20: sqlalchemy>=2.0.0b3 +setenv = + SQLALCHEMY_WARN_20 = 1 commands = - pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} + python -W always -m pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs} [testenv:pre-commit] basepython=python3.10 From f5f05d18806838c8cb9dc3d0eb21a84ff8347e11 Mon Sep 17 00:00:00 2001 From: Clemens Tolboom Date: Fri, 6 Oct 2023 22:29:36 +0200 Subject: [PATCH 54/67] docs: Add database session to the example (#249) * Add database session to the example Coming from https://docs.graphene-python.org/projects/sqlalchemy/en/latest/tutorial/ as a python noob I failed to run their example but could fix this example by adding the database session. 
* Update README.md --------- Co-authored-by: Erik Wrede --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 29da89da..4e61f96c 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,21 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) ``` +We need a database session first: + +```python +from sqlalchemy import (create_engine) +from sqlalchemy.orm import (scoped_session, sessionmaker) + +engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) +db_session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) +# We will need this for querying, Graphene extracts the session from the base. +# Alternatively it can be provided in the GraphQLResolveInfo.context dictionary under context["session"] +Base.query = db_session.query_property() +``` + Then you can simply query the schema: ```python From 1436807fe43d028bd31a06329953e4e2b021eb36 Mon Sep 17 00:00:00 2001 From: Daniel Pepper Date: Fri, 6 Oct 2023 13:33:38 -0700 Subject: [PATCH 55/67] feat: association_proxy support (#267) * association_proxy support * better support for assoc proxy lists (rather than one-to-one) * scope down * add support for sqlalchemy 1.1 * fix pytest due to master merge * fix: throw error when association proxy could not be converted * fix: adjust association proxy to new relationship handling --------- Co-authored-by: Erik Wrede --- graphene_sqlalchemy/converter.py | 51 +++++++++++++++++- graphene_sqlalchemy/tests/models.py | 16 ++++++ graphene_sqlalchemy/tests/test_converter.py | 60 +++++++++++++++++++++ graphene_sqlalchemy/tests/test_query.py | 2 + graphene_sqlalchemy/tests/test_types.py | 16 +++++- graphene_sqlalchemy/types.py | 15 +++++- 6 files changed, 157 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 8c7cd7a1..84c7886c 100644 --- a/graphene_sqlalchemy/converter.py +++ 
b/graphene_sqlalchemy/converter.py @@ -7,8 +7,14 @@ from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import ( + ColumnProperty, + RelationshipProperty, + class_mapper, + interfaces, + strategies, +) from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import interfaces, strategies import graphene from graphene.types.json import JSONString @@ -101,6 +107,49 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) +def convert_sqlalchemy_association_proxy( + parent, + assoc_prop, + obj_type, + registry, + connection_field_factory, + batching, + resolver, + **field_kwargs, +): + def dynamic_type(): + prop = class_mapper(parent).attrs[assoc_prop.target_collection] + scalar = not prop.uselist + model = prop.mapper.class_ + attr = class_mapper(model).attrs[assoc_prop.value_attr] + + if isinstance(attr, ColumnProperty): + field = convert_sqlalchemy_column(attr, registry, resolver, **field_kwargs) + if not scalar: + # repackage as List + field.__dict__["_type"] = graphene.List(field.type) + return field + elif isinstance(attr, RelationshipProperty): + return convert_sqlalchemy_relationship( + attr, + obj_type, + connection_field_factory, + field_kwargs.pop("batching", batching), + assoc_prop.value_attr, + **field_kwargs, + ).get_type() + else: + raise TypeError( + "Unsupported association proxy target type: {} for prop {} on type {}. 
" + "Please disable the conversion of this field using an ORMField.".format( + type(attr), assoc_prop, obj_type + ) + ) + # else, not supported + + return graphene.Dynamic(dynamic_type) + + def convert_sqlalchemy_relationship( relationship_prop, obj_type, diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index b638b5d4..c871bedd 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -17,6 +17,7 @@ Table, func, ) +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, column_property, composite, mapper, relationship @@ -78,6 +79,18 @@ def __repr__(self): return "{} {}".format(self.first_name, self.last_name) +class ProxiedReporter(Base): + __tablename__ = "reporters_error" + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + reporter = relationship("Reporter", uselist=False) + + # This is a hybrid property, we don't support proxies on hybrids yet + composite_prop = association_proxy("reporter", "composite_prop") + + class Reporter(Base): __tablename__ = "reporters" @@ -135,6 +148,8 @@ def hybrid_prop_list(self) -> List[int]: CompositeFullName, first_name, last_name, doc="Composite" ) + headlines = association_proxy("articles", "headline") + class Article(Base): __tablename__ = "articles" @@ -145,6 +160,7 @@ class Article(Base): readers = relationship( "Reader", secondary="articles_readers", back_populates="articles" ) + recommended_reads = association_proxy("reporter", "articles") class Reader(Base): diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 884af7d6..84069245 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ 
b/graphene_sqlalchemy/tests/test_converter.py @@ -25,6 +25,7 @@ ) from .utils import wrap_select_func from ..converter import ( + convert_sqlalchemy_association_proxy, convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, @@ -41,6 +42,7 @@ CompositeFullName, CustomColumnModel, Pet, + ProxiedReporter, Reporter, ShoppingCart, ShoppingCartItem, @@ -650,6 +652,64 @@ class Meta: assert graphene_type.type == A +def test_should_convert_association_proxy(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + + field = convert_sqlalchemy_association_proxy( + Reporter, + Reporter.headlines, + ReporterType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + assert isinstance(field, graphene.Dynamic) + assert isinstance(field.get_type().type, graphene.List) + assert field.get_type().type.of_type == graphene.String + + dynamic_field = convert_sqlalchemy_association_proxy( + Article, + Article.recommended_reads, + ArticleType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + dynamic_field_type = dynamic_field.get_type().type + assert isinstance(dynamic_field, graphene.Dynamic) + assert isinstance(dynamic_field_type, graphene.NonNull) + assert isinstance(dynamic_field_type.of_type, graphene.List) + assert isinstance(dynamic_field_type.of_type.of_type, graphene.NonNull) + assert dynamic_field_type.of_type.of_type.of_type == ArticleType + + +def test_should_throw_error_association_proxy_unsupported_target(): + class ProxiedReporterType(SQLAlchemyObjectType): + class Meta: + model = ProxiedReporter + + field = convert_sqlalchemy_association_proxy( + ProxiedReporter, + ProxiedReporter.composite_prop, + ProxiedReporterType, + get_global_registry(), + default_connection_field_factory, + True, + mock_resolver, + ) + + with pytest.raises(TypeError): + 
field.get_type() + + def test_should_postgresql_uuid_convert(): assert get_field(postgresql.UUID()).type == graphene.UUID diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 055a87f8..168a82f9 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -80,6 +80,7 @@ async def resolve_reporters(self, _info): columnProp hybridProp compositeProp + headlines } reporters { firstName @@ -92,6 +93,7 @@ async def resolve_reporters(self, _info): "hybridProp": "John", "columnProp": 2, "compositeProp": "John Doe", + "headlines": ["Hi!"], }, "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 3de443d5..e5b154cd 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -138,6 +138,8 @@ class Meta: "pets", "articles", "favorite_article", + # AssociationProxy + "headlines", ] ) @@ -206,6 +208,16 @@ class Meta: assert favorite_article_field.type().type == ArticleType assert favorite_article_field.type().description is None + # assocation proxy + assoc_field = ReporterType._meta.fields["headlines"] + assert isinstance(assoc_field, Dynamic) + assert isinstance(assoc_field.type().type, List) + assert assoc_field.type().type.of_type == String + + assoc_field = ArticleType._meta.fields["recommended_reads"] + assert isinstance(assoc_field, Dynamic) + assert assoc_field.type().type == ArticleType.connection + def test_sqlalchemy_override_fields(): @convert_sqlalchemy_composite.register(CompositeFullName) @@ -275,6 +287,7 @@ class Meta: "hybrid_prop_float", "hybrid_prop_bool", "hybrid_prop_list", + "headlines", ] ) @@ -390,6 +403,7 @@ class Meta: "pets", "articles", "favorite_article", + "headlines", ] ) @@ -510,7 +524,7 @@ class Meta: assert issubclass(CustomReporterType, ObjectType) assert CustomReporterType._meta.model == Reporter - assert 
len(CustomReporterType._meta.fields) == 17 + assert len(CustomReporterType._meta.fields) == 18 # Test Custom SQLAlchemyObjectType with Custom Options diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 66db1e64..dac5b15f 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,6 +3,7 @@ from typing import Any import sqlalchemy +from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound @@ -16,6 +17,7 @@ from graphene.utils.orderedtype import OrderedType from .converter import ( + convert_sqlalchemy_association_proxy, convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, @@ -152,7 +154,7 @@ def construct_fields( + [ (name, item) for name, item in inspected_model.all_orm_descriptors.items() - if isinstance(item, hybrid_property) + if isinstance(item, hybrid_property) or isinstance(item, AssociationProxy) ] + inspected_model.relationships.items() ) @@ -230,6 +232,17 @@ def construct_fields( field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) + elif isinstance(attr, AssociationProxy): + field = convert_sqlalchemy_association_proxy( + model, + attr, + obj_type, + registry, + connection_field_factory, + batching, + resolver, + **orm_field.kwargs + ) else: raise Exception("Property type is not supported") # Should never happen From b94230e0d85c7f165bb8f4fd320430ffb43dd486 Mon Sep 17 00:00:00 2001 From: Charlie Andrews Date: Mon, 9 Oct 2023 15:26:31 -0400 Subject: [PATCH 56/67] chore: recreate loader if old loader is on different loop (#395) * Recreate loader if old loader is on incorrect loop * Lint --------- Co-authored-by: Cadu --- graphene_sqlalchemy/batching.py | 4 +--- 
graphene_sqlalchemy/tests/models.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index a5804516..731d7645 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -137,9 +137,7 @@ def _get_loader(relationship_prop): RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader return loader - loader = _get_loader(relationship_prop) - async def resolve(root, info, **args): - return await loader.load(root) + return await _get_loader(relationship_prop).load(root) return resolve diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index c871bedd..be07b896 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -27,7 +27,6 @@ from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 # fmt: off -import sqlalchemy if SQL_VERSION_HIGHER_EQUAL_THAN_2: from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: From c927ada0af29f567a33ff1aa004e85efb9ee7549 Mon Sep 17 00:00:00 2001 From: Sabar Dasgupta Date: Mon, 4 Dec 2023 15:28:33 -0500 Subject: [PATCH 57/67] feat: add filters (#357) Co-authored-by: Paul Schweizer Co-authored-by: Erik Wrede --- .gitignore | 3 + .pre-commit-config.yaml | 4 +- docs/filters.rst | 213 ++++ docs/index.rst | 1 + examples/filters/README.md | 47 + examples/filters/__init__.py | 0 examples/filters/app.py | 16 + examples/filters/database.py | 49 + examples/filters/models.py | 34 + examples/filters/requirements.txt | 3 + examples/filters/run.sh | 1 + examples/filters/schema.py | 42 + graphene_sqlalchemy/converter.py | 15 +- graphene_sqlalchemy/fields.py | 38 +- graphene_sqlalchemy/filters.py | 525 ++++++++ graphene_sqlalchemy/registry.py | 135 +- graphene_sqlalchemy/tests/conftest.py | 22 +- graphene_sqlalchemy/tests/models.py | 50 +- graphene_sqlalchemy/tests/models_batching.py | 11 +- 
graphene_sqlalchemy/tests/test_converter.py | 53 +- graphene_sqlalchemy/tests/test_filters.py | 1201 ++++++++++++++++++ graphene_sqlalchemy/tests/test_sort_enums.py | 10 +- graphene_sqlalchemy/types.py | 224 +++- graphene_sqlalchemy/utils.py | 13 + 24 files changed, 2635 insertions(+), 75 deletions(-) create mode 100644 docs/filters.rst create mode 100644 examples/filters/README.md create mode 100644 examples/filters/__init__.py create mode 100644 examples/filters/app.py create mode 100644 examples/filters/database.py create mode 100644 examples/filters/models.py create mode 100644 examples/filters/requirements.txt create mode 100755 examples/filters/run.sh create mode 100644 examples/filters/schema.py create mode 100644 graphene_sqlalchemy/filters.py create mode 100644 graphene_sqlalchemy/tests/test_filters.py diff --git a/.gitignore b/.gitignore index 47a82df0..1c86b9be 100644 --- a/.gitignore +++ b/.gitignore @@ -71,5 +71,8 @@ target/ *.sqlite3 .vscode +# Schema +*.gql + # mypy cache .mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 470a29eb..262e7608 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.7 + python: python3.8 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace exclude: README.md - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/docs/filters.rst b/docs/filters.rst new file mode 100644 index 00000000..ac36803d --- /dev/null +++ b/docs/filters.rst @@ -0,0 +1,213 @@ +======= +Filters +======= + +Starting in graphene-sqlalchemy version 3, the SQLAlchemyConnectionField class implements filtering by default. The query utilizes a ``filter`` keyword to specify a filter class that inherits from ``graphene.InputObjectType``. 
+ +Migrating from graphene-sqlalchemy-filter +--------------------------------------------- + +If like many of us, you have been using |graphene-sqlalchemy-filter|_ to implement filters and would like to use the in-built mechanism here, there are a couple key differences to note. Mainly, in an effort to simplify the generated schema, filter keywords are nested under their respective fields instead of concatenated. For example, the filter partial ``{usernameIn: ["moderator", "cool guy"]}`` would be represented as ``{username: {in: ["moderator", "cool guy"]}}``. + +.. |graphene-sqlalchemy-filter| replace:: ``graphene-sqlalchemy-filter`` +.. _graphene-sqlalchemy-filter: https://github.com/art1415926535/graphene-sqlalchemy-filter + +Further, some of the constructs found in libraries like `DGraph's DQL `_ have been implemented, so if you have created custom implementations for these features, you may want to take a look at the examples below. + + +Example model +------------- + +Take as example a Pet model similar to that in the sorting example. We will use variations on this arrangement for the following examples. + +.. code:: + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + + + class Query(graphene.ObjectType): + allPets = SQLAlchemyConnectionField(PetNode.connection) + + +Simple filter example +--------------------- + +Filters are defined at the object level through the ``BaseTypeFilter`` class. The ``BaseType`` encompasses both Graphene ``ObjectType``\ s and ``Interface``\ s. Each ``BaseTypeFilter`` instance may define fields via ``FieldFilter`` and relationships via ``RelationshipFilter``. Here's a basic example querying a single field on the Pet model: + +.. code:: + + allPets(filter: {name: {eq: "Fido"}}){ + edges { + node { + name + } + } + } + +This will return all pets with the name "Fido". 
+ + +Custom filter types +------------------- + +If you'd like to implement custom behavior for filtering a field, you can do so by extending one of the base filter classes in ``graphene_sqlalchemy.filters``. For example, if you'd like to add a ``divisible_by`` keyword to filter the age attribute on the ``Pet`` model, you can do so as follows: + +.. code:: python + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + ... + + age = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + +Filtering over relationships with RelationshipFilter +---------------------------------------------------- + +When a filter class field refers to another object in a relationship, you may nest filters on relationship object attributes. This happens directly for 1:1 and m:1 relationships and through the ``contains`` and ``containsExactly`` keywords for 1:n and m:n relationships. + + +:1 relationships +^^^^^^^^^^^^^^^^ + +When an object or interface defines a singular relationship, relationship object attributes may be filtered directly like so: + +Take the following SQLAlchemy model definition as an example: + +.. code:: python + + class Pet + ... + person_id = Column(Integer(), ForeignKey("people.id")) + + class Person + ... + pets = relationship("Pet", backref="person") + + +Then, this query will return all pets whose person is named "Ada": + +.. code:: + + allPets(filter: { + person: {name: {eq: "Ada"}} + }) { + ... + } + + +:n relationships +^^^^^^^^^^^^^^^^ + +However, for plural relationships, relationship object attributes must be filtered through either ``contains`` or ``containsExactly``: + +Now, using a many-to-many model definition: + +.. 
code:: python + + people_pets_table = sqlalchemy.Table( + "people_pets", + Base.metadata, + Column("person_id", ForeignKey("people.id")), + Column("pet_id", ForeignKey("pets.id")), + ) + + class Pet + ... + + class Person + ... + pets = relationship("Pet", backref="people") + + +this query will return all pets which have a person named "Ben" in their ``people`` list. + +.. code:: + + allPets(filter: { + people: { + contains: [{name: {eq: "Ben"}}], + } + }) { + ... + } + + +and this one will return all pets which have a person list that contains exactly the people "Ada" and "Ben" and no fewer, nor people with other names. + +.. code:: + + allPets(filter: { + people: { + containsExactly: [ + {name: {eq: "Ada"}}, + {name: {eq: "Ben"}}, + ], + } + }) { + ... + } + +And/Or Logic +------------ + +Filters can also be chained together logically using `and` and `or` keywords nested under `filter`. Clauses are passed directly to `sqlalchemy.and_` and `sqlalchemy.or_`, respectively. To return all pets named "Fido" or "Spot", use: + + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + {name: {eq: "Spot"}}, + ] + }) { + ... + } + +And to return all pets that are named "Fido" or are 5 years old and named "Spot", use: + +.. code:: + + allPets(filter: { + or: [ + {name: {eq: "Fido"}}, + { and: [ + {name: {eq: "Spot"}}, + {age: {eq: 5}} + ] + } + ] + }) { + ... + } + + +Hybrid Property support +----------------------- + +Filtering over SQLAlchemy `hybrid properties `_ is fully supported. + + +Reporting feedback and bugs +--------------------------- + +Filtering is a new feature to graphene-sqlalchemy, so please `post an issue on Github `_ if you run into any problems or have ideas on how to improve the implementation.
diff --git a/docs/index.rst b/docs/index.rst index b663752a..4245eba8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ Contents: inheritance relay tips + filters examples tutorial api diff --git a/examples/filters/README.md b/examples/filters/README.md new file mode 100644 index 00000000..a72e75de --- /dev/null +++ b/examples/filters/README.md @@ -0,0 +1,47 @@ +Example Filters Project +================================ + +This example highlights the ability to filter queries in graphene-sqlalchemy. + +The project contains two models, one named `Department` and another +named `Employee`. + +Getting started +--------------- + +First you'll need to get the source of the project. Do this by cloning the +whole Graphene-SQLAlchemy repository: + +```bash +# Get the example project code +git clone https://github.com/graphql-python/graphene-sqlalchemy.git +cd graphene-sqlalchemy/examples/filters +``` + +It is recommended to create a virtual environment +for this project. We'll do this using +[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) +to keep things simple, +but you may also find something like +[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/) +to be useful: + +```bash +# Create a virtualenv in which we can install the dependencies +virtualenv env +source env/bin/activate +``` + +Install our dependencies: + +```bash +pip install -r requirements.txt +``` + +The following command will setup the database, and start the server: + +```bash +python app.py +``` + +Now head over to your favorite GraphQL client, POST to [http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) and run some queries! 
diff --git a/examples/filters/__init__.py b/examples/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/filters/app.py b/examples/filters/app.py new file mode 100644 index 00000000..ab918da7 --- /dev/null +++ b/examples/filters/app.py @@ -0,0 +1,16 @@ +from database import init_db +from fastapi import FastAPI +from schema import schema +from starlette_graphene3 import GraphQLApp, make_playground_handler + + +def create_app() -> FastAPI: + init_db() + app = FastAPI() + + app.mount("/graphql", GraphQLApp(schema, on_get=make_playground_handler())) + + return app + + +app = create_app() diff --git a/examples/filters/database.py b/examples/filters/database.py new file mode 100644 index 00000000..8f6522f7 --- /dev/null +++ b/examples/filters/database.py @@ -0,0 +1,49 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, echo=True +) +session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +from sqlalchemy.orm import scoped_session as scoped_session_factory + +scoped_session = scoped_session_factory(session_factory) + +Base.query = scoped_session.query_property() +Base.metadata.bind = engine + + +def init_db(): + from models import Person, Pet, Toy + + Base.metadata.create_all() + scoped_session.execute("PRAGMA foreign_keys=on") + db = scoped_session() + + person1 = Person(name="A") + person2 = Person(name="B") + + pet1 = Pet(name="Spot") + pet2 = Pet(name="Milo") + + toy1 = Toy(name="disc") + toy2 = Toy(name="ball") + + person1.pet = pet1 + person2.pet = pet2 + + pet1.toys.append(toy1) + pet2.toys.append(toy1) + pet2.toys.append(toy2) + + db.add(person1) + db.add(person2) + db.add(pet1) + db.add(pet2) + db.add(toy1) + db.add(toy2) + + db.commit() diff --git a/examples/filters/models.py b/examples/filters/models.py 
new file mode 100644 index 00000000..1b22956b --- /dev/null +++ b/examples/filters/models.py @@ -0,0 +1,34 @@ +import sqlalchemy +from database import Base +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + age = Column(Integer()) + person_id = Column(Integer(), ForeignKey("people.id")) + + +class Person(Base): + __tablename__ = "people" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + pets = relationship("Pet", backref="person") + + +pets_toys_table = sqlalchemy.Table( + "pets_toys", + Base.metadata, + Column("pet_id", ForeignKey("pets.id")), + Column("toy_id", ForeignKey("toys.id")), +) + + +class Toy(Base): + __tablename__ = "toys" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pets = relationship("Pet", secondary=pets_toys_table, backref="toys") diff --git a/examples/filters/requirements.txt b/examples/filters/requirements.txt new file mode 100644 index 00000000..b433ec59 --- /dev/null +++ b/examples/filters/requirements.txt @@ -0,0 +1,3 @@ +-e ../../ +fastapi +uvicorn diff --git a/examples/filters/run.sh b/examples/filters/run.sh new file mode 100755 index 00000000..ec365444 --- /dev/null +++ b/examples/filters/run.sh @@ -0,0 +1 @@ +uvicorn app:app --port 5000 diff --git a/examples/filters/schema.py b/examples/filters/schema.py new file mode 100644 index 00000000..2728cab7 --- /dev/null +++ b/examples/filters/schema.py @@ -0,0 +1,42 @@ +from models import Person as PersonModel +from models import Pet as PetModel +from models import Toy as ToyModel + +import graphene +from graphene import relay +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene_sqlalchemy.fields import SQLAlchemyConnectionField + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + name = "Pet" + interfaces = (relay.Node,) + batching = True + 
+ +class Person(SQLAlchemyObjectType): + class Meta: + model = PersonModel + name = "Person" + interfaces = (relay.Node,) + batching = True + + +class Toy(SQLAlchemyObjectType): + class Meta: + model = ToyModel + name = "Toy" + interfaces = (relay.Node,) + batching = True + + +class Query(graphene.ObjectType): + node = relay.Node.Field() + pets = SQLAlchemyConnectionField(Pet.connection) + people = SQLAlchemyConnectionField(Person.connection) + toys = SQLAlchemyConnectionField(Toy.connection) + + +schema = graphene.Schema(query=Query) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 84c7886c..efcf3c6c 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,7 +3,7 @@ import typing import uuid from decimal import Decimal -from typing import Any, Optional, Union, cast +from typing import Any, Dict, Optional, TypeVar, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql @@ -21,7 +21,6 @@ from .batching import get_batch_resolver from .enums import enum_for_sa_enum -from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -237,6 +236,8 @@ def _convert_o2m_or_m2m_relationship( :param dict field_kwargs: :rtype: Field """ + from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory + child_type = obj_type._meta.registry.get_type_for_model( relationship_prop.mapper.entity ) @@ -332,8 +333,12 @@ def convert_sqlalchemy_type( # noqa type_arg: Any, column: Optional[Union[MapperProperty, hybrid_property]] = None, registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, **kwargs, ): + if replace_type_vars and type_arg in replace_type_vars: + return replace_type_vars[type_arg] + # No valid type found, raise an error raise TypeError( @@ -373,6 +378,11 @@ def 
convert_scalar_type(type_arg: Any, **kwargs): return type_arg +@convert_sqlalchemy_type.register(safe_isinstance(TypeVar)) +def convert_type_var(type_arg: Any, replace_type_vars: Dict[TypeVar, Any], **kwargs): + return replace_type_vars[type_arg] + + @convert_sqlalchemy_type.register(column_type_eq(str)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.String)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Text)) @@ -618,6 +628,7 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): # Just get the T out of the list of arguments by filtering out the NoneType nested_types = list(filter(lambda x: not type(None) == x, type_arg.__args__)) + # TODO redo this for , *args, **kwargs # Map the graphene types to the nested types. # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... graphene_types = list(map(convert_sqlalchemy_type, nested_types)) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 6dbc134f..ef798852 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -5,13 +5,19 @@ from promise import Promise, is_thenable from sqlalchemy.orm.query import Query -from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session +from .filters import BaseTypeFilter +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + EnumValue, + get_nullable_type, + get_query, + get_session, +) if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -40,6 +46,7 @@ def type(self): def __init__(self, type_, *args, **kwargs): nullable_type = get_nullable_type(type_) + # Handle Sorting and Filtering if ( "sort" not in kwargs and 
nullable_type @@ -57,6 +64,19 @@ def __init__(self, type_, *args, **kwargs): ) elif "sort" in kwargs and kwargs["sort"] is None: del kwargs["sort"] + + if ( + "filter" not in kwargs + and nullable_type + and issubclass(nullable_type, Connection) + ): + # Only add filtering if a filter argument exists on the object type + filter_argument = nullable_type.Edge.node._type.get_filter_argument() + if filter_argument: + kwargs.setdefault("filter", filter_argument) + elif "filter" in kwargs and kwargs["filter"] is None: + del kwargs["filter"] + super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs) @property @@ -64,7 +84,7 @@ def model(self): return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): + def get_query(cls, model, info, sort=None, filter=None, **args): query = get_query(model, info.context) if sort is not None: if not isinstance(sort, list): @@ -80,6 +100,12 @@ def get_query(cls, model, info, sort=None, **args): else: sort_args.append(item) query = query.order_by(*sort_args) + + if filter is not None: + assert isinstance(filter, dict) + filter_type: BaseTypeFilter = type(filter) + query, clauses = filter_type.execute_filters(query, filter) + query = query.filter(*clauses) return query @classmethod @@ -264,9 +290,3 @@ def unregisterConnectionFieldFactory(): ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField - - -def get_nullable_type(_type): - if isinstance(_type, NonNull): - return _type.of_type - return _type diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py new file mode 100644 index 00000000..bb422724 --- /dev/null +++ b/graphene_sqlalchemy/filters.py @@ -0,0 +1,525 @@ +import re +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union + +from graphql import Undefined +from sqlalchemy import and_, not_, or_ +from sqlalchemy.orm import Query, aliased # , selectinload + +import graphene +from 
graphene.types.inputobjecttype import ( + InputObjectTypeContainer, + InputObjectTypeOptions, +) +from graphene_sqlalchemy.utils import is_list + +BaseTypeFilterSelf = TypeVar( + "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer +) + + +class SQLAlchemyFilterInputField(graphene.InputField): + def __init__( + self, + type_, + model_attr, + name=None, + default_value=Undefined, + deprecation_reason=None, + description=None, + required=False, + _creation_counter=None, + **extra_args, + ): + super(SQLAlchemyFilterInputField, self).__init__( + type_, + name, + default_value, + deprecation_reason, + description, + required, + _creation_counter, + **extra_args, + ) + + self.model_attr = model_attr + + +def _get_functions_by_regex( + regex: str, subtract_regex: str, class_: Type +) -> List[Tuple[str, Dict[str, Any]]]: + function_regex = re.compile(regex) + + matching_functions = [] + + # Search the entire class for functions matching the filter regex + for fn in dir(class_): + func_attr = getattr(class_, fn) + # Check if attribute is a function + if callable(func_attr) and function_regex.match(fn): + # add function and attribute name to the list + matching_functions.append( + (re.sub(subtract_regex, "", fn), func_attr.__annotations__) + ) + return matching_functions + + +class BaseTypeFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, filter_fields=None, model=None, _meta=None, **options + ): + from graphene_sqlalchemy.converter import convert_sqlalchemy_type + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls) + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in logic_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is 
generic, replace with actual type of filter class + + replace_type_vars = {BaseTypeFilterSelf: cls} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: graphene.InputField(field_type)}) + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + + if _meta.fields: + _meta.fields.update(filter_fields) + else: + _meta.fields = filter_fields + _meta.fields.update(new_filter_fields) + + _meta.model = model + + super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + @classmethod + def and_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [and_(*clauses)] + + @classmethod + def or_logic( + cls, + query, + filter_type: "BaseTypeFilter", + val: List[BaseTypeFilterSelf], + model_alias=None, + ): + # # Get the model to join on the Filter Query + # joined_model = filter_type._meta.model + # # Always alias the model + # joined_model_alias = aliased(joined_model) + + clauses = [] + for value in val: + # # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + + query, _clauses = filter_type.execute_filters( + query, value, model_alias=model_alias + ) # , model_alias=joined_model_alias) + clauses += _clauses + + return query, [or_(*clauses)] + + @classmethod + def execute_filters( + cls, query, filter_dict: Dict[str, Any], 
model_alias=None + ) -> Tuple[Query, List[Any]]: + model = cls._meta.model + if model_alias: + model = model_alias + + clauses = [] + + for field, field_filters in filter_dict.items(): + # Relationships are Dynamic, we need to resolve them first + # Maybe we can cache these dynamics to improve efficiency + # Checking with a profiler is required to determine necessity + input_field = cls._meta.fields[field] + if isinstance(input_field, graphene.Dynamic): + input_field = input_field.get_type() + field_filter_type = input_field.type + else: + field_filter_type = cls._meta.fields[field].type + # raise Exception + # TODO we need to save the relationship props in the meta fields array + # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) + if field == "and": + query, _clauses = cls.and_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + elif field == "or": + query, _clauses = cls.or_logic( + query, field_filter_type.of_type, field_filters, model_alias=model + ) + clauses.extend(_clauses) + else: + # Get the model attr from the inputfield in case the field is aliased in graphql + model_field = getattr(model, input_field.model_attr or field) + if issubclass(field_filter_type, BaseTypeFilter): + # Get the model to join on the Filter Query + joined_model = field_filter_type._meta.model + # Always alias the model + joined_model_alias = aliased(joined_model) + # Join the aliased model onto the query + query = query.join(model_field.of_type(joined_model_alias)) + # Pass the joined query down to the next object type filter for processing + query, _clauses = field_filter_type.execute_filters( + query, field_filters, model_alias=joined_model_alias + ) + clauses.extend(_clauses) + if issubclass(field_filter_type, RelationshipFilter): + # TODO see above; not yet working + relationship_prop = field_filter_type._meta.model + # Always alias the model + # joined_model_alias =
aliased(relationship_prop) + + # Join the aliased model onto the query + # query = query.join(model_field.of_type(joined_model_alias)) + # todo should we use selectinload here instead of join for large lists? + + query, _clauses = field_filter_type.execute_filters( + query, model, model_field, field_filters, relationship_prop + ) + clauses.extend(_clauses) + elif issubclass(field_filter_type, FieldFilter): + query, _clauses = field_filter_type.execute_filters( + query, model_field, field_filters + ) + clauses.extend(_clauses) + + return query, clauses + + +ScalarFilterInputType = TypeVar("ScalarFilterInputType") + + +class FieldFilterOptions(InputObjectTypeOptions): + graphene_type: Type = None + + +class FieldFilter(graphene.InputObjectType): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + @classmethod + def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): + from .converter import convert_sqlalchemy_type + + # get all filter functions + + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + # Init meta options class if it doesn't exist already + if not _meta: + _meta = FieldFilterOptions(cls) + + if not _meta.graphene_type: + _meta.graphene_type = graphene_type + + new_filter_fields = {} + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} + field_type = convert_sqlalchemy_type( + _annotations.get("val", str), replace_type_vars=replace_type_vars + ) + new_filter_fields.update({field_name: 
graphene.InputField(field_type)}) + + # Add all fields to the meta options. graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(new_filter_fields) + else: + _meta.fields = new_filter_fields + + # Pass modified meta to the super class + super(FieldFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + @classmethod + def in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.in_(val) + + @classmethod + def not_in_filter(cls, query, field, val: List[ScalarFilterInputType]): + return field.notin_(val) + + # TODO add like/ilike + + @classmethod + def execute_filters( + cls, query, field, filter_dict: Dict[str, any] + ) -> Tuple[Query, List[Any]]: + clauses = [] + for filt, val in filter_dict.items(): + clause = getattr(cls, filt + "_filter")(query, field, val) + if isinstance(clause, tuple): + query, clause = clause + clauses.append(clause) + + return query, clauses + + +class SQLEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType.
See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val.value + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val.value) + + +class PyEnumFilter(FieldFilter): + """Basic Filter for Scalars in Graphene. + We want this filter to use Dynamic fields so it provides the base + filtering methods ("eq, nEq") for different types of scalars. + The Dynamic fields will resolve to Meta.filtered_type""" + + class Meta: + graphene_type = graphene.Enum + + # Abstract methods can be marked using ScalarFilterInputType. See comment on the init method + @classmethod + def eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return field == val + + @classmethod + def n_eq_filter( + cls, query, field, val: ScalarFilterInputType + ) -> Union[Tuple[Query, Any], Any]: + return not_(field == val) + + +class StringFilter(FieldFilter): + class Meta: + graphene_type = graphene.String + + @classmethod + def like_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.like(val) + + @classmethod + def ilike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.ilike(val) + + @classmethod + def notlike_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field.notlike(val) + + +class BooleanFilter(FieldFilter): + class Meta: + graphene_type = graphene.Boolean + + +class OrderedFilter(FieldFilter): + class Meta: + abstract = True + + @classmethod + def gt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field > val + + @classmethod + def gte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field >= val + + @classmethod + def lt_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field < val + + @classmethod + def 
lte_filter(cls, query, field, val: ScalarFilterInputType) -> bool: + return field <= val + + +class NumberFilter(OrderedFilter): + """Intermediate Filter class since all Numbers are in an order relationship (support <, > etc)""" + + class Meta: + abstract = True + + +class FloatFilter(NumberFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Float + + +class IntFilter(NumberFilter): + class Meta: + graphene_type = graphene.Int + + +class DateFilter(OrderedFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.Date + + +class IdFilter(FieldFilter): + class Meta: + graphene_type = graphene.ID + + +class RelationshipFilter(graphene.InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, base_type_filter=None, model=None, _meta=None, **options + ): + if not base_type_filter: + raise Exception("Relationship Filters must be specific to an object type") + # Init meta options class if it doesn't exist already + if not _meta: + _meta = InputObjectTypeOptions(cls) + + # get all filter functions + filter_functions = _get_functions_by_regex(".+_filter$", "_filter$", cls) + + relationship_filters = {} + + # Generate Graphene Fields from the filter functions based on type hints + for field_name, _annotations in filter_functions: + assert ( + "val" in _annotations + ), "Each filter method must have a value field with valid type annotations" + # If type is generic, replace with actual type of filter class + if is_list(_annotations["val"]): + relationship_filters.update( + {field_name: graphene.InputField(graphene.List(base_type_filter))} + ) + else: + relationship_filters.update( + {field_name: graphene.InputField(base_type_filter)} + ) + + # Add all fields to the meta options. 
graphene.InputObjectType will take care of the rest + if _meta.fields: + _meta.fields.update(relationship_filters) + else: + _meta.fields = relationship_filters + + _meta.model = model + _meta.base_type_filter = base_type_filter + super(RelationshipFilter, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + @classmethod + def contains_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + clauses = [] + for v in val: + # Always alias the model + joined_model_alias = aliased(relationship_prop) + + # Join the aliased model onto the query + query = query.join(field.of_type(joined_model_alias)).distinct() + # pass the alias so group can join group + query, _clauses = cls._meta.base_type_filter.execute_filters( + query, v, model_alias=joined_model_alias + ) + clauses.append(and_(*_clauses)) + return query, [or_(*clauses)] + + @classmethod + def contains_exactly_filter( + cls, + query, + parent_model, + field, + relationship_prop, + val: List[ScalarFilterInputType], + ): + raise NotImplementedError + + @classmethod + def execute_filters( + cls: Type[FieldFilter], + query, + parent_model, + field, + filter_dict: Dict, + relationship_prop, + ) -> Tuple[Query, List[Any]]: + query, clauses = (query, []) + + for filt, val in filter_dict.items(): + query, _clauses = getattr(cls, filt + "_filter")( + query, parent_model, field, relationship_prop, val + ) + clauses += _clauses + + return query, clauses diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 3c463013..b959d221 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,10 +1,15 @@ +import inspect from collections import defaultdict -from typing import List, Type +from typing import TYPE_CHECKING, List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType import graphene from graphene import Enum +from graphene.types.base import BaseType + +if TYPE_CHECKING: # pragma: no_cover 
+ from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter class Registry(object): @@ -16,6 +21,30 @@ def __init__(self): self._registry_enums = {} self._registry_sort_enums = {} self._registry_unions = {} + self._registry_scalar_filters = {} + self._registry_base_type_filters = {} + self._registry_relationship_filters = {} + + self._init_base_filters() + + def _init_base_filters(self): + import graphene_sqlalchemy.filters as gsqa_filters + + from .filters import FieldFilter + + field_filter_classes = [ + filter_cls[1] + for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) + if ( + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) + ) + ] + for field_filter_class in field_filter_classes: + self.register_filter_for_scalar_type( + field_filter_class._meta.graphene_type, field_filter_class + ) def register(self, obj_type): from .types import SQLAlchemyBase @@ -99,6 +128,110 @@ def register_union_type( def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]): return self._registry_unions.get(frozenset(obj_types)) + # Filter Scalar Fields of Object Types + def register_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not isinstance(scalar_type, type(graphene.Scalar)): + raise TypeError("Expected Scalar, but got: {!r}".format(scalar_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[scalar_type] = filter_obj + + def get_filter_for_sql_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import SQLEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = SQLEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", 
graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_py_enum_type( + self, enum_type: Type[graphene.Enum] + ) -> Type["FieldFilter"]: + from .filters import PyEnumFilter + + filter_type = self._registry_scalar_filters.get(enum_type) + if not filter_type: + filter_type = PyEnumFilter.create_type( + f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type + ) + self._registry_scalar_filters[enum_type] = filter_type + return filter_type + + def get_filter_for_scalar_type( + self, scalar_type: Type[graphene.Scalar] + ) -> Type["FieldFilter"]: + from .filters import FieldFilter + + filter_type = self._registry_scalar_filters.get(scalar_type) + if not filter_type: + filter_type = FieldFilter.create_type( + f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type + ) + self._registry_scalar_filters[scalar_type] = filter_type + + return filter_type + + # TODO register enums automatically + def register_filter_for_enum_type( + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + ): + from .filters import FieldFilter + + if not issubclass(enum_type, graphene.Enum): + raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) + + if not issubclass(filter_obj, FieldFilter): + raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) + self._registry_scalar_filters[enum_type] = filter_obj + + # Filter Base Types + def register_filter_for_base_type( + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], + ): + from .filters import BaseTypeFilter + + if not issubclass(base_type, BaseType): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, BaseTypeFilter): + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) + self._registry_base_type_filters[base_type] = filter_obj + + def get_filter_for_base_type(self, base_type: Type[BaseType]): + return 
self._registry_base_type_filters.get(base_type) + + # Filter Relationships between base types + def register_relationship_filter_for_base_type( + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + ): + from .filters import RelationshipFilter + + if not isinstance(base_type, type(BaseType)): + raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) + + if not issubclass(filter_obj, RelationshipFilter): + raise TypeError( + "Expected RelationshipFilter, but got: {!r}".format(filter_obj) + ) + self._registry_relationship_filters[base_type] = filter_obj + + def get_relationship_filter_for_base_type( + self, base_type: Type[BaseType] + ) -> "RelationshipFilter": + return self._registry_relationship_filters.get(base_type) + registry = None diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 89b357a4..2c749da7 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -2,6 +2,7 @@ import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing_extensions import Literal import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 @@ -25,14 +26,23 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(params=[False, True]) -def async_session(request): +# make a typed literal for session one is sync and one is async +SESSION_TYPE = Literal["sync", "session_factory"] + + +@pytest.fixture(params=["sync", "async"]) +def session_type(request) -> SESSION_TYPE: return request.param @pytest.fixture -def test_db_url(async_session: bool): - if async_session: +def async_session(session_type): + return session_type == "async" + + +@pytest.fixture +def test_db_url(session_type: SESSION_TYPE): + if session_type == "async": return "sqlite+aiosqlite://" else: return "sqlite://" @@ -40,8 +50,8 @@ def test_db_url(async_session: bool): @pytest.mark.asyncio 
@pytest_asyncio.fixture(scope="function") -async def session_factory(async_session: bool, test_db_url: str): - if async_session: +async def session_factory(session_type: SESSION_TYPE, test_db_url: str): + if session_type == "async": if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") engine = create_async_engine(test_db_url) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index be07b896..8911b0a2 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,6 +6,7 @@ from decimal import Decimal from typing import List, Optional +# fmt: off from sqlalchemy import ( Column, Date, @@ -24,13 +25,16 @@ from sqlalchemy.sql.type_api import TypeEngine from graphene_sqlalchemy.tests.utils import wrap_select_func -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 +from graphene_sqlalchemy.utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, +) # fmt: off if SQL_VERSION_HIGHER_EQUAL_THAN_2: - from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: - from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip # fmt: on PetKind = Enum("cat", "dog", name="pet_kind") @@ -64,6 +68,7 @@ class Pet(Base): pet_kind = Column(PetKind, nullable=False) hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + legs = Column(Integer(), default=4) class CompositeFullName(object): @@ -150,6 +155,27 @@ def hybrid_prop_list(self) -> List[int]: headlines = association_proxy("articles", "headline") +articles_tags_table = Table( + "articles_tags", + 
Base.metadata, + Column("article_id", ForeignKey("articles.id")), + Column("tag_id", ForeignKey("tags.id")), +) + + +class Image(Base): + __tablename__ = "images" + id = Column(Integer(), primary_key=True) + external_id = Column(Integer()) + description = Column(String(30)) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + + class Article(Base): __tablename__ = "articles" id = Column(Integer(), primary_key=True) @@ -161,6 +187,13 @@ class Article(Base): ) recommended_reads = association_proxy("reporter", "articles") + # one-to-one relationship with image + image_id = Column(Integer(), ForeignKey("images.id"), unique=True) + image = relationship("Image", backref=backref("articles", uselist=False)) + + # many-to-many relationship with tags + tags = relationship("Tag", secondary=articles_tags_table, backref="articles") + class Reader(Base): __tablename__ = "readers" @@ -273,11 +306,20 @@ def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: ], ] - # Other SQLAlchemy Instances + # Other SQLAlchemy Instance @hybrid_property def hybrid_prop_first_shopping_cart_item(self) -> ShoppingCartItem: return ShoppingCartItem(id=1) + # Other SQLAlchemy Instance with expression + @hybrid_property + def hybrid_prop_first_shopping_cart_item_expression(self) -> ShoppingCartItem: + return ShoppingCartItem(id=1) + + @hybrid_prop_first_shopping_cart_item_expression.expression + def hybrid_prop_first_shopping_cart_item_expression(cls): + return ShoppingCartItem + # Other SQLAlchemy Instances @hybrid_property def hybrid_prop_shopping_cart_item_list(self) -> List[ShoppingCartItem]: diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 5dde366f..e0f5d4bd 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -2,16 +2,7 @@ import enum -from sqlalchemy import ( - Column, - Date, - Enum, - 
ForeignKey, - Integer, - String, - Table, - func, -) +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 84069245..e62e07d2 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,13 +1,10 @@ import enum import sys -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, TypeVar, Union -import graphene import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from graphene.relay import Node -from graphene.types.structures import Structure from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -15,15 +12,10 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from .models import ( - Article, - CompositeFullName, - Pet, - Reporter, - ShoppingCart, - ShoppingCartItem, -) -from .utils import wrap_select_func +import graphene +from graphene.relay import Node +from graphene.types.structures import Structure + from ..converter import ( convert_sqlalchemy_association_proxy, convert_sqlalchemy_column, @@ -47,6 +39,7 @@ ShoppingCart, ShoppingCartItem, ) +from .utils import wrap_select_func def mock_resolver(): @@ -206,6 +199,17 @@ def hybrid_prop(self) -> "ShoppingCartItem": get_hybrid_property_type(hybrid_prop).type == ShoppingCartType +def test_converter_replace_type_var(): + + T = TypeVar("T") + + replace_type_vars = {T: graphene.String} + + field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars) + + assert field_type == graphene.String + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -215,9 +219,9 @@ def prop_method() 
-> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -471,7 +475,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): field = get_field_from_column( - column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) + column_property( + wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1) + ) ) assert field.type == graphene.Int @@ -888,8 +894,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -900,7 +906,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -925,6 +931,7 @@ class Meta: graphene.List(graphene.List(graphene.Int)) ), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, + "hybrid_prop_first_shopping_cart_item_expression": ShoppingCartItemType, "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, @@ -947,8 +954,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ -959,5 +966,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + 
hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py new file mode 100644 index 00000000..4acf89a8 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -0,0 +1,1201 @@ +import pytest +from sqlalchemy.sql.operators import is_ + +import graphene +from graphene import Connection, relay + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType +from .models import ( + Article, + Editor, + HairKind, + Image, + Pet, + Reader, + Reporter, + ShoppingCart, + ShoppingCartItem, + Tag, +) +from .utils import eventually_await_session, to_std_dicts + +# TODO test that generated schema is correct for all examples with: +# with open('schema.gql', 'w') as fp: +# fp.write(str(schema)) + + +def assert_and_raise_result(result, expected): + if result.errors: + for error in result.errors: + raise error + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +async def add_test_data(session): + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") + session.add(reporter) + + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT, legs=4) + pet.reporter = reporter + session.add(pet) + + pet = Pet(name="Snoopy", pet_kind="dog", hair_kind=HairKind.SHORT, legs=3) + pet.reporter = reporter + session.add(pet) + + reporter = Reporter(first_name="John", last_name="Woe", favorite_pet_kind="cat") + session.add(reporter) + + article = Article(headline="Hi!") + article.reporter = reporter + session.add(article) + + article = Article(headline="Hello!") + article.reporter = reporter + session.add(article) + + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") + session.add(reporter) + + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) + pet.reporter = reporter + 
session.add(pet) + + editor = Editor(name="Jack") + session.add(editor) + + await eventually_await_session(session, "commit") + + +def create_schema(session): + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + interfaces = (relay.Node,) + + class ImageType(SQLAlchemyObjectType): + class Meta: + model = Image + name = "Image" + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + + class ReaderType(SQLAlchemyObjectType): + class Meta: + model = Reader + name = "Reader" + interfaces = (relay.Node,) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + class TagType(SQLAlchemyObjectType): + class Meta: + model = Tag + name = "Tag" + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection) + images = SQLAlchemyConnectionField(ImageType.connection) + readers = SQLAlchemyConnectionField(ReaderType.connection) + reporters = SQLAlchemyConnectionField(ReporterType.connection) + pets = SQLAlchemyConnectionField(PetType.connection) + tags = SQLAlchemyConnectionField(TagType.connection) + + return Query + + +# Test a simple example of filtering +@pytest.mark.asyncio +async def test_filter_simple(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: {lastName: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_alias(session): + """ + Test aliasing of column names in the type + """ + 
await add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + name = "Reporter" + interfaces = (relay.Node,) + + lastNameAlias = ORMField(model_attr="last_name") + + class Query(graphene.ObjectType): + node = relay.Node.Field() + reporters = SQLAlchemyConnectionField(ReporterType.connection) + + query = """ + query { + reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"firstName": "Jane"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a custom filter type +@pytest.mark.asyncio +async def test_filter_custom_type(session): + await add_test_data(session) + + class MathFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def divisible_by_filter(cls, query, field, val: int) -> bool: + return is_(field % val, 0) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection + + legs = ORMField(filter_type=MathFilter) + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + query = """ + query { + pets (filter: { + legs: {divisibleBy: 2} + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": { + "edges": [{"node": {"name": "Garfield"}}, {"node": {"name": "Lassie"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test filtering on enums +@pytest.mark.asyncio +async def test_filter_enum(session): + await add_test_data(session) + + Query = create_schema(session) + + # test sqlalchemy enum + query = """ + query { + reporters (filter: { + favoritePetKind: {eq: DOG} + } 
+ ) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + } + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test Python enum and sqlalchemy enum + query = """ + query { + pets (filter: { + and: [ + { hairKind: {eq: LONG} }, + { petKind: {eq: DOG} } + ]}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Lassie"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:1 relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_one(session): + article = Article(headline="Hi!") + image = Image(external_id=1, description="A beautiful image.") + article.image = image + session.add(article) + session.add(image) + await eventually_await_session(session, "commit") + + Query = create_schema(session) + + query = """ + query { + articles (filter: { + image: {description: {eq: "A beautiful image."}} + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test a 1:n relationship +@pytest.mark.asyncio +async def test_filter_relationship_one_to_many(session): + await add_test_data(session) + Query = create_schema(session) + + # test contains + query = """ + query { + reporters (filter: { + articles: { + contains: [{headline: {eq: "Hi!"}}], + } + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + 
"reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # TODO test containsExactly + # # test containsExactly + # query = """ + # query { + # reporters (filter: { + # articles: { + # containsExactly: [ + # {headline: {eq: "Hi!"}} + # {headline: {eq: "Hello!"}} + # ] + # } + # }) { + # edges { + # node { + # firstName + # lastName + # } + # } + # } + # } + # """ + # expected = { + # "reporters": {"edges": [{"node": {"firstName": "John", "lastName": "Woe"}}]} + # } + # schema = graphene.Schema(query=Query) + # result = await schema.execute_async(query, context_value={"session": session}) + # assert_and_raise_result(result, expected) + + +async def add_n2m_test_data(session): + # create objects + reader1 = Reader(name="Ada") + reader2 = Reader(name="Bip") + article1 = Article(headline="Article! Look!") + article2 = Article(headline="Woah! 
Another!") + tag1 = Tag(name="sensational") + tag2 = Tag(name="eye-grabbing") + image1 = Image(description="article 1") + image2 = Image(description="article 2") + + # set relationships + article1.tags = [tag1] + article2.tags = [tag1, tag2] + article1.image = image1 + article2.image = image2 + reader1.articles = [article1] + reader2.articles = [article1, article2] + + # save + session.add(image1) + session.add(image2) + session.add(tag1) + session.add(tag2) + session.add(article1) + session.add(article2) + session.add(reader1) + session.add(reader2) + await eventually_await_session(session, "commit") + + +# Test n:m relationship contains +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Article! Look!"}}, + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" 
} }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_with_and(session): + """ + This test is necessary to ensure we don't accidentally turn and-contains filter + into or-contains filters due to incorrect aliasing of the joined table. + """ + await add_n2m_test_data(session) + Query = create_schema(session) + + # test contains 1 + query = """ + query { + articles (filter: { + tags: { + contains: [{ + and: [ + { name: { in: ["sensational", "eye-grabbing"] } }, + { name: { eq: "eye-grabbing" } }, + ] + + } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test contains 2 + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": { + "edges": [ + {"node": {"headline": "Woah! Another!"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { headline: { eq: "Article! Look!" 
} }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": { + "edges": [ + {"node": {"name": "sensational"}}, + ], + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + # test containsExactly 1 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test containsExactly 2 + query = """ + query { + articles (filter: { + tags: { + containsExactly: [ + { name: { eq: "sensational" } } + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Article! Look!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test reverse + query = """ + query { + tags (filter: { + articles: { + containsExactly: [ + { headline: { eq: "Article! Look!" } }, + { headline: { eq: "Woah! Another!" 
} }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "eye-grabbing"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m relationship both contains and containsExactly +@pytest.mark.xfail +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_contains_and_contains_exactly(session): + raise NotImplementedError + await add_n2m_test_data(session) + Query = create_schema(session) + + query = """ + query { + articles (filter: { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + containsExactly: [ + { name: { eq: "eye-grabbing" } }, + { name: { eq: "sensational" } }, + ] + } + }) { + edges { + node { + headline + } + } + } + } + """ + expected = { + "articles": {"edges": [{"node": {"headline": "Woah! Another!"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test n:m nested relationship +# TODO add containsExactly +@pytest.mark.asyncio +async def test_filter_relationship_many_to_many_nested(session): + await add_n2m_test_data(session) + Query = create_schema(session) + + # test readers->articles relationship + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" 
} }, + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested readers->articles->tags + query = """ + query { + readers (filter: { + articles: { + contains: [ + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test nested reverse + query = """ + query { + tags (filter: { + articles: { + contains: [ + { + readers: { + contains: [ + { name: { eq: "Ada" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "tags": {"edges": [{"node": {"name": "sensational"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test filter on both levels of nesting + query = """ + query { + readers (filter: { + articles: { + contains: [ + { headline: { eq: "Woah! Another!" 
} }, + { + tags: { + contains: [ + { name: { eq: "eye-grabbing" } }, + ] + } + } + ] + } + }) { + edges { + node { + name + } + } + } + } + """ + expected = { + "readers": {"edges": [{"node": {"name": "Bip"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" +@pytest.mark.asyncio +async def test_filter_logic_and(session): + await add_test_data(session) + + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { favoritePetKind: { eq: CAT } }, + ] + }) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [{"node": {"lastName": "Doe"}}, {"node": {"lastName": "Woe"}}] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "or" +@pytest.mark.asyncio +async def test_filter_logic_or(session): + await add_test_data(session) + Query = create_schema(session) + + query = """ + query { + reporters (filter: { + or: [ + { lastName: { eq: "Woe" } }, + { favoritePetKind: { eq: DOG } }, + ] + }) { + edges { + node { + firstName + lastName + favoritePetKind + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + { + "node": { + "firstName": "John", + "lastName": "Woe", + "favoritePetKind": "CAT", + } + }, + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + }, + ] + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +# Test connecting filters with "and" and "or" together +@pytest.mark.asyncio +async def test_filter_logic_and_or(session): + await add_test_data(session) + Query = 
create_schema(session) + + query = """ + query { + reporters (filter: { + and: [ + { firstName: { eq: "John" } }, + { + or: [ + { lastName: { eq: "Doe" } }, + # TODO get enums working for filters + # { favoritePetKind: { eq: "cat" } }, + ] + } + ] + }) { + edges { + node { + firstName + } + } + } + } + """ + expected = { + "reporters": { + "edges": [ + {"node": {"firstName": "John"}}, + # {"node": {"firstName": "Jane"}}, + ], + } + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + +async def add_hybrid_prop_test_data(session): + cart = ShoppingCart() + session.add(cart) + await eventually_await_session(session, "commit") + + +def create_hybrid_prop_schema(session): + class ShoppingCartItemType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + name = "ShoppingCartItem" + interfaces = (relay.Node,) + connection_class = Connection + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCart + name = "ShoppingCart" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + items = SQLAlchemyConnectionField(ShoppingCartItemType.connection) + carts = SQLAlchemyConnectionField(ShoppingCartType.connection) + + return Query + + +# Test filtering over and returning hybrid_property +@pytest.mark.asyncio +async def test_filter_hybrid_property(session): + await add_hybrid_prop_test_data(session) + Query = create_hybrid_prop_schema(session) + + # test hybrid_prop_int + query = """ + query { + carts (filter: {hybridPropInt: {eq: 42}}) { + edges { + node { + hybridPropInt + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropInt": 42}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # 
test hybrid_prop_float + query = """ + query { + carts (filter: {hybridPropFloat: {gt: 42}}) { + edges { + node { + hybridPropFloat + } + } + } + } + """ + expected = { + "carts": { + "edges": [ + {"node": {"hybridPropFloat": 42.3}}, + ] + }, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test hybrid_prop different model without expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItem { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop different model with expression + query = """ + query { + carts { + edges { + node { + hybridPropFirstShoppingCartItemExpression { + id + } + } + } + } + } + """ + + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + + # test hybrid_prop list of models + query = """ + query { + carts { + edges { + node { + hybridPropShoppingCartItemList { + id + } + } + } + } + } + """ + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert len(result["carts"]["edges"]) == 1 + assert ( + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + ) + + +# Test edge cases to improve test coverage +@pytest.mark.asyncio +async def test_filter_edge_cases(session): + await add_test_data(session) + + # test disabling filtering + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + name = "Article" + 
interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + node = relay.Node.Field() + articles = SQLAlchemyConnectionField(ArticleType.connection, filter=None) + + schema = graphene.Schema(query=Query) + assert not hasattr(schema, "ArticleTypeFilter") + + +# Test additional filter types to improve test coverage +@pytest.mark.asyncio +async def test_additional_filters(session): + await add_test_data(session) + Query = create_schema(session) + + # test n_eq and not_in filters + query = """ + query { + reporters (filter: {firstName: {nEq: "Jane"}, lastName: {notIn: "Doe"}}) { + edges { + node { + lastName + } + } + } + } + """ + expected = { + "reporters": {"edges": [{"node": {"lastName": "Woe"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + # test gt, lt, gte, and lte filters + query = """ + query { + pets (filter: {legs: {gt: 2, lt: 4, gte: 3, lte: 3}}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "pets": {"edges": [{"node": {"name": "Snoopy"}}]}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index f8f1ff8c..bb530f2c 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -41,6 +41,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", + "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -95,6 +97,8 @@ class Meta: "PET_KIND_DESC", "HAIR_KIND_ASC", "HAIR_KIND_DESC", + "LEGS_ASC", + "LEGS_DESC", ] @@ -135,6 +139,8 @@ class Meta: "HAIR_KIND_DESC", "REPORTER_ID_ASC", "REPORTER_ID_DESC", 
+ "LEGS_ASC", + "LEGS_DESC", ] assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" @@ -149,7 +155,7 @@ def test_sort_argument_with_excluded_fields_in_object_type(): class PetType(SQLAlchemyObjectType): class Meta: model = Pet - exclude_fields = ["hair_kind", "reporter_id"] + exclude_fields = ["hair_kind", "reporter_id", "legs"] sort_arg = PetType.sort_argument() sort_enum = sort_arg.type._of_type @@ -238,6 +244,8 @@ def get_symbol_name(column_name, sort_asc=True): "HairKindDown", "ReporterIdUp", "ReporterIdDown", + "LegsUp", + "LegsDown", ] assert sort_arg.default_value == ["IdUp"] diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index dac5b15f..18d06eef 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,6 +1,10 @@ +import inspect +import logging +import warnings from collections import OrderedDict +from functools import partial from inspect import isawaitable -from typing import Any +from typing import Any, Optional, Type, Union import sqlalchemy from sqlalchemy.ext.associationproxy import AssociationProxy @@ -8,11 +12,13 @@ from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty from sqlalchemy.orm.exc import NoResultFound -from graphene import Field +import graphene +from graphene import Dynamic, Field, InputField from graphene.relay import Connection, Node from graphene.types.base import BaseType from graphene.types.interface import Interface, InterfaceOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType @@ -28,10 +34,12 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) +from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers 
import get_attr_resolver, get_custom_resolver from .utils import ( SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_nullable_type, get_query, get_session, is_mapped_class, @@ -41,6 +49,8 @@ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession +logger = logging.getLogger(__name__) + class ORMField(OrderedType): def __init__( @@ -51,8 +61,10 @@ def __init__( description=None, deprecation_reason=None, batching=None, + create_filter=None, + filter_type: Optional[Type] = None, _creation_counter=None, - **field_kwargs + **field_kwargs, ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -89,6 +101,12 @@ class Meta: Same behavior as in graphene.Field. Defaults to None. :param bool batching: Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. + :param bool create_filter: + Create a filter for this field. Defaults to True. + :param Type filter_type: + Override for the filter of this field with a custom filter type. + Default behavior is to get a matching filter type for this field from the registry. + Create_filter needs to be true :param int _creation_counter: Same behavior as in graphene.Field. 
""" @@ -100,6 +118,8 @@ class Meta: "required": required, "description": description, "deprecation_reason": deprecation_reason, + "create_filter": create_filter, + "filter_type": filter_type, "batching": batching, } common_kwargs = { @@ -109,6 +129,139 @@ class Meta: self.kwargs.update(common_kwargs) +def get_or_create_relationship_filter( + base_type: Type[BaseType], registry: Registry +) -> Type[RelationshipFilter]: + relationship_filter = registry.get_relationship_filter_for_base_type(base_type) + + if not relationship_filter: + try: + base_type_filter = registry.get_filter_for_base_type(base_type) + relationship_filter = RelationshipFilter.create_type( + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, + ) + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter + ) + except Exception as e: + print("e") + raise e + + return relationship_filter + + +def filter_field_from_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, + model_attr_name: str, +) -> Optional[graphene.InputField]: + # Field might be a SQLAlchemyObjectType, due to hybrid properties + if issubclass(type_, SQLAlchemyObjectType): + filter_class = registry.get_filter_for_base_type(type_) + # Enum Special Case + elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): + column = model_attr.columns[0] + model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None) + if not getattr(model_enum_type, "enum_class", None): + filter_class = registry.get_filter_for_sql_enum_type(type_) + else: + filter_class = registry.get_filter_for_py_enum_type(type_) + else: + filter_class = registry.get_filter_for_scalar_type(type_) + if not filter_class: + warnings.warn( + f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field." 
+ ) + return None + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + + +def resolve_dynamic_relationship_filter( + field: graphene.Dynamic, registry: Registry, model_attr_name: str +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # Resolve Dynamic Type + type_ = get_nullable_type(field.get_type()) + from graphene_sqlalchemy import SQLAlchemyConnectionField + + # Connections always result in list filters + if isinstance(type_, SQLAlchemyConnectionField): + inner_type = get_nullable_type(type_.type.Edge.node._type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + # Field relationships can either be a list or a single object + elif isinstance(type_, Field): + if isinstance(type_.type, graphene.List): + inner_type = get_nullable_type(type_.type.of_type) + reg_res = get_or_create_relationship_filter(inner_type, registry) + else: + reg_res = registry.get_filter_for_base_type(type_.type) + else: + # Other dynamic type constellation are not yet supported, + # please open an issue with reproduction if you need them + reg_res = None + + if not reg_res: + warnings.warn( + f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field." 
+ ) + return None + + return SQLAlchemyFilterInputField(reg_res, model_attr_name) + + +def filter_field_from_type_field( + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, + model_attr_name: str, +) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: + # If a custom filter type was set for this field, use it here + if filter_type: + return SQLAlchemyFilterInputField(filter_type, model_attr_name) + elif issubclass(type(field), graphene.Scalar): + filter_class = registry.get_filter_for_scalar_type(type(field)) + return SQLAlchemyFilterInputField(filter_class, model_attr_name) + # If the generated field is Dynamic, it is always a relationship + # (due to graphene-sqlalchemy's conversion mechanism). + elif isinstance(field, graphene.Dynamic): + return Dynamic( + partial( + resolve_dynamic_relationship_filter, field, registry, model_attr_name + ) + ) + # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them + elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): + # Pure lists are not yet supported + pass + elif isinstance(field._type, graphene.Dynamic): + # Fields with nested dynamic Dynamic are not yet supported + pass + # Order matters, this comes last as field._type == list also matches Field + elif isinstance(field, graphene.Field): + if inspect.isfunction(field._type) or isinstance(field._type, partial): + return Dynamic( + lambda: filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + ) + else: + return filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + + def get_polymorphic_on(model): """ Check whether this model is a polymorphic type, and if so return the name @@ -121,13 +274,14 @@ def get_polymorphic_on(model): return polymorphic_on.name -def 
construct_fields( +def construct_fields_and_filters( obj_type, model, registry, only_fields, exclude_fields, batching, + create_filters, connection_field_factory, ): """ @@ -143,6 +297,7 @@ def construct_fields( :param tuple[string] only_fields: :param tuple[string] exclude_fields: :param bool batching: + :param bool create_filters: Enable filter generation for this type :param function|None connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ @@ -201,7 +356,12 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() + filters = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): + filtering_enabled_for_field = orm_field.kwargs.pop( + "create_filter", create_filters + ) + filter_type = orm_field.kwargs.pop("filter_type", None) attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( @@ -220,7 +380,7 @@ def construct_fields( connection_field_factory, batching_, orm_field_name, - **orm_field.kwargs + **orm_field.kwargs, ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: @@ -241,15 +401,21 @@ def construct_fields( connection_field_factory, batching, resolver, - **orm_field.kwargs + **orm_field.kwargs, ) else: raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field + if filtering_enabled_for_field and not isinstance(attr, AssociationProxy): + # we don't support filtering on association proxies yet. 
+ # Support will be patched in a future release of graphene-sqlalchemy + filters[orm_field_name] = filter_field_from_type_field( + field, registry, filter_type, attr, attr_name + ) - return fields + return fields, filters class SQLAlchemyBase(BaseType): @@ -274,7 +440,7 @@ def __init_subclass_with_meta__( batching=False, connection_field_factory=None, _meta=None, - **options + **options, ): # We always want to bypass this hook unless we're defining a concrete # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. @@ -301,16 +467,19 @@ def __init_subclass_with_meta__( "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." ) + fields, filters = construct_fields_and_filters( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + batching=batching, + create_filters=True, + connection_field_factory=connection_field_factory, + ) + sqla_fields = yank_fields_from_attrs( - construct_fields( - obj_type=cls, - model=model, - registry=registry, - only_fields=only_fields, - exclude_fields=exclude_fields, - batching=batching, - connection_field_factory=connection_field_factory, - ), + fields, _as=Field, sort=False, ) @@ -342,6 +511,19 @@ def __init_subclass_with_meta__( else: _meta.fields = sqla_fields + # Save Generated filter class in Meta Class + if not _meta.filter_class: + # Map graphene fields to filters + # TODO we might need to pass the ORMFields containing the SQLAlchemy models + # to the scalar filters here (to generate expressions from the model) + + filter_fields = yank_fields_from_attrs(filters, _as=InputField, sort=False) + + _meta.filter_class = BaseTypeFilter.create_type( + f"{cls.__name__}Filter", filter_fields=filter_fields, model=model + ) + registry.register_filter_for_base_type(cls, _meta.filter_class) + _meta.connection = connection _meta.id = id or "id" @@ -401,6 +583,12 @@ def resolve_id(self, info): def enum_for_field(cls, field_name): return enum_for_field(cls, 
field_name) + @classmethod + def get_filter_argument(cls): + if cls._meta.filter_class: + return graphene.Argument(cls._meta.filter_class) + return None + sort_enum = classmethod(sort_enum_for_object_type) sort_argument = classmethod(sort_argument_for_object_type) @@ -411,6 +599,7 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): @@ -447,6 +636,7 @@ class SQLAlchemyInterfaceOptions(InterfaceOptions): registry = None # type: sqlalchemy.Registry connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str + filter_class: Type[BaseTypeFilter] = None class SQLAlchemyInterface(SQLAlchemyBase, Interface): diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index bb9386e8..3ba14865 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,4 +1,5 @@ import re +import typing import warnings from collections import OrderedDict from functools import _c3_mro @@ -10,6 +11,14 @@ from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene import NonNull + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type + def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" @@ -259,6 +268,10 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: pass +def is_list(x): + return getattr(x, "__origin__", None) in [list, typing.List] + + class DummyImport: """The dummy module returns 'object' for a query for any member""" From ae4f87c771763c6b511218158d1f1af55d1708fb Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 4 Dec 2023 21:45:15 +0100 Subject: [PATCH 58/67] fix: keep converting tuples 
to strings for composite primary keys in relay ID field (#399) --- graphene_sqlalchemy/tests/models.py | 7 ++++ graphene_sqlalchemy/tests/test_types.py | 51 +++++++++++++++++++++++++ graphene_sqlalchemy/types.py | 2 +- 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 8911b0a2..e1ee9858 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -436,3 +436,10 @@ class CustomColumnModel(Base): id = Column(Integer(), primary_key=True) custom_col = Column(CustomIntegerColumn) + + +class CompositePrimaryKeyTestModel(Base): + __tablename__ = "compositekeytestmodel" + + first_name = Column(String(30), primary_key=True) + last_name = Column(String(30), primary_key=True) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index e5b154cd..f25b0dc2 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -9,6 +9,7 @@ from graphene import ( Boolean, + DefaultGlobalIDType, Dynamic, Field, Float, @@ -42,6 +43,7 @@ from .models import ( Article, CompositeFullName, + CompositePrimaryKeyTestModel, Employee, NonAbstractPerson, Person, @@ -513,6 +515,55 @@ async def resolve_reporter(self, _info): # Test Custom SQLAlchemyObjectType Implementation +@pytest.mark.asyncio +async def test_composite_id_resolver(session): + """Test that the correct resolver functions are called""" + + composite_reporter = CompositePrimaryKeyTestModel( + first_name="graphql", last_name="foundation" + ) + + session.add(composite_reporter) + await eventually_await_session(session, "commit") + + class CompositePrimaryKeyTestModelType(SQLAlchemyObjectType): + class Meta: + model = CompositePrimaryKeyTestModel + interfaces = (Node,) + + class Query(ObjectType): + composite_reporter = Field(CompositePrimaryKeyTestModelType) + + async def resolve_composite_reporter(self, _info): + session = 
utils.get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return ( + (await session.scalars(select(CompositePrimaryKeyTestModel))) + .unique() + .first() + ) + return session.query(CompositePrimaryKeyTestModel).first() + + schema = Schema(query=Query) + result = await schema.execute_async( + """ + query { + compositeReporter { + id + firstName + lastName + } + } + """, + context_value={"session": session}, + ) + + assert not result.errors + assert result.data["compositeReporter"]["id"] == DefaultGlobalIDType.to_global_id( + CompositePrimaryKeyTestModelType, str(("graphql", "foundation")) + ) + + def test_custom_objecttype_registered(): class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): class Meta: diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 18d06eef..70539880 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -577,7 +577,7 @@ async def get_result() -> Any: def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) - return tuple(keys) if len(keys) > 1 else keys[0] + return str(tuple(keys)) if len(keys) > 1 else keys[0] @classmethod def enum_for_field(cls, field_name): From 9c2bc8468f4c88fee2b2f6fd2c0b8725fa9ccee3 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 4 Dec 2023 22:33:32 +0100 Subject: [PATCH 59/67] release: 3.0rc1 --- graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 253e1d9c..f0e7a45b 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .types import SQLAlchemyInterface, SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "3.0.0b4" +__version__ = "3.0.0rc1" __all__ = [ "__version__", From b30bc921cb3881a7d8cf9873d9b192788e749c6b Mon Sep 17 00:00:00 2001 From: 
Adam Schubert Date: Tue, 5 Mar 2024 16:29:06 +0100 Subject: [PATCH 60/67] feat(filters): Added DateTimeFilter (#404) --- graphene_sqlalchemy/filters.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index bb422724..cbe3d09d 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -423,6 +423,13 @@ class Meta: graphene_type = graphene.Date +class DateTimeFilter(OrderedFilter): + """Concrete Filter Class which specifies a type for all the abstract filter methods defined in the super classes""" + + class Meta: + graphene_type = graphene.DateTime + + class IdFilter(FieldFilter): class Meta: graphene_type = graphene.ID From eb9c663cc0e314987397626573e3d2f940bea138 Mon Sep 17 00:00:00 2001 From: Zet Date: Fri, 13 Sep 2024 17:28:35 +0200 Subject: [PATCH 61/67] fix: create_filters option now does what it says (#414) Co-authored-by: zbynek.skola --- graphene_sqlalchemy/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 70539880..06957511 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -440,6 +440,7 @@ def __init_subclass_with_meta__( batching=False, connection_field_factory=None, _meta=None, + create_filters=True, **options, ): # We always want to bypass this hook unless we're defining a concrete @@ -474,7 +475,7 @@ def __init_subclass_with_meta__( only_fields=only_fields, exclude_fields=exclude_fields, batching=batching, - create_filters=True, + create_filters=create_filters, connection_field_factory=connection_field_factory, ) From a6161dd488810440c7be06fc4dea924b55032eeb Mon Sep 17 00:00:00 2001 From: Ricardo Madriz Date: Thu, 5 Dec 2024 05:58:43 -0600 Subject: [PATCH 62/67] hoursekeeping: add support for python 3.12 (#417) * Add support for python 3.12 Fixes #416 * Remove python 3.7 * Drop python 3.8, add 3.13 * housekeeping: ci 3.9-3.13 --------- 
Co-authored-by: Erik Wrede --- .github/workflows/tests.yml | 2 +- graphene_sqlalchemy/converter.py | 2 +- graphene_sqlalchemy/utils.py | 11 ++++------- setup.py | 8 +++++--- tox.ini | 7 ++++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c471166a..f03a405f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: max-parallel: 10 matrix: sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ] - python-version: [ "3.7", "3.8", "3.9", "3.10" ] + python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index efcf3c6c..6502412f 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,6 +7,7 @@ from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( ColumnProperty, RelationshipProperty, @@ -14,7 +15,6 @@ interfaces, strategies, ) -from sqlalchemy.ext.hybrid import hybrid_property import graphene from graphene.types.json import JSONString diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 3ba14865..17d774d2 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -3,9 +3,10 @@ import warnings from collections import OrderedDict from functools import _c3_mro +from importlib.metadata import version as get_version from typing import Any, Callable, Dict, Optional -import pkg_resources +from packaging import version from sqlalchemy import select from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper @@ -22,16 +23,12 @@ def get_nullable_type(_type): def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution( - "SQLAlchemy" - ).parsed_version < 
pkg_resources.parse_version(version_string) + return version.parse(get_version("SQLAlchemy")) < version.parse(version_string) def is_graphene_version_less_than(version_string): # pragma: no cover """Check the installed graphene version""" - return pkg_resources.get_distribution( - "graphene" - ).parsed_version < pkg_resources.parse_version(version_string) + return version.parse(get_version("graphene")) < version.parse(version_string) SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False diff --git a/setup.py b/setup.py index fdace116..33eabcb6 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "promise>=2.3", "SQLAlchemy>=1.1", "aiodataloader>=0.2.0,<1.0", + "packaging>=23.0", ] tests_require = [ @@ -48,13 +49,14 @@ "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: PyPy", ], - keywords="api graphql protocol rest relay graphene", + keywords="api graphql protocol rest relay graphene sqlalchemy", packages=find_packages(exclude=["tests"]), install_requires=requirements, extras_require={ diff --git a/tox.ini b/tox.ini index 9ce901e4..6ec4699e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,15 @@ [tox] -envlist = pre-commit,py{37,38,39,310}-sql{12,13,14,20} +envlist = pre-commit,py{39,310,311,312,313}-sql{12,13,14,20} skipsdist = true minversion = 3.7.0 [gh-actions] python = - 3.7: py37 - 3.8: py38 3.9: py39 3.10: py310 + 3.11: py311 + 3.12: py312 + 3.13: py313 [gh-actions:env] SQLALCHEMY = From febdc451edc3e45af51f7332f0353401e051091c Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Thu, 5 Dec 2024 13:00:22 +0100 Subject: [PATCH 63/67] release: 3.0.0rc2 --- 
graphene_sqlalchemy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index f0e7a45b..69bb79bb 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .types import SQLAlchemyInterface, SQLAlchemyObjectType from .utils import get_query, get_session -__version__ = "3.0.0rc1" +__version__ = "3.0.0rc2" __all__ = [ "__version__", From 72c3cceb9cd2917a2932c6acf24809addc3ac542 Mon Sep 17 00:00:00 2001 From: Yonatan Romero <4235177+romeroyonatan@users.noreply.github.com> Date: Mon, 7 Apr 2025 04:12:03 -0300 Subject: [PATCH 64/67] fix: Do not create filter class if create_filters is False (#420) --- graphene_sqlalchemy/tests/test_filters.py | 27 +++++++++++++++++++++++ graphene_sqlalchemy/types.py | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 4acf89a8..87bbceae 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1199,3 +1199,30 @@ async def test_additional_filters(session): schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) assert_and_raise_result(result, expected) + + +@pytest.mark.asyncio +async def test_do_not_create_filters(): + class WithoutFilters(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, _meta=None, **options): + super().__init_subclass_with_meta__( + _meta=_meta, create_filters=False, **options + ) + + class PetType(WithoutFilters): + class Meta: + model = Pet + name = "Pet" + interfaces = (relay.Node,) + connection_class = Connection + + class Query(graphene.ObjectType): + pets = SQLAlchemyConnectionField(PetType.connection) + + schema = graphene.Schema(query=Query) + + assert "filter" not in str(schema).lower() diff --git 
a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 06957511..894ebfdb 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -513,7 +513,7 @@ def __init_subclass_with_meta__( _meta.fields = sqla_fields # Save Generated filter class in Meta Class - if not _meta.filter_class: + if create_filters and not _meta.filter_class: # Map graphene fields to filters # TODO we might need to pass the ORMFields containing the SQLAlchemy models # to the scalar filters here (to generate expressions from the model) From 83e0c17ef8c203540f818a4aecd37d4375d44aaf Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 7 Apr 2025 09:49:38 +0200 Subject: [PATCH 65/67] chore: update tests actions --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f03a405f..66fe306b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: @@ -34,7 +34,7 @@ jobs: TOXENV: ${{ matrix.toxenv }} - name: Upload coverage.xml if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.10' }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: graphene-sqlalchemy-coverage path: coverage.xml From 6dbd94fd3419b9642f6f74be4c6948e4f156ede7 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 7 Apr 2025 09:50:01 +0200 Subject: [PATCH 66/67] chore: update deploy actions --- .github/workflows/deploy.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 9cc136a1..30ed9526 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -10,9 +10,9 @@ jobs: runs-on: ubuntu-latest steps: - - 
uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.10' - name: Build wheel and source tarball From 4ea6ee819600d65ad784c783a68321105a643d76 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Mon, 7 Apr 2025 09:51:12 +0200 Subject: [PATCH 67/67] chore: update lint actions (#421) --- .github/workflows/lint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 355a94d2..099e9177 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,9 +13,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.10' - name: Install dependencies