diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index ec34535f218..b7b845c510d 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -8,3 +8,4 @@ 1e1a38e7801f410f244e4bbb44ec795ae152e04e # initial blackification 1e278de4cc9a4181e0747640a960e80efcea1ca9 # follow up mass style changes 058c230cea83811c3bebdd8259988c5c501f4f7e # Update black to v23.3.0 and flake8 to v6 +573d004e5f210a199d8b25335c71f973fee21a4b # Update black to 24.1.1 diff --git a/.github/workflows/create-wheels.yaml b/.github/workflows/create-wheels.yaml index b5c0126be68..e3089cd2d1b 100644 --- a/.github/workflows/create-wheels.yaml +++ b/.github/workflows/create-wheels.yaml @@ -20,15 +20,17 @@ jobs: matrix: # emulated wheels on linux take too much time, split wheels into multiple runs python: - - "cp37-* cp38-*" - - "cp39-* cp310-*" - - "cp311-* cp312-*" + - "cp37-* cp38-* cp39-*" + - "cp310-* cp311-*" + - "cp312-* cp313-*" wheel_mode: - compiled os: - "windows-2022" - - "macos-12" + # TODO: macos-14 uses arm macs (only python 3.10+) - make arm wheel on it + - "macos-13" - "ubuntu-22.04" + - "ubuntu-22.04-arm" linux_archs: # this is only meaningful on linux. windows and macos ignore exclude all but one arch - "aarch64" @@ -38,13 +40,17 @@ jobs: # create pure python build - os: ubuntu-22.04 wheel_mode: pure-python - python: "cp-311*" + python: "cp-312*" exclude: - os: "windows-2022" linux_archs: "aarch64" - - os: "macos-12" + - os: "macos-13" linux_archs: "aarch64" + - os: "ubuntu-22.04" + linux_archs: "aarch64" + - os: "ubuntu-22.04-arm" + linux_archs: "x86_64" fail-fast: false @@ -65,15 +71,16 @@ jobs: (cat setup.cfg) | %{$_ -replace "tag_build.?=.?dev",""} | set-content setup.cfg # See details at https://cibuildwheel.readthedocs.io/en/stable/faq/#emulation - - name: Set up QEMU on linux - if: ${{ runner.os == 'Linux' }} - uses: docker/setup-qemu-action@v3 - with: - platforms: all + # no longer needed since arm runners are now available + # - name: Set up QEMU on linux + # if: ${{ runner.os == 'Linux' }} + # uses: docker/setup-qemu-action@v3 + # with: + # platforms: all - name: Build compiled wheels if: ${{ matrix.wheel_mode == 'compiled' }} - uses: pypa/cibuildwheel@v2.16.2 + uses: pypa/cibuildwheel@v2.22.0 env: CIBW_ARCHS_LINUX: ${{ matrix.linux_archs }} CIBW_BUILD: ${{ matrix.python }} @@ -82,9 +89,9 @@ jobs: - name: Set up Python for twine and pure-python wheel - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Build pure-python wheel if: ${{ matrix.wheel_mode == 'pure-python' && runner.os == 'Linux' }} diff --git a/.github/workflows/run-on-pr.yaml b/.github/workflows/run-on-pr.yaml index c19e7a59018..889da8499f3 100644 --- a/.github/workflows/run-on-pr.yaml +++ b/.github/workflows/run-on-pr.yaml @@ -10,7 +10,7 @@ on: env: # global env to all steps - TOX_WORKERS: -n2 + TOX_WORKERS: -n4 permissions: contents: read @@ -23,9 +23,9 @@ jobs: # run this job using this matrix, excluding some combinations below. 
matrix: os: - - "ubuntu-latest" + - "ubuntu-22.04" python-version: - - "3.11" + - "3.13" build-type: - "cext" - "nocext" @@ -40,7 +40,7 @@ jobs: uses: actions/checkout@v4 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.architecture }} @@ -60,9 +60,9 @@ jobs: strategy: matrix: os: - - "ubuntu-latest" + - "ubuntu-22.04" python-version: - - "3.11" + - "3.12" tox-env: - mypy - lint @@ -75,7 +75,7 @@ jobs: uses: actions/checkout@v4 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.architecture }} diff --git a/.github/workflows/run-test.yaml b/.github/workflows/run-test.yaml index fa2fa54f2ea..6c93ef1b4f7 100644 --- a/.github/workflows/run-test.yaml +++ b/.github/workflows/run-test.yaml @@ -13,22 +13,24 @@ on: env: # global env to all steps - TOX_WORKERS: -n2 + TOX_WORKERS: -n4 permissions: contents: read jobs: run-test: - name: test-${{ matrix.python-version }}-${{ matrix.build-type }}-${{ matrix.architecture }}-${{ matrix.os }} + name: test-${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.architecture }}-${{ matrix.build-type }} runs-on: ${{ matrix.os }} strategy: # run this job using this matrix, excluding some combinations below. matrix: os: - - "ubuntu-latest" + - "ubuntu-22.04" + - "ubuntu-22.04-arm" - "windows-latest" - "macos-latest" + - "macos-13" python-version: - "3.7" - "3.8" @@ -36,32 +38,66 @@ jobs: - "3.10" - "3.11" - "3.12" - - "pypy-3.9" + - "3.13" + - "pypy-3.10" build-type: - "cext" - "nocext" architecture: - x64 - x86 + - arm64 include: # autocommit tests fail on the ci for some reason - - python-version: "pypy-3.9" + - python-version: "pypy-3.10" pytest-args: "-k 'not test_autocommit_on and not test_turn_autocommit_off_via_default_iso_level and not test_autocommit_isolation_level'" - - os: "ubuntu-latest" + - os: "ubuntu-22.04" pytest-args: "--dbdriver pysqlite --dbdriver aiosqlite" + - os: "ubuntu-22.04-arm" + pytest-args: "--dbdriver pysqlite --dbdriver aiosqlite" + exclude: - # linux and osx do not have x86 python - - os: "ubuntu-latest" + # linux does not have x86 / arm64 python + - os: "ubuntu-22.04" + architecture: x86 + - os: "ubuntu-22.04" + architecture: arm64 + # linux-arm does not have x86 / x64 python + - os: "ubuntu-22.04-arm" architecture: x86 + - os: "ubuntu-22.04-arm" + architecture: x64 + # linux-arm does not have 3.7 python + - os: "ubuntu-22.04-arm" + python-version: "3.7" + # windows does not have arm64 python + - os: "windows-latest" + architecture: arm64 + # macos: latests uses arm macs. only 3.10+; no x86/x64 - os: "macos-latest" architecture: x86 - # pypy does not have cext or x86 - - python-version: "pypy-3.9" + - os: "macos-latest" + architecture: x64 + - os: "macos-latest" + python-version: "3.7" + - os: "macos-latest" + python-version: "3.8" + - os: "macos-latest" + python-version: "3.9" + # macos 13: uses intel macs. 
no arm64, x86 + - os: "macos-13" + architecture: arm64 + - os: "macos-13" + architecture: x86 + # pypy does not have cext or x86 or arm on linux + - python-version: "pypy-3.10" build-type: "cext" + - os: "ubuntu-22.04-arm" + python-version: "pypy-3.10" - os: "windows-latest" - python-version: "pypy-3.9" + python-version: "pypy-3.10" architecture: x86 fail-fast: false @@ -72,7 +108,7 @@ jobs: uses: actions/checkout@v4 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.architecture }} @@ -91,45 +127,7 @@ jobs: - name: Run tests run: tox -e github-${{ matrix.build-type }} -- -q --nomemory --notimingintensive ${{ matrix.pytest-args }} - continue-on-error: ${{ matrix.python-version == 'pypy-3.9' }} - - run-test-arm64: - name: test-arm64-${{ matrix.python-version }}-${{ matrix.build-type }}-${{ matrix.os }} - runs-on: ubuntu-latest - strategy: - matrix: - python-version: - - cp37-cp37m - - cp38-cp38 - - cp39-cp39 - - cp310-cp310 - - cp311-cp311 - build-type: - - "cext" - - "nocext" - - fail-fast: false - - steps: - - name: Checkout repo - uses: actions/checkout@v4 - - - name: Set up emulation - run: | - docker run --rm --privileged multiarch/qemu-user-static --reset -p yes - - - name: Run tests - uses: docker://quay.io/pypa/manylinux2014_aarch64 - with: - args: | - bash -c " - export PATH=/opt/python/${{ matrix.python-version }}/bin:$PATH && - python --version && - python -m pip install --upgrade pip && - pip install --upgrade tox setuptools && - pip list && - tox -e github-${{ matrix.build-type }} -- -q --nomemory --notimingintensive ${{ matrix.pytest-args }} - " + continue-on-error: ${{ matrix.python-version == 'pypy-3.10' }} run-tox: name: ${{ matrix.tox-env }}-${{ matrix.python-version }} @@ -138,25 +136,24 @@ jobs: # run this job using this matrix, excluding some combinations below. 
matrix: os: - - "ubuntu-latest" + - "ubuntu-22.04" python-version: - "3.8" - "3.9" - "3.10" - "3.11" + - "3.12" + - "3.13" tox-env: - mypy - - lint - pep484 - exclude: - # run lint only on 3.11 - - tox-env: lint - python-version: "3.8" - - tox-env: lint - python-version: "3.9" + include: + # run lint only on 3.12 - tox-env: lint - python-version: "3.10" + python-version: "3.12" + os: "ubuntu-22.04" + exclude: # run pep484 only on 3.10+ - tox-env: pep484 python-version: "3.8" @@ -171,7 +168,7 @@ jobs: uses: actions/checkout@v4 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.architecture }} diff --git a/.gitignore b/.gitignore index 13b40c819ad..d2ee9a2f4ad 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ test/test_schema.db /db_idents.txt .DS_Store .vs +/scratch diff --git a/.gitreview b/.gitreview index 01d8b1770f7..3e5e2b50dac 100644 --- a/.gitreview +++ b/.gitreview @@ -1,4 +1,4 @@ [gerrit] host=gerrit.sqlalchemy.org project=sqlalchemy/sqlalchemy -defaultbranch=main +defaultbranch=rel_2_0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab722e4f309..c7d225e1ae0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/python/black - rev: 23.3.0 + rev: 25.1.0 hooks: - id: black @@ -12,7 +12,7 @@ repos: - id: zimports - repo: https://github.com/pycqa/flake8 - rev: 5.0.0 + rev: 7.2.0 hooks: - id: flake8 additional_dependencies: @@ -33,6 +33,8 @@ repos: - id: black-docs name: Format docs code block with black entry: python tools/format_docs_code.py -f - language: system + language: python types: [rst] exclude: README.* + additional_dependencies: + - black==25.1.0 diff --git a/LICENSE b/LICENSE index 7bf9bbe9683..dfe1a4d815b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2005-2023 SQLAlchemy authors and contributors . +Copyright 2005-2025 SQLAlchemy authors and contributors . Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/README.dialects.rst b/README.dialects.rst index 810267a20cf..798ed21fbd3 100644 --- a/README.dialects.rst +++ b/README.dialects.rst @@ -26,7 +26,9 @@ compliance suite" should be viewed as the primary target for new dialects. Dialect Layout =============== -The file structure of a dialect is typically similar to the following:: +The file structure of a dialect is typically similar to the following: + +.. sourcecode:: text sqlalchemy-/ setup.py @@ -52,9 +54,9 @@ Key aspects of this file layout include: dialect to be usable from create_engine(), e.g.:: entry_points = { - 'sqlalchemy.dialects': [ - 'access.pyodbc = sqlalchemy_access.pyodbc:AccessDialect_pyodbc', - ] + "sqlalchemy.dialects": [ + "access.pyodbc = sqlalchemy_access.pyodbc:AccessDialect_pyodbc", + ] } Above, the entrypoint ``access.pyodbc`` allow URLs to be used such as:: @@ -63,7 +65,9 @@ Key aspects of this file layout include: * setup.cfg - this file contains the traditional contents such as [tool:pytest] directives, but also contains new directives that are used - by SQLAlchemy's testing framework. E.g. for Access:: + by SQLAlchemy's testing framework. E.g. for Access: + + .. 
sourcecode:: text [tool:pytest] addopts= --tb native -v -r fxX --maxfail=25 -p no:warnings @@ -129,6 +133,7 @@ Key aspects of this file layout include: from sqlalchemy.testing import exclusions + class Requirements(SuiteRequirements): @property def nullable_booleans(self): @@ -148,7 +153,9 @@ Key aspects of this file layout include: The requirements system can also be used when running SQLAlchemy's primary test suite against the external dialect. In this use case, a ``--dburi`` as well as a ``--requirements`` flag are passed to SQLAlchemy's - test runner so that exclusions specific to the dialect take place:: + test runner so that exclusions specific to the dialect take place: + + .. sourcecode:: text cd /path/to/sqlalchemy pytest -v \ @@ -175,6 +182,7 @@ Key aspects of this file layout include: from sqlalchemy.testing.suite import IntegerTest as _IntegerTest + class IntegerTest(_IntegerTest): @testing.skip("access") diff --git a/README.unittests.rst b/README.unittests.rst index 9cf309d2d7e..ce280bb4d23 100644 --- a/README.unittests.rst +++ b/README.unittests.rst @@ -49,7 +49,7 @@ database options and test selection. A generic pytest run looks like:: - pytest -n4 + pytest - n4 Above, the full test suite will run against SQLite, using four processes. If the "-n" flag is not used, the pytest-xdist is skipped and the tests will @@ -280,7 +280,7 @@ intended for production use! # configure the database sleep 20 - docker exec -ti mariadb mysql -u root -ppassword -w -e "CREATE DATABASE test_schema CHARSET utf8mb4; GRANT ALL ON test_schema.* TO scott;" + docker exec -ti mariadb mariadb -u root -ppassword -w -e "CREATE DATABASE test_schema CHARSET utf8mb4; GRANT ALL ON test_schema.* TO scott;" # To stop the container. It will also remove it. docker stop mariadb @@ -307,11 +307,11 @@ be used with pytest by using ``--db docker_mssql``. **Oracle configuration**:: # create the container with the proper configuration for sqlalchemy - docker run --rm --name oracle -p 127.0.0.1:1521:1521 -d -e ORACLE_PASSWORD=tiger -e ORACLE_DATABASE=test -e APP_USER=scott -e APP_USER_PASSWORD=tiger gvenzl/oracle-xe:21-slim + docker run --rm --name oracle -p 127.0.0.1:1521:1521 -d -e ORACLE_PASSWORD=tiger -e ORACLE_DATABASE=test -e APP_USER=scott -e APP_USER_PASSWORD=tiger gvenzl/oracle-free:23-slim # enter the database container and run the command docker exec -ti oracle bash - >> sqlplus system/tiger@//localhost/XEPDB1 <> sqlplus system/tiger@//localhost/FREEPDB1 <`` option for PostgreSQL ``CREATE TABLE`` to + specify the access method to use to store the contents for the new table. + Pull request courtesy Edgar Ramírez-Mondragón. + + .. seealso:: + + :ref:`postgresql_table_options` + + .. change:: + :tags: bug, examples + :tickets: 10920 + + Fixed regression in history_meta example where the use of + :meth:`_schema.MetaData.to_metadata` to make a copy of the history table + would also copy indexes (which is a good thing), but causing naming + conflicts indexes regardless of naming scheme used for those indexes. A + "_history" suffix is now added to these indexes in the same way as is + achieved for the table name. + + + .. change:: + :tags: bug, orm + :tickets: 10967 + + Fixed issue where using :meth:`_orm.Session.delete` along with the + :paramref:`_orm.Mapper.version_id_col` feature would fail to use the + correct version identifier in the case that an additional UPDATE were + emitted against the target object as a result of the use of + :paramref:`_orm.relationship.post_update` on the object. 
The issue is + similar to :ticket:`10800` just fixed in version 2.0.25 for the case of + updates alone. + + .. change:: + :tags: bug, orm + :tickets: 10990 + + Fixed issue where an assertion within the implementation for + :func:`_orm.with_expression` would raise if a SQL expression that was not + cacheable were used; this was a 2.0 regression since 1.4. + + .. change:: + :tags: postgresql, usecase + :tickets: 9736 + + Correctly type PostgreSQL RANGE and MULTIRANGE types as ``Range[T]`` + and ``Sequence[Range[T]]``. + Introduced utility sequence :class:`_postgresql.MultiRange` to allow better + interoperability of MULTIRANGE types. + + .. change:: + :tags: postgresql, usecase + + Differentiate between INT4 and INT8 ranges and multi-ranges types when + inferring the database type from a :class:`_postgresql.Range` or + :class:`_postgresql.MultiRange` instance, preferring INT4 if the values + fit into it. + + .. change:: + :tags: bug, typing + + Fixed the type signature for the :meth:`.PoolEvents.checkin` event to + indicate that the given :class:`.DBAPIConnection` argument may be ``None`` + in the case where the connection has been invalidated. + + .. change:: + :tags: bug, examples + + Fixed the performance example scripts in examples/performance to mostly + work with the Oracle database, by adding the :class:`.Identity` construct + to all the tables and allowing primary generation to occur on this backend. + A few of the "raw DBAPI" cases still are not compatible with Oracle. + + + .. change:: + :tags: bug, mssql + + Fixed an issue regarding the use of the :class:`.Uuid` datatype with the + :paramref:`.Uuid.as_uuid` parameter set to False, when using the pymssql + dialect. ORM-optimized INSERT statements (e.g. the "insertmanyvalues" + feature) would not correctly align primary key UUID values for bulk INSERT + statements, resulting in errors. Similar issues were fixed for the + PostgreSQL drivers as well. + + + .. change:: + :tags: bug, postgresql + + Fixed an issue regarding the use of the :class:`.Uuid` datatype with the + :paramref:`.Uuid.as_uuid` parameter set to False, when using PostgreSQL + dialects. ORM-optimized INSERT statements (e.g. the "insertmanyvalues" + feature) would not correctly align primary key UUID values for bulk INSERT + statements, resulting in errors. Similar issues were fixed for the + pymssql driver as well. + +.. changelog:: + :version: 2.0.25 + :released: January 2, 2024 + + .. change:: + :tags: oracle, asyncio + :tickets: 10679 + + Added support for :ref:`oracledb` in asyncio mode, using the newly released + version of the ``oracledb`` DBAPI that includes asyncio support. For the + 2.0 series, this is a preview release, where the current implementation + does not yet have include support for + :meth:`_asyncio.AsyncConnection.stream`. Improved support is planned for + the 2.1 release of SQLAlchemy. + + .. change:: + :tags: bug, orm + :tickets: 10800 + + Fixed issue where when making use of the + :paramref:`_orm.relationship.post_update` feature at the same time as using + a mapper version_id_col could lead to a situation where the second UPDATE + statement emitted by the post-update feature would fail to make use of the + correct version identifier, assuming an UPDATE was already emitted in that + flush which had already bumped the version counter. + + .. 
change:: + :tags: bug, typing + :tickets: 10801, 10818 + + Fixed regressions caused by typing added to the ``sqlalchemy.sql.functions`` + module in version 2.0.24, as part of :ticket:`6810`: + + * Further enhancements to pep-484 typing to allow SQL functions from + :attr:`_sql.func` derived elements to work more effectively with ORM-mapped + attributes (:ticket:`10801`) + + * Fixed the argument types passed to functions so that literal expressions + like strings and ints are again interpreted correctly (:ticket:`10818`) + + + .. change:: + :tags: usecase, orm + :tickets: 10807 + + Added preliminary support for Python 3.12 pep-695 type alias structures, + when resolving custom type maps for ORM Annotated Declarative mappings. + + + .. change:: + :tags: bug, orm + :tickets: 10815 + + Fixed issue where ORM Annotated Declarative would mis-interpret the left + hand side of a relationship without any collection specified as + uselist=True if the left type were given as a class and not a string, + without using future-style annotations. + + .. change:: + :tags: bug, sql + :tickets: 10817 + + Improved compilation of :func:`_sql.any_` / :func:`_sql.all_` in the + context of a negation of boolean comparison, will now render ``NOT (expr)`` + rather than reversing the equality operator to not equals, allowing + finer-grained control of negations for these non-typical operators. + .. changelog:: :version: 2.0.24 - :include_notes_from: unreleased_20 + :released: December 28, 2023 + + .. change:: + :tags: bug, orm + :tickets: 10597 + + Fixed issue where use of :func:`_orm.foreign` annotation on a + non-initialized :func:`_orm.mapped_column` construct would produce an + expression without a type, which was then not updated at initialization + time of the actual column, leading to issues such as relationships not + determining ``use_get`` appropriately. + + + .. change:: + :tags: bug, schema + :tickets: 10654 + + Fixed issue where error reporting for unexpected schema item when creating + objects like :class:`_schema.Table` would incorrectly handle an argument + that was itself passed as a tuple, leading to a formatting error. The + error message has been modernized to use f-strings. + + .. change:: + :tags: bug, engine + :tickets: 10662 + + Fixed URL-encoding of the username and password components of + :class:`.engine.URL` objects when converting them to string using the + :meth:`_engine.URL.render_as_string` method, by using Python standard + library ``urllib.parse.quote`` while allowing for plus signs and spaces to + remain unchanged as supported by SQLAlchemy's non-standard URL parsing, + rather than the legacy home-grown routine from many years ago. Pull request + courtesy of Xavier NUNN. + + .. change:: + :tags: bug, orm + :tickets: 10668 + + Improved the error message produced when the unit of work process sets the + value of a primary key column to NULL due to a related object with a + dependency rule on that column being deleted, to include not just the + destination object and column name but also the source column from which + the NULL value is originating. Pull request courtesy Jan Vollmer. + + .. change:: + :tags: bug, postgresql + :tickets: 10717 + + Adjusted the asyncpg dialect such that when the ``terminate()`` method is + used to discard an invalidated connection, the dialect will first attempt + to gracefully close the connection using ``.close()`` with a timeout, if + the operation is proceeding within an async event loop context only. 
This + allows the asyncpg driver to attend to finalizing a ``TimeoutError`` + including being able to close a long-running query server side, which + otherwise can keep running after the program has exited. + + .. change:: + :tags: bug, orm + :tickets: 10732 + + Modified the ``__init_subclass__()`` method used by + :class:`_orm.MappedAsDataclass`, :class:`_orm.DeclarativeBase` and + :class:`_orm.DeclarativeBaseNoMeta` to accept arbitrary ``**kw`` and to + propagate them to the ``super()`` call, allowing greater flexibility in + arranging custom superclasses and mixins which make use of + ``__init_subclass__()`` keyword arguments. Pull request courtesy Michael + Oliver. + + + .. change:: + :tags: bug, tests + :tickets: 10747 + + Improvements to the test suite to further harden its ability to run + when Python ``greenlet`` is not installed. There is now a tox + target that includes the token "nogreenlet" that will run the suite + with greenlet not installed (note that it still temporarily installs + greenlet as part of the tox config, however). + + .. change:: + :tags: bug, sql + :tickets: 10753 + + Fixed issue in stringify for SQL elements, where a specific dialect is not + passed, where a dialect-specific element such as the PostgreSQL "on + conflict do update" construct is encountered and then fails to provide for + a stringify dialect with the appropriate state to render the construct, + leading to internal errors. + + .. change:: + :tags: bug, sql + + Fixed issue where stringifying or compiling a :class:`.CTE` that was + against a DML construct such as an :func:`_sql.insert` construct would fail + to stringify, due to a mis-detection that the statement overall is an + INSERT, leading to internal errors. + + .. change:: + :tags: bug, orm + :tickets: 10776 + + Ensured the use case of :class:`.Bundle` objects used in the + ``returning()`` portion of ORM-enabled INSERT, UPDATE and DELETE statements + is tested and works fully. This was never explicitly implemented or + tested previously and did not work correctly in the 1.4 series; in the 2.0 + series, ORM UPDATE/DELETE with WHERE criteria was missing an implementation + method preventing :class:`.Bundle` objects from working. + + .. change:: + :tags: bug, orm + :tickets: 10784 + + Fixed 2.0 regression in :class:`.MutableList` where a routine that detects + sequences would not correctly filter out string or bytes instances, making + it impossible to assign a string value to a specific index (while + non-sequence values would work fine). + + .. change:: + :tags: change, asyncio + + The ``async_fallback`` dialect argument is now deprecated, and will be + removed in SQLAlchemy 2.1. This flag has not been used for SQLAlchemy's + test suite for some time. asyncio dialects can still run in a synchronous + style by running code within a greenlet using :func:`_util.greenlet_spawn`. + + .. change:: + :tags: bug, typing + :tickets: 6810 + + Completed pep-484 typing for the ``sqlalchemy.sql.functions`` module. + :func:`_sql.select` constructs made against ``func`` elements should now + have filled-in return types. .. changelog:: :version: 2.0.23 @@ -249,12 +2212,17 @@ .. change:: :tags: bug, orm - :tickets: 10365 + :tickets: 10365, 11412 Fixed bug where ORM :func:`_orm.with_loader_criteria` would not apply itself to a :meth:`_sql.Select.join` where the ON clause were given as a plain SQL comparison, rather than as a relationship target or similar. 
+ **update** - this was found to also fix an issue where + single-inheritance criteria would not be correctly applied to a + subclass entity that only appeared in the ``select_from()`` list, + see :ticket:`11412` + .. change:: :tags: bug, sql :tickets: 10408 @@ -3149,7 +5117,7 @@ Added an error message when a :func:`_orm.relationship` is mapped against an abstract container type, such as ``Mapped[Sequence[B]]``, without providing the :paramref:`_orm.relationship.container_class` parameter which - is necessary when the type is abstract. Previously the the abstract + is necessary when the type is abstract. Previously the abstract container would attempt to be instantiated at a later step and fail. diff --git a/doc/build/changelog/migration_05.rst b/doc/build/changelog/migration_05.rst index d26a22c0d00..8b48f13f6b4 100644 --- a/doc/build/changelog/migration_05.rst +++ b/doc/build/changelog/migration_05.rst @@ -443,8 +443,7 @@ Schema/Types :: - class MyType(AdaptOldConvertMethods, TypeEngine): - ... + class MyType(AdaptOldConvertMethods, TypeEngine): ... * The ``quote`` flag on ``Column`` and ``Table`` as well as the ``quote_schema`` flag on ``Table`` now control quoting @@ -589,8 +588,7 @@ Removed :: class MyQuery(Query): - def get(self, ident): - ... + def get(self, ident): ... session = sessionmaker(query_cls=MyQuery)() diff --git a/doc/build/changelog/migration_06.rst b/doc/build/changelog/migration_06.rst index 0330ac5d4a4..320f34009af 100644 --- a/doc/build/changelog/migration_06.rst +++ b/doc/build/changelog/migration_06.rst @@ -86,11 +86,10 @@ sign "+": Important Dialect Links: * Documentation on connect arguments: - https://www.sqlalchemy.org/docs/06/dbengine.html#create- - engine-url-arguments. + https://www.sqlalchemy.org/docs/06/dbengine.html#create-engine-url-arguments. -* Reference documentation for individual dialects: https://ww - w.sqlalchemy.org/docs/06/reference/dialects/index.html +* Reference documentation for individual dialects: + https://www.sqlalchemy.org/docs/06/reference/dialects/index.html. * The tips and tricks at DatabaseNotes. @@ -1223,8 +1222,8 @@ SQLSoup SQLSoup has been modernized and updated to reflect common 0.5/0.6 capabilities, including well defined session -integration. Please read the new docs at [https://www.sqlalc -hemy.org/docs/06/reference/ext/sqlsoup.html]. +integration. Please read the new docs at +[https://www.sqlalchemy.org/docs/06/reference/ext/sqlsoup.html]. Declarative ----------- diff --git a/doc/build/changelog/migration_07.rst b/doc/build/changelog/migration_07.rst index 19716ad3c4c..4f1c98be1a8 100644 --- a/doc/build/changelog/migration_07.rst +++ b/doc/build/changelog/migration_07.rst @@ -204,8 +204,7 @@ scenarios. Highlights of this release include: A demonstration of callcount reduction including a sample benchmark script is at -https://techspot.zzzeek.org/2010/12/12/a-tale-of-three- -profiles/ +https://techspot.zzzeek.org/2010/12/12/a-tale-of-three-profiles/ Composites Rewritten -------------------- diff --git a/doc/build/changelog/migration_08.rst b/doc/build/changelog/migration_08.rst index 0f661cca790..ea9b9170537 100644 --- a/doc/build/changelog/migration_08.rst +++ b/doc/build/changelog/migration_08.rst @@ -1394,8 +1394,7 @@ yet, we'll be adding the ``inspector`` argument into it directly:: @event.listens_for(Table, "column_reflect") - def listen_for_col(inspector, table, column_info): - ... + def listen_for_col(inspector, table, column_info): ... 
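(The ``column_reflect`` stub above elides the listener body. A minimal sketch of a working listener, assuming a hypothetical convention that coerces all reflected string columns to ``String(255)``; the convention and the length are illustrative only, not part of the original example)::

    from sqlalchemy import event
    from sqlalchemy import String
    from sqlalchemy import Table


    @event.listens_for(Table, "column_reflect")
    def listen_for_col(inspector, table, column_info):
        # column_info is a dictionary of reflected facts about the column
        # ("name", "type", "nullable", "default", ...); mutating it here
        # changes how the resulting Column object is constructed.
        if isinstance(column_info["type"], String):
            column_info["type"] = String(255)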
:ticket:`2418` @@ -1495,7 +1494,7 @@ SQLSoup SQLSoup is a handy package that presents an alternative interface on top of the SQLAlchemy ORM. SQLSoup is now moved into its own project and documented/released -separately; see https://bitbucket.org/zzzeek/sqlsoup. +separately; see https://github.com/zzzeek/sqlsoup. SQLSoup is a very simple tool that could also benefit from contributors who are interested in its style of usage. diff --git a/doc/build/changelog/migration_09.rst b/doc/build/changelog/migration_09.rst index 287fc2c933a..61cd9a3a307 100644 --- a/doc/build/changelog/migration_09.rst +++ b/doc/build/changelog/migration_09.rst @@ -1148,7 +1148,7 @@ can be dropped in using callable functions. It is hoped that the :class:`.AutomapBase` system provides a quick and modernized solution to the problem that the very famous -`SQLSoup `_ +`SQLSoup `_ also tries to solve, that of generating a quick and rudimentary object model from an existing database on the fly. By addressing the issue strictly at the mapper configuration level, and integrating fully with existing diff --git a/doc/build/changelog/migration_10.rst b/doc/build/changelog/migration_10.rst index 5a016140ae3..1e61b308571 100644 --- a/doc/build/changelog/migration_10.rst +++ b/doc/build/changelog/migration_10.rst @@ -2680,7 +2680,7 @@ on MySQL:: Drizzle Dialect is now an External Dialect ------------------------------------------ -The dialect for `Drizzle `_ is now an external +The dialect for `Drizzle `_ is now an external dialect, available at https://bitbucket.org/zzzeek/sqlalchemy-drizzle. This dialect was added to SQLAlchemy right before SQLAlchemy was able to accommodate third party dialects well; going forward, all databases that aren't diff --git a/doc/build/changelog/migration_11.rst b/doc/build/changelog/migration_11.rst index 8a1ba3ba0e6..15ef6fcd0c7 100644 --- a/doc/build/changelog/migration_11.rst +++ b/doc/build/changelog/migration_11.rst @@ -2129,7 +2129,7 @@ table to an integer "id" column on the other:: pets = relationship( "Pets", primaryjoin=( - "foreign(Pets.person_id)" "==cast(type_coerce(Person.id, Integer), Integer)" + "foreign(Pets.person_id)==cast(type_coerce(Person.id, Integer), Integer)" ), ) diff --git a/doc/build/changelog/migration_12.rst b/doc/build/changelog/migration_12.rst index 454b17f12a5..cd21d087910 100644 --- a/doc/build/changelog/migration_12.rst +++ b/doc/build/changelog/migration_12.rst @@ -1586,7 +1586,7 @@ Support for Batch Mode / Fast Execution Helpers The psycopg2 ``cursor.executemany()`` method has been identified as performing poorly, particularly with INSERT statements. To alleviate this, psycopg2 -has added `Fast Execution Helpers `_ +has added `Fast Execution Helpers `_ which rework statements into fewer server round trips by sending multiple DML statements in batch. SQLAlchemy 1.2 now includes support for these helpers to be used transparently whenever the :class:`_engine.Engine` makes use diff --git a/doc/build/changelog/migration_14.rst b/doc/build/changelog/migration_14.rst index ae93003ae65..aef07864d60 100644 --- a/doc/build/changelog/migration_14.rst +++ b/doc/build/changelog/migration_14.rst @@ -552,8 +552,7 @@ SQLAlchemy has for a long time used a parameter-injecting decorator to help reso mutually-dependent module imports, like this:: @util.dependency_for("sqlalchemy.sql.dml") - def insert(self, dml, *args, **kw): - ... + def insert(self, dml, *args, **kw): ... Where the above function would be rewritten to no longer have the ``dml`` parameter on the outside. 
This would confuse code-linting tools into seeing a missing parameter @@ -2274,8 +2273,7 @@ in any way:: addresses = relationship(Address, backref=backref("user", viewonly=True)) - class Address(Base): - ... + class Address(Base): ... u1 = session.query(User).filter_by(name="x").first() diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index fe86338ee21..523eb638101 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -250,7 +250,7 @@ With warnings turned on, our program now has a lot to say: .. sourcecode:: text - $ SQLALCHEMY_WARN_20=1 python2 -W always::DeprecationWarning test3.py + $ SQLALCHEMY_WARN_20=1 python -W always::DeprecationWarning test3.py test3.py:9: RemovedIn20Warning: The Engine.execute() function/method is considered legacy as of the 1.x series of SQLAlchemy and will be removed in 2.0. All statement execution in SQLAlchemy 2.0 is performed by the Connection.execute() method of Connection, or in the ORM by the Session.execute() method of Session. (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) engine.execute("CREATE TABLE foo (id integer)") /home/classic/dev/sqlalchemy/lib/sqlalchemy/engine/base.py:2856: RemovedIn20Warning: Passing a string to Connection.execute() is deprecated and will be removed in version 2.0. Use the text() construct, or the Connection.exec_driver_sql() method to invoke a driver-level SQL string. (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) @@ -296,7 +296,7 @@ as a bonus our program is much clearer:: # select() now accepts column / table expressions positionally result = connection.execute(select(foo.c.id)) - print(result.fetchall()) + print(result.fetchall()) The goal of "2.0 deprecations mode" is that a program which runs with no :class:`_exc.RemovedIn20Warning` warnings with "2.0 deprecations mode" turned diff --git a/doc/build/changelog/unreleased_20/12593.rst b/doc/build/changelog/unreleased_20/12593.rst new file mode 100644 index 00000000000..945e0d65f5b --- /dev/null +++ b/doc/build/changelog/unreleased_20/12593.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 12593 + + Implemented the :func:`_orm.defer`, :func:`_orm.undefer` and + :func:`_orm.load_only` loader options to work for composite attributes, a + use case that had never been supported previously. diff --git a/doc/build/changelog/unreleased_20/12600.rst b/doc/build/changelog/unreleased_20/12600.rst new file mode 100644 index 00000000000..d544a225d3a --- /dev/null +++ b/doc/build/changelog/unreleased_20/12600.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, postgresql, reflection + :tickets: 12600 + + Fixed regression caused by :ticket:`10665` where the newly modified + constraint reflection query would fail on older versions of PostgreSQL + such as version 9.6. Pull request courtesy Denis Laxalde. diff --git a/doc/build/changelog/unreleased_20/12648.rst b/doc/build/changelog/unreleased_20/12648.rst new file mode 100644 index 00000000000..4abe0e395d6 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12648.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mysql + :tickets: 12648 + + Fixed yet another regression caused by by the DEFAULT rendering changes in + 2.0.40 :ticket:`12425`, similar to :ticket:`12488`, this time where using a + CURRENT_TIMESTAMP function with a fractional seconds portion inside a + textual default value would also fail to be recognized as a + non-parenthesized server default. 
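For context, a minimal sketch of the kind of definition the entry above describes, a textual server default containing CURRENT_TIMESTAMP with a fractional seconds portion (the table and column names here are hypothetical)::

    from sqlalchemy import Column
    from sqlalchemy import DateTime
    from sqlalchemy import Integer
    from sqlalchemy import MetaData
    from sqlalchemy import Table
    from sqlalchemy import text

    metadata_obj = MetaData()

    # CURRENT_TIMESTAMP(6) includes a fractional seconds portion; the fix
    # ensures the MySQL dialect still recognizes this textual default as
    # one rendered without surrounding parenthesis in the DDL.
    t = Table(
        "t",
        metadata_obj,
        Column("id", Integer, primary_key=True),
        Column(
            "updated_at",
            DateTime,
            server_default=text("CURRENT_TIMESTAMP(6)"),
        ),
    )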
+ + diff --git a/doc/build/changelog/unreleased_20/8664.rst b/doc/build/changelog/unreleased_20/8664.rst new file mode 100644 index 00000000000..8a17e439720 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8664.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 8664 + + Added ``postgresql_ops`` key to the ``dialect_options`` entry in reflected + dictionary. This maps names of columns used in the index to respective + operator class, if distinct from the default one for column's data type. + Pull request courtesy Denis Laxalde. + + .. seealso:: + + :ref:`postgresql_operator_classes` diff --git a/doc/build/changelog/whatsnew_20.rst b/doc/build/changelog/whatsnew_20.rst index 179ed55f2da..59f3273333b 100644 --- a/doc/build/changelog/whatsnew_20.rst +++ b/doc/build/changelog/whatsnew_20.rst @@ -1050,7 +1050,7 @@ implemented by :meth:`_orm.Session.bulk_insert_mappings`, with additional enhancements. This will optimize the batching of rows making use of the new :ref:`fast insertmany ` feature, while also adding support for -heterogenous parameter sets and multiple-table mappings like joined table +heterogeneous parameter sets and multiple-table mappings like joined table inheritance:: >>> users = session.scalars( @@ -2184,6 +2184,11 @@ hold onto database connections after they are released, did in fact have a measurable negative performance impact. As always, the pool class is customizable via the :paramref:`_sa.create_engine.poolclass` parameter. +.. versionchanged:: 2.0.38 - an equivalent change is also made for the + ``aiosqlite`` dialect, using :class:`._pool.AsyncAdaptedQueuePool` instead + of :class:`._pool.NullPool`. The ``aiosqlite`` dialect was not included + in the initial change in error. + .. seealso:: :ref:`pysqlite_threading_pooling` diff --git a/doc/build/conf.py b/doc/build/conf.py index 7abecb59cdc..b91d5cd9c8c 100644 --- a/doc/build/conf.py +++ b/doc/build/conf.py @@ -20,7 +20,9 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath("../../lib")) sys.path.insert(0, os.path.abspath("../..")) # examples -sys.path.insert(0, os.path.abspath(".")) + +# was never needed, does not work as of python 3.12 due to conflicts +# sys.path.insert(0, os.path.abspath(".")) os.environ["DISABLE_SQLALCHEMY_CEXT_RUNTIME"] = "true" @@ -233,7 +235,7 @@ # General information about the project. project = "SQLAlchemy" -copyright = "2007-2023, the SQLAlchemy authors and contributors" # noqa +copyright = "2007-2025, the SQLAlchemy authors and contributors" # noqa # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -242,9 +244,9 @@ # The short X.Y version. version = "2.0" # The full version, including alpha/beta/rc tags. -release = "2.0.23" +release = "2.0.41" -release_date = "November 2, 2023" +release_date = "May 14, 2025" site_base = os.environ.get("RTD_SITE_BASE", "https://www.sqlalchemy.org") site_adapter_template = "docs_adapter.mako" diff --git a/doc/build/copyright.rst b/doc/build/copyright.rst index aa4abac9b1d..54535474c42 100644 --- a/doc/build/copyright.rst +++ b/doc/build/copyright.rst @@ -6,7 +6,7 @@ Appendix: Copyright This is the MIT license: ``_ -Copyright (c) 2005-2023 Michael Bayer and contributors. +Copyright (c) 2005-2025 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index 994daa8f541..030d41cd3b3 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -140,15 +140,15 @@ each time the transaction is ended, and a new statement is emitted, a new transaction begins implicitly:: with engine.connect() as connection: - connection.execute("") + connection.execute(text("")) connection.commit() # commits "some statement" # new transaction starts - connection.execute("") + connection.execute(text("")) connection.rollback() # rolls back "some other statement" # new transaction starts - connection.execute("") + connection.execute(text("")) connection.commit() # commits "a third statement" .. versionadded:: 2.0 "commit as you go" style is a new feature of @@ -321,7 +321,7 @@ begin a transaction:: isolation_level="REPEATABLE READ" ) as connection: with connection.begin(): - connection.execute("") + connection.execute(text("")) .. tip:: The return value of the :meth:`_engine.Connection.execution_options` method is the same @@ -419,7 +419,7 @@ reverted when a connection is returned to the connection pool. :ref:`SQL Server Transaction Isolation ` - :ref:`Oracle Transaction Isolation ` + :ref:`Oracle Database Transaction Isolation ` :ref:`session_transaction_isolation` - for the ORM @@ -443,8 +443,8 @@ If we wanted to check out a :class:`_engine.Connection` object and use it with engine.connect() as connection: connection.execution_options(isolation_level="AUTOCOMMIT") - connection.execute("") - connection.execute("") + connection.execute(text("")) + connection.execute(text("")) Above illustrates normal usage of "DBAPI autocommit" mode. There is no need to make use of methods such as :meth:`_engine.Connection.begin` @@ -472,8 +472,8 @@ In the example below, statements remain # this begin() does not affect the DBAPI connection, isolation stays at AUTOCOMMIT with connection.begin() as trans: - connection.execute("") - connection.execute("") + connection.execute(text("")) + connection.execute(text("")) When we run a block like the above with logging turned on, the logging will attempt to indicate that while a DBAPI level ``.commit()`` is called, @@ -496,11 +496,11 @@ called after autobegin has already occurred:: connection = connection.execution_options(isolation_level="AUTOCOMMIT") # "transaction" is autobegin (but has no effect due to autocommit) - connection.execute("") + connection.execute(text("")) # this will raise; "transaction" is already begun with connection.begin() as trans: - connection.execute("") + connection.execute(text("")) The above example also demonstrates the same theme that the "autocommit" isolation level is a configurational detail of the underlying database @@ -545,7 +545,7 @@ before we call upon :meth:`_engine.Connection.begin`:: connection.execution_options(isolation_level="AUTOCOMMIT") # run statement(s) in autocommit mode - connection.execute("") + connection.execute(text("")) # "commit" the autobegun "transaction" connection.commit() @@ -555,7 +555,7 @@ before we call upon :meth:`_engine.Connection.begin`:: # use a begin block with connection.begin() as trans: - connection.execute("") + connection.execute(text("")) Above, to manually revert the isolation level we made use of :attr:`_engine.Connection.default_isolation_level` to restore the default @@ -568,11 +568,11 @@ use two blocks :: # use an autocommit block with 
engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection: # run statement in autocommit mode - connection.execute("") + connection.execute(text("")) # use a regular block with engine.begin() as connection: - connection.execute("") + connection.execute(text("")) To sum up: @@ -588,17 +588,17 @@ To sum up: Using Server Side Cursors (a.k.a. stream results) ------------------------------------------------- -Some backends feature explicit support for the concept of "server -side cursors" versus "client side cursors". A client side cursor here -means that the database driver fully fetches all rows from a result set -into memory before returning from a statement execution. Drivers such as -those of PostgreSQL and MySQL/MariaDB generally use client side cursors -by default. A server side cursor, by contrast, indicates that result rows -remain pending within the database server's state as result rows are consumed -by the client. The drivers for Oracle generally use a "server side" model, -for example, and the SQLite dialect, while not using a real "client / server" -architecture, still uses an unbuffered result fetching approach that will -leave result rows outside of process memory before they are consumed. +Some backends feature explicit support for the concept of "server side cursors" +versus "client side cursors". A client side cursor here means that the +database driver fully fetches all rows from a result set into memory before +returning from a statement execution. Drivers such as those of PostgreSQL and +MySQL/MariaDB generally use client side cursors by default. A server side +cursor, by contrast, indicates that result rows remain pending within the +database server's state as result rows are consumed by the client. The drivers +for Oracle Database generally use a "server side" model, for example, and the +SQLite dialect, while not using a real "client / server" architecture, still +uses an unbuffered result fetching approach that will leave result rows outside +of process memory before they are consumed. .. topic:: What we really mean is "buffered" vs. "unbuffered" results @@ -1490,10 +1490,8 @@ Basic guidelines include: def my_stmt(parameter, thing=False): stmt = lambda_stmt(lambda: select(table)) - stmt += ( - lambda s: s.where(table.c.x > parameter) - if thing - else s.where(table.c.y == parameter) + stmt += lambda s: ( + s.where(table.c.x > parameter) if thing else s.where(table.c.y == parameter) ) return stmt @@ -1809,17 +1807,18 @@ Current Support ~~~~~~~~~~~~~~~ The feature is enabled for all backend included in SQLAlchemy that support -RETURNING, with the exception of Oracle for which both the cx_Oracle and -OracleDB drivers offer their own equivalent feature. The feature normally takes -place when making use of the :meth:`_dml.Insert.returning` method of an -:class:`_dml.Insert` construct in conjunction with :term:`executemany` -execution, which occurs when passing a list of dictionaries to the -:paramref:`_engine.Connection.execute.parameters` parameter of the -:meth:`_engine.Connection.execute` or :meth:`_orm.Session.execute` methods (as -well as equivalent methods under :ref:`asyncio ` and -shorthand methods like :meth:`_orm.Session.scalars`). It also takes place -within the ORM :term:`unit of work` process when using methods such as -:meth:`_orm.Session.add` and :meth:`_orm.Session.add_all` to add rows. +RETURNING, with the exception of Oracle Database for which both the +python-oracledb and cx_Oracle drivers offer their own equivalent feature. 
The +feature normally takes place when making use of the +:meth:`_dml.Insert.returning` method of an :class:`_dml.Insert` construct in +conjunction with :term:`executemany` execution, which occurs when passing a +list of dictionaries to the :paramref:`_engine.Connection.execute.parameters` +parameter of the :meth:`_engine.Connection.execute` or +:meth:`_orm.Session.execute` methods (as well as equivalent methods under +:ref:`asyncio ` and shorthand methods like +:meth:`_orm.Session.scalars`). It also takes place within the ORM :term:`unit +of work` process when using methods such as :meth:`_orm.Session.add` and +:meth:`_orm.Session.add_all` to add rows. For SQLAlchemy's included dialects, support or equivalent support is currently as follows: @@ -1829,8 +1828,8 @@ as follows: * SQL Server - all supported SQL Server versions [#]_ * MariaDB - supported for MariaDB versions 10.5 and above * MySQL - no support, no RETURNING feature is present -* Oracle - supports RETURNING with executemany using native cx_Oracle / OracleDB - APIs, for all supported Oracle versions 9 and above, using multi-row OUT +* Oracle Database - supports RETURNING with executemany using native python-oracledb / cx_Oracle + APIs, for all supported Oracle Database versions 9 and above, using multi-row OUT parameters. This is not the same implementation as "executemanyvalues", however has the same usage patterns and equivalent performance benefits. diff --git a/doc/build/core/constraints.rst b/doc/build/core/constraints.rst index c63ad858e2c..9251bbf8306 100644 --- a/doc/build/core/constraints.rst +++ b/doc/build/core/constraints.rst @@ -308,8 +308,12 @@ arguments. The value is any string which will be output after the appropriate ), ) -Note that these clauses require ``InnoDB`` tables when used with MySQL. -They may also not be supported on other databases. +Note that some backends have special requirements for cascades to function: + +* MySQL / MariaDB - the ``InnoDB`` storage engine should be used (this is + typically the default in modern databases) +* SQLite - constraints are not enabled by default. + See :ref:`sqlite_foreign_keys` .. seealso:: @@ -320,6 +324,12 @@ They may also not be supported on other databases. :ref:`passive_deletes_many_to_many` + :ref:`postgresql_constraint_options` - indicates additional options + available for foreign key cascades such as column lists + + :ref:`sqlite_foreign_keys` - background on enabling foreign key support + with SQLite + .. _schema_unique_constraint: UNIQUE Constraint diff --git a/doc/build/core/custom_types.rst b/doc/build/core/custom_types.rst index 6ae9e066ace..4b27f2f18a2 100644 --- a/doc/build/core/custom_types.rst +++ b/doc/build/core/custom_types.rst @@ -15,7 +15,7 @@ A frequent need is to force the "string" version of a type, that is the one rendered in a CREATE TABLE statement or other SQL function like CAST, to be changed. For example, an application may want to force the rendering of ``BINARY`` for all platforms -except for one, in which is wants ``BLOB`` to be rendered. Usage +except for one, in which it wants ``BLOB`` to be rendered. Usage of an existing generic type, in this case :class:`.LargeBinary`, is preferred for most use cases. 
But to control types more accurately, a compilation directive that is per-dialect @@ -156,7 +156,7 @@ denormalize:: def process_bind_param(self, value, dialect): if value is not None: - if not value.tzinfo: + if not value.tzinfo or value.tzinfo.utcoffset(value) is None: raise TypeError("tzinfo is required") value = value.astimezone(datetime.timezone.utc).replace(tzinfo=None) return value @@ -173,7 +173,7 @@ Backend-agnostic GUID Type .. note:: Since version 2.0 the built-in :class:`_types.Uuid` type that behaves similarly should be preferred. This example is presented - just as an example of a type decorator that recieves and returns + just as an example of a type decorator that receives and returns python objects. Receives and returns Python uuid() objects. @@ -212,10 +212,8 @@ string, using a CHAR(36) type:: return dialect.type_descriptor(self._default_type) def process_bind_param(self, value, dialect): - if value is None: + if value is None or dialect.name in ("postgresql", "mssql"): return value - elif dialect.name in ("postgresql", "mssql"): - return str(value) else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) @@ -527,7 +525,10 @@ transparently:: with engine.begin() as conn: metadata_obj.create_all(conn) - conn.execute(message.insert(), username="some user", message="this is my message") + conn.execute( + message.insert(), + {"username": "some user", "message": "this is my message"}, + ) print( conn.scalar(select(message.c.message).where(message.c.username == "some user")) diff --git a/doc/build/core/defaults.rst b/doc/build/core/defaults.rst index ef5ad208159..586f0531438 100644 --- a/doc/build/core/defaults.rst +++ b/doc/build/core/defaults.rst @@ -349,7 +349,7 @@ SQLAlchemy represents database sequences using the :class:`~sqlalchemy.schema.Sequence` object, which is considered to be a special case of "column default". It only has an effect on databases which have explicit support for sequences, which among SQLAlchemy's included dialects -includes PostgreSQL, Oracle, MS SQL Server, and MariaDB. The +includes PostgreSQL, Oracle Database, MS SQL Server, and MariaDB. The :class:`~sqlalchemy.schema.Sequence` object is otherwise ignored. .. tip:: @@ -466,8 +466,8 @@ column:: In the above example, ``CREATE TABLE`` for PostgreSQL will make use of the ``SERIAL`` datatype for the ``cart_id`` column, and the ``cart_id_seq`` -sequence will be ignored. However on Oracle, the ``cart_id_seq`` sequence -will be created explicitly. +sequence will be ignored. However on Oracle Database, the ``cart_id_seq`` +sequence will be created explicitly. .. tip:: @@ -544,7 +544,7 @@ Associating a Sequence as the Server Side Default ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. note:: The following technique is known to work only with the PostgreSQL - database. It does not work with Oracle. + database. It does not work with Oracle Database. The preceding sections illustrate how to associate a :class:`.Sequence` with a :class:`_schema.Column` as the **Python side default generator**:: @@ -627,7 +627,7 @@ including the default schema, if any. :ref:`postgresql_sequences` - in the PostgreSQL dialect documentation - :ref:`oracle_returning` - in the Oracle dialect documentation + :ref:`oracle_returning` - in the Oracle Database dialect documentation .. _computed_ddl: @@ -704,9 +704,9 @@ eagerly fetched. 
* PostgreSQL as of version 12 -* Oracle - with the caveat that RETURNING does not work correctly with UPDATE - (a warning will be emitted to this effect when the UPDATE..RETURNING that - includes a computed column is rendered) +* Oracle Database - with the caveat that RETURNING does not work correctly with + UPDATE (a warning will be emitted to this effect when the UPDATE..RETURNING + that includes a computed column is rendered) * Microsoft SQL Server @@ -792,7 +792,7 @@ The :class:`.Identity` construct is currently known to be supported by: * PostgreSQL as of version 10. -* Oracle as of version 12. It also supports passing ``always=None`` to +* Oracle Database as of version 12. It also supports passing ``always=None`` to enable the default generated mode and the parameter ``on_null=True`` to specify "ON NULL" in conjunction with a "BY DEFAULT" identity column. diff --git a/doc/build/core/dml.rst b/doc/build/core/dml.rst index 7070277f14f..1724dd6985c 100644 --- a/doc/build/core/dml.rst +++ b/doc/build/core/dml.rst @@ -32,11 +32,15 @@ Class documentation for the constructors listed at .. automethod:: Delete.where + .. automethod:: Delete.with_dialect_options + .. automethod:: Delete.returning .. autoclass:: Insert :members: + .. automethod:: Insert.with_dialect_options + .. automethod:: Insert.values .. automethod:: Insert.returning @@ -48,6 +52,8 @@ Class documentation for the constructors listed at .. automethod:: Update.where + .. automethod:: Update.with_dialect_options + .. automethod:: Update.values .. autoclass:: sqlalchemy.sql.expression.UpdateBase diff --git a/doc/build/core/engines.rst b/doc/build/core/engines.rst index 3397a65e83e..8ac57cdaaf3 100644 --- a/doc/build/core/engines.rst +++ b/doc/build/core/engines.rst @@ -200,13 +200,23 @@ More notes on connecting to MySQL at :ref:`mysql_toplevel`. Oracle ^^^^^^^^^^ -The Oracle dialect uses cx_oracle as the default DBAPI:: +The preferred Oracle Database dialect uses the python-oracledb driver as the +DBAPI:: - engine = create_engine("oracle://scott:tiger@127.0.0.1:1521/sidname") + engine = create_engine( + "oracle+oracledb://scott:tiger@127.0.0.1:1521/?service_name=freepdb1" + ) - engine = create_engine("oracle+cx_oracle://scott:tiger@tnsname") + engine = create_engine("oracle+oracledb://scott:tiger@tnsalias") -More notes on connecting to Oracle at :ref:`oracle_toplevel`. +For historical reasons, the Oracle dialect uses the obsolete cx_Oracle driver +as the default DBAPI:: + + engine = create_engine("oracle://scott:tiger@127.0.0.1:1521/?service_name=freepdb1") + + engine = create_engine("oracle+cx_oracle://scott:tiger@tnsalias") + +More notes on connecting to Oracle Database at :ref:`oracle_toplevel`. Microsoft SQL Server ^^^^^^^^^^^^^^^^^^^^ @@ -578,21 +588,57 @@ getting duplicate log lines. Setting the Logging Name ------------------------- -The logger name of instance such as an :class:`~sqlalchemy.engine.Engine` or -:class:`~sqlalchemy.pool.Pool` defaults to using a truncated hex identifier -string. To set this to a specific name, use the +The logger name for :class:`~sqlalchemy.engine.Engine` or +:class:`~sqlalchemy.pool.Pool` is set to be the module-qualified class name of the +object. 
This name can be further qualified with an additional name +using the :paramref:`_sa.create_engine.logging_name` and -:paramref:`_sa.create_engine.pool_logging_name` with -:func:`sqlalchemy.create_engine`:: +:paramref:`_sa.create_engine.pool_logging_name` parameters with +:func:`sqlalchemy.create_engine`; the name will be appended to existing +class-qualified logging name. This use is recommended for applications that +make use of multiple global :class:`.Engine` instances simultaenously, so +that they may be distinguished in logging:: + >>> import logging >>> from sqlalchemy import create_engine >>> from sqlalchemy import text - >>> e = create_engine("sqlite://", echo=True, logging_name="myengine") + >>> logging.basicConfig() + >>> logging.getLogger("sqlalchemy.engine.Engine.myengine").setLevel(logging.INFO) + >>> e = create_engine("sqlite://", logging_name="myengine") >>> with e.connect() as conn: ... conn.execute(text("select 'hi'")) 2020-10-24 12:47:04,291 INFO sqlalchemy.engine.Engine.myengine select 'hi' 2020-10-24 12:47:04,292 INFO sqlalchemy.engine.Engine.myengine () +.. tip:: + + The :paramref:`_sa.create_engine.logging_name` and + :paramref:`_sa.create_engine.pool_logging_name` parameters may also be used in + conjunction with :paramref:`_sa.create_engine.echo` and + :paramref:`_sa.create_engine.echo_pool`. However, an unavoidable double logging + condition will occur if other engines are created with echo flags set to True + and **no** logging name. This is because a handler will be added automatically + for ``sqlalchemy.engine.Engine`` which will log messages both for the name-less + engine as well as engines with logging names. For example:: + + from sqlalchemy import create_engine, text + + e1 = create_engine("sqlite://", echo=True, logging_name="myname") + with e1.begin() as conn: + conn.execute(text("SELECT 1")) + + e2 = create_engine("sqlite://", echo=True) + with e2.begin() as conn: + conn.execute(text("SELECT 2")) + + with e1.begin() as conn: + conn.execute(text("SELECT 3")) + + The above scenario will double log ``SELECT 3``. To resolve, ensure + all engines have a ``logging_name`` set, or use explicit logger / handler + setup without using :paramref:`_sa.create_engine.echo` and + :paramref:`_sa.create_engine.echo_pool`. + .. _dbengine_logging_tokens: Setting Per-Connection / Sub-Engine Tokens @@ -616,7 +662,7 @@ tokens:: >>> from sqlalchemy import create_engine >>> e = create_engine("sqlite://", echo="debug") >>> with e.connect().execution_options(logging_token="track1") as conn: - ... conn.execute("select 1").all() + ... conn.execute(text("select 1")).all() 2021-02-03 11:48:45,754 INFO sqlalchemy.engine.Engine [track1] select 1 2021-02-03 11:48:45,754 INFO sqlalchemy.engine.Engine [track1] [raw sql] () 2021-02-03 11:48:45,754 DEBUG sqlalchemy.engine.Engine [track1] Col ('1',) @@ -633,14 +679,14 @@ of an application without creating new engines:: >>> e1 = e.execution_options(logging_token="track1") >>> e2 = e.execution_options(logging_token="track2") >>> with e1.connect() as conn: - ... conn.execute("select 1").all() + ... conn.execute(text("select 1")).all() 2021-02-03 11:51:08,960 INFO sqlalchemy.engine.Engine [track1] select 1 2021-02-03 11:51:08,960 INFO sqlalchemy.engine.Engine [track1] [raw sql] () 2021-02-03 11:51:08,960 DEBUG sqlalchemy.engine.Engine [track1] Col ('1',) 2021-02-03 11:51:08,961 DEBUG sqlalchemy.engine.Engine [track1] Row (1,) >>> with e2.connect() as conn: - ... conn.execute("select 2").all() + ... 
    2021-02-03 11:52:05,518 INFO sqlalchemy.engine.Engine [track2] select 2
    2021-02-03 11:52:05,519 INFO sqlalchemy.engine.Engine [track2] [raw sql] ()
    2021-02-03 11:52:05,520 DEBUG sqlalchemy.engine.Engine [track2] Col ('2',)
@@ -660,4 +706,3 @@ these parameters from being logged for privacy purposes, enable the

    ...     conn.execute(text("select :some_private_name"), {"some_private_name": "pii"})
    2020-10-24 12:48:32,808 INFO sqlalchemy.engine.Engine select ?
    2020-10-24 12:48:32,808 INFO sqlalchemy.engine.Engine [SQL parameters hidden due to hide_parameters=True]
-
diff --git a/doc/build/core/event.rst b/doc/build/core/event.rst
index 427da8fb15b..e07329f4e75 100644
--- a/doc/build/core/event.rst
+++ b/doc/build/core/event.rst
@@ -140,6 +140,33 @@ this value can be supported::

    # it to use the return value
    listen(UserContact.phone, "set", validate_phone, retval=True)

+Events and Multiprocessing
+--------------------------
+
+SQLAlchemy's event hooks are implemented with Python functions and objects, so
+events propagate via plain Python function calls. Python multiprocessing
+follows the same model as OS multiprocessing, such as a parent process forking
+a child process, so the SQLAlchemy event system's behavior can be described
+using the same model.
+
+Event hooks registered in a parent process will be present in new child
+processes that are forked from that parent after the hooks have been
+registered, since the child process starts with a copy of all existing Python
+structures from the parent when spawned. Child processes that already exist
+before the hooks are registered will not receive those new event hooks, as
+changes made to Python structures in a parent process do not propagate to
+child processes.
+
+The events themselves are Python function calls, which have no ability to
+propagate between processes. SQLAlchemy's event system does not implement any
+inter-process communication. It is possible to implement event hooks that use
+Python inter-process messaging within them, however this would need to be
+implemented by the user.
+
Event Reference
---------------

diff --git a/doc/build/core/metadata.rst b/doc/build/core/metadata.rst
index 1a933828856..318509bbdac 100644
--- a/doc/build/core/metadata.rst
+++ b/doc/build/core/metadata.rst
@@ -296,9 +296,9 @@ refer to alternate sets of tables and other constructs. The server-side
geometry of a "schema" takes many forms, including names of "schemas" under the
scope of a particular database (e.g. PostgreSQL schemas), named sibling
databases (e.g. MySQL / MariaDB access to other databases on the same server),
-as well as other concepts like tables owned by other usernames (Oracle, SQL
-Server) or even names that refer to alternate database files (SQLite ATTACH) or
-remote servers (Oracle DBLINK with synonyms).
+as well as other concepts like tables owned by other usernames (Oracle
+Database, SQL Server) or even names that refer to alternate database files
+(SQLite ATTACH) or remote servers (Oracle Database DBLINK with synonyms).

What all of the above approaches have (mostly) in common is that there's a way
of referencing this alternate set of tables using a string name. SQLAlchemy

@@ -328,14 +328,15 @@ schema names on a per-connection or per-statement basis.

  "database" that typically has a single "owner". Within this database there
  can be any number of "schemas" which then contain the actual table objects.
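  As a concrete illustration of how such a name is supplied in SQLAlchemy, a
  :class:`_schema.Table` accepts a ``schema`` argument; a minimal sketch, using
  a hypothetical ``remote_banks`` schema name::

      from sqlalchemy import Column, Integer, MetaData, String, Table

      metadata_obj = MetaData()

      # this Table is rendered and referenced as "remote_banks.financial_info"
      financial_info = Table(
          "financial_info",
          metadata_obj,
          Column("id", Integer, primary_key=True),
          Column("value", String(100)),
          schema="remote_banks",
      )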
-  A table within a specific schema is referenced explicitly using the
-  syntax "<schemaname>.<tablename>". Contrast this to an architecture such
-  as that of MySQL, where there are only "databases", however SQL statements
-  can refer to multiple databases at once, using the same syntax except it
-  is "<databasename>.<tablename>". On Oracle, this syntax refers to yet another
-  concept, the "owner" of a table. Regardless of which kind of database is
-  in use, SQLAlchemy uses the phrase "schema" to refer to the qualifying
-  identifier within the general syntax of "<qualifier>.<tablename>".
+  A table within a specific schema is referenced explicitly using the syntax
+  "<schemaname>.<tablename>". Contrast this to an architecture such as that
+  of MySQL, where there are only "databases", however SQL statements can
+  refer to multiple databases at once, using the same syntax except it is
+  "<databasename>.<tablename>". On Oracle Database, this syntax refers to yet
+  another concept, the "owner" of a table. Regardless of which kind of
+  database is in use, SQLAlchemy uses the phrase "schema" to refer to the
+  qualifying identifier within the general syntax of
+  "<qualifier>.<tablename>".

  .. seealso::

@@ -510,17 +511,19 @@ These names are usually configured at the login level, such as when connecting
to a PostgreSQL database, the default "schema" is called "public".

There are often cases where the default "schema" cannot be set via the login
-itself and instead would usefully be configured each time a connection
-is made, using a statement such as "SET SEARCH_PATH" on PostgreSQL or
-"ALTER SESSION" on Oracle. These approaches may be achieved by using
-the :meth:`_pool.PoolEvents.connect` event, which allows access to the
-DBAPI connection when it is first created. For example, to set the
-Oracle CURRENT_SCHEMA variable to an alternate name::
+itself and instead would usefully be configured each time a connection is made,
+using a statement such as "SET SEARCH_PATH" on PostgreSQL or "ALTER SESSION" on
+Oracle Database. These approaches may be achieved by using the
+:meth:`_pool.PoolEvents.connect` event, which allows access to the DBAPI
+connection when it is first created. For example, to set the Oracle Database
+CURRENT_SCHEMA variable to an alternate name::

    from sqlalchemy import event
    from sqlalchemy import create_engine

-    engine = create_engine("oracle+cx_oracle://scott:tiger@tsn_name")
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
+    )


    @event.listens_for(engine, "connect", insert=True)
diff --git a/doc/build/core/operators.rst b/doc/build/core/operators.rst
index 0450aab03ee..35c25fe75c3 100644
--- a/doc/build/core/operators.rst
+++ b/doc/build/core/operators.rst
@@ -303,7 +303,7 @@ databases support:
    using the :meth:`_sql.ColumnOperators.__eq__` overloaded operator, i.e. ``==``,
    in conjunction with the ``None`` or :func:`_sql.null` value. In this way, there's
    typically not a need to use :meth:`_sql.ColumnOperators.is_`
-    explicitly, paricularly when used with a dynamic value::
+    explicitly, particularly when used with a dynamic value::

    >>> a = None
    >>> print(column("x") == a)
diff --git a/doc/build/core/pooling.rst b/doc/build/core/pooling.rst
index 78bbdcb1af8..1a4865ba2b9 100644
--- a/doc/build/core/pooling.rst
+++ b/doc/build/core/pooling.rst
@@ -50,6 +50,13 @@ queued up - the pool would only grow to that size if the application
actually used five connections concurrently, in which case the usage of a
small pool is an entirely appropriate default behavior.

+.. note:: The :class:`.QueuePool` class is **not compatible with asyncio**.
+   When using :class:`_asyncio.create_async_engine` to create an instance of
+   :class:`.AsyncEngine`, the :class:`_pool.AsyncAdaptedQueuePool` class,
+   which makes use of an asyncio-compatible queue implementation, is used
+   instead.
+
+
.. _pool_switching:

Switching Pool Implementations
@@ -502,30 +509,32 @@ particular error should be considered a "disconnect" situation or not, as well
as if this disconnect should cause the entire connection pool to be invalidated
or not.

-For example, to add support to consider the Oracle error codes
-``DPY-1001`` and ``DPY-4011`` to be handled as disconnect codes, apply an
-event handler to the engine after creation::
+For example, to add support to consider the Oracle Database driver error codes
+``DPY-1001`` and ``DPY-4011`` to be handled as disconnect codes, apply an event
+handler to the engine after creation::

    import re

    from sqlalchemy import create_engine
+    from sqlalchemy import event
+    from sqlalchemy.engine import ExceptionContext

-    engine = create_engine("oracle://scott:tiger@dnsname")
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
+    )


    @event.listens_for(engine, "handle_error")
    def handle_exception(context: ExceptionContext) -> None:
        if not context.is_disconnect and re.match(
-            r"^(?:DPI-1001|DPI-4011)", str(context.original_exception)
+            r"^(?:DPY-1001|DPY-4011)", str(context.original_exception)
        ):
            context.is_disconnect = True

        return None

-The above error processing function will be invoked for all Oracle errors
-raised, including those caught when using the
-:ref:`pool pre ping ` feature for those backends
-that rely upon disconnect error handling (new in 2.0).
+The above error processing function will be invoked for all Oracle Database
+errors raised, including those caught when using the :ref:`pool pre ping
+` feature for those backends that rely upon
+disconnect error handling (new in 2.0).

.. seealso::

@@ -549,7 +558,7 @@ close these connections out. The difference between FIFO and LIFO is
basically whether or not it's desirable for the pool to keep a full set of
connections ready to go even during idle periods::

-    engine = create_engine("postgreql://", pool_use_lifo=True, pool_pre_ping=True)
+    engine = create_engine("postgresql://", pool_use_lifo=True, pool_pre_ping=True)

Above, we also make use of the :paramref:`_sa.create_engine.pool_pre_ping` flag
so that connections which are closed from the server side are gracefully
@@ -713,6 +722,8 @@ like in the following example::

    my_pool = create_pool_from_url("mysql+mysqldb://", poolclass=NullPool)

+.. _pool_api:
+
API Documentation - Available Pool Implementations
--------------------------------------------------

@@ -722,6 +733,9 @@ API Documentation - Available Pool Implementations
.. autoclass:: sqlalchemy.pool.QueuePool
    :members:

+.. autoclass:: sqlalchemy.pool.AsyncAdaptedQueuePool
+    :members:
+
.. autoclass:: SingletonThreadPool
    :members:

@@ -748,4 +762,3 @@ API Documentation - Available Pool Implementations
.. autoclass:: _ConnectionFairy

.. autoclass:: _ConnectionRecord
-
diff --git a/doc/build/core/reflection.rst b/doc/build/core/reflection.rst
index 4f3805b7ed2..043f6f8ee7e 100644
--- a/doc/build/core/reflection.rst
+++ b/doc/build/core/reflection.rst
@@ -123,8 +123,9 @@ object's dictionary of tables::

    metadata_obj = MetaData()
    metadata_obj.reflect(bind=someengine)
-    for table in reversed(metadata_obj.sorted_tables):
-        someengine.execute(table.delete())
+    with someengine.begin() as conn:
+        for table in reversed(metadata_obj.sorted_tables):
+            conn.execute(table.delete())

..
_metadata_reflection_schemas:

diff --git a/doc/build/core/type_basics.rst b/doc/build/core/type_basics.rst
index a8bb0f84afb..817bca601aa 100644
--- a/doc/build/core/type_basics.rst
+++ b/doc/build/core/type_basics.rst
@@ -63,9 +63,9 @@ not every backend has a real "boolean" datatype; some make use of
integers or BIT values 0 and 1, some have boolean literal constants
``true`` and ``false`` while others don't. For this datatype,
:class:`_types.Boolean` may render ``BOOLEAN`` on a backend such as
PostgreSQL, ``BIT`` on the
-MySQL backend and ``SMALLINT`` on Oracle. As data is sent and received
-from the database using this type, based on the dialect in use it may be
-interpreting Python numeric or boolean values.
+MySQL backend and ``SMALLINT`` on Oracle Database. As data is sent and
+received from the database using this type, based on the dialect in use it
+may be interpreting Python numeric or boolean values.

The typical SQLAlchemy application will likely wish to use primarily
"CamelCase" types in the general case, as they will generally provide the best
@@ -259,7 +259,9 @@ its exact name in DDL with ``CREATE TABLE`` is issued.

.. autoclass:: ARRAY
-    :members:
+    :members: __init__, Comparator
+    :member-order: bysource
+

.. autoclass:: BIGINT

@@ -334,5 +336,3 @@ its exact name in DDL with ``CREATE TABLE`` is issued.

.. autoclass:: VARCHAR
-
-
diff --git a/doc/build/dialects/index.rst b/doc/build/dialects/index.rst
index 70ac258e401..bca807355c6 100644
--- a/doc/build/dialects/index.rst
+++ b/doc/build/dialects/index.rst
@@ -24,8 +24,8 @@ Included Dialects
    oracle
    mssql

-Support Levels for Included Dialects
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Supported versions for Included Dialects
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following table summarizes the support level for each included dialect.

@@ -35,21 +35,20 @@ The following table summarizes the support level for each included dialect.
Support Definitions
^^^^^^^^^^^^^^^^^^^

-.. glossary::
+  .. Fully tested in CI
+  .. **Fully tested in CI** indicates a version that is tested in the sqlalchemy
+  .. CI system and passes all the tests in the test suite.

-    Fully tested in CI
-        **Fully tested in CI** indicates a version that is tested in the sqlalchemy
-        CI system and passes all the tests in the test suite.
+.. glossary::

-    Normal support
-        **Normal support** indicates that most features should work,
-        but not all versions are tested in the ci configuration so there may
-        be some not supported edge cases. We will try to fix issues that affect
-        these versions.
+    Supported version
+        **Supported version** indicates that most SQLAlchemy features should work
+        for the mentioned database version. Since not all database versions may be
+        tested in CI, there may be some unsupported edge cases.

    Best effort
-        **Best effort** indicates that we try to support basic features on them,
-        but most likely there will be unsupported features or errors in some use cases.
+        **Best effort** indicates that SQLAlchemy tries to support basic features on these
+        versions, but most likely there will be unsupported features or errors in some use cases.
        Pull requests with associated issues may be accepted to continue supporting
        older versions, which are reviewed on a case-by-case basis.
@@ -63,7 +62,7 @@ Currently maintained external dialect projects for SQLAlchemy include: +------------------------------------------------+---------------------------------------+ | Database | Dialect | +================================================+=======================================+ -| Actian Avalanche, Vector, Actian X, and Ingres | sqlalchemy-ingres_ | +| Actian Data Platform, Vector, Actian X, Ingres | sqlalchemy-ingres_ | +------------------------------------------------+---------------------------------------+ | Amazon Athena | pyathena_ | +------------------------------------------------+---------------------------------------+ @@ -77,9 +76,17 @@ Currently maintained external dialect projects for SQLAlchemy include: +------------------------------------------------+---------------------------------------+ | Apache Solr | sqlalchemy-solr_ | +------------------------------------------------+---------------------------------------+ +| Clickhouse | clickhouse-sqlalchemy_ | ++------------------------------------------------+---------------------------------------+ | CockroachDB | sqlalchemy-cockroachdb_ | +------------------------------------------------+---------------------------------------+ -| CrateDB | crate-python_ | +| CrateDB | sqlalchemy-cratedb_ | ++------------------------------------------------+---------------------------------------+ +| Databend | databend-sqlalchemy_ | ++------------------------------------------------+---------------------------------------+ +| Databricks | databricks_ | ++------------------------------------------------+---------------------------------------+ +| Denodo | denodo-sqlalchemy_ | +------------------------------------------------+---------------------------------------+ | EXASolution | sqlalchemy_exasol_ | +------------------------------------------------+---------------------------------------+ @@ -89,21 +96,29 @@ Currently maintained external dialect projects for SQLAlchemy include: +------------------------------------------------+---------------------------------------+ | Firebolt | firebolt-sqlalchemy_ | +------------------------------------------------+---------------------------------------+ -| Google BigQuery | pybigquery_ | +| Google BigQuery | sqlalchemy-bigquery_ | +------------------------------------------------+---------------------------------------+ | Google Sheets | gsheets_ | +------------------------------------------------+---------------------------------------+ +| Greenplum | sqlalchemy-greenplum_ | ++------------------------------------------------+---------------------------------------+ +| HyperSQL (hsqldb) | sqlalchemy-hsqldb_ | ++------------------------------------------------+---------------------------------------+ | IBM DB2 and Informix | ibm-db-sa_ | +------------------------------------------------+---------------------------------------+ | IBM Netezza Performance Server [1]_ | nzalchemy_ | +------------------------------------------------+---------------------------------------+ +| Impala | impyla_ | ++------------------------------------------------+---------------------------------------+ +| Kinetica | sqlalchemy-kinetica_ | ++------------------------------------------------+---------------------------------------+ | Microsoft Access (via pyodbc) | sqlalchemy-access_ | +------------------------------------------------+---------------------------------------+ -| Microsoft SQL Server (via python-tds) | sqlalchemy-tds_ | +| Microsoft SQL Server (via python-tds) | sqlalchemy-pytds_ | 
+------------------------------------------------+---------------------------------------+ | Microsoft SQL Server (via turbodbc) | sqlalchemy-turbodbc_ | +------------------------------------------------+---------------------------------------+ -| MonetDB [1]_ | sqlalchemy-monetdb_ | +| MonetDB | sqlalchemy-monetdb_ | +------------------------------------------------+---------------------------------------+ | OpenGauss | openGauss-sqlalchemy_ | +------------------------------------------------+---------------------------------------+ @@ -111,7 +126,7 @@ Currently maintained external dialect projects for SQLAlchemy include: +------------------------------------------------+---------------------------------------+ | SAP ASE (fork of former Sybase dialect) | sqlalchemy-sybase_ | +------------------------------------------------+---------------------------------------+ -| SAP Hana [1]_ | sqlalchemy-hana_ | +| SAP HANA | sqlalchemy-hana_ | +------------------------------------------------+---------------------------------------+ | SAP Sybase SQL Anywhere | sqlalchemy-sqlany_ | +------------------------------------------------+---------------------------------------+ @@ -119,27 +134,33 @@ Currently maintained external dialect projects for SQLAlchemy include: +------------------------------------------------+---------------------------------------+ | Teradata Vantage | teradatasqlalchemy_ | +------------------------------------------------+---------------------------------------+ +| TiDB | sqlalchemy-tidb_ | ++------------------------------------------------+---------------------------------------+ +| YDB | ydb-sqlalchemy_ | ++------------------------------------------------+---------------------------------------+ +| YugabyteDB | sqlalchemy-yugabytedb_ | ++------------------------------------------------+---------------------------------------+ .. [1] Supports version 1.3.x only at the moment. .. _openGauss-sqlalchemy: https://gitee.com/opengauss/openGauss-sqlalchemy .. _rockset-sqlalchemy: https://pypi.org/project/rockset-sqlalchemy -.. _sqlalchemy-ingres: https://github.com/clach04/ingres_sa_dialect +.. _sqlalchemy-ingres: https://github.com/ActianCorp/sqlalchemy-ingres .. _nzalchemy: https://pypi.org/project/nzalchemy/ .. _ibm-db-sa: https://pypi.org/project/ibm-db-sa/ .. _PyHive: https://github.com/dropbox/PyHive#sqlalchemy .. _teradatasqlalchemy: https://pypi.org/project/teradatasqlalchemy/ -.. _pybigquery: https://github.com/mxmzdlv/pybigquery/ +.. _sqlalchemy-bigquery: https://pypi.org/project/sqlalchemy-bigquery/ .. _sqlalchemy-redshift: https://pypi.org/project/sqlalchemy-redshift .. _sqlalchemy-drill: https://github.com/JohnOmernik/sqlalchemy-drill .. _sqlalchemy-hana: https://github.com/SAP/sqlalchemy-hana .. _sqlalchemy-solr: https://github.com/aadel/sqlalchemy-solr .. _sqlalchemy_exasol: https://github.com/blue-yonder/sqlalchemy_exasol .. _sqlalchemy-sqlany: https://github.com/sqlanywhere/sqlalchemy-sqlany -.. _sqlalchemy-monetdb: https://github.com/gijzelaerr/sqlalchemy-monetdb +.. _sqlalchemy-monetdb: https://github.com/MonetDB/sqlalchemy-monetdb .. _snowflake-sqlalchemy: https://github.com/snowflakedb/snowflake-sqlalchemy -.. _sqlalchemy-tds: https://github.com/m32/sqlalchemy-tds -.. _crate-python: https://github.com/crate/crate-python +.. _sqlalchemy-pytds: https://pypi.org/project/sqlalchemy-pytds/ +.. _sqlalchemy-cratedb: https://github.com/crate/sqlalchemy-cratedb .. _sqlalchemy-access: https://pypi.org/project/sqlalchemy-access/ .. 
_elasticsearch-dbapi: https://github.com/preset-io/elasticsearch-dbapi/ .. _pydruid: https://github.com/druid-io/pydruid @@ -150,3 +171,14 @@ Currently maintained external dialect projects for SQLAlchemy include: .. _sqlalchemy-sybase: https://pypi.org/project/sqlalchemy-sybase/ .. _firebolt-sqlalchemy: https://pypi.org/project/firebolt-sqlalchemy/ .. _pyathena: https://github.com/laughingman7743/PyAthena/ +.. _sqlalchemy-yugabytedb: https://pypi.org/project/sqlalchemy-yugabytedb/ +.. _impyla: https://pypi.org/project/impyla/ +.. _databend-sqlalchemy: https://github.com/datafuselabs/databend-sqlalchemy +.. _sqlalchemy-greenplum: https://github.com/PlaidCloud/sqlalchemy-greenplum +.. _sqlalchemy-hsqldb: https://pypi.org/project/sqlalchemy-hsqldb/ +.. _databricks: https://docs.databricks.com/en/dev-tools/sqlalchemy.html +.. _clickhouse-sqlalchemy: https://pypi.org/project/clickhouse-sqlalchemy/ +.. _sqlalchemy-kinetica: https://github.com/kineticadb/sqlalchemy-kinetica/ +.. _sqlalchemy-tidb: https://github.com/pingcap/sqlalchemy-tidb +.. _ydb-sqlalchemy: https://github.com/ydb-platform/ydb-sqlalchemy/ +.. _denodo-sqlalchemy: https://pypi.org/project/denodo-sqlalchemy/ diff --git a/doc/build/dialects/mysql.rst b/doc/build/dialects/mysql.rst index a46bf721e21..657cd2a4189 100644 --- a/doc/build/dialects/mysql.rst +++ b/doc/build/dialects/mysql.rst @@ -56,7 +56,14 @@ valid with MySQL are importable from the top level dialect:: YEAR, ) -Types which are specific to MySQL, or have MySQL-specific +In addition to the above types, MariaDB also supports the following:: + + from sqlalchemy.dialects.mysql import ( + INET4, + INET6, + ) + +Types which are specific to MySQL or MariaDB, or have specific construction arguments, are as follows: .. note: where :noindex: is used, indicates a type that is not redefined @@ -117,6 +124,10 @@ construction arguments, are as follows: :members: __init__ +.. autoclass:: INET4 + +.. autoclass:: INET6 + .. autoclass:: INTEGER :members: __init__ diff --git a/doc/build/dialects/oracle.rst b/doc/build/dialects/oracle.rst index 8187e714798..882f9266047 100644 --- a/doc/build/dialects/oracle.rst +++ b/doc/build/dialects/oracle.rst @@ -5,12 +5,12 @@ Oracle .. automodule:: sqlalchemy.dialects.oracle.base -Oracle Data Types ------------------ +Oracle Database Data Types +-------------------------- -As with all SQLAlchemy dialects, all UPPERCASE types that are known to be -valid with Oracle are importable from the top level dialect, whether -they originate from :mod:`sqlalchemy.types` or from the local dialect:: +As with all SQLAlchemy dialects, all UPPERCASE types that are known to be valid +with Oracle Database are importable from the top level dialect, whether they +originate from :mod:`sqlalchemy.types` or from the local dialect:: from sqlalchemy.dialects.oracle import ( BFILE, @@ -31,12 +31,13 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect:: TIMESTAMP, VARCHAR, VARCHAR2, + VECTOR, ) .. versionadded:: 1.2.19 Added :class:`_types.NCHAR` to the list of datatypes exported by the Oracle dialect. -Types which are specific to Oracle, or have Oracle-specific +Types which are specific to Oracle Database, or have Oracle-specific construction arguments, are as follows: .. currentmodule:: sqlalchemy.dialects.oracle @@ -80,12 +81,22 @@ construction arguments, are as follows: .. autoclass:: TIMESTAMP :members: __init__ -.. _cx_oracle: +.. autoclass:: VECTOR + :members: __init__ -cx_Oracle ---------- +.. autoclass:: VectorIndexType + :members: + +.. 
autoclass:: VectorIndexConfig
+    :members:
+    :undoc-members:
+
+.. autoclass:: VectorStorageFormat
+    :members:
+
+.. autoclass:: VectorDistanceType
+    :members:

-.. automodule:: sqlalchemy.dialects.oracle.cx_oracle

.. _oracledb:

@@ -94,3 +105,9 @@ python-oracledb

.. automodule:: sqlalchemy.dialects.oracle.oracledb

+.. _cx_oracle:
+
+cx_Oracle
+---------
+
+.. automodule:: sqlalchemy.dialects.oracle.cx_oracle
diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst
index 0575837185c..2d377e3623e 100644
--- a/doc/build/dialects/postgresql.rst
+++ b/doc/build/dialects/postgresql.rst
@@ -238,6 +238,8 @@ dialect, **does not** support multirange datatypes.
.. versionadded:: 2.0.17 Added multirange support for the pg8000 dialect.
   pg8000 1.29.8 or greater is required.

+.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added.
+
The example below illustrates use of the :class:`_postgresql.TSMULTIRANGE`
datatype::

@@ -260,6 +262,7 @@ datatype::

        id: Mapped[int] = mapped_column(primary_key=True)
        event_name: Mapped[str]
+        added: Mapped[datetime]
        in_session_periods: Mapped[List[Range[datetime]]] = mapped_column(TSMULTIRANGE)

Illustrating insertion and selecting of a record::

@@ -294,6 +297,38 @@ Illustrating insertion and selecting of a record::
    a new list to the attribute, or use the :class:`.MutableList`
    type modifier. See the section :ref:`mutable_toplevel` for background.

+.. _postgresql_multirange_list_use:
+
+Use of a MultiRange sequence to infer the multirange type
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+When using a multirange as a literal without specifying the type,
+the utility :class:`_postgresql.MultiRange` sequence can be used::
+
+    from sqlalchemy import literal
+    from sqlalchemy.dialects.postgresql import MultiRange
+
+    with Session(engine) as session:
+        stmt = select(EventCalendar).where(
+            EventCalendar.added.op("<@")(
+                MultiRange(
+                    [
+                        Range(datetime(2023, 1, 1), datetime(2023, 3, 31)),
+                        Range(datetime(2023, 7, 1), datetime(2023, 9, 30)),
+                    ]
+                )
+            )
+        )
+        in_range = session.execute(stmt).all()
+
+    with engine.connect() as conn:
+        row = conn.scalar(select(literal(MultiRange([Range(2, 4)]))))
+        print(f"{row.lower} -> {row.upper}")
+
+Using a simple ``list`` instead of :class:`_postgresql.MultiRange` would require
+manually setting the type of the literal value to the appropriate multirange type.
+
+.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added.

The available multirange datatypes are as follows:

@@ -416,12 +451,14 @@ construction arguments, are as follows:
.. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange
    :members: comparator_factory

+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractSingleRange
+
.. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange

.. autoclass:: ARRAY
    :members: __init__, Comparator
-
+    :member-order: bysource

.. autoclass:: BIT

@@ -529,6 +566,9 @@ construction arguments, are as follows:

.. autoclass:: TSTZMULTIRANGE

+.. autoclass:: MultiRange
+
+
PostgreSQL SQL Elements and Functions
--------------------------------------

diff --git a/doc/build/errors.rst b/doc/build/errors.rst
index 48fdedeace0..99701ba2790 100644
--- a/doc/build/errors.rst
+++ b/doc/build/errors.rst
@@ -136,7 +136,7 @@ What causes an application to use up all the connections that it has available?
upon to release resources in a timely manner.

A common reason this can occur is that the application uses ORM sessions and
A common reason this can occur is that the application uses ORM sessions and - does not call :meth:`.Session.close` upon them one the work involving that + does not call :meth:`.Session.close` upon them once the work involving that session is complete. Solution is to make sure ORM sessions if using the ORM, or engine-bound :class:`_engine.Connection` objects if using Core, are explicitly closed at the end of the work being done, either via the appropriate @@ -188,6 +188,28 @@ sooner. :ref:`connections_toplevel` +.. _error_pcls: + +Pool class cannot be used with asyncio engine (or vice versa) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`_pool.QueuePool` pool class uses a ``thread.Lock`` object internally +and is not compatible with asyncio. If using the :func:`_asyncio.create_async_engine` +function to create an :class:`.AsyncEngine`, the appropriate queue pool class +is :class:`_pool.AsyncAdaptedQueuePool`, which is used automatically and does +not need to be specified. + +In addition to :class:`_pool.AsyncAdaptedQueuePool`, the :class:`_pool.NullPool` +and :class:`_pool.StaticPool` pool classes do not use locks and are also +suitable for use with async engines. + +This error is also raised in reverse in the unlikely case that the +:class:`_pool.AsyncAdaptedQueuePool` pool class is indicated explicitly with +the :func:`_sa.create_engine` function. + +.. seealso:: + + :ref:`pooling_toplevel` .. _error_8s2b: @@ -453,7 +475,7 @@ when a construct is stringified without any dialect-specific information. However, there are many constructs that are specific to some particular kind of database dialect, for which the :class:`.StrSQLCompiler` doesn't know how to turn into a string, such as the PostgreSQL -`"insert on conflict" `_ construct:: +:ref:`postgresql_insert_on_conflict` construct:: >>> from sqlalchemy.dialects.postgresql import insert >>> from sqlalchemy import table, column @@ -550,7 +572,7 @@ is executed:: Above, no value has been provided for the parameter "my_param". The correct approach is to provide a value:: - result = conn.execute(stmt, my_param=12) + result = conn.execute(stmt, {"my_param": 12}) When the message takes the form "a value is required for bind parameter in parameter group ", the message is referring to the "executemany" style @@ -1777,8 +1799,7 @@ and associating the :class:`_engine.Engine` with the Base = declarative_base(metadata=metadata_obj) - class MyClass(Base): - ... + class MyClass(Base): ... session = Session() @@ -1796,8 +1817,7 @@ engine:: Base = declarative_base() - class MyClass(Base): - ... + class MyClass(Base): ... 
session = Session()
diff --git a/doc/build/faq/connections.rst b/doc/build/faq/connections.rst
index d93a4b1af76..3177d7ea926 100644
--- a/doc/build/faq/connections.rst
+++ b/doc/build/faq/connections.rst
@@ -258,11 +258,13 @@ statement executions::

                fn(cursor_obj, statement, context=context, *arg)
            except engine.dialect.dbapi.Error as raw_dbapi_err:
                connection = context.root_connection
-                if engine.dialect.is_disconnect(raw_dbapi_err, connection, cursor_obj):
-                    if retry > num_retries:
-                        raise
+                if engine.dialect.is_disconnect(
+                    raw_dbapi_err, connection.connection.dbapi_connection, cursor_obj
+                ):
                    engine.logger.error(
-                        "disconnection error, retrying operation",
+                        "disconnection error, attempt %d/%d",
+                        retry + 1,
+                        num_retries + 1,
                        exc_info=True,
                    )
                    connection.invalidate()
@@ -275,6 +277,9 @@ statement executions::
                    if trans:
                        trans.rollback()

+                    if retry == num_retries:
+                        raise
+
                    time.sleep(retry_interval)
                    context.cursor = cursor_obj = connection.connection.cursor()
                else:
diff --git a/doc/build/faq/ormconfiguration.rst b/doc/build/faq/ormconfiguration.rst
index 90d74d23ee9..bfcf117ae09 100644
--- a/doc/build/faq/ormconfiguration.rst
+++ b/doc/build/faq/ormconfiguration.rst
@@ -349,3 +349,94 @@ loads directly to primary key values just loaded.

.. seealso::

    :ref:`subquery_eager_loading`
+
+.. _defaults_default_factory_insert_default:
+
+What are ``default``, ``default_factory`` and ``insert_default`` and what should I use?
+---------------------------------------------------------------------------------------
+
+There's a bit of a clash in SQLAlchemy's API here due to the addition of PEP-681
+dataclass transforms, which is strict about its naming conventions. PEP-681 comes
+into play if you are using :class:`_orm.MappedAsDataclass` as shown in :ref:`orm_declarative_native_dataclasses`.
+If you are not using MappedAsDataclass, then it does not apply.
+
+Part One - Classic SQLAlchemy that is not using dataclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When **not** using :class:`_orm.MappedAsDataclass`, as has been the case for many years
+in SQLAlchemy, the :func:`_orm.mapped_column` (and :class:`_schema.Column`)
+construct supports a parameter :paramref:`_orm.mapped_column.default`.
+This indicates a Python-side default (as opposed to a server side default that
+would be part of your database's schema definition) that will take place when
+an ``INSERT`` statement is emitted. This default can be **any** of a static Python value
+like a string, **or** a Python callable function, **or** a SQLAlchemy SQL construct.
+Full documentation for :paramref:`_orm.mapped_column.default` is at
+:ref:`defaults_client_invoked_sql`.
+
+When using :paramref:`_orm.mapped_column.default` with an ORM mapping that is **not**
+using :class:`_orm.MappedAsDataclass`, this default value / callable **does not show
+up on your object when you first construct it**. It only takes place when SQLAlchemy
+works up an ``INSERT`` statement for your object.
+
+A very important thing to note is that when using :func:`_orm.mapped_column`
+(and :class:`_schema.Column`), the classic :paramref:`_orm.mapped_column.default`
+parameter is also available under a new name, called
+:paramref:`_orm.mapped_column.insert_default`. If you build a
+:func:`_orm.mapped_column` and you are **not** using :class:`_orm.MappedAsDataclass`, the
+:paramref:`_orm.mapped_column.default` and :paramref:`_orm.mapped_column.insert_default`
+parameters are **synonymous**.
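To make the Part One behavior concrete, a minimal sketch of a non-dataclass
mapping (the ``Event`` class and column names here are hypothetical) where
:paramref:`_orm.mapped_column.default` is given a callable; the value is
generated only when the ``INSERT`` is emitted, not when the object is
constructed::

    from datetime import datetime, timezone

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Event(Base):
        __tablename__ = "event"

        id: Mapped[int] = mapped_column(primary_key=True)

        # Python-side default, invoked when the INSERT is emitted; a newly
        # constructed Event() has no created_at value until it is flushed
        created_at: Mapped[datetime] = mapped_column(
            default=lambda: datetime.now(timezone.utc)
        )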
+
+Part Two - Using Dataclasses support with MappedAsDataclass
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When you **are** using :class:`_orm.MappedAsDataclass`, that is, the specific form
+of mapping used at :ref:`orm_declarative_native_dataclasses`, the meaning of the
+:paramref:`_orm.mapped_column.default` keyword changes. We recognize that it's not
+ideal that this name changes its behavior; however, there was no alternative, as
+PEP-681 requires :paramref:`_orm.mapped_column.default` to take on this meaning.
+
+When dataclasses are used, the :paramref:`_orm.mapped_column.default` parameter must
+be used the way it's described at
+`Python Dataclasses <https://docs.python.org/3/library/dataclasses.html>`_ - it refers
+to a constant value like a string or a number, and **is applied to your object
+immediately when constructed**. At the moment it is also applied to the
+:paramref:`_orm.mapped_column.default` parameter of :class:`_schema.Column` where
+it would be used in an ``INSERT`` statement automatically even if not present
+on the object. If you instead want to use a callable for your dataclass,
+which will be applied to the object when constructed, you would use
+:paramref:`_orm.mapped_column.default_factory`.
+
+To get access to the ``INSERT``-only behavior of :paramref:`_orm.mapped_column.default`
+that is described in part one above, you would use the
+:paramref:`_orm.mapped_column.insert_default` parameter instead.
+:paramref:`_orm.mapped_column.insert_default` when dataclasses are used continues
+to be a direct route to the Core-level "default" process where the parameter can
+be a static value or callable.
+
+.. list-table:: Summary Chart
+    :header-rows: 1
+
+    * - Construct
+      - Works with dataclasses?
+      - Works without dataclasses?
+      - Accepts scalar?
+      - Accepts callable?
+      - Populates object immediately?
+    * - :paramref:`_orm.mapped_column.default`
+      - ✔
+      - ✔
+      - ✔
+      - Only if no dataclasses
+      - Only if dataclasses
+    * - :paramref:`_orm.mapped_column.insert_default`
+      - ✔
+      - ✔
+      - ✔
+      - ✔
+      - ✖
+    * - :paramref:`_orm.mapped_column.default_factory`
+      - ✔
+      - ✖
+      - ✖
+      - ✔
+      - Only if dataclasses
diff --git a/doc/build/faq/sessions.rst b/doc/build/faq/sessions.rst
index a2c61c0a41d..a95580ef514 100644
--- a/doc/build/faq/sessions.rst
+++ b/doc/build/faq/sessions.rst
@@ -370,7 +370,7 @@ See :ref:`session_deleting_from_collections` for a description of this behavior.

why isn't my ``__init__()`` called when I load objects?
-------------------------------------------------------

-See :ref:`mapping_constructors` for a description of this behavior.
+See :ref:`mapped_class_load_events` for a description of this behavior.

how do I use ON DELETE CASCADE with SA's ORM?
---------------------------------------------

diff --git a/doc/build/faq/sqlexpressions.rst b/doc/build/faq/sqlexpressions.rst
index 051d5cca204..7a03bdb0362 100644
--- a/doc/build/faq/sqlexpressions.rst
+++ b/doc/build/faq/sqlexpressions.rst
@@ -319,7 +319,7 @@ known values are passed.
"Expanding" parameters are used for string can be safely cached independently of the actual lists of values being passed to a particular invocation of :meth:`_sql.ColumnOperators.in_`:: - >>> stmt = select(A).where(A.id.in_[1, 2, 3]) + >>> stmt = select(A).where(A.id.in_([1, 2, 3])) To render the IN clause with real bound parameter symbols, use the ``render_postcompile=True`` flag with :meth:`_sql.ClauseElement.compile`: diff --git a/doc/build/glossary.rst b/doc/build/glossary.rst index c3e49cacf61..1d8ac29aabe 100644 --- a/doc/build/glossary.rst +++ b/doc/build/glossary.rst @@ -298,7 +298,7 @@ Glossary A key limitation of the ``cursor.executemany()`` method as used with all known DBAPIs is that the ``cursor`` is not configured to return rows when this method is used. For **most** backends (a notable - exception being the cx_Oracle, / OracleDB DBAPIs), this means that + exception being the python-oracledb / cx_Oracle DBAPIs), this means that statements like ``INSERT..RETURNING`` typically cannot be used with ``cursor.executemany()`` directly, since DBAPIs typically do not aggregate the single row from each INSERT execution together. @@ -811,6 +811,19 @@ Glossary :ref:`session_basics` + flush + flushing + flushed + + This refers to the actual process used by the :term:`unit of work` + to emit changes to a database. In SQLAlchemy this process occurs + via the :class:`_orm.Session` object and is usually automatic, but + can also be controlled manually. + + .. seealso:: + + :ref:`session_flushing` + expire expired expires @@ -1038,7 +1051,6 @@ Glossary isolation isolated - Isolation isolation level The isolation property of the :term:`ACID` model ensures that the concurrent execution @@ -1146,16 +1158,17 @@ Glossary values as they are not included otherwise (but note any series of columns or SQL expressions can be placed into RETURNING, not just default-value columns). - The backends that currently support - RETURNING or a similar construct are PostgreSQL, SQL Server, Oracle, - and Firebird. The PostgreSQL and Firebird implementations are generally - full featured, whereas the implementations of SQL Server and Oracle - have caveats. On SQL Server, the clause is known as "OUTPUT INSERTED" - for INSERT and UPDATE statements and "OUTPUT DELETED" for DELETE statements; - the key caveat is that triggers are not supported in conjunction with this - keyword. On Oracle, it is known as "RETURNING...INTO", and requires that the - value be placed into an OUT parameter, meaning not only is the syntax awkward, - but it can also only be used for one row at a time. + The backends that currently support RETURNING or a similar construct + are PostgreSQL, SQL Server, Oracle Database, and Firebird. The + PostgreSQL and Firebird implementations are generally full featured, + whereas the implementations of SQL Server and Oracle Database have + caveats. On SQL Server, the clause is known as "OUTPUT INSERTED" for + INSERT and UPDATE statements and "OUTPUT DELETED" for DELETE + statements; the key caveat is that triggers are not supported in + conjunction with this keyword. In Oracle Database, it is known as + "RETURNING...INTO", and requires that the value be placed into an OUT + parameter, meaning not only is the syntax awkward, but it can also only + be used for one row at a time. SQLAlchemy's :meth:`.UpdateBase.returning` system provides a layer of abstraction on top of the RETURNING systems of these backends to provide a consistent @@ -1690,4 +1703,3 @@ Glossary .. 
seealso::

    :ref:`session_object_states`
-
diff --git a/doc/build/index.rst b/doc/build/index.rst
index 37b807723f3..44914b0bb54 100644
--- a/doc/build/index.rst
+++ b/doc/build/index.rst
@@ -149,9 +149,9 @@ SQLAlchemy Documentation
    This section describes notes, options, and usage patterns regarding individual
    dialects.

    :doc:`PostgreSQL ` |
-    :doc:`MySQL ` |
+    :doc:`MySQL and MariaDB ` |
    :doc:`SQLite ` |
-    :doc:`Oracle ` |
+    :doc:`Oracle Database ` |
    :doc:`Microsoft SQL Server `

    :doc:`More Dialects ... `

@@ -168,7 +168,6 @@ SQLAlchemy Documentation

* :doc:`Frequently Asked Questions ` - A collection of common problems and solutions
* :doc:`Glossary ` - Terms used in SQLAlchemy's documentation
-* :doc:`Error Message Guide ` - Explainations of many SQLAlchemy Errors
+* :doc:`Error Message Guide ` - Explanations of many SQLAlchemy Errors
* :doc:`Complete table of contents `
* :ref:`Index `
-
diff --git a/doc/build/intro.rst b/doc/build/intro.rst
index cac103ed831..709d56b7b87 100644
--- a/doc/build/intro.rst
+++ b/doc/build/intro.rst
@@ -42,7 +42,7 @@ augmented by ORM-specific automations and object-centric querying capabilities.
Whereas working with Core and the SQL Expression language presents a
schema-centric view of the database, along with a programming paradigm that is
oriented around immutability, the ORM builds on top of this a domain-centric
-view of the database with a programming paradigm that is more explcitly
+view of the database with a programming paradigm that is more explicitly
object-oriented and reliant upon mutability. Since a relational database is
itself a mutable service, the difference is that Core/SQL Expression language
is command oriented whereas the ORM is state oriented.
diff --git a/doc/build/orm/basic_relationships.rst b/doc/build/orm/basic_relationships.rst
index 7e3ce5ec551..b4a3ed2b5f5 100644
--- a/doc/build/orm/basic_relationships.rst
+++ b/doc/build/orm/basic_relationships.rst
@@ -1018,7 +1018,7 @@ within any of these string expressions::

In an example like the above, the string passed to :class:`_orm.Mapped`
can be disambiguated from a specific class argument by passing the class
-location string directly to :paramref:`_orm.relationship.argument` as well.
+location string directly to the first positional parameter (:paramref:`_orm.relationship.argument`) as well.
Below illustrates a typing-only import for ``Child``, combined with a
runtime specifier for the target class that will search for the correct
name within the :class:`_orm.registry`::

@@ -1102,8 +1102,10 @@ that will be passed to ``eval()`` are:
    are **evaluated as Python code expressions using eval(). DO NOT PASS UNTRUSTED
    INPUT TO THESE ARGUMENTS.**

+.. _orm_declarative_table_adding_relationship:
+
Adding Relationships to Mapped Classes After Declaration
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It should also be noted that in a similar way as described at
:ref:`orm_declarative_table_adding_columns`, any :class:`_orm.MapperProperty`
@@ -1116,15 +1118,13 @@ class were available, we could also apply it afterwards::

    # we create a Parent class which knows nothing about Child

-    class Parent(Base):
-        ...
+    class Parent(Base): ...


    # ... later, in Module B, which is imported after module A:

-    class Child(Base):
-        ...
+    class Child(Base): ...
from module_a import Parent diff --git a/doc/build/orm/cascades.rst b/doc/build/orm/cascades.rst index 02d68669eee..20f96001e33 100644 --- a/doc/build/orm/cascades.rst +++ b/doc/build/orm/cascades.rst @@ -301,6 +301,14 @@ The feature by default works completely independently of database-configured In order to integrate more efficiently with this configuration, additional directives described at :ref:`passive_deletes` should be used. +.. warning:: Note that the ORM's "delete" and "delete-orphan" behavior applies + **only** to the use of the :meth:`_orm.Session.delete` method to mark + individual ORM instances for deletion within the :term:`unit of work` process. + It does **not** apply to "bulk" deletes, which would be emitted using + the :func:`_sql.delete` construct as illustrated at + :ref:`orm_queryguide_update_delete_where`. See + :ref:`orm_queryguide_update_delete_caveats` for additional background. + .. seealso:: :ref:`passive_deletes` diff --git a/doc/build/orm/collection_api.rst b/doc/build/orm/collection_api.rst index 2d56bb9b2b0..be8e4ea9516 100644 --- a/doc/build/orm/collection_api.rst +++ b/doc/build/orm/collection_api.rst @@ -129,7 +129,7 @@ Python code, as well as in a few special cases, the collection class for a In the absence of :paramref:`_orm.relationship.collection_class` or :class:`_orm.Mapped`, the default collection type is ``list``. -Beyond ``list`` and ``set`` builtins, there is also support for two varities of +Beyond ``list`` and ``set`` builtins, there is also support for two varieties of dictionary, described below at :ref:`orm_dictionary_collection`. There is also support for any arbitrary mutable sequence type can be set up as the target collection, with some additional configuration steps; this is described in the @@ -533,8 +533,7 @@ methods can be changed as well: ... @collection.iterator - def hey_use_this_instead_for_iteration(self): - ... + def hey_use_this_instead_for_iteration(self): ... There is no requirement to be "list-like" or "set-like" at all. Collection classes can be any shape, so long as they have the append, remove and iterate diff --git a/doc/build/orm/composites.rst b/doc/build/orm/composites.rst index 2e625509e02..2fc62cbfd01 100644 --- a/doc/build/orm/composites.rst +++ b/doc/build/orm/composites.rst @@ -63,6 +63,12 @@ of the columns to be generated, in this case the names; the def __repr__(self): return f"Vertex(start={self.start}, end={self.end})" +.. tip:: In the example above the columns that represent the composites + (``x1``, ``y1``, etc.) are also accessible on the class but are not + correctly understood by type checkers. + If accessing the single columns is important they can be explicitly declared, + as shown in :ref:`composite_with_typing`. + The above mapping would correspond to a CREATE TABLE statement as: .. sourcecode:: pycon+sql @@ -182,14 +188,15 @@ Other mapping forms for composites The :func:`_orm.composite` construct may be passed the relevant columns using a :func:`_orm.mapped_column` construct, a :class:`_schema.Column`, or the string name of an existing mapped column. The following examples -illustrate an equvalent mapping as that of the main section above. +illustrate an equivalent mapping as that of the main section above. 
-* Map columns directly, then pass to composite +Map columns directly, then pass to composite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Here we pass the existing :func:`_orm.mapped_column` instances to the - :func:`_orm.composite` construct, as in the non-annotated example below - where we also pass the ``Point`` class as the first argument to - :func:`_orm.composite`:: +Here we pass the existing :func:`_orm.mapped_column` instances to the +:func:`_orm.composite` construct, as in the non-annotated example below +where we also pass the ``Point`` class as the first argument to +:func:`_orm.composite`:: from sqlalchemy import Integer from sqlalchemy.orm import mapped_column, composite @@ -207,11 +214,14 @@ illustrate an equvalent mapping as that of the main section above. start = composite(Point, x1, y1) end = composite(Point, x2, y2) -* Map columns directly, pass attribute names to composite +.. _composite_with_typing: + +Map columns directly, pass attribute names to composite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - We can write the same example above using more annotated forms where we have - the option to pass attribute names to :func:`_orm.composite` instead of - full column constructs:: +We can write the same example above using more annotated forms where we have +the option to pass attribute names to :func:`_orm.composite` instead of +full column constructs:: from sqlalchemy.orm import mapped_column, composite, Mapped @@ -228,12 +238,13 @@ illustrate an equvalent mapping as that of the main section above. start: Mapped[Point] = composite("x1", "y1") end: Mapped[Point] = composite("x2", "y2") -* Imperative mapping and imperative table +Imperative mapping and imperative table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - When using :ref:`imperative table ` or - fully :ref:`imperative ` mappings, we have access - to :class:`_schema.Column` objects directly. These may be passed to - :func:`_orm.composite` as well, as in the imperative example below:: +When using :ref:`imperative table ` or +fully :ref:`imperative ` mappings, we have access +to :class:`_schema.Column` objects directly. These may be passed to +:func:`_orm.composite` as well, as in the imperative example below:: mapper_registry.map_imperatively( Vertex, diff --git a/doc/build/orm/dataclasses.rst b/doc/build/orm/dataclasses.rst index b7d0bee4313..7f377ca3996 100644 --- a/doc/build/orm/dataclasses.rst +++ b/doc/build/orm/dataclasses.rst @@ -18,7 +18,7 @@ attrs_ third party integration library. .. _orm_declarative_native_dataclasses: Declarative Dataclass Mapping -------------------------------- +----------------------------- SQLAlchemy :ref:`Annotated Declarative Table ` mappings may be augmented with an additional @@ -41,7 +41,7 @@ decorator. limited and is currently known to be supported by Pyright_ as well as Mypy_ as of **version 1.2**. Note that Mypy 1.1.1 introduced :pep:`681` support but did not correctly accommodate Python descriptors - which will lead to errors when using SQLAlhcemy's ORM mapping scheme. + which will lead to errors when using SQLAlchemy's ORM mapping scheme. .. seealso:: @@ -278,17 +278,24 @@ parameter for ``created_at`` were passed proceeds as: Integration with Annotated ~~~~~~~~~~~~~~~~~~~~~~~~~~ -The approach introduced at :ref:`orm_declarative_mapped_column_pep593` illustrates -how to use :pep:`593` ``Annotated`` objects to package whole -:func:`_orm.mapped_column` constructs for re-use. This feature is supported -with the dataclasses feature. 
One aspect of the feature however requires -a workaround when working with typing tools, which is that the -:pep:`681`-specific arguments ``init``, ``default``, ``repr``, and ``default_factory`` -**must** be on the right hand side, packaged into an explicit :func:`_orm.mapped_column` -construct, in order for the typing tool to interpret the attribute correctly. -As an example, the approach below will work perfectly fine at runtime, -however typing tools will consider the ``User()`` construction to be -invalid, as they do not see the ``init=False`` parameter present:: +The approach introduced at :ref:`orm_declarative_mapped_column_pep593` +illustrates how to use :pep:`593` ``Annotated`` objects to package whole +:func:`_orm.mapped_column` constructs for re-use. While ``Annotated`` objects +can be combined with the use of dataclasses, **dataclass-specific keyword +arguments unfortunately cannot be used within the Annotated construct**. This +includes :pep:`681`-specific arguments ``init``, ``default``, ``repr``, and +``default_factory``, which **must** be present in a :func:`_orm.mapped_column` +or similar construct inline with the class attribute. + +.. versionchanged:: 2.0.14/2.0.22 the ``Annotated`` construct when used with + an ORM construct like :func:`_orm.mapped_column` cannot accommodate dataclass + field parameters such as ``init`` and ``repr`` - this use goes against the + design of Python dataclasses and is not supported by :pep:`681`, and therefore + is also rejected by the SQLAlchemy ORM at runtime. A deprecation warning + is now emitted and the attribute will be ignored. + +As an example, the ``init=False`` parameter below will be ignored and additionally +emit a deprecation warning:: from typing import Annotated @@ -296,7 +303,7 @@ invalid, as they do not see the ``init=False`` parameter present:: from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry - # typing tools will ignore init=False here + # typing tools as well as SQLAlchemy will ignore init=False here intpk = Annotated[int, mapped_column(init=False, primary_key=True)] reg = registry() @@ -308,7 +315,7 @@ invalid, as they do not see the ``init=False`` parameter present:: id: Mapped[intpk] - # typing error: Argument missing for parameter "id" + # typing error as well as runtime error: Argument missing for parameter "id" u1 = User() Instead, :func:`_orm.mapped_column` must be present on the right side @@ -424,7 +431,7 @@ scalar object references may make use of The above mapping will generate an empty list for ``Parent.children`` when a new ``Parent()`` object is constructed without passing ``children``, and similarly a ``None`` value for ``Child.parent`` when a new ``Child()`` object -is constructed without passsing ``parent``. +is constructed without passing ``parent``. While the :paramref:`_orm.relationship.default_factory` can be automatically derived from the given collection class of the :func:`_orm.relationship` @@ -705,6 +712,15 @@ which itself is specified within the ``__mapper_args__`` dictionary, so that it is passed to the constructor for :class:`_orm.Mapper`. An alternative to this approach is in the next example. + +.. warning:: + Declaring a dataclass ``field()`` setting a ``default`` together with ``init=False`` + will not work as would be expected with a totally plain dataclass, + since the SQLAlchemy class instrumentation will replace + the default value set on the class by the dataclass creation process. + Use ``default_factory`` instead. 
This adaptation is done automatically when + making use of :ref:`orm_declarative_native_dataclasses`. + .. _orm_declarative_dataclasses_declarative_table: Mapping pre-existing dataclasses using Declarative-style fields @@ -778,8 +794,8 @@ example at :ref:`orm_declarative_mixins_relationships`:: class RefTargetMixin: @declared_attr - def target_id(cls): - return Column("target_id", ForeignKey("target.id")) + def target_id(cls) -> Mapped[int]: + return mapped_column("target_id", ForeignKey("target.id")) @declared_attr def target(cls): @@ -909,11 +925,19 @@ variables:: mapper_registry.map_imperatively(Address, address) +The same warning mentioned in :ref:`orm_declarative_dataclasses_imperative_table` +applies when using this mapping style. + .. _orm_declarative_attrs_imperative_table: Applying ORM mappings to an existing attrs class ------------------------------------------------- +.. warning:: The ``attrs`` library is not part of SQLAlchemy's continuous + integration testing, and compatibility with this library may change without + notice due to incompatibilities introduced by either side. + + The attrs_ library is a popular third party library that provides similar features as dataclasses, with many additional features provided not found in ordinary dataclasses. @@ -923,103 +947,27 @@ initiates a process to scan the class for attributes that define the class' behavior, which are then used to generate methods, documentation, and annotations. -The SQLAlchemy ORM supports mapping an attrs_ class using **Declarative with -Imperative Table** or **Imperative** mapping. The general form of these two -styles is fully equivalent to the -:ref:`orm_declarative_dataclasses_declarative_table` and -:ref:`orm_declarative_dataclasses_imperative_table` mapping forms used with -dataclasses, where the inline attribute directives used by dataclasses or attrs -are unchanged, and SQLAlchemy's table-oriented instrumentation is applied at -runtime. +The SQLAlchemy ORM supports mapping an attrs_ class using **Imperative** mapping. +The general form of this style is equivalent to the +:ref:`orm_imperative_dataclasses` mapping form used with +dataclasses, where the class construction uses ``attrs`` alone, with ORM mappings +applied after the fact without any class attribute scanning. The ``@define`` decorator of attrs_ by default replaces the annotated class with a new __slots__ based class, which is not supported. When using the old style annotation ``@attr.s`` or using ``define(slots=False)``, the class -does not get replaced. Furthermore attrs removes its own class-bound attributes +does not get replaced. Furthermore ``attrs`` removes its own class-bound attributes after the decorator runs, so that SQLAlchemy's mapping process takes over these attributes without any issue. Both decorators, ``@attr.s`` and ``@define(slots=False)`` work with SQLAlchemy. -Mapping attrs with Declarative "Imperative Table" -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the "Declarative with Imperative Table" style, a :class:`_schema.Table` -object is declared inline with the declarative class. 
The -``@define`` decorator is applied to the class first, then the -:meth:`_orm.registry.mapped` decorator second:: - - from __future__ import annotations - - from typing import List - from typing import Optional - - from attrs import define - from sqlalchemy import Column - from sqlalchemy import ForeignKey - from sqlalchemy import Integer - from sqlalchemy import MetaData - from sqlalchemy import String - from sqlalchemy import Table - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import registry - from sqlalchemy.orm import relationship - - mapper_registry = registry() - - - @mapper_registry.mapped - @define(slots=False) - class User: - __table__ = Table( - "user", - mapper_registry.metadata, - Column("id", Integer, primary_key=True), - Column("name", String(50)), - Column("FullName", String(50), key="fullname"), - Column("nickname", String(12)), - ) - id: Mapped[int] - name: Mapped[str] - fullname: Mapped[str] - nickname: Mapped[str] - addresses: Mapped[List[Address]] - - __mapper_args__ = { # type: ignore - "properties": { - "addresses": relationship("Address"), - } - } - - - @mapper_registry.mapped - @define(slots=False) - class Address: - __table__ = Table( - "address", - mapper_registry.metadata, - Column("id", Integer, primary_key=True), - Column("user_id", Integer, ForeignKey("user.id")), - Column("email_address", String(50)), - ) - id: Mapped[int] - user_id: Mapped[int] - email_address: Mapped[Optional[str]] - -.. note:: The ``attrs`` ``slots=True`` option, which enables ``__slots__`` on - a mapped class, cannot be used with SQLAlchemy mappings without fully - implementing alternative - :ref:`attribute instrumentation `, as mapped - classes normally rely upon direct access to ``__dict__`` for state storage. - Behavior is undefined when this option is present. - - - -Mapping attrs with Imperative Mapping -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. versionchanged:: 2.0 SQLAlchemy integration with ``attrs`` works only + with imperative mapping style, that is, not using Declarative. + The introduction of ORM Annotated Declarative style is not cross-compatible + with ``attrs``. -Just as is the case with dataclasses, we can make use of -:meth:`_orm.registry.map_imperatively` to map an existing ``attrs`` class -as well:: +The ``attrs`` class is built first. The SQLAlchemy ORM mapping can be +applied after the fact using :meth:`_orm.registry.map_imperatively`:: from __future__ import annotations @@ -1083,11 +1031,6 @@ as well:: mapper_registry.map_imperatively(Address, address) -The above form is equivalent to the previous example using -Declarative with Imperative Table. - - - .. _dataclass: https://docs.python.org/3/library/dataclasses.html .. _dataclasses: https://docs.python.org/3/library/dataclasses.html .. _attrs: https://pypi.org/project/attrs/ diff --git a/doc/build/orm/declarative_mixins.rst b/doc/build/orm/declarative_mixins.rst index 0ee8a952bb8..9f26207c07a 100644 --- a/doc/build/orm/declarative_mixins.rst +++ b/doc/build/orm/declarative_mixins.rst @@ -152,7 +152,7 @@ Augmenting the Base In addition to using a pure mixin, most of the techniques in this section can also be applied to the base class directly, for patterns that should apply to all classes derived from a particular base. 
The example -below illustrates some of the the previous section's example in terms of the +below illustrates some of the previous section's example in terms of the ``Base`` class:: from sqlalchemy import ForeignKey diff --git a/doc/build/orm/declarative_styles.rst b/doc/build/orm/declarative_styles.rst index 48897ee6d6d..8feb5398b10 100644 --- a/doc/build/orm/declarative_styles.rst +++ b/doc/build/orm/declarative_styles.rst @@ -51,6 +51,7 @@ With the declarative base class, new mapped classes are declared as subclasses of the base:: from datetime import datetime + from typing import List from typing import Optional from sqlalchemy import ForeignKey diff --git a/doc/build/orm/declarative_tables.rst b/doc/build/orm/declarative_tables.rst index 711fa11bbee..5bffe97b0a1 100644 --- a/doc/build/orm/declarative_tables.rst +++ b/doc/build/orm/declarative_tables.rst @@ -108,7 +108,7 @@ further at :ref:`orm_declarative_metadata`. The :func:`_orm.mapped_column` construct accepts all arguments that are accepted by the :class:`_schema.Column` construct, as well as additional -ORM-specific arguments. The :paramref:`_orm.mapped_column.__name` field, +ORM-specific arguments. The :paramref:`_orm.mapped_column.__name` positional parameter, indicating the name of the database column, is typically omitted, as the Declarative process will make use of the attribute name given to the construct and assign this as the name of the column (in the above example, this refers to @@ -133,22 +133,19 @@ itself (more on this at :ref:`mapper_column_distinct_names`). :ref:`mapping_columns_toplevel` - contains additional notes on affecting how :class:`_orm.Mapper` interprets incoming :class:`.Column` objects. -.. _orm_declarative_mapped_column: - -Using Annotated Declarative Table (Type Annotated Forms for ``mapped_column()``) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :func:`_orm.mapped_column` construct is capable of deriving its column-configuration -information from :pep:`484` type annotations associated with the attribute -as declared in the Declarative mapped class. These type annotations, -if used, **must** -be present within a special SQLAlchemy type called :class:`_orm.Mapped`, which -is a generic_ type that then indicates a specific Python type within it. +ORM Annotated Declarative - Automated Mapping with Type Annotations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Below illustrates the mapping from the previous section, adding the use of -:class:`_orm.Mapped`:: +The :func:`_orm.mapped_column` construct in modern Python is normally augmented +by the use of :pep:`484` Python type annotations, where it is capable of +deriving its column-configuration information from type annotations associated +with the attribute as declared in the Declarative mapped class. These type +annotations, if used, must be present within a special SQLAlchemy type called +:class:`.Mapped`, which is a generic type that indicates a specific Python type +within it. 
- from typing import Optional +Using this technique, the example in the previous section can be written +more succinctly as below:: from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase @@ -165,222 +162,779 @@ Below illustrates the mapping from the previous section, adding the use of id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(50)) - fullname: Mapped[Optional[str]] - nickname: Mapped[Optional[str]] = mapped_column(String(30)) - -Above, when Declarative processes each class attribute, each -:func:`_orm.mapped_column` will derive additional arguments from the -corresponding :class:`_orm.Mapped` type annotation on the left side, if -present. Additionally, Declarative will generate an empty -:func:`_orm.mapped_column` directive implicitly, whenever a -:class:`_orm.Mapped` type annotation is encountered that does not have -a value assigned to the attribute (this form is inspired by the similar -style used in Python dataclasses_); this :func:`_orm.mapped_column` construct -proceeds to derive its configuration from the :class:`_orm.Mapped` -annotation present. + fullname: Mapped[str | None] + nickname: Mapped[str | None] = mapped_column(String(30)) -.. _orm_declarative_mapped_column_nullability: +The example above demonstrates that if a class attribute is type-hinted with +:class:`.Mapped` but doesn't have an explicit :func:`_orm.mapped_column` assigned +to it, SQLAlchemy will automatically create one. Furthermore, details like the +column's datatype and whether it can be null (nullability) are inferred from +the :class:`.Mapped` annotation. However, you can always explicitly provide these +arguments to :func:`_orm.mapped_column` to override these automatically-derived +settings. -``mapped_column()`` derives the datatype and nullability from the ``Mapped`` annotation -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +For complete details on using the ORM Annotated Declarative system, see +:ref:`orm_declarative_mapped_column` later in this chapter. -The two qualities that :func:`_orm.mapped_column` derives from the -:class:`_orm.Mapped` annotation are: +.. seealso:: -* **datatype** - the Python type given inside :class:`_orm.Mapped`, as contained - within the ``typing.Optional`` construct if present, is associated with a - :class:`_sqltypes.TypeEngine` subclass such as :class:`.Integer`, :class:`.String`, - :class:`.DateTime`, or :class:`.Uuid`, to name a few common types. + :ref:`orm_declarative_mapped_column` - complete reference for ORM Annotated Declarative - The datatype is determined based on a dictionary of Python type to - SQLAlchemy datatype. This dictionary is completely customizable, - as detailed in the next section :ref:`orm_declarative_mapped_column_type_map`. - The default type map is implemented as in the code example below:: +Dataclass features in ``mapped_column()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - from typing import Any - from typing import Dict - from typing import Type +The :func:`_orm.mapped_column` construct integrates with SQLAlchemy's +"native dataclasses" feature, discussed at +:ref:`orm_declarative_native_dataclasses`. See that section for current +background on additional directives supported by :func:`_orm.mapped_column`. 
- import datetime - import decimal - import uuid - from sqlalchemy import types - # default type mapping, deriving the type for mapped_column() - # from a Mapped[] annotation - type_map: Dict[Type[Any], TypeEngine[Any]] = { - bool: types.Boolean(), - bytes: types.LargeBinary(), - datetime.date: types.Date(), - datetime.datetime: types.DateTime(), - datetime.time: types.Time(), - datetime.timedelta: types.Interval(), - decimal.Decimal: types.Numeric(), - float: types.Float(), - int: types.Integer(), - str: types.String(), - uuid.UUID: types.Uuid(), - } - If the :func:`_orm.mapped_column` construct indicates an explicit type - as passed to the :paramref:`_orm.mapped_column.__type` argument, then - the given Python type is disregarded. +.. _orm_declarative_metadata: -* **nullability** - The :func:`_orm.mapped_column` construct will indicate - its :class:`_schema.Column` as ``NULL`` or ``NOT NULL`` first and foremost by - the presence of the :paramref:`_orm.mapped_column.nullable` parameter, passed - either as ``True`` or ``False``. Additionally , if the - :paramref:`_orm.mapped_column.primary_key` parameter is present and set to - ``True``, that will also imply that the column should be ``NOT NULL``. +Accessing Table and Metadata +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - In the absence of **both** of these parameters, the presence of - ``typing.Optional[]`` within the :class:`_orm.Mapped` type annotation will be - used to determine nullability, where ``typing.Optional[]`` means ``NULL``, - and the absense of ``typing.Optional[]`` means ``NOT NULL``. If there is no - ``Mapped[]`` annotation present at all, and there is no - :paramref:`_orm.mapped_column.nullable` or - :paramref:`_orm.mapped_column.primary_key` parameter, then SQLAlchemy's usual - default for :class:`_schema.Column` of ``NULL`` is used. +A declaratively mapped class will always include an attribute called +``__table__``; when the above configuration using ``__tablename__`` is +complete, the declarative process makes the :class:`_schema.Table` +available via the ``__table__`` attribute:: - In the example below, the ``id`` and ``data`` columns will be ``NOT NULL``, - and the ``additional_info`` column will be ``NULL``:: - from typing import Optional + # access the Table + user_table = User.__table__ - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column +The above table is ultimately the same one that corresponds to the +:attr:`_orm.Mapper.local_table` attribute, which we can see through the +:ref:`runtime inspection system `:: + from sqlalchemy import inspect - class Base(DeclarativeBase): - pass + user_table = inspect(User).local_table +The :class:`_schema.MetaData` collection associated with both the declarative +:class:`_orm.registry` as well as the base class is frequently necessary in +order to run DDL operations such as CREATE, as well as in use with migration +tools such as Alembic. This object is available via the ``.metadata`` +attribute of :class:`_orm.registry` as well as the declarative base class. +Below, for a small script we may wish to emit a CREATE for all tables against a +SQLite database:: - class SomeClass(Base): - __tablename__ = "some_table" + engine = create_engine("sqlite://") - # primary_key=True, therefore will be NOT NULL - id: Mapped[int] = mapped_column(primary_key=True) + Base.metadata.create_all(engine) - # not Optional[], therefore will be NOT NULL - data: Mapped[str] +.. 
_orm_declarative_table_configuration: - # Optional[], therefore will be NULL - additional_info: Mapped[Optional[str]] +Declarative Table Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - It is also perfectly valid to have a :func:`_orm.mapped_column` whose - nullability is **different** from what would be implied by the annotation. - For example, an ORM mapped attribute may be annotated as allowing ``None`` - within Python code that works with the object as it is first being created - and populated, however the value will ultimately be written to a database - column that is ``NOT NULL``. The :paramref:`_orm.mapped_column.nullable` - parameter, when present, will always take precedence:: +When using Declarative Table configuration with the ``__tablename__`` +declarative class attribute, additional arguments to be supplied to the +:class:`_schema.Table` constructor should be provided using the +``__table_args__`` declarative class attribute. - class SomeClass(Base): - # ... +This attribute accommodates both positional as well as keyword +arguments that are normally sent to the +:class:`_schema.Table` constructor. +The attribute can be specified in one of two forms. One is as a +dictionary:: - # will be String() NOT NULL, but can be None in Python - data: Mapped[Optional[str]] = mapped_column(nullable=False) + class MyClass(Base): + __tablename__ = "sometable" + __table_args__ = {"mysql_engine": "InnoDB"} - Similarly, a non-None attribute that's written to a database column that - for whatever reason needs to be NULL at the schema level, - :paramref:`_orm.mapped_column.nullable` may be set to ``True``:: +The other, a tuple, where each argument is positional +(usually constraints):: - class SomeClass(Base): - # ... + class MyClass(Base): + __tablename__ = "sometable" + __table_args__ = ( + ForeignKeyConstraint(["id"], ["remote_table.id"]), + UniqueConstraint("foo"), + ) - # will be String() NULL, but type checker will not expect - # the attribute to be None - data: Mapped[str] = mapped_column(nullable=True) +Keyword arguments can be specified with the above form by +specifying the last argument as a dictionary:: -.. _orm_declarative_mapped_column_type_map: + class MyClass(Base): + __tablename__ = "sometable" + __table_args__ = ( + ForeignKeyConstraint(["id"], ["remote_table.id"]), + UniqueConstraint("foo"), + {"autoload": True}, + ) -Customizing the Type Map -~~~~~~~~~~~~~~~~~~~~~~~~ +A class may also specify the ``__table_args__`` declarative attribute, +as well as the ``__tablename__`` attribute, in a dynamic style using the +:func:`_orm.declared_attr` method decorator. See +:ref:`orm_mixins_toplevel` for background. -The mapping of Python types to SQLAlchemy :class:`_types.TypeEngine` types -described in the previous section defaults to a hardcoded dictionary -present in the ``sqlalchemy.sql.sqltypes`` module. However, the :class:`_orm.registry` -object that coordinates the Declarative mapping process will first consult -a local, user defined dictionary of types which may be passed -as the :paramref:`_orm.registry.type_annotation_map` parameter when -constructing the :class:`_orm.registry`, which may be associated with -the :class:`_orm.DeclarativeBase` superclass when first used. +.. 
_orm_declarative_table_schema_name: -As an example, if we wish to make use of the :class:`_sqltypes.BIGINT` datatype for -``int``, the :class:`_sqltypes.TIMESTAMP` datatype with ``timezone=True`` for -``datetime.datetime``, and then only on Microsoft SQL Server we'd like to use -:class:`_sqltypes.NVARCHAR` datatype when Python ``str`` is used, -the registry and Declarative base could be configured as:: +Explicit Schema Name with Declarative Table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - import datetime +The schema name for a :class:`_schema.Table` as documented at +:ref:`schema_table_schema_name` is applied to an individual :class:`_schema.Table` +using the :paramref:`_schema.Table.schema` argument. When using Declarative +tables, this option is passed like any other to the ``__table_args__`` +dictionary:: - from sqlalchemy import BIGINT, Integer, NVARCHAR, String, TIMESTAMP from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped, mapped_column, registry class Base(DeclarativeBase): - type_annotation_map = { - int: BIGINT, - datetime.datetime: TIMESTAMP(timezone=True), - str: String().with_variant(NVARCHAR, "mssql"), - } + pass - class SomeClass(Base): - __tablename__ = "some_table" + class MyClass(Base): + __tablename__ = "sometable" + __table_args__ = {"schema": "some_schema"} - id: Mapped[int] = mapped_column(primary_key=True) - date: Mapped[datetime.datetime] - status: Mapped[str] +The schema name can also be applied to all :class:`_schema.Table` objects +globally by using the :paramref:`_schema.MetaData.schema` parameter documented +at :ref:`schema_metadata_schema_name`. The :class:`_schema.MetaData` object +may be constructed separately and associated with a :class:`_orm.DeclarativeBase` +subclass by assigning to the ``metadata`` attribute directly:: -Below illustrates the CREATE TABLE statement generated for the above mapping, -first on the Microsoft SQL Server backend, illustrating the ``NVARCHAR`` datatype: + from sqlalchemy import MetaData + from sqlalchemy.orm import DeclarativeBase -.. sourcecode:: pycon+sql + metadata_obj = MetaData(schema="some_schema") - >>> from sqlalchemy.schema import CreateTable - >>> from sqlalchemy.dialects import mssql, postgresql - >>> print(CreateTable(SomeClass.__table__).compile(dialect=mssql.dialect())) - {printsql}CREATE TABLE some_table ( - id BIGINT NOT NULL IDENTITY, - date TIMESTAMP NOT NULL, - status NVARCHAR(max) NOT NULL, - PRIMARY KEY (id) - ) -Then on the PostgreSQL backend, illustrating ``TIMESTAMP WITH TIME ZONE``: + class Base(DeclarativeBase): + metadata = metadata_obj -.. sourcecode:: pycon+sql - >>> print(CreateTable(SomeClass.__table__).compile(dialect=postgresql.dialect())) - {printsql}CREATE TABLE some_table ( - id BIGSERIAL NOT NULL, - date TIMESTAMP WITH TIME ZONE NOT NULL, - status VARCHAR NOT NULL, - PRIMARY KEY (id) - ) + class MyClass(Base): + # will use "some_schema" by default + __tablename__ = "sometable" -By making use of methods such as :meth:`.TypeEngine.with_variant`, we're able -to build up a type map that's customized to what we need for different backends, -while still being able to use succinct annotation-only :func:`_orm.mapped_column` -configurations. There are two more levels of Python-type configurability -available beyond this, described in the next two sections. +.. seealso:: -.. _orm_declarative_mapped_column_type_map_pep593: + :ref:`schema_table_schema_name` - in the :ref:`metadata_toplevel` documentation. 
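+
+As a quick illustration of the configuration above, a minimal sketch (the
+``MyClass`` mapping is the hypothetical one from the preceding example, with a
+primary key column added so that the class is mappable) shows the
+metadata-level schema name rendered into DDL::
+
+    from sqlalchemy import MetaData
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+    from sqlalchemy.schema import CreateTable
+
+    metadata_obj = MetaData(schema="some_schema")
+
+
+    class Base(DeclarativeBase):
+        metadata = metadata_obj
+
+
+    class MyClass(Base):
+        __tablename__ = "sometable"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+
+
+    # renders the schema-qualified name, e.g.
+    # CREATE TABLE some_schema.sometable (...)
+    print(CreateTable(MyClass.__table__))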
-Mapping Multiple Type Configurations to Python Types -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. _orm_declarative_column_options: -As individual Python types may be associated with :class:`_types.TypeEngine` -configurations of any variety by using the :paramref:`_orm.registry.type_annotation_map` -parameter, an additional -capability is the ability to associate a single Python type with different -variants of a SQL type based on additional type qualifiers. One typical -example of this is mapping the Python ``str`` datatype to ``VARCHAR`` -SQL types of different lengths. Another is mapping different varieties of +Setting Load and Persistence Options for Declarative Mapped Columns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :func:`_orm.mapped_column` construct accepts additional ORM-specific +arguments that affect how the generated :class:`_schema.Column` is +mapped, affecting its load and persistence-time behavior. Options +that are commonly used include: + +* **deferred column loading** - The :paramref:`_orm.mapped_column.deferred` + boolean establishes the :class:`_schema.Column` using + :ref:`deferred column loading ` by default. In the example + below, the ``User.bio`` column will not be loaded by default, but only + when accessed:: + + class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + bio: Mapped[str] = mapped_column(Text, deferred=True) + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - full description of deferred column loading + +* **active history** - The :paramref:`_orm.mapped_column.active_history` + ensures that upon change of value for the attribute, the previous value + will have been loaded and made part of the :attr:`.AttributeState.history` + collection when inspecting the history of the attribute. This may incur + additional SQL statements:: + + class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + important_identifier: Mapped[str] = mapped_column(active_history=True) + +See the docstring for :func:`_orm.mapped_column` for a list of supported +parameters. + +.. seealso:: + + :ref:`orm_imperative_table_column_options` - describes using + :func:`_orm.column_property` and :func:`_orm.deferred` for use with + Imperative Table configuration + +.. _mapper_column_distinct_names: + +.. _orm_declarative_table_column_naming: + +Naming Declarative Mapped Columns Explicitly +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All of the examples thus far feature the :func:`_orm.mapped_column` construct +linked to an ORM mapped attribute, where the Python attribute name given +to the :func:`_orm.mapped_column` is also that of the column as we see in +CREATE TABLE statements as well as queries. The name for a column as +expressed in SQL may be indicated by passing the string positional argument +:paramref:`_orm.mapped_column.__name` as the first positional argument. +In the example below, the ``User`` class is mapped with alternate names +given to the columns themselves:: + + class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column("user_id", primary_key=True) + name: Mapped[str] = mapped_column("user_name") + +Where above ``User.id`` resolves to a column named ``user_id`` +and ``User.name`` resolves to a column named ``user_name``. We +may write a :func:`_sql.select` statement using our Python attribute names +and will see the SQL names generated: + +.. 
sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User.id, User.name).where(User.name == "x")) + {printsql}SELECT "user".user_id, "user".user_name + FROM "user" + WHERE "user".user_name = :user_name_1 + + +.. seealso:: + + :ref:`orm_imperative_table_column_naming` - applies to Imperative Table + +.. _orm_declarative_table_adding_columns: + +Appending additional columns to an existing Declarative mapped class +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A declarative table configuration allows the addition of new +:class:`_schema.Column` objects to an existing mapping after the :class:`.Table` +metadata has already been generated. + +For a declarative class that is declared using a declarative base class, +the underlying metaclass :class:`.DeclarativeMeta` includes a ``__setattr__()`` +method that will intercept additional :func:`_orm.mapped_column` or Core +:class:`.Column` objects and +add them to both the :class:`.Table` using :meth:`.Table.append_column` +as well as to the existing :class:`.Mapper` using :meth:`.Mapper.add_property`:: + + MyClass.some_new_column = mapped_column(String) + +Using core :class:`_schema.Column`:: + + MyClass.some_new_column = Column(String) + +All arguments are supported including an alternate name, such as +``MyClass.some_new_column = mapped_column("some_name", String)``. However, +the SQL type must be passed to the :func:`_orm.mapped_column` or +:class:`_schema.Column` object explicitly, as in the above examples where +the :class:`_sqltypes.String` type is passed. There's no capability for +the :class:`_orm.Mapped` annotation type to take part in the operation. + +Additional :class:`_schema.Column` objects may also be added to a mapping +in the specific circumstance of using single table inheritance, where +additional columns are present on mapped subclasses that have +no :class:`.Table` of their own. This is illustrated in the section +:ref:`single_inheritance`. + +.. seealso:: + + :ref:`orm_declarative_table_adding_relationship` - similar examples for :func:`_orm.relationship` + +.. note:: Assignment of mapped + properties to an already mapped class will only + function correctly if the "declarative base" class is used, meaning + the user-defined subclass of :class:`_orm.DeclarativeBase` or the + dynamically generated class returned by :func:`_orm.declarative_base` + or :meth:`_orm.registry.generate_base`. This "base" class includes + a Python metaclass which implements a special ``__setattr__()`` method + that intercepts these operations. + + Runtime assignment of class-mapped attributes to a mapped class will **not** work + if the class is mapped using decorators like :meth:`_orm.registry.mapped` + or imperative functions like :meth:`_orm.registry.map_imperatively`. + + +.. _orm_declarative_mapped_column: + +ORM Annotated Declarative - Complete Guide +------------------------------------------ + +The :func:`_orm.mapped_column` construct is capable of deriving its +column-configuration information from :pep:`484` type annotations associated +with the attribute as declared in the Declarative mapped class. These type +annotations, if used, must be present within a special SQLAlchemy type called +:class:`_orm.Mapped`, which is a generic_ type that then indicates a specific +Python type within it. 
+
+Using this technique, the ``User`` example from previous sections may be
+written as below::
+
+    from sqlalchemy import String
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class User(Base):
+        __tablename__ = "user"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str] = mapped_column(String(50))
+        fullname: Mapped[str | None]
+        nickname: Mapped[str | None] = mapped_column(String(30))
+
+Above, when Declarative processes each class attribute, each
+:func:`_orm.mapped_column` will derive additional arguments from the
+corresponding :class:`_orm.Mapped` type annotation on the left side, if
+present. Additionally, Declarative will generate an empty
+:func:`_orm.mapped_column` directive implicitly, whenever a
+:class:`_orm.Mapped` type annotation is encountered that does not have
+a value assigned to the attribute (this form is inspired by the similar
+style used in Python dataclasses_); this :func:`_orm.mapped_column` construct
+proceeds to derive its configuration from the :class:`_orm.Mapped`
+annotation present.
+
+.. _orm_declarative_mapped_column_nullability:
+
+``mapped_column()`` derives the datatype and nullability from the ``Mapped`` annotation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The two qualities that :func:`_orm.mapped_column` derives from the
+:class:`_orm.Mapped` annotation are:
+
+* **datatype** - the Python type given inside :class:`_orm.Mapped`, as contained
+  within the ``typing.Optional`` construct if present, is associated with a
+  :class:`_sqltypes.TypeEngine` subclass such as :class:`.Integer`, :class:`.String`,
+  :class:`.DateTime`, or :class:`.Uuid`, to name a few common types.
+
+  The datatype is determined based on a dictionary of Python type to
+  SQLAlchemy datatype. This dictionary is completely customizable,
+  as detailed in the next section :ref:`orm_declarative_mapped_column_type_map`.
+  The default type map is implemented as in the code example below::
+
+      from typing import Any
+      from typing import Dict
+      from typing import Type
+
+      import datetime
+      import decimal
+      import uuid
+
+      from sqlalchemy import types
+      from sqlalchemy.types import TypeEngine
+
+      # default type mapping, deriving the type for mapped_column()
+      # from a Mapped[] annotation
+      type_map: Dict[Type[Any], TypeEngine[Any]] = {
+          bool: types.Boolean(),
+          bytes: types.LargeBinary(),
+          datetime.date: types.Date(),
+          datetime.datetime: types.DateTime(),
+          datetime.time: types.Time(),
+          datetime.timedelta: types.Interval(),
+          decimal.Decimal: types.Numeric(),
+          float: types.Float(),
+          int: types.Integer(),
+          str: types.String(),
+          uuid.UUID: types.Uuid(),
+      }
+
+  If the :func:`_orm.mapped_column` construct indicates an explicit type
+  as passed to the :paramref:`_orm.mapped_column.__type` argument, then
+  the given Python type is disregarded.
+
+* **nullability** - The :func:`_orm.mapped_column` construct will indicate
+  its :class:`_schema.Column` as ``NULL`` or ``NOT NULL`` first and foremost by
+  the presence of the :paramref:`_orm.mapped_column.nullable` parameter, passed
+  either as ``True`` or ``False``. Additionally, if the
+  :paramref:`_orm.mapped_column.primary_key` parameter is present and set to
+  ``True``, that will also imply that the column should be ``NOT NULL``.
+ + In the absence of **both** of these parameters, the presence of + ``typing.Optional[]`` within the :class:`_orm.Mapped` type annotation will be + used to determine nullability, where ``typing.Optional[]`` means ``NULL``, + and the absence of ``typing.Optional[]`` means ``NOT NULL``. If there is no + ``Mapped[]`` annotation present at all, and there is no + :paramref:`_orm.mapped_column.nullable` or + :paramref:`_orm.mapped_column.primary_key` parameter, then SQLAlchemy's usual + default for :class:`_schema.Column` of ``NULL`` is used. + + In the example below, the ``id`` and ``data`` columns will be ``NOT NULL``, + and the ``additional_info`` column will be ``NULL``:: + + from typing import Optional + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + class SomeClass(Base): + __tablename__ = "some_table" + + # primary_key=True, therefore will be NOT NULL + id: Mapped[int] = mapped_column(primary_key=True) + + # not Optional[], therefore will be NOT NULL + data: Mapped[str] + + # Optional[], therefore will be NULL + additional_info: Mapped[Optional[str]] + + It is also perfectly valid to have a :func:`_orm.mapped_column` whose + nullability is **different** from what would be implied by the annotation. + For example, an ORM mapped attribute may be annotated as allowing ``None`` + within Python code that works with the object as it is first being created + and populated, however the value will ultimately be written to a database + column that is ``NOT NULL``. The :paramref:`_orm.mapped_column.nullable` + parameter, when present, will always take precedence:: + + class SomeClass(Base): + # ... + + # will be String() NOT NULL, but can be None in Python + data: Mapped[Optional[str]] = mapped_column(nullable=False) + + Similarly, a non-None attribute that's written to a database column that + for whatever reason needs to be NULL at the schema level, + :paramref:`_orm.mapped_column.nullable` may be set to ``True``:: + + class SomeClass(Base): + # ... + + # will be String() NULL, but type checker will not expect + # the attribute to be None + data: Mapped[str] = mapped_column(nullable=True) + +.. _orm_declarative_mapped_column_type_map: + +Customizing the Type Map +^^^^^^^^^^^^^^^^^^^^^^^^ + + +The mapping of Python types to SQLAlchemy :class:`_types.TypeEngine` types +described in the previous section defaults to a hardcoded dictionary +present in the ``sqlalchemy.sql.sqltypes`` module. However, the :class:`_orm.registry` +object that coordinates the Declarative mapping process will first consult +a local, user defined dictionary of types which may be passed +as the :paramref:`_orm.registry.type_annotation_map` parameter when +constructing the :class:`_orm.registry`, which may be associated with +the :class:`_orm.DeclarativeBase` superclass when first used. 
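+
+As a minimal sketch of the first of these options (the specific type entry
+shown is illustrative only; the full example that follows uses the
+``DeclarativeBase`` form instead), the dictionary may be passed when
+constructing a :class:`_orm.registry` directly::
+
+    import datetime
+
+    from sqlalchemy import TIMESTAMP
+    from sqlalchemy.orm import registry
+
+    # a registry with a custom type map; a DeclarativeBase subclass may
+    # equivalently declare the same dictionary as type_annotation_map
+    reg = registry(
+        type_annotation_map={
+            datetime.datetime: TIMESTAMP(timezone=True),
+        }
+    )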
+ +As an example, if we wish to make use of the :class:`_sqltypes.BIGINT` datatype for +``int``, the :class:`_sqltypes.TIMESTAMP` datatype with ``timezone=True`` for +``datetime.datetime``, and then only on Microsoft SQL Server we'd like to use +:class:`_sqltypes.NVARCHAR` datatype when Python ``str`` is used, +the registry and Declarative base could be configured as:: + + import datetime + + from sqlalchemy import BIGINT, NVARCHAR, String, TIMESTAMP + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + + class Base(DeclarativeBase): + type_annotation_map = { + int: BIGINT, + datetime.datetime: TIMESTAMP(timezone=True), + str: String().with_variant(NVARCHAR, "mssql"), + } + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + date: Mapped[datetime.datetime] + status: Mapped[str] + +Below illustrates the CREATE TABLE statement generated for the above mapping, +first on the Microsoft SQL Server backend, illustrating the ``NVARCHAR`` datatype: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.schema import CreateTable + >>> from sqlalchemy.dialects import mssql, postgresql + >>> print(CreateTable(SomeClass.__table__).compile(dialect=mssql.dialect())) + {printsql}CREATE TABLE some_table ( + id BIGINT NOT NULL IDENTITY, + date TIMESTAMP NOT NULL, + status NVARCHAR(max) NOT NULL, + PRIMARY KEY (id) + ) + +Then on the PostgreSQL backend, illustrating ``TIMESTAMP WITH TIME ZONE``: + +.. sourcecode:: pycon+sql + + >>> print(CreateTable(SomeClass.__table__).compile(dialect=postgresql.dialect())) + {printsql}CREATE TABLE some_table ( + id BIGSERIAL NOT NULL, + date TIMESTAMP WITH TIME ZONE NOT NULL, + status VARCHAR NOT NULL, + PRIMARY KEY (id) + ) + +By making use of methods such as :meth:`.TypeEngine.with_variant`, we're able +to build up a type map that's customized to what we need for different backends, +while still being able to use succinct annotation-only :func:`_orm.mapped_column` +configurations. There are two more levels of Python-type configurability +available beyond this, described in the next two sections. + +.. _orm_declarative_type_map_union_types: + +Union types inside the Type Map +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +.. versionchanged:: 2.0.37 The features described in this section have been + repaired and enhanced to work consistently. Prior to this change, union + types were supported in ``type_annotation_map``, however the feature + exhibited inconsistent behaviors between union syntaxes as well as in how + ``None`` was handled. Please ensure SQLAlchemy is up to date before + attempting to use the features described in this section. 
+
+SQLAlchemy supports mapping union types inside the ``type_annotation_map`` to
+allow mapping database types that can support multiple Python types, such as
+:class:`_types.JSON` or :class:`_postgresql.JSONB`::
+
+    from typing import Optional
+    from typing import Union
+
+    from sqlalchemy import JSON
+    from sqlalchemy.dialects import postgresql
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+    from sqlalchemy.schema import CreateTable
+
+    # new style Union using a pipe operator
+    json_list = list[int] | list[str]
+
+    # old style Union using Union explicitly
+    json_scalar = Union[float, str, bool]
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {
+            json_list: postgresql.JSONB,
+            json_scalar: JSON,
+        }
+
+
+    class SomeClass(Base):
+        __tablename__ = "some_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        list_col: Mapped[list[str] | list[int]]
+
+        # uses JSON
+        scalar_col: Mapped[json_scalar]
+
+        # uses JSON and is also nullable=True
+        scalar_col_nullable: Mapped[json_scalar | None]
+
+        # these forms all use JSON as well due to the json_scalar entry
+        scalar_col_newstyle: Mapped[float | str | bool]
+        scalar_col_oldstyle: Mapped[Union[float, str, bool]]
+        scalar_col_mixedstyle: Mapped[Optional[float | str | bool]]
+
+The above example maps the union of ``list[int]`` and ``list[str]`` to the
+PostgreSQL :class:`_postgresql.JSONB` datatype, while a union of
+``float, str, bool`` will match to the :class:`_types.JSON` datatype. An
+equivalent union, stated in the :class:`_orm.Mapped` construct, will match
+into the corresponding entry in the type map.
+
+The matching of a union type is based on the contents of the union regardless
+of how the individual types are named, and additionally excluding the use of
+the ``None`` type. That is, ``json_scalar`` will also match to ``str | bool |
+float | None``. It will **not** match to a union that is a subset or superset
+of this union; that is, ``str | bool`` would not match, nor would ``str | bool
+| float | int``. The individual contents of the union excluding ``None`` must
+be an exact match.
+
+The ``None`` value is never significant as far as matching
+from ``type_annotation_map`` to :class:`_orm.Mapped`, however it is significant
+as an indicator for nullability of the :class:`_schema.Column`. When ``None``
+is present in the union as it is placed in the :class:`_orm.Mapped` construct,
+it indicates that the :class:`_schema.Column` would be nullable, in the
+absence of more specific indicators. This logic works in the same way as
+indicating an ``Optional`` type as described at
+:ref:`orm_declarative_mapped_column_nullability`.
+
+The CREATE TABLE statement for the above mapping will look as below:
+
+.. sourcecode:: pycon+sql
+
+    >>> print(CreateTable(SomeClass.__table__).compile(dialect=postgresql.dialect()))
+    {printsql}CREATE TABLE some_table (
+        id SERIAL NOT NULL,
+        list_col JSONB NOT NULL,
+        scalar_col JSON NOT NULL,
+        scalar_col_nullable JSON,
+        scalar_col_newstyle JSON NOT NULL,
+        scalar_col_oldstyle JSON NOT NULL,
+        scalar_col_mixedstyle JSON,
+        PRIMARY KEY (id)
+    )
+
+While union types use a "loose" matching approach that matches on any equivalent
+set of subtypes, Python typing also features a way to create "type aliases"
+that are treated as distinct types that are non-equivalent to another type that
+includes the same composition. Integration of these types with ``type_annotation_map``
+is described in the next section, :ref:`orm_declarative_type_map_pep695_types`.
+
+..
_orm_declarative_type_map_pep695_types: + +Support for Type Alias Types (defined by PEP 695) and NewType +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +In contrast to the typing lookup described in +:ref:`orm_declarative_type_map_union_types`, Python typing also includes two +ways to create a composed type in a more formal way, using ``typing.NewType`` as +well as the ``type`` keyword introduced in :pep:`695`. These types behave +differently from ordinary type aliases (i.e. assigning a type to a variable +name), and this difference is honored in how SQLAlchemy resolves these +types from the type map. + +.. versionchanged:: 2.0.37 The behaviors described in this section for ``typing.NewType`` + as well as :pep:`695` ``type`` have been formalized and corrected. + Deprecation warnings are now emitted for "loose matching" patterns that have + worked in some 2.0 releases, but are to be removed in SQLAlchemy 2.1. + Please ensure SQLAlchemy is up to date before attempting to use the features + described in this section. + +The typing module allows the creation of "new types" using ``typing.NewType``:: + + from typing import NewType + + nstr30 = NewType("nstr30", str) + nstr50 = NewType("nstr50", str) + +Additionally, in Python 3.12, a new feature defined by :pep:`695` was introduced which +provides the ``type`` keyword to accomplish a similar task; using +``type`` produces an object that is similar in many ways to ``typing.NewType`` +which is internally referred to as ``typing.TypeAliasType``:: + + type SmallInt = int + type BigInt = int + type JsonScalar = str | float | bool | None + +For the purposes of how SQLAlchemy treats these type objects when used +for SQL type lookup inside of :class:`_orm.Mapped`, it's important to note +that Python does not consider two equivalent ``typing.TypeAliasType`` +or ``typing.NewType`` objects to be equal:: + + # two typing.NewType objects are not equal even if they are both str + >>> nstr50 == nstr30 + False + + # two TypeAliasType objects are not equal even if they are both int + >>> SmallInt == BigInt + False + + # an equivalent union is not equal to JsonScalar + >>> JsonScalar == str | float | bool | None + False + +This is the opposite behavior from how ordinary unions are compared, and +informs the correct behavior for SQLAlchemy's ``type_annotation_map``. When +using ``typing.NewType`` or :pep:`695` ``type`` objects, the type object is +expected to be explicit within the ``type_annotation_map`` for it to be matched +from a :class:`_orm.Mapped` type, where the same object must be stated in order +for a match to be made (excluding whether or not the type inside of +:class:`_orm.Mapped` also unions on ``None``). This is distinct from the +behavior described at :ref:`orm_declarative_type_map_union_types`, where a +plain ``Union`` that is referenced directly will match to other ``Unions`` +based on the composition, rather than the object identity, of a particular type +in ``type_annotation_map``. + +In the example below, the composed types for ``nstr30``, ``nstr50``, +``SmallInt``, ``BigInt``, and ``JsonScalar`` have no overlap with each other +and can be named distinctly within each :class:`_orm.Mapped` construct, and +are also all explicit in ``type_annotation_map``. 
Any of these types may
+also be unioned with ``None`` or declared as ``Optional[]`` without affecting
+the lookup, only deriving column nullability::
+
+    from typing import NewType
+
+    from sqlalchemy import SmallInteger, BigInteger, JSON, String
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+    from sqlalchemy.schema import CreateTable
+
+    nstr30 = NewType("nstr30", str)
+    nstr50 = NewType("nstr50", str)
+    type SmallInt = int
+    type BigInt = int
+    type JsonScalar = str | float | bool | None
+
+
+    class TABase(DeclarativeBase):
+        type_annotation_map = {
+            nstr30: String(30),
+            nstr50: String(50),
+            SmallInt: SmallInteger,
+            BigInt: BigInteger,
+            JsonScalar: JSON,
+        }
+
+
+    class SomeClass(TABase):
+        __tablename__ = "some_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        normal_str: Mapped[str]
+
+        short_str: Mapped[nstr30]
+        long_str_nullable: Mapped[nstr50 | None]
+
+        small_int: Mapped[SmallInt]
+        big_int: Mapped[BigInt]
+        scalar_col: Mapped[JsonScalar]
+
+A CREATE TABLE statement for the above mapping will illustrate the different
+variants of integer and string we've configured, and looks like:
+
+.. sourcecode:: pycon+sql
+
+    >>> print(CreateTable(SomeClass.__table__))
+    {printsql}CREATE TABLE some_table (
+        id INTEGER NOT NULL,
+        normal_str VARCHAR NOT NULL,
+        short_str VARCHAR(30) NOT NULL,
+        long_str_nullable VARCHAR(50),
+        small_int SMALLINT NOT NULL,
+        big_int BIGINT NOT NULL,
+        scalar_col JSON,
+        PRIMARY KEY (id)
+    )
+
+Regarding nullability, the ``JsonScalar`` type includes ``None`` in its
+definition, which indicates a nullable column. Similarly the
+``long_str_nullable`` column applies a union of ``None`` to ``nstr50``,
+which matches to the ``nstr50`` type in the ``type_annotation_map`` while
+also applying nullability to the mapped column. The other columns all remain
+NOT NULL as they are not indicated as optional.
+
+
+.. _orm_declarative_mapped_column_type_map_pep593:
+
+Mapping Multiple Type Configurations to Python Types
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
+As individual Python types may be associated with :class:`_types.TypeEngine`
+configurations of any variety by using the :paramref:`_orm.registry.type_annotation_map`
+parameter, an additional
+capability is the ability to associate a single Python type with different
+variants of a SQL type based on additional type qualifiers. One typical
+example of this is mapping the Python ``str`` datatype to ``VARCHAR``
+SQL types of different lengths. Another is mapping different varieties of
 ``decimal.Decimal`` to differently sized ``NUMERIC`` columns.

 Python's typing system provides a great way to add additional metadata to a
@@ -458,10 +1012,12 @@ us a wide degree of flexibility, the next section illustrates a second
 way in which ``Annotated`` may be used with Declarative that is even more
 open ended.

+
 .. _orm_declarative_mapped_column_pep593:

 Mapping Whole Column Declarations to Python Types
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+

 The previous section illustrated using :pep:`593` ``Annotated`` type
 instances as keys within the :paramref:`_orm.registry.type_annotation_map`
@@ -520,281 +1076,41 @@ specific to each attribute::

     class SomeClass(Base):
         __tablename__ = "some_table"

-        id: Mapped[intpk]
-        name: Mapped[required_name]
-        created_at: Mapped[timestamp]
-
-``CREATE TABLE`` for our above mapping looks like:
-
-..
sourcecode:: pycon+sql - - >>> from sqlalchemy.schema import CreateTable - >>> print(CreateTable(SomeClass.__table__)) - {printsql}CREATE TABLE some_table ( - id INTEGER NOT NULL, - name VARCHAR(30) NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, - PRIMARY KEY (id) - ) - -When using ``Annotated`` types in this way, the configuration of the type -may also be affected on a per-attribute basis. For the types in the above -example that feature explcit use of :paramref:`_orm.mapped_column.nullable`, -we can apply the ``Optional[]`` generic modifier to any of our types so that -the field is optional or not at the Python level, which will be independent -of the ``NULL`` / ``NOT NULL`` setting that takes place in the database:: - - from typing_extensions import Annotated - - import datetime - from typing import Optional - - from sqlalchemy.orm import DeclarativeBase - - timestamp = Annotated[ - datetime.datetime, - mapped_column(nullable=False), - ] - - - class Base(DeclarativeBase): - pass - - - class SomeClass(Base): - # ... - - # pep-484 type will be Optional, but column will be - # NOT NULL - created_at: Mapped[Optional[timestamp]] - -The :func:`_orm.mapped_column` construct is also reconciled with an explicitly -passed :func:`_orm.mapped_column` construct, whose arguments will take precedence -over those of the ``Annotated`` construct. Below we add a :class:`.ForeignKey` -constraint to our integer primary key and also use an alternate server -default for the ``created_at`` column:: - - import datetime - - from typing_extensions import Annotated - - from sqlalchemy import ForeignKey - from sqlalchemy import func - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - from sqlalchemy.schema import CreateTable - - intpk = Annotated[int, mapped_column(primary_key=True)] - timestamp = Annotated[ - datetime.datetime, - mapped_column(nullable=False, server_default=func.CURRENT_TIMESTAMP()), - ] - - - class Base(DeclarativeBase): - pass - - - class Parent(Base): - __tablename__ = "parent" - - id: Mapped[intpk] - - - class SomeClass(Base): - __tablename__ = "some_table" - - # add ForeignKey to mapped_column(Integer, primary_key=True) - id: Mapped[intpk] = mapped_column(ForeignKey("parent.id")) - - # change server default from CURRENT_TIMESTAMP to UTC_TIMESTAMP - created_at: Mapped[timestamp] = mapped_column(server_default=func.UTC_TIMESTAMP()) - -The CREATE TABLE statement illustrates these per-attribute settings, -adding a ``FOREIGN KEY`` constraint as well as substituting -``UTC_TIMESTAMP`` for ``CURRENT_TIMESTAMP``: - -.. sourcecode:: pycon+sql - - >>> from sqlalchemy.schema import CreateTable - >>> print(CreateTable(SomeClass.__table__)) - {printsql}CREATE TABLE some_table ( - id INTEGER NOT NULL, - created_at DATETIME DEFAULT UTC_TIMESTAMP() NOT NULL, - PRIMARY KEY (id), - FOREIGN KEY(id) REFERENCES parent (id) - ) - -.. note:: The feature of :func:`_orm.mapped_column` just described, where - a fully constructed set of column arguments may be indicated using - :pep:`593` ``Annotated`` objects that contain a "template" - :func:`_orm.mapped_column` object to be copied into the attribute, is - currently not implemented for other ORM constructs such as - :func:`_orm.relationship` and :func:`_orm.composite`. 
While this functionality - is in theory possible, for the moment attempting to use ``Annotated`` - to indicate further arguments for :func:`_orm.relationship` and similar - will raise a ``NotImplementedError`` exception at runtime, but - may be implemented in future releases. - -.. _orm_declarative_mapped_column_enums: - -Using Python ``Enum`` or pep-586 ``Literal`` types in the type map -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. versionadded:: 2.0.0b4 - Added ``Enum`` support - -.. versionadded:: 2.0.1 - Added ``Literal`` support - -User-defined Python types which derive from the Python built-in ``enum.Enum`` -as well as the ``typing.Literal`` -class are automatically linked to the SQLAlchemy :class:`.Enum` datatype -when used in an ORM declarative mapping. The example below uses -a custom ``enum.Enum`` within the ``Mapped[]`` constructor:: - - import enum - - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - - - class Base(DeclarativeBase): - pass - - - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" - - - class SomeClass(Base): - __tablename__ = "some_table" - - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[Status] - -In the above example, the mapped attribute ``SomeClass.status`` will be -linked to a :class:`.Column` with the datatype of ``Enum(Status)``. -We can see this for example in the CREATE TABLE output for the PostgreSQL -database: - -.. sourcecode:: sql - - CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED') - - CREATE TABLE some_table ( - id SERIAL NOT NULL, - status status NOT NULL, - PRIMARY KEY (id) - ) - -In a similar way, ``typing.Literal`` may be used instead, using -a ``typing.Literal`` that consists of all strings:: - - - from typing import Literal - - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - - - class Base(DeclarativeBase): - pass - - - Status = Literal["pending", "received", "completed"] - - - class SomeClass(Base): - __tablename__ = "some_table" - - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[Status] - -The entries used in :paramref:`_orm.registry.type_annotation_map` link the base -``enum.Enum`` Python type as well as the ``typing.Literal`` type to the -SQLAlchemy :class:`.Enum` SQL type, using a special form which indicates to the -:class:`.Enum` datatype that it should automatically configure itself against -an arbitrary enumerated type. This configuration, which is implicit by default, -would be indicated explicitly as:: - - import enum - import typing - - import sqlalchemy - from sqlalchemy.orm import DeclarativeBase - - - class Base(DeclarativeBase): - type_annotation_map = { - enum.Enum: sqlalchemy.Enum(enum.Enum), - typing.Literal: sqlalchemy.Enum(enum.Enum), - } - -The resolution logic within Declarative is able to resolve subclasses -of ``enum.Enum`` as well as instances of ``typing.Literal`` to match the -``enum.Enum`` or ``typing.Literal`` entry in the -:paramref:`_orm.registry.type_annotation_map` dictionary. The :class:`.Enum` -SQL type then knows how to produce a configured version of itself with the -appropriate settings, including default string length. If a ``typing.Literal`` -that does not consist of only string values is passed, an informative -error is raised. 
- -Native Enums and Naming -+++++++++++++++++++++++ - -The :paramref:`.sqltypes.Enum.native_enum` parameter refers to if the -:class:`.sqltypes.Enum` datatype should create a so-called "native" -enum, which on MySQL/MariaDB is the ``ENUM`` datatype and on PostgreSQL is -a new ``TYPE`` object created by ``CREATE TYPE``, or a "non-native" enum, -which means that ``VARCHAR`` will be used to create the datatype. For -backends other than MySQL/MariaDB or PostgreSQL, ``VARCHAR`` is used in -all cases (third party dialects may have their own behaviors). - -Because PostgreSQL's ``CREATE TYPE`` requires that there's an explicit name -for the type to be created, special fallback logic exists when working -with implicitly generated :class:`.sqltypes.Enum` without specifying an -explicit :class:`.sqltypes.Enum` datatype within a mapping: - -1. If the :class:`.sqltypes.Enum` is linked to an ``enum.Enum`` object, - the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to - ``True`` and the name of the enum will be taken from the name of the - ``enum.Enum`` datatype. The PostgreSQL backend will assume ``CREATE TYPE`` - with this name. -2. If the :class:`.sqltypes.Enum` is linked to a ``typing.Literal`` object, - the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to - ``False``; no name is generated and ``VARCHAR`` is assumed. - -To use ``typing.Literal`` with a PostgreSQL ``CREATE TYPE`` type, an -explicit :class:`.sqltypes.Enum` must be used, either within the -type map:: - - import enum - import typing + id: Mapped[intpk] + name: Mapped[required_name] + created_at: Mapped[timestamp] - import sqlalchemy - from sqlalchemy.orm import DeclarativeBase +``CREATE TABLE`` for our above mapping looks like: - Status = Literal["pending", "received", "completed"] +.. sourcecode:: pycon+sql + >>> from sqlalchemy.schema import CreateTable + >>> print(CreateTable(SomeClass.__table__)) + {printsql}CREATE TABLE some_table ( + id INTEGER NOT NULL, + name VARCHAR(30) NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (id) + ) - class Base(DeclarativeBase): - type_annotation_map = { - Status: sqlalchemy.Enum("pending", "received", "completed", name="status_enum"), - } +When using ``Annotated`` types in this way, the configuration of the type +may also be affected on a per-attribute basis. For the types in the above +example that feature explicit use of :paramref:`_orm.mapped_column.nullable`, +we can apply the ``Optional[]`` generic modifier to any of our types so that +the field is optional or not at the Python level, which will be independent +of the ``NULL`` / ``NOT NULL`` setting that takes place in the database:: -Or alternatively within :func:`_orm.mapped_column`:: + from typing_extensions import Annotated - import enum - import typing + import datetime + from typing import Optional - import sqlalchemy from sqlalchemy.orm import DeclarativeBase - Status = Literal["pending", "received", "completed"] + timestamp = Annotated[ + datetime.datetime, + mapped_column(nullable=False), + ] class Base(DeclarativeBase): @@ -802,359 +1118,365 @@ Or alternatively within :func:`_orm.mapped_column`:: class SomeClass(Base): - __tablename__ = "some_table" + # ... 
- id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[Status] = mapped_column( - sqlalchemy.Enum("pending", "received", "completed", name="status_enum") - ) + # pep-484 type will be Optional, but column will be + # NOT NULL + created_at: Mapped[Optional[timestamp]] -Altering the Configuration of the Default Enum -+++++++++++++++++++++++++++++++++++++++++++++++ +The :func:`_orm.mapped_column` construct is also reconciled with an explicitly +passed :func:`_orm.mapped_column` construct, whose arguments will take precedence +over those of the ``Annotated`` construct. Below we add a :class:`.ForeignKey` +constraint to our integer primary key and also use an alternate server +default for the ``created_at`` column:: -In order to modify the fixed configuration of the :class:`.enum.Enum` datatype -that's generated implicitly, specify new entries in the -:paramref:`_orm.registry.type_annotation_map`, indicating additional arguments. -For example, to use "non native enumerations" unconditionally, the -:paramref:`.Enum.native_enum` parameter may be set to False for all types:: + import datetime - import enum - import typing - import sqlalchemy - from sqlalchemy.orm import DeclarativeBase + from typing_extensions import Annotated + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.schema import CreateTable - class Base(DeclarativeBase): - type_annotation_map = { - enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False), - typing.Literal: sqlalchemy.Enum(enum.Enum, native_enum=False), - } + intpk = Annotated[int, mapped_column(primary_key=True)] + timestamp = Annotated[ + datetime.datetime, + mapped_column(nullable=False, server_default=func.CURRENT_TIMESTAMP()), + ] -.. versionchanged:: 2.0.1 Implemented support for overriding parameters - such as :paramref:`_sqltypes.Enum.native_enum` within the - :class:`_sqltypes.Enum` datatype when establishing the - :paramref:`_orm.registry.type_annotation_map`. Previously, this - functionality was not working. -To use a specific configuration for a specific ``enum.Enum`` subtype, such -as setting the string length to 50 when using the example ``Status`` -datatype:: + class Base(DeclarativeBase): + pass - import enum - import sqlalchemy - from sqlalchemy.orm import DeclarativeBase + class Parent(Base): + __tablename__ = "parent" - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" + id: Mapped[intpk] - class Base(DeclarativeBase): - type_annotation_map = { - Status: sqlalchemy.Enum(Status, length=50, native_enum=False) - } + class SomeClass(Base): + __tablename__ = "some_table" -Linking Specific ``enum.Enum`` or ``typing.Literal`` to other datatypes -++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # add ForeignKey to mapped_column(Integer, primary_key=True) + id: Mapped[intpk] = mapped_column(ForeignKey("parent.id")) -The above examples feature the use of an :class:`_sqltypes.Enum` that is -automatically configuring itself to the arguments / attributes present on -an ``enum.Enum`` or ``typing.Literal`` type object. For use cases where -specific kinds of ``enum.Enum`` or ``typing.Literal`` should be linked to -other types, these specific types may be placed in the type map also. 
-In the example below, an entry for ``Literal[]`` that contains non-string -types is linked to the :class:`_sqltypes.JSON` datatype:: + # change server default from CURRENT_TIMESTAMP to UTC_TIMESTAMP + created_at: Mapped[timestamp] = mapped_column(server_default=func.UTC_TIMESTAMP()) +The CREATE TABLE statement illustrates these per-attribute settings, +adding a ``FOREIGN KEY`` constraint as well as substituting +``UTC_TIMESTAMP`` for ``CURRENT_TIMESTAMP``: - from typing import Literal +.. sourcecode:: pycon+sql - from sqlalchemy import JSON - from sqlalchemy.orm import DeclarativeBase + >>> from sqlalchemy.schema import CreateTable + >>> print(CreateTable(SomeClass.__table__)) + {printsql}CREATE TABLE some_table ( + id INTEGER NOT NULL, + created_at DATETIME DEFAULT UTC_TIMESTAMP() NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY(id) REFERENCES parent (id) + ) - my_literal = Literal[0, 1, True, False, "true", "false"] +.. note:: The feature of :func:`_orm.mapped_column` just described, where + a fully constructed set of column arguments may be indicated using + :pep:`593` ``Annotated`` objects that contain a "template" + :func:`_orm.mapped_column` object to be copied into the attribute, is + currently not implemented for other ORM constructs such as + :func:`_orm.relationship` and :func:`_orm.composite`. While this functionality + is in theory possible, for the moment attempting to use ``Annotated`` + to indicate further arguments for :func:`_orm.relationship` and similar + will raise a ``NotImplementedError`` exception at runtime, but + may be implemented in future releases. +.. _orm_declarative_mapped_column_enums: - class Base(DeclarativeBase): - type_annotation_map = {my_literal: JSON} +Using Python ``Enum`` or pep-586 ``Literal`` types in the type map +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In the above configuration, the ``my_literal`` datatype will resolve to a -:class:`._sqltypes.JSON` instance. Other ``Literal`` variants will continue -to resolve to :class:`_sqltypes.Enum` datatypes. +.. versionadded:: 2.0.0b4 - Added ``Enum`` support -Dataclass features in ``mapped_column()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. versionadded:: 2.0.1 - Added ``Literal`` support -The :func:`_orm.mapped_column` construct integrates with SQLAlchemy's -"native dataclasses" feature, discussed at -:ref:`orm_declarative_native_dataclasses`. See that section for current -background on additional directives supported by :func:`_orm.mapped_column`. +User-defined Python types which derive from the Python built-in ``enum.Enum`` +as well as the ``typing.Literal`` +class are automatically linked to the SQLAlchemy :class:`.Enum` datatype +when used in an ORM declarative mapping. The example below uses +a custom ``enum.Enum`` within the ``Mapped[]`` constructor:: + import enum + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column -.. 
_orm_declarative_metadata: -Accessing Table and Metadata -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + class Base(DeclarativeBase): + pass -A declaratively mapped class will always include an attribute called -``__table__``; when the above configuration using ``__tablename__`` is -complete, the declarative process makes the :class:`_schema.Table` -available via the ``__table__`` attribute:: + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" - # access the Table - user_table = User.__table__ -The above table is ultimately the same one that corresponds to the -:attr:`_orm.Mapper.local_table` attribute, which we can see through the -:ref:`runtime inspection system `:: + class SomeClass(Base): + __tablename__ = "some_table" - from sqlalchemy import inspect + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] - user_table = inspect(User).local_table +In the above example, the mapped attribute ``SomeClass.status`` will be +linked to a :class:`.Column` with the datatype of ``Enum(Status)``. +We can see this for example in the CREATE TABLE output for the PostgreSQL +database: -The :class:`_schema.MetaData` collection associated with both the declarative -:class:`_orm.registry` as well as the base class is frequently necessary in -order to run DDL operations such as CREATE, as well as in use with migration -tools such as Alembic. This object is available via the ``.metadata`` -attribute of :class:`_orm.registry` as well as the declarative base class. -Below, for a small script we may wish to emit a CREATE for all tables against a -SQLite database:: +.. sourcecode:: sql - engine = create_engine("sqlite://") + CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED') - Base.metadata.create_all(engine) + CREATE TABLE some_table ( + id SERIAL NOT NULL, + status status NOT NULL, + PRIMARY KEY (id) + ) -.. _orm_declarative_table_configuration: +In a similar way, ``typing.Literal`` may be used instead, using +a ``typing.Literal`` that consists of all strings:: -Declarative Table Configuration -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When using Declarative Table configuration with the ``__tablename__`` -declarative class attribute, additional arguments to be supplied to the -:class:`_schema.Table` constructor should be provided using the -``__table_args__`` declarative class attribute. + from typing import Literal -This attribute accommodates both positional as well as keyword -arguments that are normally sent to the -:class:`_schema.Table` constructor. -The attribute can be specified in one of two forms. 
One is as a -dictionary:: + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column - class MyClass(Base): - __tablename__ = "sometable" - __table_args__ = {"mysql_engine": "InnoDB"} -The other, a tuple, where each argument is positional -(usually constraints):: + class Base(DeclarativeBase): + pass - class MyClass(Base): - __tablename__ = "sometable" - __table_args__ = ( - ForeignKeyConstraint(["id"], ["remote_table.id"]), - UniqueConstraint("foo"), - ) -Keyword arguments can be specified with the above form by -specifying the last argument as a dictionary:: + Status = Literal["pending", "received", "completed"] - class MyClass(Base): - __tablename__ = "sometable" - __table_args__ = ( - ForeignKeyConstraint(["id"], ["remote_table.id"]), - UniqueConstraint("foo"), - {"autoload": True}, - ) -A class may also specify the ``__table_args__`` declarative attribute, -as well as the ``__tablename__`` attribute, in a dynamic style using the -:func:`_orm.declared_attr` method decorator. See -:ref:`orm_mixins_toplevel` for background. + class SomeClass(Base): + __tablename__ = "some_table" -.. _orm_declarative_table_schema_name: + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] -Explicit Schema Name with Declarative Table -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The entries used in :paramref:`_orm.registry.type_annotation_map` link the base +``enum.Enum`` Python type as well as the ``typing.Literal`` type to the +SQLAlchemy :class:`.Enum` SQL type, using a special form which indicates to the +:class:`.Enum` datatype that it should automatically configure itself against +an arbitrary enumerated type. This configuration, which is implicit by default, +would be indicated explicitly as:: -The schema name for a :class:`_schema.Table` as documented at -:ref:`schema_table_schema_name` is applied to an individual :class:`_schema.Table` -using the :paramref:`_schema.Table.schema` argument. When using Declarative -tables, this option is passed like any other to the ``__table_args__`` -dictionary:: + import enum + import typing + import sqlalchemy from sqlalchemy.orm import DeclarativeBase class Base(DeclarativeBase): - pass + type_annotation_map = { + enum.Enum: sqlalchemy.Enum(enum.Enum), + typing.Literal: sqlalchemy.Enum(enum.Enum), + } + +The resolution logic within Declarative is able to resolve subclasses +of ``enum.Enum`` as well as instances of ``typing.Literal`` to match the +``enum.Enum`` or ``typing.Literal`` entry in the +:paramref:`_orm.registry.type_annotation_map` dictionary. The :class:`.Enum` +SQL type then knows how to produce a configured version of itself with the +appropriate settings, including default string length. If a ``typing.Literal`` +that does not consist of only string values is passed, an informative +error is raised. +``typing.TypeAliasType`` can also be used to create enums, by assigning them +to a ``typing.Literal`` of strings:: - class MyClass(Base): - __tablename__ = "sometable" - __table_args__ = {"schema": "some_schema"} + from typing import Literal -The schema name can also be applied to all :class:`_schema.Table` objects -globally by using the :paramref:`_schema.MetaData.schema` parameter documented -at :ref:`schema_metadata_schema_name`. 
The :class:`_schema.MetaData` object
-may be constructed separately and associated with a :class:`_orm.DeclarativeBase`
-subclass by assigning to the ``metadata`` attribute directly::
+    type Status = Literal["on", "off", "unknown"]

-    from sqlalchemy import MetaData
-    from sqlalchemy.orm import DeclarativeBase
+Since this is a ``typing.TypeAliasType``, it represents a unique type object,
+so it must be placed in the ``type_annotation_map`` for it to be looked up
+successfully, keyed to the :class:`.Enum` type as follows::

-    metadata_obj = MetaData(schema="some_schema")
+    import enum
+    import sqlalchemy

     class Base(DeclarativeBase):
-        metadata = metadata_obj
+        type_annotation_map = {Status: sqlalchemy.Enum(enum.Enum)}

+Since SQLAlchemy supports mapping different ``typing.TypeAliasType``
+objects that are otherwise structurally equivalent individually,
+these must be present in ``type_annotation_map`` to avoid ambiguity.

-    class MyClass(Base):
-        # will use "some_schema" by default
-        __tablename__ = "sometable"
+Native Enums and Naming
+~~~~~~~~~~~~~~~~~~~~~~~~

-.. seealso::
+The :paramref:`.sqltypes.Enum.native_enum` parameter refers to whether the
+:class:`.sqltypes.Enum` datatype should create a so-called "native"
+enum, which on MySQL/MariaDB is the ``ENUM`` datatype and on PostgreSQL is
+a new ``TYPE`` object created by ``CREATE TYPE``, or a "non-native" enum,
+which means that ``VARCHAR`` will be used to create the datatype. For
+backends other than MySQL/MariaDB or PostgreSQL, ``VARCHAR`` is used in
+all cases (third party dialects may have their own behaviors).

-    :ref:`schema_table_schema_name` - in the :ref:`metadata_toplevel` documentation.
+Because PostgreSQL's ``CREATE TYPE`` requires that there's an explicit name
+for the type to be created, special fallback logic exists when working
+with implicitly generated :class:`.sqltypes.Enum` without specifying an
+explicit :class:`.sqltypes.Enum` datatype within a mapping:

-.. _orm_declarative_column_options:
+1. If the :class:`.sqltypes.Enum` is linked to an ``enum.Enum`` object,
+   the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to
+   ``True`` and the name of the enum will be taken from the name of the
+   ``enum.Enum`` datatype. The PostgreSQL backend will assume ``CREATE TYPE``
+   with this name.
+2. If the :class:`.sqltypes.Enum` is linked to a ``typing.Literal`` object,
+   the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to
+   ``False``; no name is generated and ``VARCHAR`` is assumed.

-Setting Load and Persistence Options for Declarative Mapped Columns
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+To use ``typing.Literal`` with a PostgreSQL ``CREATE TYPE`` type, an
+explicit :class:`.sqltypes.Enum` must be used, either within the
+type map::

-The :func:`_orm.mapped_column` construct accepts additional ORM-specific
-arguments that affect how the generated :class:`_schema.Column` is
-mapped, affecting its load and persistence-time behavior. Options
-that are commonly used include:
+    import enum
+    import typing

-* **deferred column loading** - The :paramref:`_orm.mapped_column.deferred`
-  boolean establishes the :class:`_schema.Column` using
-  :ref:`deferred column loading <orm_queryguide_column_deferral>` by default. 
In the example - below, the ``User.bio`` column will not be loaded by default, but only - when accessed:: + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase - class User(Base): - __tablename__ = "user" + Status = Literal["pending", "received", "completed"] - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] - bio: Mapped[str] = mapped_column(Text, deferred=True) - .. seealso:: + class Base(DeclarativeBase): + type_annotation_map = { + Status: sqlalchemy.Enum("pending", "received", "completed", name="status_enum"), + } - :ref:`orm_queryguide_column_deferral` - full description of deferred column loading +Or alternatively within :func:`_orm.mapped_column`:: -* **active history** - The :paramref:`_orm.mapped_column.active_history` - ensures that upon change of value for the attribute, the previous value - will have been loaded and made part of the :attr:`.AttributeState.history` - collection when inspecting the history of the attribute. This may incur - additional SQL statements:: + import enum + import typing - class User(Base): - __tablename__ = "user" + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + Status = Literal["pending", "received", "completed"] + + + class Base(DeclarativeBase): + pass + + + class SomeClass(Base): + __tablename__ = "some_table" id: Mapped[int] = mapped_column(primary_key=True) - important_identifier: Mapped[str] = mapped_column(active_history=True) + status: Mapped[Status] = mapped_column( + sqlalchemy.Enum("pending", "received", "completed", name="status_enum") + ) -See the docstring for :func:`_orm.mapped_column` for a list of supported -parameters. +Altering the Configuration of the Default Enum +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. seealso:: +In order to modify the fixed configuration of the :class:`.enum.Enum` datatype +that's generated implicitly, specify new entries in the +:paramref:`_orm.registry.type_annotation_map`, indicating additional arguments. +For example, to use "non native enumerations" unconditionally, the +:paramref:`.Enum.native_enum` parameter may be set to False for all types:: - :ref:`orm_imperative_table_column_options` - describes using - :func:`_orm.column_property` and :func:`_orm.deferred` for use with - Imperative Table configuration + import enum + import typing + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase -.. _mapper_column_distinct_names: -.. _orm_declarative_table_column_naming: + class Base(DeclarativeBase): + type_annotation_map = { + enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False), + typing.Literal: sqlalchemy.Enum(enum.Enum, native_enum=False), + } -Naming Declarative Mapped Columns Explicitly -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. versionchanged:: 2.0.1 Implemented support for overriding parameters + such as :paramref:`_sqltypes.Enum.native_enum` within the + :class:`_sqltypes.Enum` datatype when establishing the + :paramref:`_orm.registry.type_annotation_map`. Previously, this + functionality was not working. -All of the examples thus far feature the :func:`_orm.mapped_column` construct -linked to an ORM mapped attribute, where the Python attribute name given -to the :func:`_orm.mapped_column` is also that of the column as we see in -CREATE TABLE statements as well as queries. The name for a column as -expressed in SQL may be indicated by passing the string positional argument -:paramref:`_orm.mapped_column.__name` as the first positional argument. 
-In the example below, the ``User`` class is mapped with alternate names
-given to the columns themselves::
+To use a specific configuration for a specific ``enum.Enum`` subtype, such
+as setting the string length to 50 when using the example ``Status``
+datatype::

-    class User(Base):
-        __tablename__ = "user"
+    import enum
+    import sqlalchemy
+    from sqlalchemy.orm import DeclarativeBase

-        id: Mapped[int] = mapped_column("user_id", primary_key=True)
-        name: Mapped[str] = mapped_column("user_name")

-Where above ``User.id`` resolves to a column named ``user_id``
-and ``User.name`` resolves to a column named ``user_name``. We
-may write a :func:`_sql.select` statement using our Python attribute names
-and will see the SQL names generated:
+    class Status(enum.Enum):
+        PENDING = "pending"
+        RECEIVED = "received"
+        COMPLETED = "completed"

-.. sourcecode:: pycon+sql

-    >>> from sqlalchemy import select
-    >>> print(select(User.id, User.name).where(User.name == "x"))
-    {printsql}SELECT "user".user_id, "user".user_name
-    FROM "user"
-    WHERE "user".user_name = :user_name_1
+    class Base(DeclarativeBase):
+        type_annotation_map = {
+            Status: sqlalchemy.Enum(Status, length=50, native_enum=False)
+        }

+By default, :class:`_sqltypes.Enum` datatypes that are automatically generated
+are not associated with the :class:`_sql.MetaData` instance used by the
+``Base``, so if the metadata defines a schema it will not be automatically
+applied to the enum. To automatically associate the enum with the schema of
+the metadata or table it belongs to, the
+:paramref:`_sqltypes.Enum.inherit_schema` parameter can be set::

-.. seealso::
+    from enum import Enum
+    import sqlalchemy as sa
+    from sqlalchemy.orm import DeclarativeBase

-    :ref:`orm_imperative_table_column_naming` - applies to Imperative Table

-.. _orm_declarative_table_adding_columns:
+    class Base(DeclarativeBase):
+        metadata = sa.MetaData(schema="my_schema")
+        type_annotation_map = {Enum: sa.Enum(Enum, inherit_schema=True)}

-Appending additional columns to an existing Declarative mapped class
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Linking Specific ``enum.Enum`` or ``typing.Literal`` to other datatypes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-A declarative table configuration allows the addition of new
-:class:`_schema.Column` objects to an existing mapping after the :class:`.Table`
-metadata has already been generated.
+The above examples feature the use of an :class:`_sqltypes.Enum` that is
+automatically configuring itself to the arguments / attributes present on
+an ``enum.Enum`` or ``typing.Literal`` type object. For use cases where
+specific kinds of ``enum.Enum`` or ``typing.Literal`` should be linked to
+other types, these specific types may be placed in the type map also. 
+
+In the example below, an entry for ``Literal[]`` that contains non-string
+types is linked to the :class:`_sqltypes.JSON` datatype::

-For a declarative class that is declared using a declarative base class,
-the underlying metaclass :class:`.DeclarativeMeta` includes a ``__setattr__()``
-method that will intercept additional :func:`_orm.mapped_column` or Core
-:class:`.Column` objects and
-add them to both the :class:`.Table` using :meth:`.Table.append_column`
-as well as to the existing :class:`.Mapper` using :meth:`.Mapper.add_property`::

-    MyClass.some_new_column = mapped_column(String)
+    from typing import Literal

-Using core :class:`_schema.Column`::
+    from sqlalchemy import JSON
+    from sqlalchemy.orm import DeclarativeBase

-    MyClass.some_new_column = Column(String)
+    my_literal = Literal[0, 1, True, False, "true", "false"]

-All arguments are supported including an alternate name, such as
-``MyClass.some_new_column = mapped_column("some_name", String)``. However,
-the SQL type must be passed to the :func:`_orm.mapped_column` or
-:class:`_schema.Column` object explicitly, as in the above examples where
-the :class:`_sqltypes.String` type is passed. There's no capability for
-the :class:`_orm.Mapped` annotation type to take part in the operation.

-Additional :class:`_schema.Column` objects may also be added to a mapping
-in the specific circumstance of using single table inheritance, where
-additional columns are present on mapped subclasses that have
-no :class:`.Table` of their own. This is illustrated in the section
-:ref:`single_inheritance`.
+    class Base(DeclarativeBase):
+        type_annotation_map = {my_literal: JSON}

-.. note:: Assignment of mapped
-   properties to an already mapped class will only
-   function correctly if the "declarative base" class is used, meaning
-   the user-defined subclass of :class:`_orm.DeclarativeBase` or the
-   dynamically generated class returned by :func:`_orm.declarative_base`
-   or :meth:`_orm.registry.generate_base`. This "base" class includes
-   a Python metaclass which implements a special ``__setattr__()`` method
-   that intercepts these operations.
+In the above configuration, the ``my_literal`` datatype will resolve to a
+:class:`_sqltypes.JSON` instance. Other ``Literal`` variants will continue
+to resolve to :class:`_sqltypes.Enum` datatypes.

-   Runtime assignment of class-mapped attributes to a mapped class will **not** work
-   if the class is mapped using decorators like :meth:`_orm.registry.mapped`
-   or imperative functions like :meth:`_orm.registry.map_imperatively`.

.. _orm_imperative_table_configuration:

@@ -1233,7 +1555,7 @@ mapper configuration::

         __mapper_args__ = {
             "polymorphic_on": __table__.c.type,
-            "polymorhpic_identity": "person",
+            "polymorphic_identity": "person",
         }

 The "imperative table" form is also used when a non-:class:`_schema.Table`
diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst
index 0815da29aff..5b881054304 100644
--- a/doc/build/orm/extensions/asyncio.rst
+++ b/doc/build/orm/extensions/asyncio.rst
@@ -64,47 +64,64 @@ methods which both deliver asynchronous context managers. 
The :class:`_asyncio.AsyncConnection` can then invoke statements using either the :meth:`_asyncio.AsyncConnection.execute` method to deliver a buffered :class:`_engine.Result`, or the :meth:`_asyncio.AsyncConnection.stream` method -to deliver a streaming server-side :class:`_asyncio.AsyncResult`:: - - import asyncio - - from sqlalchemy import Column - from sqlalchemy import MetaData - from sqlalchemy import select - from sqlalchemy import String - from sqlalchemy import Table - from sqlalchemy.ext.asyncio import create_async_engine - - meta = MetaData() - t1 = Table("t1", meta, Column("name", String(50), primary_key=True)) - - - async def async_main() -> None: - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", - echo=True, - ) - - async with engine.begin() as conn: - await conn.run_sync(meta.create_all) - - await conn.execute( - t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] - ) - - async with engine.connect() as conn: - # select a Result, which will be delivered with buffered - # results - result = await conn.execute(select(t1).where(t1.c.name == "some name 1")) - - print(result.fetchall()) - - # for AsyncEngine created in function scope, close and - # clean-up pooled connections - await engine.dispose() - - - asyncio.run(async_main()) +to deliver a streaming server-side :class:`_asyncio.AsyncResult`: + +.. sourcecode:: pycon+sql + + >>> import asyncio + + >>> from sqlalchemy import Column + >>> from sqlalchemy import MetaData + >>> from sqlalchemy import select + >>> from sqlalchemy import String + >>> from sqlalchemy import Table + >>> from sqlalchemy.ext.asyncio import create_async_engine + + >>> meta = MetaData() + >>> t1 = Table("t1", meta, Column("name", String(50), primary_key=True)) + + + >>> async def async_main() -> None: + ... engine = create_async_engine("sqlite+aiosqlite://", echo=True) + ... + ... async with engine.begin() as conn: + ... await conn.run_sync(meta.drop_all) + ... await conn.run_sync(meta.create_all) + ... + ... await conn.execute( + ... t1.insert(), [{"name": "some name 1"}, {"name": "some name 2"}] + ... ) + ... + ... async with engine.connect() as conn: + ... # select a Result, which will be delivered with buffered + ... # results + ... result = await conn.execute(select(t1).where(t1.c.name == "some name 1")) + ... + ... print(result.fetchall()) + ... + ... # for AsyncEngine created in function scope, close and + ... # clean-up pooled connections + ... await engine.dispose() + + + >>> asyncio.run(async_main()) + {execsql}BEGIN (implicit) + ... + CREATE TABLE t1 ( + name VARCHAR(50) NOT NULL, + PRIMARY KEY (name) + ) + ... + INSERT INTO t1 (name) VALUES (?) + [...] [('some name 1',), ('some name 2',)] + COMMIT + BEGIN (implicit) + SELECT t1.name + FROM t1 + WHERE t1.name = ? + [...] ('some name 1',) + [('some name 1',)] + ROLLBACK Above, the :meth:`_asyncio.AsyncConnection.run_sync` method may be used to invoke special DDL functions such as :meth:`_schema.MetaData.create_all` that @@ -154,114 +171,165 @@ this. :ref:`asyncio_concurrency` and :ref:`session_faq_threadsafe` for background. 
The example below illustrates a complete example including mapper and session -configuration:: - - from __future__ import annotations - - import asyncio - import datetime - from typing import List - - from sqlalchemy import ForeignKey - from sqlalchemy import func - from sqlalchemy import select - from sqlalchemy.ext.asyncio import AsyncAttrs - from sqlalchemy.ext.asyncio import async_sessionmaker - from sqlalchemy.ext.asyncio import AsyncSession - from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - from sqlalchemy.orm import relationship - from sqlalchemy.orm import selectinload - - - class Base(AsyncAttrs, DeclarativeBase): - pass - - - class A(Base): - __tablename__ = "a" - - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[str] - create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) - bs: Mapped[List[B]] = relationship() - - - class B(Base): - __tablename__ = "b" - id: Mapped[int] = mapped_column(primary_key=True) - a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) - data: Mapped[str] - - - async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None: - async with async_session() as session: - async with session.begin(): - session.add_all( - [ - A(bs=[B(), B()], data="a1"), - A(bs=[], data="a2"), - A(bs=[B(), B()], data="a3"), - ] - ) - - - async def select_and_update_objects( - async_session: async_sessionmaker[AsyncSession], - ) -> None: - async with async_session() as session: - stmt = select(A).options(selectinload(A.bs)) - - result = await session.execute(stmt) - - for a1 in result.scalars(): - print(a1) - print(f"created at: {a1.create_date}") - for b1 in a1.bs: - print(b1) - - result = await session.execute(select(A).order_by(A.id).limit(1)) - - a1 = result.scalars().one() - - a1.data = "new data" - - await session.commit() - - # access attribute subsequent to commit; this is what - # expire_on_commit=False allows - print(a1.data) - - # alternatively, AsyncAttrs may be used to access any attribute - # as an awaitable (new in 2.0.13) - for b1 in await a1.awaitable_attrs.bs: - print(b1) - - - async def async_main() -> None: - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/test", - echo=True, - ) - - # async_sessionmaker: a factory for new AsyncSession objects. - # expire_on_commit - don't expire objects after transaction commit - async_session = async_sessionmaker(engine, expire_on_commit=False) - - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - await insert_objects(async_session) - await select_and_update_objects(async_session) - - # for AsyncEngine created in function scope, close and - # clean-up pooled connections - await engine.dispose() - - - asyncio.run(async_main()) +configuration: + +.. 
sourcecode:: pycon+sql + + >>> from __future__ import annotations + + >>> import asyncio + >>> import datetime + >>> from typing import List + + >>> from sqlalchemy import ForeignKey + >>> from sqlalchemy import func + >>> from sqlalchemy import select + >>> from sqlalchemy.ext.asyncio import AsyncAttrs + >>> from sqlalchemy.ext.asyncio import async_sessionmaker + >>> from sqlalchemy.ext.asyncio import AsyncSession + >>> from sqlalchemy.ext.asyncio import create_async_engine + >>> from sqlalchemy.orm import DeclarativeBase + >>> from sqlalchemy.orm import Mapped + >>> from sqlalchemy.orm import mapped_column + >>> from sqlalchemy.orm import relationship + >>> from sqlalchemy.orm import selectinload + + + >>> class Base(AsyncAttrs, DeclarativeBase): + ... pass + + >>> class B(Base): + ... __tablename__ = "b" + ... + ... id: Mapped[int] = mapped_column(primary_key=True) + ... a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + ... data: Mapped[str] + + >>> class A(Base): + ... __tablename__ = "a" + ... + ... id: Mapped[int] = mapped_column(primary_key=True) + ... data: Mapped[str] + ... create_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) + ... bs: Mapped[List[B]] = relationship() + + >>> async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None: + ... async with async_session() as session: + ... async with session.begin(): + ... session.add_all( + ... [ + ... A(bs=[B(data="b1"), B(data="b2")], data="a1"), + ... A(bs=[], data="a2"), + ... A(bs=[B(data="b3"), B(data="b4")], data="a3"), + ... ] + ... ) + + + >>> async def select_and_update_objects( + ... async_session: async_sessionmaker[AsyncSession], + ... ) -> None: + ... async with async_session() as session: + ... stmt = select(A).order_by(A.id).options(selectinload(A.bs)) + ... + ... result = await session.execute(stmt) + ... + ... for a in result.scalars(): + ... print(a, a.data) + ... print(f"created at: {a.create_date}") + ... for b in a.bs: + ... print(b, b.data) + ... + ... result = await session.execute(select(A).order_by(A.id).limit(1)) + ... + ... a1 = result.scalars().one() + ... + ... a1.data = "new data" + ... + ... await session.commit() + ... + ... # access attribute subsequent to commit; this is what + ... # expire_on_commit=False allows + ... print(a1.data) + ... + ... # alternatively, AsyncAttrs may be used to access any attribute + ... # as an awaitable (new in 2.0.13) + ... for b1 in await a1.awaitable_attrs.bs: + ... print(b1, b1.data) + + + >>> async def async_main() -> None: + ... engine = create_async_engine("sqlite+aiosqlite://", echo=True) + ... + ... # async_sessionmaker: a factory for new AsyncSession objects. + ... # expire_on_commit - don't expire objects after transaction commit + ... async_session = async_sessionmaker(engine, expire_on_commit=False) + ... + ... async with engine.begin() as conn: + ... await conn.run_sync(Base.metadata.create_all) + ... + ... await insert_objects(async_session) + ... await select_and_update_objects(async_session) + ... + ... # for AsyncEngine created in function scope, close and + ... # clean-up pooled connections + ... await engine.dispose() + + + >>> asyncio.run(async_main()) + {execsql}BEGIN (implicit) + ... + CREATE TABLE a ( + id INTEGER NOT NULL, + data VARCHAR NOT NULL, + create_date DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (id) + ) + ... + CREATE TABLE b ( + id INTEGER NOT NULL, + a_id INTEGER NOT NULL, + data VARCHAR NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY(a_id) REFERENCES a (id) + ) + ... 
+    COMMIT
+    BEGIN (implicit)
+    INSERT INTO a (data) VALUES (?) RETURNING id, create_date
+    [...] ('a1',)
+    ...
+    INSERT INTO b (a_id, data) VALUES (?, ?) RETURNING id
+    [...] (1, 'b2')
+    ...
+    COMMIT
+    BEGIN (implicit)
+    SELECT a.id, a.data, a.create_date
+    FROM a ORDER BY a.id
+    [...] ()
+    SELECT b.a_id AS b_a_id, b.id AS b_id, b.data AS b_data
+    FROM b
+    WHERE b.a_id IN (?, ?, ?)
+    [...] (1, 2, 3)
+    a1
+    created at: ...
+    b1
+    b2
+    a2
+    created at: ...
+    a3
+    created at: ...
+    b3
+    b4
+    SELECT a.id, a.data, a.create_date
+    FROM a ORDER BY a.id
+    LIMIT ? OFFSET ?
+    [...] (1, 0)
+    UPDATE a SET data=? WHERE a.id = ?
+    [...] ('new data', 1)
+    COMMIT
+    new data
+    b1
+    b2

 In the example above, the :class:`_asyncio.AsyncSession` is instantiated using
 the optional :class:`_asyncio.async_sessionmaker` helper, which provides
diff --git a/doc/build/orm/extensions/mypy.rst b/doc/build/orm/extensions/mypy.rst
index 042af370914..b7d50c607ad 100644
--- a/doc/build/orm/extensions/mypy.rst
+++ b/doc/build/orm/extensions/mypy.rst
@@ -11,9 +11,10 @@ the :func:`_orm.mapped_column` construct introduced in SQLAlchemy 2.0.

 .. deprecated:: 2.0

-    **The SQLAlchemy Mypy Plugin is DEPRECATED, and will be removed possibly
-    as early as the SQLAlchemy 2.1 release. We would urge users to please
-    migrate away from it ASAP.**
+    **The SQLAlchemy Mypy Plugin is DEPRECATED, and will be removed in
+    the SQLAlchemy 2.1 release. We would urge users to please
+    migrate away from it ASAP. The mypy plugin also works only up until
+    mypy version 1.10.1; version 1.11.0 and greater may not work properly.**

     This plugin cannot be maintained across constantly changing releases
     of mypy and its stability going forward CANNOT be guaranteed.
@@ -24,7 +25,11 @@ the :func:`_orm.mapped_column` construct introduced in SQLAlchemy 2.0.

 .. topic:: SQLAlchemy Mypy Plugin Status Update

-    **Updated July 2023**
+    **Updated July 2024**
+
+    The mypy plugin is supported **only up until mypy 1.10.1, and it will have
+    issues running with 1.11.0 or greater**. Using mypy 1.11.0 or greater may
+    produce error conditions which currently cannot be resolved.

     For SQLAlchemy 2.0, the Mypy plugin continues to work at the level at which
     it reached in the SQLAlchemy 1.4 release. SQLAlchemy 2.0 however features
@@ -179,8 +184,7 @@ following::

         )
         name: Mapped[Optional[str]] = Mapped._special_method(Column(String))

-    def __init__(self, id: Optional[int] = ..., name: Optional[str] = ...) -> None:
-        ...
+    def __init__(self, id: Optional[int] = ..., name: Optional[str] = ...) -> None: ...

some_user = User(id=5, name="user")

@@ -498,7 +502,7 @@ plugin that a particular class intends to serve as a declarative mixin::

     class HasCompany:
         @declared_attr
         def company_id(cls) -> Mapped[int]:  # uses Mapped
-            return Column(ForeignKey("company.id"))
+            return mapped_column(ForeignKey("company.id"))

         @declared_attr
         def company(cls) -> Mapped["Company"]:
diff --git a/doc/build/orm/inheritance.rst b/doc/build/orm/inheritance.rst
index fe3e06bf0f0..7a19de9ae42 100644
--- a/doc/build/orm/inheritance.rst
+++ b/doc/build/orm/inheritance.rst
@@ -3,12 +3,13 @@ Mapping Class Inheritance Hierarchies
 =====================================

-SQLAlchemy supports three forms of inheritance: **single table inheritance**,
-where several types of classes are represented by a single table, **concrete
-table inheritance**, where each type of class is represented by independent
-tables, and **joined table inheritance**, where the class hierarchy is broken
-up among dependent tables, each class represented by its own table that only
-includes those attributes local to that class.
+SQLAlchemy supports three forms of inheritance:
+
+* **single table inheritance** – several types of classes are represented by a single table;
+
+* **concrete table inheritance** – each type of class is represented by independent tables;
+
+* **joined table inheritance** – the class hierarchy is broken up among dependent tables. Each class is represented by its own table that only includes those attributes local to that class.

 The most common forms of inheritance are single and joined table, while
 concrete inheritance presents more configurational challenges.
@@ -203,12 +204,10 @@ and ``Employee``::

         }

-    class Manager(Employee):
-        ...
+    class Manager(Employee): ...

-    class Engineer(Employee):
-        ...
+    class Engineer(Employee): ...

 If the foreign key constraint is on a table corresponding to a subclass,
 the relationship should target that subclass instead. In the example
@@ -248,8 +247,7 @@ established between the ``Manager`` and ``Company`` classes::

         }

-    class Engineer(Employee):
-        ...
+    class Engineer(Employee): ...

Above, the ``Manager`` class will have a ``Manager.company`` attribute;
``Company`` will have a ``Company.managers`` attribute that always
@@ -638,7 +636,7 @@ using :paramref:`_orm.Mapper.polymorphic_abstract` as follows::

     class SysAdmin(Technologist):
         """a systems administrator"""

-        __mapper_args__ = {"polymorphic_identity": "engineer"}
+        __mapper_args__ = {"polymorphic_identity": "sysadmin"}

 In the above example, the new classes ``Technologist`` and ``Executive`` are
 ordinary mapped classes, and also indicate new columns to be added to the
diff --git a/doc/build/orm/join_conditions.rst b/doc/build/orm/join_conditions.rst
index ef6d74e6676..8a220c9d8a1 100644
--- a/doc/build/orm/join_conditions.rst
+++ b/doc/build/orm/join_conditions.rst
@@ -142,7 +142,7 @@ load those ``Address`` objects which specify a city of "Boston"::

         name = mapped_column(String)
         boston_addresses = relationship(
             "Address",
-            primaryjoin="and_(User.id==Address.user_id, " "Address.city=='Boston')",
+            primaryjoin="and_(User.id==Address.user_id, Address.city=='Boston')",
         )

@@ -297,7 +297,7 @@ a :func:`_orm.relationship`::

         network = relationship(
             "Network",
-            primaryjoin="IPA.v4address.bool_op('<<')" "(foreign(Network.v4representation))",
+            primaryjoin="IPA.v4address.bool_op('<<')(foreign(Network.v4representation))",
             viewonly=True,
         )

@@ -389,7 +389,7 @@ for both; then to make ``Article`` refer to ``Writer`` as well,

         article_id = mapped_column(Integer)
         magazine_id = mapped_column(ForeignKey("magazine.id"))
-        writer_id = mapped_column()
+        writer_id = mapped_column(Integer)

         magazine = relationship("Magazine")
         writer = relationship("Writer")

@@ -424,13 +424,19 @@ What this refers to originates from the fact that ``Article.magazine_id``
 is the subject of two different foreign key constraints; it refers to
 ``Magazine.id`` directly as a source column, but also refers to
 ``Writer.magazine_id`` as a source column in the context of the
-composite key to ``Writer``. If we associate an ``Article`` with a
-particular ``Magazine``, but then associate the ``Article`` with a
-``Writer`` that's associated with a *different* ``Magazine``, the ORM
-will overwrite ``Article.magazine_id`` non-deterministically, silently
-changing which magazine to which we refer; it may
-also attempt to place NULL into this column if we de-associate a
-``Writer`` from an ``Article``. The warning lets us know this is the case.
+composite key to ``Writer``.
+
+When objects are added to an ORM :class:`.Session` using :meth:`.Session.add`,
+the ORM :term:`flush` process takes on the task of reconciling object
+references that correspond to :func:`_orm.relationship` configurations and
+delivering this state to the database using INSERT/UPDATE/DELETE statements. In
+this specific example, if we associate an ``Article`` with a particular
+``Magazine``, but then associate the ``Article`` with a ``Writer`` that's
+associated with a *different* ``Magazine``, this flush process will overwrite
+``Article.magazine_id`` non-deterministically, silently changing which magazine
+we refer to; it may also attempt to place NULL into this column if we
+de-associate a ``Writer`` from an ``Article``. The warning lets us know that
+this scenario may occur during ORM flush sequences.
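+
+As a hypothetical sketch of the scenario the warning describes (reusing the
+``Magazine``, ``Writer`` and ``Article`` mappings from above, and assuming an
+ordinary :class:`.Session` named ``session``), the conflicting assignments
+might look like::
+
+    article = Article(article_id=1)
+
+    # associate the Article with one Magazine...
+    article.magazine = Magazine(id=5)
+
+    # ...but also with a Writer whose composite key refers to a *different*
+    # Magazine; at flush time both relationships attempt to populate
+    # Article.magazine_id, and which value wins is not deterministic
+    article.writer = Writer(id=10, magazine_id=7)
+
+    session.add(article)
+    session.flush()  # Article.magazine_id may end up as either 5 or 7
+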
To solve this, we need to break out the behavior of ``Article`` to include all three of the following features: @@ -543,9 +549,9 @@ is when establishing a many-to-many relationship from a class to itself, as show from typing import List - from sqlalchemy import Integer, ForeignKey, String, Column, Table - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import relationship + from sqlalchemy import Integer, ForeignKey, Column, Table + from sqlalchemy.orm import DeclarativeBase, Mapped + from sqlalchemy.orm import mapped_column, relationship class Base(DeclarativeBase): @@ -564,14 +570,14 @@ is when establishing a many-to-many relationship from a class to itself, as show __tablename__ = "node" id: Mapped[int] = mapped_column(primary_key=True) label: Mapped[str] - right_nodes: Mapped[List["None"]] = relationship( + right_nodes: Mapped[List["Node"]] = relationship( "Node", secondary=node_to_node, primaryjoin=id == node_to_node.c.left_node_id, secondaryjoin=id == node_to_node.c.right_node_id, back_populates="left_nodes", ) - left_nodes: Mapped[List["None"]] = relationship( + left_nodes: Mapped[List["Node"]] = relationship( "Node", secondary=node_to_node, primaryjoin=id == node_to_node.c.right_node_id, @@ -702,7 +708,7 @@ join condition (requires version 0.9.2 at least to function as is):: d = relationship( "D", - secondary="join(B, D, B.d_id == D.id)." "join(C, C.d_id == D.id)", + secondary="join(B, D, B.d_id == D.id).join(C, C.d_id == D.id)", primaryjoin="and_(A.b_id == B.id, A.id == C.a_id)", secondaryjoin="D.id == B.d_id", uselist=False, @@ -752,10 +758,17 @@ there's just "one" table on both the "left" and the "right" side; the complexity is kept within the middle. .. warning:: A relationship like the above is typically marked as - ``viewonly=True`` and should be considered as read-only. While there are + ``viewonly=True``, using :paramref:`_orm.relationship.viewonly`, + and should be considered as read-only. While there are sometimes ways to make relationships like the above writable, this is generally complicated and error prone. +.. seealso:: + + :ref:`relationship_viewonly_notes` + + + .. _relationship_non_primary_mapper: .. _relationship_aliased_class: @@ -763,14 +776,6 @@ complexity is kept within the middle. Relationship to Aliased Class ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. versionadded:: 1.3 - The :class:`.AliasedClass` construct can now be specified as the - target of a :func:`_orm.relationship`, replacing the previous approach - of using non-primary mappers, which had limitations such that they did - not inherit sub-relationships of the mapped entity as well as that they - required complex configuration against an alternate selectable. The - recipes in this section are now updated to use :class:`.AliasedClass`. - In the previous section, we illustrated a technique where we used :paramref:`_orm.relationship.secondary` in order to place additional tables within a join condition. 
There is one complex join case where @@ -847,6 +852,81 @@ With the above mapping, a simple join looks like: {execsql}SELECT a.id AS a_id, a.b_id AS a_b_id FROM a JOIN (b JOIN d ON d.b_id = b.id JOIN c ON c.id = d.c_id) ON a.b_id = b.id +Integrating AliasedClass Mappings with Typing and Avoiding Early Mapper Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The creation of the :func:`_orm.aliased` construct against a mapped class +forces the :func:`_orm.configure_mappers` step to proceed, which will resolve +all current classes and their relationships. This may be problematic if +unrelated mapped classes needed by the current mappings have not yet been +declared, or if the configuration of the relationship itself needs access +to as-yet undeclared classes. Additionally, SQLAlchemy's Declarative pattern +works with Python typing most effectively when relationships are declared +up front. + +To organize the construction of the relationship to work with these issues, a +configure level event hook like :meth:`.MapperEvents.before_mapper_configured` +may be used, which will invoke the configuration code only when all mappings +are ready for configuration:: + + from sqlalchemy import event + + + class A(Base): + __tablename__ = "a" + + id = mapped_column(Integer, primary_key=True) + b_id = mapped_column(ForeignKey("b.id")) + + + @event.listens_for(A, "before_mapper_configured") + def _configure_ab_relationship(mapper, cls): + # do the above configuration in a configuration hook + + j = join(B, D, D.b_id == B.id).join(C, C.id == D.c_id) + B_viacd = aliased(B, j, flat=True) + A.b = relationship(B_viacd, primaryjoin=A.b_id == j.c.b_id) + +Above, the function ``_configure_ab_relationship()`` will be invoked only +when a fully configured version of ``A`` is requested, at which point the +classes ``B``, ``D`` and ``C`` would be available. + +For an approach that integrates with inline typing, a similar technique can be +used to effectively generate a "singleton" creation pattern for the aliased +class where it is late-initialized as a global variable, which can then be used +in the relationship inline:: + + from typing import Any + + B_viacd: Any = None + b_viacd_join: Any = None + + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + b_id: Mapped[int] = mapped_column(ForeignKey("b.id")) + + # 1. the relationship can be declared using lambdas, allowing it to resolve + # to targets that are late-configured + b: Mapped[B] = relationship( + lambda: B_viacd, primaryjoin=lambda: A.b_id == b_viacd_join.c.b_id + ) + + + # 2. configure the targets of the relationship using a before_mapper_configured + # hook. + @event.listens_for(A, "before_mapper_configured") + def _configure_ab_relationship(mapper, cls): + # 3. set up the join() and AliasedClass as globals from within + # the configuration hook. + + global B_viacd, b_viacd_join + + b_viacd_join = join(B, D, D.b_id == B.id).join(C, C.id == D.c_id) + B_viacd = aliased(B, b_viacd_join, flat=True) + Using the AliasedClass target in Queries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -986,3 +1066,247 @@ of special Python attributes. .. seealso:: :ref:`mapper_hybrids` + +.. 
_relationship_viewonly_notes:
+
+Notes on using the viewonly relationship parameter
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The :paramref:`_orm.relationship.viewonly` parameter, when applied to a
+:func:`_orm.relationship` construct, indicates that this :func:`_orm.relationship`
+will not take part in any ORM :term:`unit of work` operations, and additionally
+that the attribute does not expect to participate in in-Python mutations
+of its represented collection. This means
+that while the viewonly relationship may refer to a mutable Python collection
+like a list or set, making changes to that list or set as present on a
+mapped instance will have **no effect** on the ORM flush process.
+
+To explore this scenario, consider this mapping::
+
+    from __future__ import annotations
+
+    import datetime
+
+    from sqlalchemy import and_
+    from sqlalchemy import ForeignKey
+    from sqlalchemy import func
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+    from sqlalchemy.orm import relationship
+
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class User(Base):
+        __tablename__ = "user_account"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str | None]
+
+        all_tasks: Mapped[list[Task]] = relationship()
+
+        current_week_tasks: Mapped[list[Task]] = relationship(
+            primaryjoin=lambda: and_(
+                User.id == Task.user_account_id,
+                # this expression works on PostgreSQL but may not be supported
+                # by other database engines
+                Task.task_date >= func.now() - datetime.timedelta(days=7),
+            ),
+            viewonly=True,
+        )
+
+
+    class Task(Base):
+        __tablename__ = "task"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        user_account_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
+        description: Mapped[str | None]
+        task_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now())
+
+        user: Mapped[User] = relationship(back_populates="current_week_tasks")
+
+The following sections will note different aspects of this configuration.
+
+In-Python mutations including backrefs are not appropriate with viewonly=True
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The above mapping targets the ``User.current_week_tasks`` viewonly relationship
+as the :term:`backref` target of the ``Task.user`` attribute. This is not
+currently flagged by SQLAlchemy's ORM configuration process, however it is a
+configuration error. Changing the ``.user`` attribute on a ``Task`` will not
+affect the ``.current_week_tasks`` attribute::
+
+    >>> u1 = User()
+    >>> t1 = Task(task_date=datetime.datetime.now())
+    >>> t1.user = u1
+    >>> u1.current_week_tasks
+    []
+
+There is another parameter called :paramref:`_orm.relationship.sync_backref`
+which can be turned on here to allow ``.current_week_tasks`` to be mutated in this
+case, however this is not considered to be a best practice with a viewonly
+relationship, which instead should not be relied upon for in-Python mutations.
+
+In this mapping, backrefs can be configured between ``User.all_tasks`` and
+``Task.user``, as these are both not viewonly and will synchronize normally.
+
+Beyond the issue of backref mutations being disabled for viewonly relationships,
+plain changes to the ``User.all_tasks`` collection in Python
+are also not reflected in the ``User.current_week_tasks`` collection until
+changes have been flushed to the database. 
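+
+As a hypothetical sketch of this behavior, assuming ``u1`` is a
+:term:`persistent` ``User`` with no tasks, and whose ``current_week_tasks``
+collection has already been loaded, an in-Python append to ``all_tasks``
+leaves the viewonly collection unchanged::
+
+    >>> print(u1.current_week_tasks)  # loads and caches the collection
+    []
+    >>> u1.all_tasks.append(Task(task_date=datetime.datetime.now()))
+    >>> print(u1.current_week_tasks)  # still the previously loaded result
+    []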
+ +Overall, for a use case where a custom collection should respond immediately to +in-Python mutations, the viewonly relationship is generally not appropriate. A +better approach is to use the :ref:`hybrids_toplevel` feature of SQLAlchemy, or +for instance-only cases to use a Python ``@property``, where a user-defined +collection that is generated in terms of the current Python instance can be +implemented. To change our example to work this way, we repair the +:paramref:`_orm.relationship.back_populates` parameter on ``Task.user`` to +reference ``User.all_tasks``, and +then illustrate a simple ``@property`` that will deliver results in terms of +the immediate ``User.all_tasks`` collection:: + + class User(Base): + __tablename__ = "user_account" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str | None] + + all_tasks: Mapped[list[Task]] = relationship(back_populates="user") + + @property + def current_week_tasks(self) -> list[Task]: + past_seven_days = datetime.datetime.now() - datetime.timedelta(days=7) + return [t for t in self.all_tasks if t.task_date >= past_seven_days] + + + class Task(Base): + __tablename__ = "task" + + id: Mapped[int] = mapped_column(primary_key=True) + user_account_id: Mapped[int] = mapped_column(ForeignKey("user_account.id")) + description: Mapped[str | None] + task_date: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) + + user: Mapped[User] = relationship(back_populates="all_tasks") + +Using an in-Python collection calculated on the fly each time, we are guaranteed +to have the correct answer at all times, without the need to use a database +at all:: + + >>> u1 = User() + >>> t1 = Task(task_date=datetime.datetime.now()) + >>> t1.user = u1 + >>> u1.current_week_tasks + [<__main__.Task object at 0x7f3d699523c0>] + + +viewonly=True collections / attributes do not get re-queried until expired +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Continuing with the original viewonly attribute, if we do in fact make changes +to the ``User.all_tasks`` collection on a :term:`persistent` object, the +viewonly collection can only show the net result of this change after **two** +things occur. The first is that the change to ``User.all_tasks`` is +:term:`flushed`, so that the new data is available in the database, at least +within the scope of the local transaction. The second is that the ``User.current_week_tasks`` +attribute is :term:`expired` and reloaded via a new SQL query to the database. + +To support this requirement, the simplest flow to use is one where the +**viewonly relationship is consumed only in operations that are primarily read +only to start with**. Such as below, if we retrieve a ``User`` fresh from +the database, the collection will be current:: + + >>> with Session(e) as sess: + ... u1 = sess.scalar(select(User).where(User.id == 1)) + ... print(u1.current_week_tasks) + [<__main__.Task object at 0x7f8711b906b0>] + + +When we make modifications to ``u1.all_tasks``, if we want to see these changes +reflected in the ``u1.current_week_tasks`` viewonly relationship, these changes need to be flushed +and the ``u1.current_week_tasks`` attribute needs to be expired, so that +it will :term:`lazy load` on next access. The simplest approach to this is +to use :meth:`_orm.Session.commit`, keeping the :paramref:`_orm.Session.expire_on_commit` +parameter set at its default of ``True``:: + + >>> with Session(e) as sess: + ... u1 = sess.scalar(select(User).where(User.id == 1)) + ... 
u1.all_tasks.append(Task(task_date=datetime.datetime.now()))
+    ...     sess.commit()
+    ...     print(u1.current_week_tasks)
+    [<__main__.Task object at 0x7f8711b90ec0>, <__main__.Task object at 0x7f8711b90a10>]
+
+Above, the call to :meth:`_orm.Session.commit` flushed the changes to ``u1.all_tasks``
+to the database, then expired all objects, so that when we accessed ``u1.current_week_tasks``,
+a :term:`lazy load` occurred which fetched the contents for this attribute
+freshly from the database.
+
+To intercept operations without actually committing the transaction,
+the attribute needs to be explicitly :term:`expired`
+first. A simple way to do this is to call :meth:`_orm.Session.expire` directly. In
+the example below, :meth:`_orm.Session.flush` sends pending changes to the
+database, then :meth:`_orm.Session.expire` is used to expire the ``u1.current_week_tasks``
+collection so that it re-fetches on next access::
+
+    >>> with Session(e) as sess:
+    ...     u1 = sess.scalar(select(User).where(User.id == 1))
+    ...     u1.all_tasks.append(Task(task_date=datetime.datetime.now()))
+    ...     sess.flush()
+    ...     sess.expire(u1, ["current_week_tasks"])
+    ...     print(u1.current_week_tasks)
+    [<__main__.Task object at 0x7fd95a4c8c50>, <__main__.Task object at 0x7fd95a4c8c80>]
+
+We can in fact skip the call to :meth:`_orm.Session.flush`, assuming a
+:class:`_orm.Session` that keeps :paramref:`_orm.Session.autoflush` at its
+default value of ``True``, as the expired ``current_week_tasks`` attribute will
+trigger autoflush when accessed after expiration::
+
+    >>> with Session(e) as sess:
+    ...     u1 = sess.scalar(select(User).where(User.id == 1))
+    ...     u1.all_tasks.append(Task(task_date=datetime.datetime.now()))
+    ...     sess.expire(u1, ["current_week_tasks"])
+    ...     print(u1.current_week_tasks)  # triggers autoflush before querying
+    [<__main__.Task object at 0x7fd95a4c8c50>, <__main__.Task object at 0x7fd95a4c8c80>]
+
+Continuing with the above approach to something more elaborate, we can apply
+the expiration programmatically when the related ``User.all_tasks`` collection
+changes, using :ref:`event hooks <event_toplevel>`. This is an **advanced
+technique**, where simpler architectures like ``@property`` or sticking to
+read-only use cases should be examined first. In our simple example, this
+would be configured as::
+
+    from sqlalchemy import event, inspect
+
+
+    @event.listens_for(User.all_tasks, "append")
+    @event.listens_for(User.all_tasks, "remove")
+    @event.listens_for(User.all_tasks, "bulk_replace")
+    def _expire_User_current_week_tasks(target, value, initiator):
+        inspect(target).session.expire(target, ["current_week_tasks"])
+
+With the above hooks, mutation operations are intercepted and result in
+the ``User.current_week_tasks`` collection being expired automatically::
+
+    >>> with Session(e) as sess:
+    ...     u1 = sess.scalar(select(User).where(User.id == 1))
+    ...     u1.all_tasks.append(Task(task_date=datetime.datetime.now()))
+    ...     print(u1.current_week_tasks)
+    [<__main__.Task object at 0x7f66d093ccb0>, <__main__.Task object at 0x7f66d093cce0>]
+
+The :class:`_orm.AttributeEvents` event hooks used above are also triggered
+by backref mutations, so with the above hooks a change to ``Task.user`` is
+also intercepted::
+
+    >>> with Session(e) as sess:
+    ...     u1 = sess.scalar(select(User).where(User.id == 1))
+    ...     t1 = Task(task_date=datetime.datetime.now())
+    ...     t1.user = u1
+    ...     sess.add(t1)
+    ...     
print(u1.current_week_tasks) + [<__main__.Task object at 0x7f3b0c070d10>, <__main__.Task object at 0x7f3b0c057d10>] + diff --git a/doc/build/orm/mapped_attributes.rst b/doc/build/orm/mapped_attributes.rst index d0610f4e0fa..b114680132e 100644 --- a/doc/build/orm/mapped_attributes.rst +++ b/doc/build/orm/mapped_attributes.rst @@ -234,7 +234,7 @@ logic:: """Produce a SQL expression that represents the value of the _email column, minus the last twelve characters.""" - return func.substr(cls._email, 0, func.length(cls._email) - 12) + return func.substr(cls._email, 1, func.length(cls._email) - 12) Above, accessing the ``email`` property of an instance of ``EmailAddress`` will return the value of the ``_email`` attribute, removing or adding the @@ -249,7 +249,7 @@ attribute, a SQL function is rendered which produces the same effect: {execsql}SELECT address.email AS address_email, address.id AS address_id FROM address WHERE substr(address.email, ?, length(address.email) - ?) = ? - (0, 12, 'address') + (1, 12, 'address') {stop} Read more about Hybrids at :ref:`hybrids_toplevel`. diff --git a/doc/build/orm/mapping_api.rst b/doc/build/orm/mapping_api.rst index 57ef5e00e0f..399111d6058 100644 --- a/doc/build/orm/mapping_api.rst +++ b/doc/build/orm/mapping_api.rst @@ -53,11 +53,11 @@ Class Mapping API class HasIdMixin: @declared_attr.cascading - def id(cls): + def id(cls) -> Mapped[int]: if has_inherited_table(cls): - return Column(ForeignKey("myclass.id"), primary_key=True) + return mapped_column(ForeignKey("myclass.id"), primary_key=True) else: - return Column(Integer, primary_key=True) + return mapped_column(Integer, primary_key=True) class MyClass(HasIdMixin, Base): diff --git a/doc/build/orm/mapping_columns.rst b/doc/build/orm/mapping_columns.rst index 25c6604fafa..30220baebc8 100644 --- a/doc/build/orm/mapping_columns.rst +++ b/doc/build/orm/mapping_columns.rst @@ -4,6 +4,6 @@ Mapping Table Columns ===================== This section has been integrated into the -:ref:`orm_declarative_table_config_toplevel` Declarative section. +:ref:`orm_declarative_table_config_toplevel` section. diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst index fbe4267be78..8a4b8aece84 100644 --- a/doc/build/orm/mapping_styles.rst +++ b/doc/build/orm/mapping_styles.rst @@ -370,6 +370,13 @@ An object of type ``User`` above will have a constructor which allows Python dataclasses, and allows for a highly configurable constructor form. +.. warning:: + + The ``__init__()`` method of the class is called only when the object is + constructed in Python code, and **not when an object is loaded or refreshed + from the database**. See the next section :ref:`mapped_class_load_events` + for a primer on how to invoke special logic when objects are loaded. + A class that includes an explicit ``__init__()`` method will maintain that method, and no default constructor will be applied. @@ -404,6 +411,99 @@ will also feature the default constructor associated with the :class:`_orm.regis constructor when they are mapped via the :meth:`_orm.registry.map_imperatively` method. +.. 
_mapped_class_load_events: + +Maintaining Non-Mapped State Across Loads +------------------------------------------ + +The ``__init__()`` method of the mapped class is invoked when the object +is constructed directly in Python code:: + + u1 = User(name="some name", fullname="some fullname") + +However, when an object is loaded using the ORM :class:`_orm.Session`, +the ``__init__()`` method is **not** called:: + + u1 = session.scalars(select(User).where(User.name == "some name")).first() + +The reason for this is that when loaded from the database, the operation +used to construct the object, in the above example the ``User``, is more +analogous to **deserialization**, such as unpickling, than to initial +construction. The majority of the object's important state is not being +assembled for the first time; it's being re-loaded from database rows. + +Therefore, to maintain state within the object that is not part of the data +that's stored to the database, such that this state is present when objects +are loaded as well as constructed, there are two general approaches detailed +below. + +1. Use Python descriptors like ``@property``, rather than stored state, to dynamically + compute attributes as needed. + + For simple attributes, this is the most direct approach and the least error prone. + For example, if an object ``Point`` with ``Point.x`` and ``Point.y`` needed + an attribute representing the sum of these attributes:: + + class Point(Base): + __tablename__ = "point" + id: Mapped[int] = mapped_column(primary_key=True) + x: Mapped[int] + y: Mapped[int] + + @property + def x_plus_y(self): + return self.x + self.y + + An advantage of using dynamic descriptors is that the value is computed + every time, meaning it maintains the correct value as the underlying + attributes (``x`` and ``y`` in this case) might change. + + Other forms of the above pattern include the Python standard library + `cached_property <https://docs.python.org/3/library/functools.html#functools.cached_property>`_ + decorator (which is cached, and not re-computed each time), as well as SQLAlchemy's :class:`.hybrid_property` decorator, which + allows attributes to work for SQL querying as well. + + +2. Establish state on-load using :meth:`.InstanceEvents.load`, and optionally + the supplemental events :meth:`.InstanceEvents.refresh` and :meth:`.InstanceEvents.refresh_flush`. + + These are event hooks that are invoked whenever the object is loaded + from the database, or when it is refreshed after being expired. Typically + only the :meth:`.InstanceEvents.load` event is needed, since non-mapped local object + state is not affected by expiration operations. Revising the ``Point`` + example above to use this approach looks like:: + + from sqlalchemy import event + + + class Point(Base): + __tablename__ = "point" + id: Mapped[int] = mapped_column(primary_key=True) + x: Mapped[int] + y: Mapped[int] + + def __init__(self, x, y, **kw): + super().__init__(x=x, y=y, **kw) + self.x_plus_y = x + y + + + @event.listens_for(Point, "load") + def receive_load(target, context): + target.x_plus_y = target.x + target.y + + If using the refresh events as well, the event hooks can be stacked on + top of one callable if needed, as:: + + @event.listens_for(Point, "load") + @event.listens_for(Point, "refresh") + @event.listens_for(Point, "refresh_flush") + def receive_load(target, context, attrs=None): + target.x_plus_y = target.x + target.y + + Above, the ``attrs`` argument will be present for the ``refresh`` and + ``refresh_flush`` events and will indicate a list of attribute names that are + being refreshed. + .. 
_orm_mapper_inspection: Runtime Introspection of Mapped classes, Instances and Mappers diff --git a/doc/build/orm/persistence_techniques.rst b/doc/build/orm/persistence_techniques.rst index 982f27ebdc6..a877fcd0e0e 100644 --- a/doc/build/orm/persistence_techniques.rst +++ b/doc/build/orm/persistence_techniques.rst @@ -37,7 +37,7 @@ from the database. The feature also has conditional support to work in conjunction with primary key columns. For backends that have RETURNING support -(including Oracle, SQL Server, MariaDB 10.5, SQLite 3.35) a +(including Oracle Database, SQL Server, MariaDB 10.5, SQLite 3.35) a SQL expression may be assigned to a primary key column as well. This allows both the SQL expression to be evaluated, as well as allows any server side triggers that modify the primary key value on INSERT, to be successfully @@ -90,7 +90,7 @@ This is most easily accomplished using the session = Session() # execute a string statement - result = session.execute("select * from table where id=:id", {"id": 7}) + result = session.execute(text("select * from table where id=:id"), {"id": 7}) # execute a SQL expression construct result = session.execute(select(mytable).where(mytable.c.id == 7)) @@ -274,7 +274,7 @@ answered are, 1. is this column part of the primary key or not, and 2. does the database support RETURNING or an equivalent, such as "OUTPUT inserted"; these are SQL phrases which return a server-generated value at the same time as the INSERT or UPDATE statement is invoked. RETURNING is currently supported -by PostgreSQL, Oracle, MariaDB 10.5, SQLite 3.35, and SQL Server. +by PostgreSQL, Oracle Database, MariaDB 10.5, SQLite 3.35, and SQL Server. Case 1: non primary key, RETURNING or equivalent is supported ------------------------------------------------------------- @@ -332,7 +332,7 @@ Case 2: Table includes trigger-generated values which are not compatible with RE The ``"auto"`` setting of :paramref:`_orm.Mapper.eager_defaults` means that a backend that supports RETURNING will usually make use of RETURNING with -INSERT statements in order to retreive newly generated default values. +INSERT statements in order to retrieve newly generated default values. However there are limitations of server-generated values that are generated using triggers, such that RETURNING can't be used: @@ -367,7 +367,7 @@ this looks like:: On SQL Server with the pyodbc driver, an INSERT for the above table will not use RETURNING and will use the SQL Server ``scope_identity()`` function -to retreive the newly generated primary key value: +to retrieve the newly generated primary key value: .. sourcecode:: sql @@ -438,7 +438,7 @@ PostgreSQL SERIAL, these types are handled automatically by the Core; databases include functions for fetching the "last inserted id" where RETURNING is not supported, and where RETURNING is supported SQLAlchemy will use that. -For example, using Oracle with a column marked as :class:`.Identity`, +For example, using Oracle Database with a column marked as :class:`.Identity`, RETURNING is used automatically to fetch the new primary key value:: class MyOracleModel(Base): @@ -447,7 +447,7 @@ RETURNING is used automatically to fetch the new primary key value:: id: Mapped[int] = mapped_column(Identity(), primary_key=True) data: Mapped[str] = mapped_column(String(50)) -The INSERT for a model as above on Oracle looks like: +The INSERT for a model as above on Oracle Database looks like: .. sourcecode:: sql @@ -460,7 +460,7 @@ place and the new value will be returned immediately. 
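As a concrete illustration of the trigger-oriented "Case 2" discussion above, the sketch below marks a trigger-populated column with :class:`.FetchedValue` so that the ORM knows the value is server-generated and must be fetched after an INSERT; the ``MyTriggerModel`` mapping, its table name, and the trigger itself are hypothetical, and ``Base`` is assumed to be the declarative base used in the surrounding examples::

    from sqlalchemy import FetchedValue, String
    from sqlalchemy.orm import Mapped, mapped_column


    class MyTriggerModel(Base):
        __tablename__ = "my_triggered_table"

        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[str] = mapped_column(String(50))

        # hypothetical column populated by a server-side trigger;
        # FetchedValue marks it as server-generated with no client default
        stamp: Mapped[str] = mapped_column(
            String(30), server_default=FetchedValue()
        )

        # fetch server-generated defaults at flush time, per the
        # Mapper.eager_defaults discussion above
        __mapper_args__ = {"eager_defaults": True}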
For non-integer values generated by server side functions or triggers, as well as for integer values that come from constructs outside the table itself, including explicit sequences and triggers, the server default generation must -be marked in the table metadata. Using Oracle as the example again, we can +be marked in the table metadata. Using Oracle Database as the example again, we can illustrate a similar table as above naming an explicit sequence using the :class:`.Sequence` construct:: @@ -470,7 +470,7 @@ illustrate a similar table as above naming an explicit sequence using the id: Mapped[int] = mapped_column(Sequence("my_oracle_seq"), primary_key=True) data: Mapped[str] = mapped_column(String(50)) -An INSERT for this version of the model on Oracle would look like: +An INSERT for this version of the model on Oracle Database would look like: .. sourcecode:: sql @@ -713,20 +713,16 @@ connections:: pass - class User(BaseA): - ... + class User(BaseA): ... - class Address(BaseA): - ... + class Address(BaseA): ... - class GameInfo(BaseB): - ... + class GameInfo(BaseB): ... - class GameStats(BaseB): - ... + class GameStats(BaseB): ... Session = sessionmaker() diff --git a/doc/build/orm/queryguide/api.rst b/doc/build/orm/queryguide/api.rst index 15301cbd003..fe4d6b02a49 100644 --- a/doc/build/orm/queryguide/api.rst +++ b/doc/build/orm/queryguide/api.rst @@ -111,6 +111,8 @@ a per-query basis. Options for which this applies include: * The :func:`_orm.with_loader_criteria` option +* The :func:`_orm.load_only` option to select what attributes to refresh + The ``populate_existing`` execution option is equivalent to the :meth:`_orm.Query.populate_existing` method in :term:`1.x style` ORM queries. diff --git a/doc/build/orm/queryguide/columns.rst b/doc/build/orm/queryguide/columns.rst index 93d0919ba56..ace6a63f4ce 100644 --- a/doc/build/orm/queryguide/columns.rst +++ b/doc/build/orm/queryguide/columns.rst @@ -595,7 +595,7 @@ by default not loadable:: ... sqlalchemy.exc.InvalidRequestError: 'Book.summary' is not available due to raiseload=True -Only by overridding their behavior at query time, typically using +Only by overriding their behavior at query time, typically using :func:`_orm.undefer` or :func:`_orm.undefer_group`, or less commonly :func:`_orm.defer`, may the attributes be loaded. The example below applies ``undefer('*')`` to undefer all attributes, also making use of diff --git a/doc/build/orm/queryguide/dml.rst b/doc/build/orm/queryguide/dml.rst index 967397f1ae9..91fe9e7741d 100644 --- a/doc/build/orm/queryguide/dml.rst +++ b/doc/build/orm/queryguide/dml.rst @@ -204,8 +204,8 @@ the operation will INSERT one row at a time:: .. _orm_queryguide_insert_heterogeneous_params: -Using Heterogenous Parameter Dictionaries -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using Heterogeneous Parameter Dictionaries +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. Setup code, not for display BEGIN (implicit)... The ORM bulk insert feature supports lists of parameter dictionaries that are +"heterogeneous", which basically means "individual dictionaries can have different keys". 
When this condition is detected, the ORM will break up the parameter dictionaries into groups corresponding to each set of keys and batch accordingly into separate INSERT statements:: @@ -552,7 +552,7 @@ are not present: or other multi-table mappings are not supported, since that would require multiple INSERT statements. -* :ref:`Heterogenous parameter sets <orm_queryguide_insert_heterogeneous_params>` +* :ref:`Heterogeneous parameter sets <orm_queryguide_insert_heterogeneous_params>` are not supported - each element in the VALUES set must have the same columns. @@ -993,6 +993,52 @@ For a DELETE, an example of deleting rows based on criteria:: >>> session.connection() BEGIN (implicit)... +.. warning:: Please read the following section :ref:`orm_queryguide_update_delete_caveats` + for important notes regarding how the functionality of ORM-Enabled UPDATE and DELETE + diverges from that of ORM :term:`unit of work` features, such + as using the :meth:`_orm.Session.delete` method to delete individual objects. + + +.. _orm_queryguide_update_delete_caveats: + +Important Notes and Caveats for ORM-Enabled Update and Delete +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ORM-enabled UPDATE and DELETE features bypass ORM :term:`unit of work` +automation in favor of being able to emit a single UPDATE or DELETE statement +that matches multiple rows at once without complexity. + +* The operations do not offer in-Python cascading of relationships - it is + assumed that ON UPDATE CASCADE and/or ON DELETE CASCADE is configured for any + foreign key references which require it, otherwise the database may emit an + integrity violation if foreign key references are being enforced. See the + notes at :ref:`passive_deletes` for some examples. + +* After the UPDATE or DELETE, dependent objects in the :class:`.Session` which + were impacted by an ON UPDATE CASCADE or ON DELETE CASCADE on related tables, + particularly objects that refer to rows that have now been deleted, may still + reference those rows. This issue is resolved once the :class:`.Session` + is expired, which normally occurs upon :meth:`.Session.commit` or can be + forced by using :meth:`.Session.expire_all`. + +* ORM-enabled UPDATEs and DELETEs do not handle joined table inheritance + automatically. See the section :ref:`orm_queryguide_update_delete_joined_inh` + for notes on how to work with joined-inheritance mappings. + +* The WHERE criteria needed in order to limit the polymorphic identity to + specific subclasses for single-table-inheritance mappings **is included + automatically**. This only applies to a subclass mapper that has no table of + its own. + +* The :func:`_orm.with_loader_criteria` option **is supported** by ORM + update and delete operations; criteria here will be added to that of the UPDATE + or DELETE statement being emitted, as well as taken into account during the + "synchronize" process. + +* In order to intercept ORM-enabled UPDATE and DELETE operations with event + handlers, use the :meth:`_orm.SessionEvents.do_orm_execute` event. + + .. _orm_queryguide_update_delete_sync: diff --git a/doc/build/orm/queryguide/inheritance.rst b/doc/build/orm/queryguide/inheritance.rst index 136bed55a60..537d51ae59e 100644 --- a/doc/build/orm/queryguide/inheritance.rst +++ b/doc/build/orm/queryguide/inheritance.rst @@ -128,7 +128,7 @@ objects at once. This loader option works in a similar fashion as the SELECT statement against each sub-table for objects loaded in the hierarchy, using ``IN`` to query for additional rows based on primary key. 
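A brief sketch of how this loader option is typically applied, assuming a hypothetical joined-inheritance hierarchy with an ``Employee`` base class and ``Manager`` / ``Engineer`` subclasses, and an ORM ``session`` already present; the option's arguments are described in the lines that follow::

    from sqlalchemy import select
    from sqlalchemy.orm import selectin_polymorphic

    # load Employee rows, then emit one additional SELECT..IN per
    # subclass to load Manager- and Engineer-specific attributes
    opt = selectin_polymorphic(Employee, [Manager, Engineer])
    stmt = select(Employee).order_by(Employee.id).options(opt)
    objects = session.scalars(stmt).all()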
-:func:`_orm.selectinload` accepts as its arguments the base entity that is +:func:`_orm.selectin_polymorphic` accepts as its arguments the base entity that is being queried, followed by a sequence of subclasses of that entity for which their specific attributes should be loaded for incoming rows:: diff --git a/doc/build/orm/queryguide/relationships.rst b/doc/build/orm/queryguide/relationships.rst index 30c8b1906fc..d63ae67ac74 100644 --- a/doc/build/orm/queryguide/relationships.rst +++ b/doc/build/orm/queryguide/relationships.rst @@ -828,10 +828,10 @@ will JOIN across all three tables to match rows from one side to the other. Things to know about this kind of loading include: * The strategy emits a SELECT for up to 500 parent primary key values at a - time, as the primary keys are rendered into a large IN expression in the - SQL statement. Some databases like Oracle have a hard limit on how large - an IN expression can be, and overall the size of the SQL string shouldn't - be arbitrarily large. + time, as the primary keys are rendered into a large IN expression in the SQL + statement. Some databases like Oracle Database have a hard limit on how + large an IN expression can be, and overall the size of the SQL string + shouldn't be arbitrarily large. * As "selectin" loading relies upon IN, for a mapping with composite primary keys, it must use the "tuple" form of IN, which looks like ``WHERE @@ -1001,8 +1001,7 @@ Wildcard Loading Strategies --------------------------- Each of :func:`_orm.joinedload`, :func:`.subqueryload`, :func:`.lazyload`, -:func:`.selectinload`, -:func:`.noload`, and :func:`.raiseload` can be used to set the default +:func:`.selectinload`, and :func:`.raiseload` can be used to set the default style of :func:`_orm.relationship` loading for a particular query, affecting all :func:`_orm.relationship` -mapped attributes not otherwise diff --git a/doc/build/orm/queryguide/select.rst b/doc/build/orm/queryguide/select.rst index 678565932dd..a8b273a62dc 100644 --- a/doc/build/orm/queryguide/select.rst +++ b/doc/build/orm/queryguide/select.rst @@ -360,7 +360,7 @@ Selecting Entities from Subqueries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :func:`_orm.aliased` construct discussed in the previous section -can be used with any :class:`_sql.Subuqery` construct that comes from a +can be used with any :class:`_sql.Subquery` construct that comes from a method such as :meth:`_sql.Select.subquery` to link ORM entities to the columns returned by that subquery; there must be a **column correspondence** relationship between the columns delivered by the subquery and the columns @@ -721,7 +721,7 @@ Joining to Subqueries ^^^^^^^^^^^^^^^^^^^^^ The target of a join may be any "selectable" entity which includes -subuqeries. When using the ORM, it is typical +subqueries. When using the ORM, it is typical that these targets are stated in terms of an :func:`_orm.aliased` construct, but this is not strictly required, particularly if the joined entity is not being returned in the results. For example, to join from the diff --git a/doc/build/orm/relationship_persistence.rst b/doc/build/orm/relationship_persistence.rst index 9a5a036c695..ba686d691d1 100644 --- a/doc/build/orm/relationship_persistence.rst +++ b/doc/build/orm/relationship_persistence.rst @@ -35,12 +35,13 @@ Or: 1 'somewidget' 5 5 'someentry' 1 In the first case, a row points to itself. 
Technically, a database that uses -sequences such as PostgreSQL or Oracle can INSERT the row at once using a -previously generated value, but databases which rely upon autoincrement-style -primary key identifiers cannot. The :func:`~sqlalchemy.orm.relationship` -always assumes a "parent/child" model of row population during flush, so -unless you are populating the primary key/foreign key columns directly, -:func:`~sqlalchemy.orm.relationship` needs to use two statements. +sequences such as PostgreSQL or Oracle Database can INSERT the row at once +using a previously generated value, but databases which rely upon +autoincrement-style primary key identifiers cannot. The +:func:`~sqlalchemy.orm.relationship` always assumes a "parent/child" model of +row population during flush, so unless you are populating the primary +key/foreign key columns directly, :func:`~sqlalchemy.orm.relationship` needs to +use two statements. In the second case, the "widget" row must be inserted before any referring "entry" rows, but then the "favorite_entry_id" column of that "widget" row @@ -243,7 +244,7 @@ by emitting an UPDATE statement against foreign key columns that immediately reference a primary key column whose value has changed. The primary platforms without referential integrity features are MySQL when the ``MyISAM`` storage engine is used, and SQLite when the -``PRAGMA foreign_keys=ON`` pragma is not used. The Oracle database also +``PRAGMA foreign_keys=ON`` pragma is not used. Oracle Database also has no support for ``ON UPDATE CASCADE``, but because it still enforces referential integrity, needs constraints to be marked as deferrable so that SQLAlchemy can emit UPDATE statements. @@ -297,7 +298,7 @@ Key limitations of ``passive_updates=False`` include: map for objects that may be referencing the one with a mutating primary key, not throughout the database. -As virtually all databases other than Oracle now support ``ON UPDATE CASCADE``, -it is highly recommended that traditional ``ON UPDATE CASCADE`` support be used -in the case that natural and mutable primary key values are in use. - +As virtually all databases other than Oracle Database now support ``ON UPDATE +CASCADE``, it is highly recommended that traditional ``ON UPDATE CASCADE`` +support be used in the case that natural and mutable primary key values are in +use. diff --git a/doc/build/orm/session_basics.rst b/doc/build/orm/session_basics.rst index 0fcbf7900b1..0c04e34b2ed 100644 --- a/doc/build/orm/session_basics.rst +++ b/doc/build/orm/session_basics.rst @@ -15,12 +15,15 @@ ORM-mapped objects. The ORM objects themselves are maintained inside the structure that maintains unique copies of each object, where "unique" means "only one object with a particular primary key". -The :class:`.Session` begins in a mostly stateless form. Once queries are -issued or other objects are persisted with it, it requests a connection -resource from an :class:`_engine.Engine` that is associated with the -:class:`.Session`, and then establishes a transaction on that connection. This -transaction remains in effect until the :class:`.Session` is instructed to -commit or roll back the transaction. +The :class:`.Session` in its most common pattern of use begins in a mostly +stateless form. Once queries are issued or other objects are persisted with it, +it requests a connection resource from an :class:`_engine.Engine` that is +associated with the :class:`.Session`, and then establishes a transaction on +that connection. 
This transaction remains in effect until the :class:`.Session` +is instructed to commit or roll back the transaction. When the transaction +ends, the connection resource associated with the :class:`_engine.Engine` +is :term:`released` to the connection pool managed by the engine. A new +transaction then starts with a new connection checkout. The ORM objects maintained by a :class:`_orm.Session` are :term:`instrumented` such that whenever an attribute or a collection is modified in the Python @@ -151,7 +154,7 @@ The purpose of :class:`_orm.sessionmaker` is to provide a factory for :class:`_orm.Session` objects with a fixed configuration. As it is typical that an application will have an :class:`_engine.Engine` object in module scope, the :class:`_orm.sessionmaker` can provide a factory for -:class:`_orm.Session` objects that are against this engine:: +:class:`_orm.Session` objects that are constructed against this engine:: from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -643,8 +646,26 @@ connections. If no pending changes are detected, then no SQL is emitted to the database. This behavior is not configurable and is not affected by the :paramref:`.Session.autoflush` parameter. -Subsequent to that, :meth:`_orm.Session.commit` will then COMMIT the actual -database transaction or transactions, if any, that are in place. +Subsequent to that, assuming the :class:`_orm.Session` is bound to an +:class:`_engine.Engine`, :meth:`_orm.Session.commit` will then COMMIT the +actual database transaction that is in place, if one was started. After the +commit, the :class:`_engine.Connection` object associated with that transaction +is closed, causing its underlying DBAPI connection to be :term:`released` back +to the connection pool associated with the :class:`_engine.Engine` to which the +:class:`_orm.Session` is bound. + +For a :class:`_orm.Session` that's bound to multiple engines (e.g. as described +at :ref:`Partitioning Strategies <session_partitioning>`), the same COMMIT +steps will proceed for each :class:`_engine.Engine` / +:class:`_engine.Connection` that is in play within the "logical" transaction +being committed. These database transactions are uncoordinated with each other +unless :ref:`two-phase features <session_twophase>` are enabled. + +Other connection-interaction patterns are available as well, by binding the +:class:`_orm.Session` to a :class:`_engine.Connection` directly; in this case, +it's assumed that an externally-managed transaction is present, and a real +COMMIT will not be emitted automatically; see the section +:ref:`session_external_transaction` for background on this pattern. Finally, all objects within the :class:`_orm.Session` are :term:`expired` as the transaction is closed out. This is so that when the instances are next @@ -671,9 +692,25 @@ been begun either via :ref:`autobegin <session_autobegin>` or by calling the :meth:`_orm.Session.begin` method explicitly, is as follows: - * All transactions are rolled back and all connections returned to the - connection pool, unless the Session was bound directly to a Connection, in - which case the connection is still maintained (but still rolled back). + * Database transactions are rolled back. For a :class:`_orm.Session` + bound to a single :class:`_engine.Engine`, this means ROLLBACK is emitted + for at most a single :class:`_engine.Connection` that's currently in use. 
+ For :class:`_orm.Session` objects bound to multiple :class:`_engine.Engine` + objects, ROLLBACK is emitted for all :class:`_engine.Connection` objects + that were checked out. + * Database connections are :term:`released`. This follows the same connection-related + behavior noted in :ref:`session_committing`, where + :class:`_engine.Connection` objects obtained from :class:`_engine.Engine` + objects are closed, causing the DBAPI connections to be :term:`released` to + the connection pool within the :class:`_engine.Engine`. New connections + are checked out from the :class:`_engine.Engine` if and when a new + transaction begins. + * For a :class:`_orm.Session` + that's bound directly to a :class:`_engine.Connection` as described + at :ref:`session_external_transaction`, rollback behavior on this + :class:`_engine.Connection` would follow the behavior specified by the + :paramref:`_orm.Session.join_transaction_mode` parameter, which could + involve rolling back savepoints or emitting a real ROLLBACK. + * Objects which were initially in the :term:`pending` state when they were added + to the :class:`~sqlalchemy.orm.session.Session` within the lifespan of the + transaction are expunged, corresponding to their INSERT statement being diff --git a/doc/build/orm/session_transaction.rst b/doc/build/orm/session_transaction.rst index 10da76eda80..55ade3e5326 100644 --- a/doc/build/orm/session_transaction.rst +++ b/doc/build/orm/session_transaction.rst @@ -60,7 +60,7 @@ or rolled back:: session.commit() # commits # will automatically begin again - result = session.execute("< some select statement >") + result = session.execute(text("< some select statement >")) session.add_all([more_objects, ...]) session.commit() # commits @@ -100,7 +100,7 @@ first:: session.commit() # commits - result = session.execute("<statement>") + result = session.execute(text("<statement>")) # remaining transactional state from the .execute() call is # discarded @@ -529,8 +529,8 @@ used in a read-only fashion**, that is:: with autocommit_session() as session: - some_objects = session.execute("<statement>") - some_other_objects = session.execute("<statement>") + some_objects = session.execute(text("<statement>")) + some_other_objects = session.execute(text("<statement>")) # closes connection diff --git a/doc/build/orm/versioning.rst b/doc/build/orm/versioning.rst index 87865917cdf..9c08acef682 100644 --- a/doc/build/orm/versioning.rst +++ b/doc/build/orm/versioning.rst @@ -207,7 +207,8 @@ missed version counters: It is *strongly recommended* that server side version counters only be used when absolutely necessary and only on backends that support :term:`RETURNING`, -currently PostgreSQL, Oracle, MariaDB 10.5, SQLite 3.35, and SQL Server. +currently PostgreSQL, Oracle Database, MariaDB 10.5, SQLite 3.35, and SQL +Server. Programmatic or Conditional Version Counters @@ -232,14 +233,14 @@ at our choosing:: __mapper_args__ = {"version_id_col": version_uuid, "version_id_generator": False} - u1 = User(name="u1", version_uuid=uuid.uuid4()) + u1 = User(name="u1", version_uuid=uuid.uuid4().hex) session.add(u1) session.commit() u1.name = "u2" - u1.version_uuid = uuid.uuid4() + u1.version_uuid = uuid.uuid4().hex session.commit() diff --git a/doc/build/tutorial/data_select.rst b/doc/build/tutorial/data_select.rst index ffeb9dfdb65..5052a5bae32 100644 --- a/doc/build/tutorial/data_select.rst +++ b/doc/build/tutorial/data_select.rst @@ -130,7 +130,7 @@ for a :func:`_sql.select` by using a tuple of string names:: FROM user_account .. 
versionadded:: 2.0 Added tuple-accessor capability to the - :attr`.FromClause.c` collection + :attr:`.FromClause.c` collection .. _tutorial_selecting_orm_entities: @@ -447,7 +447,7 @@ explicitly:: FROM user_account JOIN address ON user_account.id = address.user_id -The other is the the :meth:`_sql.Select.join` method, which indicates only the +The other is the :meth:`_sql.Select.join` method, which indicates only the right side of the JOIN, while the left-hand side is inferred:: >>> print(select(user_table.c.name, address_table.c.email_address).join(address_table)) @@ -1124,7 +1124,7 @@ When using :meth:`_expression.Select.lateral`, the behavior of UNION, UNION ALL and other set operations ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In SQL,SELECT statements can be merged together using the UNION or UNION ALL +In SQL, SELECT statements can be merged together using the UNION or UNION ALL SQL operation, which produces the set of all rows produced by one or more statements together. Other set operations such as INTERSECT [ALL] and EXCEPT [ALL] are also possible. @@ -1387,8 +1387,8 @@ At the same time, a relatively small set of extremely common SQL functions such as :class:`_functions.count`, :class:`_functions.now`, :class:`_functions.max`, :class:`_functions.concat` include pre-packaged versions of themselves which provide for proper typing information as well as backend-specific SQL -generation in some cases. The example below contrasts the SQL generation -that occurs for the PostgreSQL dialect compared to the Oracle dialect for +generation in some cases. The example below contrasts the SQL generation that +occurs for the PostgreSQL dialect compared to the Oracle Database dialect for the :class:`_functions.now` function:: >>> from sqlalchemy.dialects import postgresql @@ -1410,11 +1410,18 @@ as opposed to the "return type" of a Python function. The SQL return type of any SQL function may be accessed, typically for debugging purposes, by referring to the :attr:`_functions.Function.type` -attribute:: +attribute; this will be pre-configured for a **select few** of extremely +common SQL functions, but for most SQL functions it is the "null" datatype +if not otherwise specified:: + >>> # pre-configured SQL function (only a few dozen of these) >>> func.now().type DateTime() + >>> # arbitrary SQL function (all other SQL functions) + >>> func.run_some_calculation().type + NullType() + These SQL return types are significant when making use of the function expression in the context of a larger expression; that is, math operators will work better when the datatype of the expression is @@ -1676,10 +1683,10 @@ Table-Valued Functions Table-valued SQL functions support a scalar representation that contains named sub-elements. Often used for JSON and ARRAY-oriented functions as well as functions like ``generate_series()``, the table-valued function is specified in -the FROM clause, and is then referenced as a table, or sometimes even as -a column. Functions of this form are prominent within the PostgreSQL database, +the FROM clause, and is then referenced as a table, or sometimes even as a +column. Functions of this form are prominent within the PostgreSQL database, however some forms of table valued functions are also supported by SQLite, -Oracle, and SQL Server. +Oracle Database, and SQL Server. .. seealso:: @@ -1728,9 +1735,9 @@ towards as ``value``, and then selected two of its three rows. 
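Relatedly, for the arbitrary-function case mentioned above, a return type may be supplied up front using the ``type_`` keyword argument accepted by functions generated from :data:`_sql.func`; a quick sketch reusing the hypothetical ``run_some_calculation`` name from the preceding example::

    >>> from sqlalchemy import Integer, func
    >>> # supply the SQL return type for an otherwise-unknown function
    >>> func.run_some_calculation(type_=Integer()).type
    Integer()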
Column Valued Functions - Table Valued Function as a Scalar Column ################################################################## -A special syntax supported by PostgreSQL and Oracle is that of referring -towards a function in the FROM clause, which then delivers itself as a -single column in the columns clause of a SELECT statement or other column +A special syntax supported by PostgreSQL and Oracle Database is that of +referring towards a function in the FROM clause, which then delivers itself as +a single column in the columns clause of a SELECT statement or other column expression context. PostgreSQL makes great use of this syntax for such functions as ``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``, ``json_each()``, etc. @@ -1745,8 +1752,8 @@ to a :class:`_functions.Function` construct:: {printsql}SELECT x FROM json_array_elements(:json_array_elements_1) AS x -The "column valued" form is also supported by the Oracle dialect, where -it is usable for custom SQL functions:: +The "column valued" form is also supported by the Oracle Database dialects, +where it is usable for custom SQL functions:: >>> from sqlalchemy.dialects import oracle >>> stmt = select(func.scalar_strings(5).column_valued("s")) diff --git a/doc/build/tutorial/data_update.rst b/doc/build/tutorial/data_update.rst index a82f070a3f6..e32b6676c76 100644 --- a/doc/build/tutorial/data_update.rst +++ b/doc/build/tutorial/data_update.rst @@ -279,17 +279,24 @@ Facts about :attr:`_engine.CursorResult.rowcount`: the statement. It does not matter if the row was actually modified or not. * :attr:`_engine.CursorResult.rowcount` is not necessarily available for an UPDATE - or DELETE statement that uses RETURNING. + or DELETE statement that uses RETURNING, or for one that uses an + :ref:`executemany <tutorial_multiple_parameters>` execution. The availability + depends on the DBAPI module in use. -* For an :ref:`executemany <tutorial_multiple_parameters>` execution, - :attr:`_engine.CursorResult.rowcount` may not be available either, which depends - highly on the DBAPI module in use as well as configured options. The - attribute :attr:`_engine.CursorResult.supports_sane_multi_rowcount` indicates - if this value will be available for the current backend in use. +* In any case where the DBAPI does not determine the rowcount for some type + of statement, the returned value will be ``-1``. + +* SQLAlchemy pre-memoizes the DBAPI's ``cursor.rowcount`` value before the cursor + is closed, as some DBAPIs don't support accessing this attribute after the + fact. In order to pre-memoize ``cursor.rowcount`` for a statement that is + not UPDATE or DELETE, such as INSERT or SELECT, the + :paramref:`_engine.Connection.execution_options.preserve_rowcount` execution + option may be used. * Some drivers, particularly third party dialects for non-relational databases, may not support :attr:`_engine.CursorResult.rowcount` at all. The - :attr:`_engine.CursorResult.supports_sane_rowcount` will indicate this. + :attr:`_engine.CursorResult.supports_sane_rowcount` cursor attribute will + indicate this. 
* "rowcount" is used by the ORM :term:`unit of work` process to validate that an UPDATE or DELETE statement matched the expected number of rows, and is diff --git a/doc/build/tutorial/dbapi_transactions.rst b/doc/build/tutorial/dbapi_transactions.rst index ade14eb4fb3..5525acfe510 100644 --- a/doc/build/tutorial/dbapi_transactions.rst +++ b/doc/build/tutorial/dbapi_transactions.rst @@ -11,32 +11,32 @@ Working with Transactions and the DBAPI -With the :class:`_engine.Engine` object ready to go, we may now proceed -to dive into the basic operation of an :class:`_engine.Engine` and -its primary interactive endpoints, the :class:`_engine.Connection` and -:class:`_engine.Result`. We will additionally introduce the ORM's -:term:`facade` for these objects, known as the :class:`_orm.Session`. +With the :class:`_engine.Engine` object ready to go, we can +dive into the basic operation of an :class:`_engine.Engine` and +its primary endpoints, the :class:`_engine.Connection` and +:class:`_engine.Result`. We'll also introduce the ORM's :term:`facade` +for these objects, known as the :class:`_orm.Session`. .. container:: orm-header **Note to ORM readers** - When using the ORM, the :class:`_engine.Engine` is managed by another - object called the :class:`_orm.Session`. The :class:`_orm.Session` in - modern SQLAlchemy emphasizes a transactional and SQL execution pattern that - is largely identical to that of the :class:`_engine.Connection` discussed - below, so while this subsection is Core-centric, all of the concepts here - are essentially relevant to ORM use as well and is recommended for all ORM + When using the ORM, the :class:`_engine.Engine` is managed by the + :class:`_orm.Session`. The :class:`_orm.Session` in modern SQLAlchemy + emphasizes a transactional and SQL execution pattern that is largely + identical to that of the :class:`_engine.Connection` discussed below, + so while this subsection is Core-centric, all of the concepts here + are relevant to ORM use as well and is recommended for all ORM learners. The execution pattern used by the :class:`_engine.Connection` - will be contrasted with that of the :class:`_orm.Session` at the end + will be compared to the :class:`_orm.Session` at the end of this section. As we have yet to introduce the SQLAlchemy Expression Language that is the -primary feature of SQLAlchemy, we will make use of one simple construct within -this package called the :func:`_sql.text` construct, which allows us to write -SQL statements as **textual SQL**. Rest assured that textual SQL in -day-to-day SQLAlchemy use is by far the exception rather than the rule for most -tasks, even though it always remains fully available. +primary feature of SQLAlchemy, we'll use a simple construct within +this package called the :func:`_sql.text` construct, to write +SQL statements as **textual SQL**. Rest assured that textual SQL is the +exception rather than the rule in day-to-day SQLAlchemy use, but it's +always available. .. rst-class:: core-header @@ -45,17 +45,15 @@ tasks, even though it always remains fully available. Getting a Connection --------------------- -The sole purpose of the :class:`_engine.Engine` object from a user-facing -perspective is to provide a unit of -connectivity to the database called the :class:`_engine.Connection`. When -working with the Core directly, the :class:`_engine.Connection` object -is how all interaction with the database is done. 
As the :class:`_engine.Connection` -represents an open resource against the database, we want to always limit -the scope of our use of this object to a specific context, and the best -way to do that is by using Python context manager form, also known as -`the with statement <https://docs.python.org/3/reference/compound_stmts.html#with>`_. -Below we illustrate "Hello World", using a textual SQL statement. Textual -SQL is emitted using a construct called :func:`_sql.text` that will be discussed +The purpose of the :class:`_engine.Engine` is to connect to the database by +providing a :class:`_engine.Connection` object. When working with the Core +directly, the :class:`_engine.Connection` object is how all interaction with the +database is done. Because the :class:`_engine.Connection` creates an open +resource against the database, we want to limit our use of this object to a +specific context. The best way to do that is with a Python context manager, also +known as `the with statement <https://docs.python.org/3/reference/compound_stmts.html#with>`_. +Below we use a textual SQL statement to show "Hello World". Textual SQL is +created with a construct called :func:`_sql.text` which we'll discuss in more detail later: .. sourcecode:: pycon+sql @@ -71,21 +69,21 @@ in more detail later: {stop}[('hello world',)] {execsql}ROLLBACK{stop} -In the above example, the context manager provided for a database connection -and also framed the operation inside of a transaction. The default behavior of -the Python DBAPI includes that a transaction is always in progress; when the -scope of the connection is :term:`released`, a ROLLBACK is emitted to end the -transaction. The transaction is **not committed automatically**; when we want -to commit data we normally need to call :meth:`_engine.Connection.commit` +In the example above, the context manager creates a database connection +and executes the operation in a transaction. The default behavior of +the Python DBAPI is that a transaction is always in progress; when the +connection is :term:`released`, a ROLLBACK is emitted to end the +transaction. The transaction is **not committed automatically**; if we want +to commit data we need to call :meth:`_engine.Connection.commit` as we'll see in the next section. .. tip:: "autocommit" mode is available for special cases. The section :ref:`dbapi_autocommit` discusses this. -The result of our SELECT was also returned in an object called -:class:`_engine.Result` that will be discussed later, however for the moment -we'll add that it's best to ensure this object is consumed within the -"connect" block, and is not passed along outside of the scope of our connection. +The result of our SELECT was returned in an object called +:class:`_engine.Result` that will be discussed later. For the moment +we'll add that it's best to use this object within the "connect" block, +and to not use it outside of the scope of our connection. .. rst-class:: core-header @@ -94,11 +92,11 @@ we'll add that it's best to ensure this object is consumed within the Committing Changes ------------------ -We just learned that the DBAPI connection is non-autocommitting. What if -we want to commit some data? +We just learned that the DBAPI connection doesn't commit automatically. +What if we want to commit some data? 
We can change our example above to create a +table, insert some data, and then commit the transaction using +the :meth:`_engine.Connection.commit` method, **inside** the block +where we have the :class:`_engine.Connection` object: .. sourcecode:: pycon+sql @@ -119,24 +117,22 @@ where we acquired the :class:`_engine.Connection` object: COMMIT -Above, we emitted two SQL statements that are generally transactional, a -"CREATE TABLE" statement [1]_ and an "INSERT" statement that's parameterized -(the parameterization syntax above is discussed a few sections below in -:ref:`tutorial_multiple_parameters`). As we want the work we've done to be -committed within our block, we invoke the +Above, we execute two SQL statements, a "CREATE TABLE" statement [1]_ +and an "INSERT" statement that's parameterized (we discuss the parameterization syntax +later in :ref:`tutorial_multiple_parameters`). +To commit the work we've done in our block, we call the :meth:`_engine.Connection.commit` method which commits the transaction. After -we call this method inside the block, we can continue to run more SQL -statements and if we choose we may call :meth:`_engine.Connection.commit` -again for subsequent statements. SQLAlchemy refers to this style as **commit as +this, we can continue to run more SQL statements and call :meth:`_engine.Connection.commit` +again for those statements. SQLAlchemy refers to this style as **commit as you go**. -There is also another style of committing data, which is that we can declare -our "connect" block to be a transaction block up front. For this mode of -operation, we use the :meth:`_engine.Engine.begin` method to acquire the -connection, rather than the :meth:`_engine.Engine.connect` method. This method -will both manage the scope of the :class:`_engine.Connection` and also -enclose everything inside of a transaction with COMMIT at the end, assuming -a successful block, or ROLLBACK in case of exception raise. This style +There's also another style of committing data. We can declare +our "connect" block to be a transaction block up front. To do this, we use the +:meth:`_engine.Engine.begin` method to get the connection, rather than the +:meth:`_engine.Engine.connect` method. This method +will manage the scope of the :class:`_engine.Connection` and also +enclose everything inside of a transaction with either a COMMIT at the end +if the block was successful, or a ROLLBACK if an exception was raised. This style is known as **begin once**: .. sourcecode:: pycon+sql @@ -153,9 +149,9 @@ is known as **begin once**: COMMIT -"Begin once" style is often preferred as it is more succinct and indicates the -intention of the entire block up front. However, within this tutorial we will -normally use "commit as you go" style as it is more flexible for demonstration +You should generally prefer the "begin once" style because it's shorter and shows the +intention of the entire block up front. However, in this tutorial we'll +use "commit as you go" style as it's more flexible for demonstration purposes. .. topic:: What's "BEGIN (implicit)"? @@ -169,8 +165,8 @@ purposes. .. [1] :term:`DDL` refers to the subset of SQL that instructs the database to create, modify, or remove schema-level constructs such as tables. 
DDL - such as "CREATE TABLE" is recommended to be within a transaction block that - ends with COMMIT, as many databases uses transactional DDL such that the + such as "CREATE TABLE" should be in a transaction block that + ends with COMMIT, as many databases use transactional DDL such that the schema changes don't take place until the transaction is committed. However, as we'll see later, we usually let SQLAlchemy run DDL sequences for us as part of a higher level operation where we don't generally need to worry diff --git a/doc/build/tutorial/orm_data_manipulation.rst b/doc/build/tutorial/orm_data_manipulation.rst index 73fef50aba3..9329d205245 100644 --- a/doc/build/tutorial/orm_data_manipulation.rst +++ b/doc/build/tutorial/orm_data_manipulation.rst @@ -157,7 +157,7 @@ Another effect of the INSERT that occurred was that the ORM has retrieved the new primary key identifiers for each new object; internally it normally uses the same :attr:`_engine.CursorResult.inserted_primary_key` accessor we introduced previously. The ``squidward`` and ``krabs`` objects now have these new -primary key identifiers associated with them and we can view them by acesssing +primary key identifiers associated with them and we can view them by accessing the ``id`` attribute:: >>> squidward.id @@ -533,6 +533,7 @@ a context manager as well, accomplishes the following things: are no longer associated with any database transaction in which to be refreshed:: + # note that 'squidward.name' was just expired previously, so its value is unloaded >>> squidward.name Traceback (most recent call last): ... diff --git a/examples/adjacency_list/__init__.py b/examples/adjacency_list/__init__.py index 65ce311e6de..b029e421b93 100644 --- a/examples/adjacency_list/__init__.py +++ b/examples/adjacency_list/__init__.py @@ -4,9 +4,9 @@ E.g.:: - node = TreeNode('rootnode') - node.append('node1') - node.append('node3') + node = TreeNode("rootnode") + node.append("node1") + node.append("node3") session.add(node) session.commit() diff --git a/examples/association/basic_association.py b/examples/association/basic_association.py index d2271ad430e..7a5b46097e3 100644 --- a/examples/association/basic_association.py +++ b/examples/association/basic_association.py @@ -105,7 +105,7 @@ def __init__(self, item, price=None): ) # print customers who bought 'MySQL Crowbar' on sale - q = session.query(Order).join("order_items", "item") + q = session.query(Order).join(OrderItem).join(Item) q = q.filter( and_(Item.description == "MySQL Crowbar", Item.price > OrderItem.price) ) diff --git a/examples/association/proxied_association.py b/examples/association/proxied_association.py index 0ec8fa899ac..65dcd6c0b66 100644 --- a/examples/association/proxied_association.py +++ b/examples/association/proxied_association.py @@ -112,7 +112,8 @@ def __init__(self, item, price=None): # print customers who bought 'MySQL Crowbar' on sale orders = ( session.query(Order) - .join("order_items", "item") + .join(OrderItem) + .join(Item) .filter(Item.description == "MySQL Crowbar") .filter(Item.price > OrderItem.price) ) diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py index 592323be429..daf810c65d2 100644 --- a/examples/asyncio/async_orm.py +++ b/examples/asyncio/async_orm.py @@ -2,6 +2,7 @@ for asynchronous ORM use. 
""" + from __future__ import annotations import asyncio diff --git a/examples/asyncio/async_orm_writeonly.py b/examples/asyncio/async_orm_writeonly.py index 263c0d29198..8ddc0ecdb23 100644 --- a/examples/asyncio/async_orm_writeonly.py +++ b/examples/asyncio/async_orm_writeonly.py @@ -2,6 +2,7 @@ of ORM collections under asyncio. """ + from __future__ import annotations import asyncio diff --git a/examples/asyncio/basic.py b/examples/asyncio/basic.py index 6cfa9ed0144..5994fc765e7 100644 --- a/examples/asyncio/basic.py +++ b/examples/asyncio/basic.py @@ -6,7 +6,6 @@ """ - import asyncio from sqlalchemy import Column diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py index aa9ea7a6899..da22ee3276c 100644 --- a/examples/custom_attributes/custom_management.py +++ b/examples/custom_attributes/custom_management.py @@ -9,6 +9,7 @@ """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/dogpile_caching/__init__.py b/examples/dogpile_caching/__init__.py index f8c1bb582bc..7fd6dba7217 100644 --- a/examples/dogpile_caching/__init__.py +++ b/examples/dogpile_caching/__init__.py @@ -44,13 +44,13 @@ The demo scripts themselves, in order of complexity, are run as Python modules so that relative imports work:: - python -m examples.dogpile_caching.helloworld + $ python -m examples.dogpile_caching.helloworld - python -m examples.dogpile_caching.relationship_caching + $ python -m examples.dogpile_caching.relationship_caching - python -m examples.dogpile_caching.advanced + $ python -m examples.dogpile_caching.advanced - python -m examples.dogpile_caching.local_session_caching + $ python -m examples.dogpile_caching.local_session_caching .. autosource:: :files: environment.py, caching_query.py, model.py, fixture_data.py, \ diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index b1848631565..8c85d74811c 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -19,6 +19,7 @@ dogpile.cache constructs. """ + from dogpile.cache.api import NO_VALUE from sqlalchemy import event @@ -28,7 +29,6 @@ class ORMCache: - """An add-on for an ORM :class:`.Session` optionally loads full results from a dogpile cache region. diff --git a/examples/dogpile_caching/environment.py b/examples/dogpile_caching/environment.py index 4b5a317917b..4962826280a 100644 --- a/examples/dogpile_caching/environment.py +++ b/examples/dogpile_caching/environment.py @@ -2,6 +2,7 @@ bootstrap fixture data if necessary. """ + from hashlib import md5 import os diff --git a/examples/dogpile_caching/fixture_data.py b/examples/dogpile_caching/fixture_data.py index 8387a2cb275..775fb63b1a8 100644 --- a/examples/dogpile_caching/fixture_data.py +++ b/examples/dogpile_caching/fixture_data.py @@ -3,6 +3,7 @@ with a randomly selected postal code. """ + import random from .environment import Base diff --git a/examples/dogpile_caching/helloworld.py b/examples/dogpile_caching/helloworld.py index 01934c59fab..df1c2a318ef 100644 --- a/examples/dogpile_caching/helloworld.py +++ b/examples/dogpile_caching/helloworld.py @@ -1,6 +1,4 @@ -"""Illustrate how to load some data, and cache the results. 
- -""" +"""Illustrate how to load some data, and cache the results.""" from sqlalchemy import select from .caching_query import FromCache diff --git a/examples/dogpile_caching/model.py b/examples/dogpile_caching/model.py index cae2ae27762..926a5fa5d68 100644 --- a/examples/dogpile_caching/model.py +++ b/examples/dogpile_caching/model.py @@ -7,6 +7,7 @@ City --(has a)--> Country """ + from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy import Integer diff --git a/examples/dogpile_caching/relationship_caching.py b/examples/dogpile_caching/relationship_caching.py index 058d5522259..a5b654b06c8 100644 --- a/examples/dogpile_caching/relationship_caching.py +++ b/examples/dogpile_caching/relationship_caching.py @@ -6,6 +6,7 @@ term cache. """ + import os from sqlalchemy import select diff --git a/examples/dynamic_dict/__init__.py b/examples/dynamic_dict/__init__.py index ed31df062fb..c1d52d3c430 100644 --- a/examples/dynamic_dict/__init__.py +++ b/examples/dynamic_dict/__init__.py @@ -1,4 +1,4 @@ -""" Illustrates how to place a dictionary-like facade on top of a +"""Illustrates how to place a dictionary-like facade on top of a "dynamic" relation, so that dictionary operations (assuming simple string keys) can operate upon a large collection without loading the full collection at once. diff --git a/examples/extending_query/temporal_range.py b/examples/extending_query/temporal_range.py index 50cbb664591..29ea1193623 100644 --- a/examples/extending_query/temporal_range.py +++ b/examples/extending_query/temporal_range.py @@ -5,6 +5,7 @@ """ import datetime +from functools import partial from sqlalchemy import Column from sqlalchemy import create_engine @@ -23,7 +24,9 @@ class HasTemporal: """Mixin that identifies a class as having a timestamp column""" timestamp = Column( - DateTime, default=datetime.datetime.utcnow, nullable=False + DateTime, + default=partial(datetime.datetime.now, datetime.timezone.utc), + nullable=False, ) diff --git a/examples/generic_associations/discriminator_on_association.py b/examples/generic_associations/discriminator_on_association.py index f0f1d7ed99c..93c1b29ef98 100644 --- a/examples/generic_associations/discriminator_on_association.py +++ b/examples/generic_associations/discriminator_on_association.py @@ -15,6 +15,7 @@ objects, but is also slightly more complex. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/generic_associations/generic_fk.py b/examples/generic_associations/generic_fk.py index 5c70f93aac5..d45166d333f 100644 --- a/examples/generic_associations/generic_fk.py +++ b/examples/generic_associations/generic_fk.py @@ -17,6 +17,7 @@ or "table_per_association" instead of this approach. 
""" + from sqlalchemy import and_ from sqlalchemy import Column from sqlalchemy import create_engine diff --git a/examples/generic_associations/table_per_association.py b/examples/generic_associations/table_per_association.py index 2e412869f08..04786bd49be 100644 --- a/examples/generic_associations/table_per_association.py +++ b/examples/generic_associations/table_per_association.py @@ -11,6 +11,7 @@ """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/generic_associations/table_per_related.py b/examples/generic_associations/table_per_related.py index 5b83e6e68f3..23c75b0b9d6 100644 --- a/examples/generic_associations/table_per_related.py +++ b/examples/generic_associations/table_per_related.py @@ -16,6 +16,7 @@ is completely automated. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/inheritance/concrete.py b/examples/inheritance/concrete.py index f7f6b3ac641..e718e2fc350 100644 --- a/examples/inheritance/concrete.py +++ b/examples/inheritance/concrete.py @@ -1,4 +1,5 @@ """Concrete-table (table-per-class) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/inheritance/joined.py b/examples/inheritance/joined.py index 7dee935fab2..c2ba6942cc8 100644 --- a/examples/inheritance/joined.py +++ b/examples/inheritance/joined.py @@ -1,4 +1,5 @@ """Joined-table (table-per-subclass) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/inheritance/single.py b/examples/inheritance/single.py index 8da75dd7c45..6337bb4b2e4 100644 --- a/examples/inheritance/single.py +++ b/examples/inheritance/single.py @@ -1,4 +1,5 @@ """Single-table (table-per-hierarchy) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/materialized_paths/materialized_paths.py b/examples/materialized_paths/materialized_paths.py index f458270c726..19d3ed491c1 100644 --- a/examples/materialized_paths/materialized_paths.py +++ b/examples/materialized_paths/materialized_paths.py @@ -26,6 +26,7 @@ descendants and changing the prefix. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import func diff --git a/examples/nested_sets/__init__.py b/examples/nested_sets/__init__.py index 5fdfbcedc08..cacab411b9a 100644 --- a/examples/nested_sets/__init__.py +++ b/examples/nested_sets/__init__.py @@ -1,4 +1,4 @@ -""" Illustrates a rudimentary way to implement the "nested sets" +"""Illustrates a rudimentary way to implement the "nested sets" pattern for hierarchical data using the SQLAlchemy ORM. .. 
autosource:: diff --git a/examples/performance/__init__.py b/examples/performance/__init__.py index 7e24b9b8fdd..3854fdbea52 100644 --- a/examples/performance/__init__.py +++ b/examples/performance/__init__.py @@ -129,15 +129,15 @@ class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) children = relationship("Child") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) - parent_id = Column(Integer, ForeignKey('parent.id')) + parent_id = Column(Integer, ForeignKey("parent.id")) # Init with name of file, default number of items @@ -152,10 +152,12 @@ def setup_once(dburl, echo, num): Base.metadata.drop_all(engine) Base.metadata.create_all(engine) sess = Session(engine) - sess.add_all([ - Parent(children=[Child() for j in range(100)]) - for i in range(num) - ]) + sess.add_all( + [ + Parent(children=[Child() for j in range(100)]) + for i in range(num) + ] + ) sess.commit() @@ -191,7 +193,8 @@ def test_subqueryload(n): for parent in session.query(Parent).options(subqueryload("children")): parent.children - if __name__ == '__main__': + + if __name__ == "__main__": Profiler.main() We can run our new script directly:: @@ -205,6 +208,7 @@ def test_subqueryload(n): """ # noqa + import argparse import cProfile import gc diff --git a/examples/performance/bulk_updates.py b/examples/performance/bulk_updates.py index c15d0f16726..de5e6dc27da 100644 --- a/examples/performance/bulk_updates.py +++ b/examples/performance/bulk_updates.py @@ -3,8 +3,10 @@ """ + from sqlalchemy import Column from sqlalchemy import create_engine +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.declarative import declarative_base @@ -18,7 +20,7 @@ class Customer(Base): __tablename__ = "customer" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255)) description = Column(String(255)) diff --git a/examples/performance/large_resultsets.py b/examples/performance/large_resultsets.py index 9c0d9fc4e21..36171411276 100644 --- a/examples/performance/large_resultsets.py +++ b/examples/performance/large_resultsets.py @@ -13,8 +13,10 @@ provide a huge amount of functionality. 
""" + from sqlalchemy import Column from sqlalchemy import create_engine +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.declarative import declarative_base @@ -29,7 +31,7 @@ class Customer(Base): __tablename__ = "customer" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255)) description = Column(String(255)) diff --git a/examples/performance/short_selects.py b/examples/performance/short_selects.py index d0e5f6e9d22..bc6a9c79ac4 100644 --- a/examples/performance/short_selects.py +++ b/examples/performance/short_selects.py @@ -3,11 +3,13 @@ """ + import random from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import create_engine +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String @@ -28,7 +30,7 @@ class Customer(Base): __tablename__ = "customer" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255)) description = Column(String(255)) q = Column(Integer) diff --git a/examples/performance/single_inserts.py b/examples/performance/single_inserts.py index 991d213a07b..4b8132c50af 100644 --- a/examples/performance/single_inserts.py +++ b/examples/performance/single_inserts.py @@ -4,9 +4,11 @@ a database connection, inserts the row, commits and closes. """ + from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import create_engine +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import pool from sqlalchemy import String @@ -21,7 +23,7 @@ class Customer(Base): __tablename__ = "customer" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255)) description = Column(String(255)) diff --git a/examples/sharding/asyncio.py b/examples/sharding/asyncio.py index 4b32034c9f1..a63b0fcaaae 100644 --- a/examples/sharding/asyncio.py +++ b/examples/sharding/asyncio.py @@ -8,6 +8,7 @@ the routine that generates new primary keys. """ + from __future__ import annotations import asyncio diff --git a/examples/sharding/separate_databases.py b/examples/sharding/separate_databases.py index f836aaec00a..9a700734c51 100644 --- a/examples/sharding/separate_databases.py +++ b/examples/sharding/separate_databases.py @@ -1,4 +1,5 @@ """Illustrates sharding using distinct SQLite databases.""" + from __future__ import annotations import datetime diff --git a/examples/sharding/separate_schema_translates.py b/examples/sharding/separate_schema_translates.py index 095ae1cc698..fd754356e5d 100644 --- a/examples/sharding/separate_schema_translates.py +++ b/examples/sharding/separate_schema_translates.py @@ -4,6 +4,7 @@ In this example we will set a "shard id" at all times. 
""" + from __future__ import annotations import datetime diff --git a/examples/sharding/separate_tables.py b/examples/sharding/separate_tables.py index 1caaaf329b0..3084e9f0693 100644 --- a/examples/sharding/separate_tables.py +++ b/examples/sharding/separate_tables.py @@ -1,5 +1,6 @@ """Illustrates sharding using a single SQLite database, that will however have multiple tables using a naming convention.""" + from __future__ import annotations import datetime diff --git a/examples/space_invaders/__init__.py b/examples/space_invaders/__init__.py index 944f8bb466c..993d1e45431 100644 --- a/examples/space_invaders/__init__.py +++ b/examples/space_invaders/__init__.py @@ -11,11 +11,11 @@ To run:: - python -m examples.space_invaders.space_invaders + $ python -m examples.space_invaders.space_invaders While it runs, watch the SQL output in the log:: - tail -f space_invaders.log + $ tail -f space_invaders.log enjoy! diff --git a/examples/versioned_history/__init__.py b/examples/versioned_history/__init__.py index 0593881e2de..a872a63c034 100644 --- a/examples/versioned_history/__init__.py +++ b/examples/versioned_history/__init__.py @@ -6,21 +6,23 @@ class which represents historical versions of the target object. Compare to the :ref:`examples_versioned_rows` examples which write updates as new rows in the same table, without using a separate history table. -Usage is illustrated via a unit test module ``test_versioning.py``, which can -be run like any other module, using ``unittest`` internally:: +Usage is illustrated via a unit test module ``test_versioning.py``, which is +run using SQLAlchemy's internal pytest plugin:: - python -m examples.versioned_history.test_versioning + $ pytest test/base/test_examples.py A fragment of example usage, using declarative:: from history_meta import Versioned, versioned_session + class Base(DeclarativeBase): pass + class SomeClass(Versioned, Base): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -28,25 +30,25 @@ class SomeClass(Versioned, Base): def __eq__(self, other): assert type(other) is SomeClass and other.id == self.id + Session = sessionmaker(bind=engine) versioned_session(Session) sess = Session() - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() - sc.name = 'sc1modified' + sc.name = "sc1modified" sess.commit() assert sc.version == 2 SomeClassHistory = SomeClass.__history_mapper__.class_ - assert sess.query(SomeClassHistory).\\ - filter(SomeClassHistory.version == 1).\\ - all() \\ - == [SomeClassHistory(version=1, name='sc1')] + assert sess.query(SomeClassHistory).filter( + SomeClassHistory.version == 1 + ).all() == [SomeClassHistory(version=1, name="sc1")] The ``Versioned`` mixin is designed to work with declarative. 
To use the extension with classical mappers, the ``_history_mapper`` function @@ -64,7 +66,7 @@ def __eq__(self, other): set the flag ``Versioned.use_mapper_versioning`` to True:: class SomeClass(Versioned, Base): - __tablename__ = 'sometable' + __tablename__ = "sometable" use_mapper_versioning = True diff --git a/examples/versioned_history/history_meta.py b/examples/versioned_history/history_meta.py index 806267cb414..88fb16a0049 100644 --- a/examples/versioned_history/history_meta.py +++ b/examples/versioned_history/history_meta.py @@ -2,13 +2,16 @@ import datetime +from sqlalchemy import and_ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import select from sqlalchemy import util from sqlalchemy.orm import attributes from sqlalchemy.orm import object_mapper @@ -56,6 +59,10 @@ def _history_mapper(local_mapper): local_mapper.local_table.metadata, name=local_mapper.local_table.name + "_history", ) + for idx in history_table.indexes: + if idx.name is not None: + idx.name += "_history" + idx.unique = False for orig_c, history_c in zip( local_mapper.local_table.c, history_table.c @@ -144,8 +151,39 @@ def _history_mapper(local_mapper): super_history_table.append_column(col) if not super_mapper: + + def default_version_from_history(context): + # Set default value of version column to the maximum of the + # version in history columns already present +1 + # Otherwise re-appearance of deleted rows would cause an error + # with the next update + current_parameters = context.get_current_parameters() + return context.connection.scalar( + select( + func.coalesce(func.max(history_table.c.version), 0) + 1 + ).where( + and_( + *[ + history_table.c[c.name] + == current_parameters.get(c.name, None) + for c in inspect( + local_mapper.local_table + ).primary_key + ] + ) + ) + ) + local_mapper.local_table.append_column( - Column("version", Integer, default=1, nullable=False), + Column( + "version", + Integer, + # if rows are not being deleted from the main table with + # subsequent re-use of primary key, this default can be + # "1" instead of running a query per INSERT + default=default_version_from_history, + nullable=False, + ), replace_existing=True, ) local_mapper.add_property( diff --git a/examples/versioned_history/test_versioning.py b/examples/versioned_history/test_versioning.py index 7b9c82c60fa..b3fe2170904 100644 --- a/examples/versioned_history/test_versioning.py +++ b/examples/versioned_history/test_versioning.py @@ -8,11 +8,15 @@ from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import Index from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import join from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy import UniqueConstraint from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import column_property from sqlalchemy.orm import declarative_base @@ -31,7 +35,6 @@ from .history_meta import Versioned from .history_meta import versioned_session - warnings.simplefilter("error") @@ -127,6 +130,98 @@ class SomeClass(Versioned, self.Base, ComparableEntity): ], ) + @testing.variation( + "constraint_type", + [ + "index_single_col", + "composite_index", + 
"explicit_name_index", + "unique_constraint", + "unique_constraint_naming_conv", + "unique_constraint_explicit_name", + "fk_constraint", + "fk_constraint_naming_conv", + "fk_constraint_explicit_name", + ], + ) + def test_index_naming(self, constraint_type): + """test #10920""" + + if ( + constraint_type.unique_constraint_naming_conv + or constraint_type.fk_constraint_naming_conv + ): + self.Base.metadata.naming_convention = { + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "fk": ( + "fk_%(table_name)s_%(column_0_name)s" + "_%(referred_table_name)s" + ), + } + + if ( + constraint_type.fk_constraint + or constraint_type.fk_constraint_naming_conv + or constraint_type.fk_constraint_explicit_name + ): + + class Related(self.Base): + __tablename__ = "related" + + id = Column(Integer, primary_key=True) + + class SomeClass(Versioned, self.Base): + __tablename__ = "sometable" + + id = Column(Integer, primary_key=True) + x = Column(Integer) + y = Column(Integer) + + # Index objects are copied and these have to have a new name + if constraint_type.index_single_col: + __table_args__ = ( + Index( + None, + x, + ), + ) + elif constraint_type.composite_index: + __table_args__ = (Index(None, x, y),) + elif constraint_type.explicit_name_index: + __table_args__ = (Index("my_index", x, y),) + # unique constraint objects are discarded. + elif ( + constraint_type.unique_constraint + or constraint_type.unique_constraint_naming_conv + ): + __table_args__ = (UniqueConstraint(x, y),) + elif constraint_type.unique_constraint_explicit_name: + __table_args__ = (UniqueConstraint(x, y, name="my_uq"),) + # foreign key constraint objects are copied and have the same + # name, but no database in Core has any problem with this as the + # names are local to the parent table. 
+ elif ( + constraint_type.fk_constraint + or constraint_type.fk_constraint_naming_conv + ): + __table_args__ = (ForeignKeyConstraint([x], [Related.id]),) + elif constraint_type.fk_constraint_explicit_name: + __table_args__ = ( + ForeignKeyConstraint([x], [Related.id], name="my_fk"), + ) + else: + constraint_type.fail() + + eq_( + set(idx.name + "_history" for idx in SomeClass.__table__.indexes), + set( + idx.name + for idx in SomeClass.__history_mapper__.local_table.indexes + ), + ) + self.create_tables() + def test_discussion_9546(self): class ThingExternal(Versioned, self.Base): __tablename__ = "things_external" @@ -786,6 +881,79 @@ class SomeClass(Versioned, self.Base, ComparableEntity): sc2.name = "sc2 modified" sess.commit() + def test_external_id(self): + class ObjectExternal(Versioned, self.Base, ComparableEntity): + __tablename__ = "externalobjects" + + id1 = Column(String(3), primary_key=True) + id2 = Column(String(3), primary_key=True) + name = Column(String(50)) + + self.create_tables() + sess = self.session + sc = ObjectExternal(id1="aaa", id2="bbb", name="sc1") + sess.add(sc) + sess.commit() + + sc.name = "sc1modified" + sess.commit() + + assert sc.version == 2 + + ObjectExternalHistory = ObjectExternal.__history_mapper__.class_ + + eq_( + sess.query(ObjectExternalHistory).all(), + [ + ObjectExternalHistory( + version=1, id1="aaa", id2="bbb", name="sc1" + ), + ], + ) + + sess.delete(sc) + sess.commit() + + assert sess.query(ObjectExternal).count() == 0 + + eq_( + sess.query(ObjectExternalHistory).all(), + [ + ObjectExternalHistory( + version=1, id1="aaa", id2="bbb", name="sc1" + ), + ObjectExternalHistory( + version=2, id1="aaa", id2="bbb", name="sc1modified" + ), + ], + ) + + sc = ObjectExternal(id1="aaa", id2="bbb", name="sc1reappeared") + sess.add(sc) + sess.commit() + + assert sc.version == 3 + + sc.name = "sc1reappearedmodified" + sess.commit() + + assert sc.version == 4 + + eq_( + sess.query(ObjectExternalHistory).all(), + [ + ObjectExternalHistory( + version=1, id1="aaa", id2="bbb", name="sc1" + ), + ObjectExternalHistory( + version=2, id1="aaa", id2="bbb", name="sc1modified" + ), + ObjectExternalHistory( + version=3, id1="aaa", id2="bbb", name="sc1reappeared" + ), + ], + ) + class TestVersioningNewBase(TestVersioning): def make_base(self): diff --git a/examples/versioned_rows/versioned_rows.py b/examples/versioned_rows/versioned_rows.py index 96d2e399ec1..80803b39329 100644 --- a/examples/versioned_rows/versioned_rows.py +++ b/examples/versioned_rows/versioned_rows.py @@ -3,6 +3,7 @@ row is inserted with the new data, keeping the old row intact. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import event diff --git a/examples/versioned_rows/versioned_rows_w_versionid.py b/examples/versioned_rows/versioned_rows_w_versionid.py index fcf8082814a..d030ed065cc 100644 --- a/examples/versioned_rows/versioned_rows_w_versionid.py +++ b/examples/versioned_rows/versioned_rows_w_versionid.py @@ -6,6 +6,7 @@ as the ability to see which row is the most "current" version. 
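Stepping back to the two behavioral changes in ``history_meta.py`` above: copied ``Index`` objects now receive a ``_history`` suffix and lose uniqueness (several history rows may legitimately repeat a value), and the ``version`` default now resumes counting from the history table when a deleted primary key reappears, which is what ``test_external_id`` exercises. A compact sketch of the index behavior, reusing the example's ``Versioned`` mixin and ``Base`` with a made-up ``Document`` mapping::

    from sqlalchemy import Column, Index, Integer, String


    class Document(Versioned, Base):
        __tablename__ = "document"

        id = Column(Integer, primary_key=True)
        slug = Column(String(50))

        __table_args__ = (Index("ix_document_slug", "slug", unique=True),)


    hist = Document.__history_mapper__.local_table

    # the copied index is renamed so it cannot collide with the original,
    # and is made non-unique since history rows repeat values
    assert {ix.name for ix in hist.indexes} == {"ix_document_slug_history"}
    assert not any(ix.unique for ix in hist.indexes)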
""" + from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import create_engine diff --git a/examples/vertical/__init__.py b/examples/vertical/__init__.py index b0c00b664e7..997510e1b07 100644 --- a/examples/vertical/__init__.py +++ b/examples/vertical/__init__.py @@ -15,19 +15,20 @@ Example:: - shrew = Animal(u'shrew') - shrew[u'cuteness'] = 5 - shrew[u'weasel-like'] = False - shrew[u'poisonous'] = True + shrew = Animal("shrew") + shrew["cuteness"] = 5 + shrew["weasel-like"] = False + shrew["poisonous"] = True session.add(shrew) session.flush() - q = (session.query(Animal). - filter(Animal.facts.any( - and_(AnimalFact.key == u'weasel-like', - AnimalFact.value == True)))) - print('weasel-like animals', q.all()) + q = session.query(Animal).filter( + Animal.facts.any( + and_(AnimalFact.key == "weasel-like", AnimalFact.value == True) + ) + ) + print("weasel-like animals", q.all()) .. autosource:: diff --git a/examples/vertical/dictlike-polymorphic.py b/examples/vertical/dictlike-polymorphic.py index 69f32cf4a8e..7de8fa80d9f 100644 --- a/examples/vertical/dictlike-polymorphic.py +++ b/examples/vertical/dictlike-polymorphic.py @@ -3,15 +3,17 @@ Builds upon the dictlike.py example to also add differently typed columns to the "fact" table, e.g.:: - Table('properties', metadata - Column('owner_id', Integer, ForeignKey('owner.id'), - primary_key=True), - Column('key', UnicodeText), - Column('type', Unicode(16)), - Column('int_value', Integer), - Column('char_value', UnicodeText), - Column('bool_value', Boolean), - Column('decimal_value', Numeric(10,2))) + Table( + "properties", + metadata, + Column("owner_id", Integer, ForeignKey("owner.id"), primary_key=True), + Column("key", UnicodeText), + Column("type", Unicode(16)), + Column("int_value", Integer), + Column("char_value", UnicodeText), + Column("bool_value", Boolean), + Column("decimal_value", Numeric(10, 2)), + ) For any given properties row, the value of the 'type' column will point to the '_value' column active for that row. diff --git a/examples/vertical/dictlike.py b/examples/vertical/dictlike.py index f561499e8fd..bd1701c89c6 100644 --- a/examples/vertical/dictlike.py +++ b/examples/vertical/dictlike.py @@ -6,24 +6,30 @@ example, instead of:: # A regular ("horizontal") table has columns for 'species' and 'size' - Table('animal', metadata, - Column('id', Integer, primary_key=True), - Column('species', Unicode), - Column('size', Unicode)) + Table( + "animal", + metadata, + Column("id", Integer, primary_key=True), + Column("species", Unicode), + Column("size", Unicode), + ) A vertical table models this as two tables: one table for the base or parent entity, and another related table holding key/value pairs:: - Table('animal', metadata, - Column('id', Integer, primary_key=True)) + Table("animal", metadata, Column("id", Integer, primary_key=True)) # The properties table will have one row for a 'species' value, and # another row for the 'size' value. - Table('properties', metadata - Column('animal_id', Integer, ForeignKey('animal.id'), - primary_key=True), - Column('key', UnicodeText), - Column('value', UnicodeText)) + Table( + "properties", + metadata, + Column( + "animal_id", Integer, ForeignKey("animal.id"), primary_key=True + ), + Column("key", UnicodeText), + Column("value", UnicodeText), + ) Because the key/value pairs in a vertical scheme are not fixed in advance, accessing them like a Python dict can be very convenient. 
The example below diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 472f01ad063..ec8060bccd0 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -1,5 +1,5 @@ -# sqlalchemy/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# __init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -55,7 +55,7 @@ from .pool import PoolProxiedConnection as PoolProxiedConnection from .pool import PoolResetState as PoolResetState from .pool import QueuePool as QueuePool -from .pool import SingletonThreadPool as SingleonThreadPool +from .pool import SingletonThreadPool as SingletonThreadPool from .pool import StaticPool as StaticPool from .schema import BaseDDLElement as BaseDDLElement from .schema import BLANK_SCHEMA as BLANK_SCHEMA @@ -269,13 +269,11 @@ from .types import VARBINARY as VARBINARY from .types import VARCHAR as VARCHAR -__version__ = "2.0.24" +__version__ = "2.0.42" def __go(lcls: Any) -> None: - from . import util as _sa_util - - _sa_util.preloaded.import_prefix("sqlalchemy") + _util.preloaded.import_prefix("sqlalchemy") from . import exc diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py index 1969d7236bc..43cd1035c62 100644 --- a/lib/sqlalchemy/connectors/__init__.py +++ b/lib/sqlalchemy/connectors/__init__.py @@ -1,5 +1,5 @@ # connectors/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/connectors/aioodbc.py b/lib/sqlalchemy/connectors/aioodbc.py index c6986366e1c..6e4b864e7dc 100644 --- a/lib/sqlalchemy/connectors/aioodbc.py +++ b/lib/sqlalchemy/connectors/aioodbc.py @@ -1,5 +1,5 @@ # connectors/aioodbc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -20,6 +20,7 @@ from ..util.concurrency import await_fallback from ..util.concurrency import await_only + if TYPE_CHECKING: from ..engine.interfaces import ConnectArgsType from ..engine.url import URL @@ -58,6 +59,15 @@ def autocommit(self, value): self._connection._conn.autocommit = value + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def add_output_converter(self, *arg, **kw): + self._connection.add_output_converter(*arg, **kw) + + def character_set_name(self): + return self._connection.character_set_name() + def cursor(self, server_side=False): # aioodbc sets connection=None when closed and just fails with # AttributeError here. 
Here we use the same ProgrammingError + @@ -170,18 +180,5 @@ def get_pool_class(cls, url): else: return pool.AsyncAdaptedQueuePool - def _do_isolation_level(self, connection, autocommit, isolation_level): - connection.set_autocommit(autocommit) - connection.set_isolation_level(isolation_level) - - def _do_autocommit(self, connection, value): - connection.set_autocommit(value) - - def set_readonly(self, connection, value): - connection.set_read_only(value) - - def set_deferrable(self, connection, value): - connection.set_deferrable(value) - def get_driver_connection(self, connection): return connection._connection diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 997407ccd58..fda21b6d6f0 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -1,22 +1,124 @@ # connectors/asyncio.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """generic asyncio-adapted versions of DBAPI connection and cursor""" from __future__ import annotations +import asyncio import collections -import itertools +import sys +from typing import Any +from typing import AsyncIterator +from typing import Deque +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING from ..engine import AdaptedConnection -from ..util.concurrency import asyncio from ..util.concurrency import await_fallback from ..util.concurrency import await_only +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.interfaces import _DBAPICursorDescription + from ..engine.interfaces import _DBAPIMultiExecuteParams + from ..engine.interfaces import _DBAPISingleExecuteParams + from ..engine.interfaces import DBAPIModule + from ..util.typing import Self + + +class AsyncIODBAPIConnection(Protocol): + """protocol representing an async adapted version of a + :pep:`249` database connection. + + + """ + + # note that async DBAPIs dont agree if close() should be awaitable, + # so it is omitted here and picked up by the __getattr__ hook below + + async def commit(self) -> None: ... + + def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ... + + async def rollback(self) -> None: ... + + def __getattr__(self, key: str) -> Any: ... + + def __setattr__(self, key: str, value: Any) -> None: ... + + +class AsyncIODBAPICursor(Protocol): + """protocol representing an async adapted version + of a :pep:`249` database cursor. + + + """ + + def __aenter__(self) -> Any: ... + + @property + def description( + self, + ) -> _DBAPICursorDescription: + """The description attribute of the Cursor.""" + ... + + @property + def rowcount(self) -> int: ... + + arraysize: int + + lastrowid: int + + async def close(self) -> None: ... + + async def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: ... + + async def executemany( + self, + operation: Any, + parameters: _DBAPIMultiExecuteParams, + ) -> Any: ... + + async def fetchone(self) -> Optional[Any]: ... + + async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ... + + async def fetchall(self) -> Sequence[Any]: ... + + async def setinputsizes(self, sizes: Sequence[Any]) -> None: ... 
+ + def setoutputsize(self, size: Any, column: Any) -> None: ... + + async def callproc( + self, procname: str, parameters: Sequence[Any] = ... + ) -> Any: ... + + async def nextset(self) -> Optional[bool]: ... + + def __aiter__(self) -> AsyncIterator[Any]: ... + + +class AsyncAdapt_dbapi_module: + if TYPE_CHECKING: + Error = DBAPIModule.Error + OperationalError = DBAPIModule.OperationalError + InterfaceError = DBAPIModule.InterfaceError + IntegrityError = DBAPIModule.IntegrityError + + def __getattr__(self, key: str) -> Any: ... class AsyncAdapt_dbapi_cursor: @@ -29,99 +131,136 @@ class AsyncAdapt_dbapi_cursor: "_rows", ) - def __init__(self, adapt_connection): + _cursor: AsyncIODBAPICursor + _adapt_connection: AsyncAdapt_dbapi_connection + _connection: AsyncIODBAPIConnection + _rows: Deque[Any] + + def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ - cursor = self._connection.cursor() + cursor = self._make_new_cursor(self._connection) + self._cursor = self._aenter_cursor(cursor) + + if not self.server_side: + self._rows = collections.deque() + + def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor: + return self.await_(cursor.__aenter__()) # type: ignore[no-any-return] - self._cursor = self.await_(cursor.__aenter__()) - self._rows = collections.deque() + def _make_new_cursor( + self, connection: AsyncIODBAPIConnection + ) -> AsyncIODBAPICursor: + return connection.cursor() @property - def description(self): + def description(self) -> Optional[_DBAPICursorDescription]: return self._cursor.description @property - def rowcount(self): + def rowcount(self) -> int: return self._cursor.rowcount @property - def arraysize(self): + def arraysize(self) -> int: return self._cursor.arraysize @arraysize.setter - def arraysize(self, value): + def arraysize(self, value: int) -> None: self._cursor.arraysize = value @property - def lastrowid(self): + def lastrowid(self) -> int: return self._cursor.lastrowid - def close(self): + def close(self) -> None: # note we aren't actually closing the cursor here, # we are just letting GC do it. 
see notes in aiomysql dialect
        self._rows.clear()

-    def execute(self, operation, parameters=None):
-        return self.await_(self._execute_async(operation, parameters))
-
-    def executemany(self, operation, seq_of_parameters):
-        return self.await_(
-            self._executemany_async(operation, seq_of_parameters)
-        )
+    def execute(
+        self,
+        operation: Any,
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
+        try:
+            return self.await_(self._execute_async(operation, parameters))
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
+
+    def executemany(
+        self,
+        operation: Any,
+        seq_of_parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
+        try:
+            return self.await_(
+                self._executemany_async(operation, seq_of_parameters)
+            )
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)

-    async def _execute_async(self, operation, parameters):
+    async def _execute_async(
+        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
+    ) -> Any:
         async with self._adapt_connection._execute_mutex:
-            result = await self._cursor.execute(operation, parameters or ())
+            if parameters is None:
+                result = await self._cursor.execute(operation)
+            else:
+                result = await self._cursor.execute(operation, parameters)

             if self._cursor.description and not self.server_side:
-                # aioodbc has a "fake" async result, so we have to pull it out
-                # of that here since our default result is not async.
-                # we could just as easily grab "_rows" here and be done with it
-                # but this is safer.
                 self._rows = collections.deque(await self._cursor.fetchall())
             return result

-    async def _executemany_async(self, operation, seq_of_parameters):
+    async def _executemany_async(
+        self,
+        operation: Any,
+        seq_of_parameters: _DBAPIMultiExecuteParams,
+    ) -> Any:
         async with self._adapt_connection._execute_mutex:
             return await self._cursor.executemany(operation, seq_of_parameters)

-    def nextset(self):
+    def nextset(self) -> None:
         self.await_(self._cursor.nextset())
         if self._cursor.description and not self.server_side:
             self._rows = collections.deque(
                 self.await_(self._cursor.fetchall())
             )

-    def setinputsizes(self, *inputsizes):
+    def setinputsizes(self, *inputsizes: Any) -> None:
         # NOTE: this is overridden in aioodbc due to
         # see https://github.com/aio-libs/aioodbc/issues/451
         # right now
         return self.await_(self._cursor.setinputsizes(*inputsizes))

-    def __iter__(self):
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        self.close()
+
+    def __iter__(self) -> Iterator[Any]:
         while self._rows:
             yield self._rows.popleft()

-    def fetchone(self):
+    def fetchone(self) -> Optional[Any]:
         if self._rows:
             return self._rows.popleft()
         else:
             return None

-    def fetchmany(self, size=None):
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
         if size is None:
             size = self.arraysize
+        rr = self._rows
+        return [rr.popleft() for _ in range(min(size, len(rr)))]

-        rr = iter(self._rows)
-        retval = list(itertools.islice(rr, 0, size))
-        self._rows = collections.deque(rr)
-        return retval
-
-    def fetchall(self):
+    def fetchall(self) -> Sequence[Any]:
         retval = list(self._rows)
         self._rows.clear()
         return retval

@@ -131,75 +270,78 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
     server_side = True

-    def __init__(self, adapt_connection):
-        self._adapt_connection = adapt_connection
-        self._connection = adapt_connection._connection
-        self.await_ = adapt_connection.await_
-
-        cursor = self._connection.cursor()
-
-
self._cursor = self.await_(cursor.__aenter__()) - - def close(self): + def close(self) -> None: if self._cursor is not None: self.await_(self._cursor.close()) - self._cursor = None + self._cursor = None # type: ignore - def fetchone(self): + def fetchone(self) -> Optional[Any]: return self.await_(self._cursor.fetchone()) - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> Any: return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self): + def fetchall(self) -> Sequence[Any]: return self.await_(self._cursor.fetchall()) + def __iter__(self) -> Iterator[Any]: + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + class AsyncAdapt_dbapi_connection(AdaptedConnection): _cursor_cls = AsyncAdapt_dbapi_cursor _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") - def __init__(self, dbapi, connection): + _connection: AsyncIODBAPIConnection + + def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection): self.dbapi = dbapi self._connection = connection self._execute_mutex = asyncio.Lock() - def ping(self, reconnect): - return self.await_(self._connection.ping(reconnect)) - - def add_output_converter(self, *arg, **kw): - self._connection.add_output_converter(*arg, **kw) - - def character_set_name(self): - return self._connection.character_set_name() - - @property - def autocommit(self): - return self._connection.autocommit - - @autocommit.setter - def autocommit(self, value): - # https://github.com/aio-libs/aioodbc/issues/448 - # self._connection.autocommit = value - - self._connection._conn.autocommit = value - - def cursor(self, server_side=False): + def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor: if server_side: return self._ss_cursor_cls(self) else: return self._cursor_cls(self) - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - - def close(self): + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: + """lots of DBAPIs seem to provide this, so include it""" + cursor = self.cursor() + cursor.execute(operation, parameters) + return cursor + + def _handle_exception(self, error: Exception) -> NoReturn: + exc_info = sys.exc_info() + + raise error.with_traceback(exc_info[2]) + + def rollback(self) -> None: + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self) -> None: + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self) -> None: self.await_(self._connection.close()) diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 49712a57c41..766493e2e0c 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -1,5 +1,5 @@ # connectors/pyodbc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,7 +8,6 @@ from __future__ import annotations import re -from types import ModuleType import typing from typing import Any from typing import Dict @@ -29,6 +28,7 @@ from ..sql.type_api import TypeEngine if typing.TYPE_CHECKING: + from ..engine.interfaces import DBAPIModule from 
..engine.interfaces import IsolationLevel @@ -48,15 +48,13 @@ class PyODBCConnector(Connector): # hold the desired driver name pyodbc_driver_name: Optional[str] = None - dbapi: ModuleType - def __init__(self, use_setinputsizes: bool = False, **kw: Any): super().__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @classmethod - def import_dbapi(cls) -> ModuleType: + def import_dbapi(cls) -> DBAPIModule: return __import__("pyodbc") def create_connect_args(self, url: URL) -> ConnectArgsType: @@ -150,7 +148,7 @@ def is_disconnect( ], cursor: Optional[interfaces.DBAPICursor], ) -> bool: - if isinstance(e, self.dbapi.ProgrammingError): + if isinstance(e, self.loaded_dbapi.ProgrammingError): return "The cursor's connection has been closed." in str( e ) or "Attempt to use a closed connection." in str(e) @@ -217,19 +215,19 @@ def do_set_input_sizes( cursor.setinputsizes( [ - (dbtype, None, None) - if not isinstance(dbtype, tuple) - else dbtype + ( + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + ) for key, dbtype, sqltype in list_of_tuples ] ) def get_isolation_level_values( - self, dbapi_connection: interfaces.DBAPIConnection + self, dbapi_conn: interfaces.DBAPIConnection ) -> List[IsolationLevel]: - return super().get_isolation_level_values(dbapi_connection) + [ - "AUTOCOMMIT" - ] + return [*super().get_isolation_level_values(dbapi_conn), "AUTOCOMMIT"] def set_isolation_level( self, diff --git a/lib/sqlalchemy/cyextension/__init__.py b/lib/sqlalchemy/cyextension/__init__.py index e69de29bb2d..cb8dc2c6ec3 100644 --- a/lib/sqlalchemy/cyextension/__init__.py +++ b/lib/sqlalchemy/cyextension/__init__.py @@ -0,0 +1,6 @@ +# cyextension/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index 4d134ccf302..86d24852b3f 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -1,3 +1,9 @@ +# cyextension/collections.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php cimport cython from cpython.long cimport PyLong_FromLongLong from cpython.set cimport PySet_Add diff --git a/lib/sqlalchemy/cyextension/immutabledict.pxd b/lib/sqlalchemy/cyextension/immutabledict.pxd index fe7ad6a81a8..76f22893168 100644 --- a/lib/sqlalchemy/cyextension/immutabledict.pxd +++ b/lib/sqlalchemy/cyextension/immutabledict.pxd @@ -1,2 +1,8 @@ +# cyextension/immutabledict.pxd +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php cdef class immutabledict(dict): pass diff --git a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx index 100287b380d..b37eccc4c39 100644 --- a/lib/sqlalchemy/cyextension/immutabledict.pyx +++ b/lib/sqlalchemy/cyextension/immutabledict.pyx @@ -1,3 +1,9 @@ +# cyextension/immutabledict.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from 
cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size diff --git a/lib/sqlalchemy/cyextension/processors.pyx b/lib/sqlalchemy/cyextension/processors.pyx index b0ad865c54a..3d714569fa0 100644 --- a/lib/sqlalchemy/cyextension/processors.pyx +++ b/lib/sqlalchemy/cyextension/processors.pyx @@ -1,3 +1,9 @@ +# cyextension/processors.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php import datetime from datetime import datetime as datetime_cls from datetime import time as time_cls diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx index 0d7eeece93c..b6e357a1f35 100644 --- a/lib/sqlalchemy/cyextension/resultproxy.pyx +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -1,3 +1,9 @@ +# cyextension/resultproxy.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php import operator cdef class BaseRow: diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx index 92e91a6edc1..cb17acd69c0 100644 --- a/lib/sqlalchemy/cyextension/util.pyx +++ b/lib/sqlalchemy/cyextension/util.pyx @@ -1,3 +1,9 @@ +# cyextension/util.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from collections.abc import Mapping from sqlalchemy import exc diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 055d087cf24..30928a98455 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -1,5 +1,5 @@ # dialects/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,6 +7,7 @@ from __future__ import annotations +from typing import Any from typing import Callable from typing import Optional from typing import Type @@ -39,7 +40,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]: # hardcoded. if mysql / mariadb etc were third party dialects # they would just publish all the entrypoints, which would actually # look much nicer. 
- module = __import__( + module: Any = __import__( "sqlalchemy.dialects.mysql.mariadb" ).dialects.mysql.mariadb return module.loader(driver) # type: ignore diff --git a/lib/sqlalchemy/dialects/_typing.py b/lib/sqlalchemy/dialects/_typing.py index 932742bd045..4dd40d7220f 100644 --- a/lib/sqlalchemy/dialects/_typing.py +++ b/lib/sqlalchemy/dialects/_typing.py @@ -1,3 +1,9 @@ +# dialects/_typing.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations from typing import Any @@ -6,14 +12,19 @@ from typing import Optional from typing import Union -from ..sql._typing import _DDLColumnArgument -from ..sql.elements import DQLDMLClauseElement +from ..sql import roles +from ..sql.base import ColumnCollection +from ..sql.schema import Column from ..sql.schema import ColumnCollectionConstraint from ..sql.schema import Index _OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None] -_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]] -_OnConflictIndexWhereT = Optional[DQLDMLClauseElement] -_OnConflictSetT = Optional[Mapping[Any, Any]] -_OnConflictWhereT = Union[DQLDMLClauseElement, str, None] +_OnConflictIndexElementsT = Optional[ + Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]] +] +_OnConflictIndexWhereT = Optional[roles.WhereHavingRole] +_OnConflictSetT = Optional[ + Union[Mapping[Any, Any], ColumnCollection[Any, Any]] +] +_OnConflictWhereT = Optional[roles.WhereHavingRole] diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index 6bbb934157a..20140fdddb3 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -1,5 +1,5 @@ -# mssql/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/mssql/aioodbc.py b/lib/sqlalchemy/dialects/mssql/aioodbc.py index 23c2790f29d..522ad1d6b0d 100644 --- a/lib/sqlalchemy/dialects/mssql/aioodbc.py +++ b/lib/sqlalchemy/dialects/mssql/aioodbc.py @@ -1,5 +1,5 @@ -# mssql/aioodbc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/aioodbc.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -32,13 +32,12 @@ styles are otherwise equivalent to those documented in the pyodbc section:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine( "mssql+aioodbc://scott:tiger@mssql2017:1433/test?" "driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" ) - - """ from __future__ import annotations diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 687de04e4d3..f641ff03ea8 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1,5 +1,5 @@ -# mssql/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,6 @@ """ .. 
dialect:: mssql :name: Microsoft SQL Server - :full_support: 2017 :normal_support: 2012+ :best_effort: 2005+ @@ -40,9 +39,12 @@ from sqlalchemy import Table, MetaData, Column, Integer m = MetaData() - t = Table('t', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) + t = Table( + "t", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) m.create_all(engine) The above example will generate DDL as: @@ -60,9 +62,12 @@ on the first integer primary key column:: m = MetaData() - t = Table('t', m, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('x', Integer)) + t = Table( + "t", + m, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", Integer), + ) m.create_all(engine) To add the ``IDENTITY`` keyword to a non-primary key column, specify @@ -72,9 +77,12 @@ is set to ``False`` on any integer primary key column:: m = MetaData() - t = Table('t', m, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('x', Integer, autoincrement=True)) + t = Table( + "t", + m, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("x", Integer, autoincrement=True), + ) m.create_all(engine) .. versionchanged:: 1.4 Added :class:`_schema.Identity` construct @@ -137,14 +145,12 @@ from sqlalchemy import Table, Integer, Column, Identity test = Table( - 'test', metadata, + "test", + metadata, Column( - 'id', - Integer, - primary_key=True, - Identity(start=100, increment=10) + "id", Integer, primary_key=True, Identity(start=100, increment=10) ), - Column('name', String(20)) + Column("name", String(20)), ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -154,7 +160,7 @@ CREATE TABLE test ( id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, name VARCHAR(20) NULL, - ) + ) .. note:: @@ -187,6 +193,7 @@ Base = declarative_base() + class TestTable(Base): __tablename__ = "test" id = Column( @@ -212,8 +219,9 @@ class TestTable(Base): from sqlalchemy import TypeDecorator + class NumericAsInteger(TypeDecorator): - '''normalize floating point return values into ints''' + "normalize floating point return values into ints" impl = Numeric(10, 0, asdecimal=False) cache_ok = True @@ -223,6 +231,7 @@ def process_result_value(self, value, dialect): value = int(value) return value + class TestTable(Base): __tablename__ = "test" id = Column( @@ -271,11 +280,11 @@ class TestTable(Base): fetched in order to receive the value. Given a table as:: t = Table( - 't', + "t", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - implicit_returning=False + Column("id", Integer, primary_key=True), + Column("x", Integer), + implicit_returning=False, ) an INSERT will look like: @@ -301,12 +310,13 @@ class TestTable(Base): execution. Given this example:: m = MetaData() - t = Table('t', m, Column('id', Integer, primary_key=True), - Column('x', Integer)) + t = Table( + "t", m, Column("id", Integer, primary_key=True), Column("x", Integer) + ) m.create_all(engine) with engine.begin() as conn: - conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + conn.execute(t.insert(), {"id": 1, "x": 1}, {"id": 2, "x": 2}) The above column will be created with IDENTITY, however the INSERT statement we emit is specifying explicit values. 
In the echo output we can see @@ -342,7 +352,11 @@ class TestTable(Base): >>> from sqlalchemy import Sequence >>> from sqlalchemy.schema import CreateSequence >>> from sqlalchemy.dialects import mssql - >>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect())) + >>> print( + ... CreateSequence(Sequence("my_seq", start=1)).compile( + ... dialect=mssql.dialect() + ... ) + ... ) {printsql}CREATE SEQUENCE my_seq START WITH 1 For integer primary key generation, SQL Server's ``IDENTITY`` construct should @@ -376,12 +390,12 @@ class TestTable(Base): To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None:: my_table = Table( - 'my_table', metadata, - Column('my_data', VARCHAR(None)), - Column('my_n_data', NVARCHAR(None)) + "my_table", + metadata, + Column("my_data", VARCHAR(None)), + Column("my_n_data", NVARCHAR(None)), ) - Collation Support ----------------- @@ -389,10 +403,13 @@ class TestTable(Base): specified by the string argument "collation":: from sqlalchemy import VARCHAR - Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) + + Column("login", VARCHAR(32, collation="Latin1_General_CI_AS")) When such a column is associated with a :class:`_schema.Table`, the -CREATE TABLE statement for this column will yield:: +CREATE TABLE statement for this column will yield: + +.. sourcecode:: sql login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL @@ -412,7 +429,9 @@ class TestTable(Base): select(some_table).limit(5) -will render similarly to:: +will render similarly to: + +.. sourcecode:: sql SELECT TOP 5 col1, col2.. FROM table @@ -422,7 +441,9 @@ class TestTable(Base): select(some_table).order_by(some_table.c.col3).limit(5).offset(10) -will render similarly to:: +will render similarly to: + +.. sourcecode:: sql SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2, ROW_NUMBER() OVER (ORDER BY col3) AS @@ -475,16 +496,13 @@ class TestTable(Base): To set isolation level using :func:`_sa.create_engine`:: engine = create_engine( - "mssql+pyodbc://scott:tiger@ms_2008", - isolation_level="REPEATABLE READ" + "mssql+pyodbc://scott:tiger@ms_2008", isolation_level="REPEATABLE READ" ) To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options( - isolation_level="READ COMMITTED" - ) + connection = connection.execution_options(isolation_level="READ COMMITTED") Valid values for ``isolation_level`` include: @@ -534,7 +552,6 @@ class TestTable(Base): mssql_engine = create_engine( "mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", - # disable default reset-on-return scheme pool_reset_on_return=None, ) @@ -563,13 +580,17 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): ----------- MSSQL has support for three levels of column nullability. The default nullability allows nulls and is explicit in the CREATE TABLE -construct:: +construct: + +.. sourcecode:: sql name VARCHAR(20) NULL If ``nullable=None`` is specified then no specification is made. In other words the database's configured default is used. This will -render:: +render: + +.. 
sourcecode:: sql name VARCHAR(20) @@ -625,8 +646,9 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): * The flag can be set to either ``True`` or ``False`` when the dialect is created, typically via :func:`_sa.create_engine`:: - eng = create_engine("mssql+pymssql://user:pass@host/db", - deprecate_large_types=True) + eng = create_engine( + "mssql+pymssql://user:pass@host/db", deprecate_large_types=True + ) * Complete control over whether the "old" or "new" types are rendered is available in all SQLAlchemy versions by using the UPPERCASE type objects @@ -648,9 +670,10 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): :class:`_schema.Table`:: Table( - "some_table", metadata, + "some_table", + metadata, Column("q", String(50)), - schema="mydatabase.dbo" + schema="mydatabase.dbo", ) When performing operations such as table or component reflection, a schema @@ -662,9 +685,10 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): special characters. Given an argument as below:: Table( - "some_table", metadata, + "some_table", + metadata, Column("q", String(50)), - schema="MyDataBase.dbo" + schema="MyDataBase.dbo", ) The above schema would be rendered as ``[MyDataBase].dbo``, and also in @@ -677,21 +701,22 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): "database" will be None:: Table( - "some_table", metadata, + "some_table", + metadata, Column("q", String(50)), - schema="[MyDataBase.dbo]" + schema="[MyDataBase.dbo]", ) To individually specify both database and owner name with special characters or embedded dots, use two sets of brackets:: Table( - "some_table", metadata, + "some_table", + metadata, Column("q", String(50)), - schema="[MyDataBase.Period].[MyOwner.Dot]" + schema="[MyDataBase.Period].[MyOwner.Dot]", ) - .. versionchanged:: 1.2 the SQL Server dialect now treats brackets as identifier delimiters splitting the schema into separate database and owner tokens, to allow dots within either name itself. @@ -706,10 +731,11 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): SELECT statement; given a table:: account_table = Table( - 'account', metadata, - Column('id', Integer, primary_key=True), - Column('info', String(100)), - schema="customer_schema" + "account", + metadata, + Column("id", Integer, primary_key=True), + Column("info", String(100)), + schema="customer_schema", ) this legacy mode of rendering would assume that "customer_schema.account" @@ -752,37 +778,55 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): To generate a clustered primary key use:: - Table('my_table', metadata, - Column('x', ...), - Column('y', ...), - PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + Table( + "my_table", + metadata, + Column("x", ...), + Column("y", ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True), + ) -which will render the table, for example, as:: +which will render the table, for example, as: - CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, - PRIMARY KEY CLUSTERED (x, y)) +.. 
sourcecode:: sql + + CREATE TABLE my_table ( + x INTEGER NOT NULL, + y INTEGER NOT NULL, + PRIMARY KEY CLUSTERED (x, y) + ) Similarly, we can generate a clustered unique constraint using:: - Table('my_table', metadata, - Column('x', ...), - Column('y', ...), - PrimaryKeyConstraint("x"), - UniqueConstraint("y", mssql_clustered=True), - ) + Table( + "my_table", + metadata, + Column("x", ...), + Column("y", ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) To explicitly request a non-clustered primary key (for example, when a separate clustered index is desired), use:: - Table('my_table', metadata, - Column('x', ...), - Column('y', ...), - PrimaryKeyConstraint("x", "y", mssql_clustered=False)) + Table( + "my_table", + metadata, + Column("x", ...), + Column("y", ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=False), + ) -which will render the table, for example, as:: +which will render the table, for example, as: + +.. sourcecode:: sql - CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, - PRIMARY KEY NONCLUSTERED (x, y)) + CREATE TABLE my_table ( + x INTEGER NOT NULL, + y INTEGER NOT NULL, + PRIMARY KEY NONCLUSTERED (x, y) + ) Columnstore Index Support ------------------------- @@ -820,7 +864,7 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): The ``mssql_include`` option renders INCLUDE(colname) for the given string names:: - Index("my_index", table.c.x, mssql_include=['y']) + Index("my_index", table.c.x, mssql_include=["y"]) would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` @@ -875,18 +919,19 @@ def _reset_mssql(dbapi_connection, connection_record, reset_state): specify ``implicit_returning=False`` for each :class:`_schema.Table` which has triggers:: - Table('mytable', metadata, - Column('id', Integer, primary_key=True), + Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), # ..., - implicit_returning=False + implicit_returning=False, ) Declarative form:: class MyClass(Base): # ... - __table_args__ = {'implicit_returning':False} - + __table_args__ = {"implicit_returning": False} .. _mssql_rowcount_versioning: @@ -920,7 +965,9 @@ class MyClass(Base): applications to have long held locks and frequent deadlocks. Enabling snapshot isolation for the database as a whole is recommended for modern levels of concurrency support. This is accomplished via the -following ALTER DATABASE commands executed at the SQL prompt:: +following ALTER DATABASE commands executed at the SQL prompt: + +.. sourcecode:: sql ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON @@ -1426,7 +1473,6 @@ class ROWVERSION(TIMESTAMP): class NTEXT(sqltypes.UnicodeText): - """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" @@ -1551,44 +1597,11 @@ def process(value): def process(value): return f"""'{ - value.replace("-", "").replace("'", "''") - }'""" + value.replace("-", "").replace("'", "''") + }'""" return process - def _sentinel_value_resolver(self, dialect): - """Return a callable that will receive the uuid object or string - as it is normally passed to the DB in the parameter set, after - bind_processor() is called. Convert this value to match - what it would be as coming back from an INSERT..OUTPUT inserted. - - for the UUID type, there are four varieties of settings so here - we seek to convert to the string or UUID representation that comes - back from the driver. 
- - """ - character_based_uuid = ( - not dialect.supports_native_uuid or not self.native_uuid - ) - - if character_based_uuid: - if self.native_uuid: - # for pyodbc, uuid.uuid() objects are accepted for incoming - # data, as well as strings. but the driver will always return - # uppercase strings in result sets. - def process(value): - return str(value).upper() - - else: - - def process(value): - return str(value) - - return process - else: - # for pymssql, we get uuid.uuid() objects back. - return None - class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): __visit_name__ = "UNIQUEIDENTIFIER" @@ -1596,12 +1609,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): @overload def __init__( self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... - ): - ... + ): ... @overload - def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): - ... + def __init__( + self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ... + ): ... def __init__(self, as_uuid: bool = True): """Construct a :class:`_mssql.UNIQUEIDENTIFIER` type. @@ -1852,7 +1865,6 @@ class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False _lastrowid = None - _rowcount = None dialect: MSDialect @@ -1972,13 +1984,6 @@ def post_exec(self): def get_lastrowid(self): return self._lastrowid - @property - def rowcount(self): - if self._rowcount is not None: - return self._rowcount - else: - return self.cursor.rowcount - def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: @@ -2030,6 +2035,10 @@ def __init__(self, *args, **kwargs): self.tablealiases = {} super().__init__(*args, **kwargs) + def _format_frame_clause(self, range_, **kw): + kw["literal_execute"] = True + return super()._format_frame_clause(range_, **kw) + def _with_legacy_schema_aliasing(fn): def decorate(self, *arg, **kw): if self.dialect.legacy_schema_aliasing: @@ -2483,10 +2492,12 @@ def _render_json_extract_from_binary(self, binary, operator, **kw): type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), - "FLOAT" - if isinstance(binary.type, sqltypes.Float) - else "NUMERIC(%s, %s)" - % (binary.type.precision, binary.type.scale), + ( + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale) + ), ) elif binary.type._type_affinity is sqltypes.Boolean: # the NULL handling is particularly weird with boolean, so @@ -2522,7 +2533,6 @@ def visit_sequence(self, seq, **kw): class MSSQLStrictCompiler(MSSQLCompiler): - """A subclass of MSSQLCompiler which disables the usage of bind parameters where not allowed natively by MS-SQL. 
@@ -3981,10 +3991,8 @@ def get_foreign_keys( ) # group rows by constraint ID, to handle multi-column FKs - fkeys = [] - - def fkey_rec(): - return { + fkeys = util.defaultdict( + lambda: { "name": None, "constrained_columns": [], "referred_schema": None, @@ -3992,8 +4000,7 @@ def fkey_rec(): "referred_columns": [], "options": {}, } - - fkeys = util.defaultdict(fkey_rec) + ) for r in connection.execute(s).all(): ( diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index e770313f937..b60bb158b46 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -1,5 +1,5 @@ -# mssql/information_schema.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/information_schema.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -207,6 +207,7 @@ class NumericSqlVariant(TypeDecorator): int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the correct value as string. """ + impl = Unicode cache_ok = True diff --git a/lib/sqlalchemy/dialects/mssql/json.py b/lib/sqlalchemy/dialects/mssql/json.py index 815b5d2ff86..a2d3ce81469 100644 --- a/lib/sqlalchemy/dialects/mssql/json.py +++ b/lib/sqlalchemy/dialects/mssql/json.py @@ -1,3 +1,9 @@ +# dialects/mssql/json.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... import types as sqltypes @@ -48,9 +54,7 @@ class JSON(sqltypes.JSON): dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor should be used:: - stmt = select( - data_table.c.data["some key"].as_json() - ).where( + stmt = select(data_table.c.data["some key"].as_json()).where( data_table.c.data["some key"].as_json() == {"sub": "structure"} ) @@ -61,9 +65,7 @@ class JSON(sqltypes.JSON): :meth:`_types.JSON.Comparator.as_integer`, :meth:`_types.JSON.Comparator.as_float`:: - stmt = select( - data_table.c.data["some key"].as_string() - ).where( + stmt = select(data_table.c.data["some key"].as_string()).where( data_table.c.data["some key"].as_string() == "some string" ) diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index 096ae03fa56..10165856e1a 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -1,3 +1,9 @@ +# dialects/mssql/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from sqlalchemy import inspect @@ -16,10 +22,17 @@ from ...testing.provision import get_temp_table_name from ...testing.provision import log from ...testing.provision import normalize_sequence +from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs from ...testing.provision import temp_table_keyword_args +@post_configure_engine.for_db("mssql") +def post_configure_engine(url, engine, follower_ident): + if engine.driver == "pyodbc": + engine.dialect.dbapi.pooling = False + + @generate_driver_url.for_db("mssql") def generate_driver_url(url, driver, query_str): backend = url.get_backend_name() @@ -29,6 +42,9 @@ def generate_driver_url(url, driver, 
query_str): if driver not in ("pyodbc", "aioodbc"): new_url = new_url.set(query="") + if driver == "aioodbc": + new_url = new_url.update_query_dict({"MARS_Connection": "Yes"}) + if query_str: new_url = new_url.update_query_string(query_str) diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 3823db91b3a..301a98eb417 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -1,5 +1,5 @@ -# mssql/pymssql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/pymssql.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -103,6 +103,7 @@ def is_disconnect(self, e, connection, cursor): "message 20006", # Write to the server failed "message 20017", # Unexpected EOF from the server "message 20047", # DBPROCESS is dead or not enabled + "The server failed to resume the transaction", ): if msg in str(e): return True diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index a8f12fd984c..cbf0adbfe08 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,5 +1,5 @@ -# mssql/pyodbc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mssql/pyodbc.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -30,7 +30,9 @@ engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn") -Which above, will pass the following connection string to PyODBC:: +Which above, will pass the following connection string to PyODBC: + +.. sourcecode:: text DSN=some_dsn;UID=scott;PWD=tiger @@ -49,7 +51,9 @@ query parameters of the URL. As these names usually have spaces in them, the name must be URL encoded which means using plus signs for spaces:: - engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server") + engine = create_engine( + "mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server" + ) The ``driver`` keyword is significant to the pyodbc dialect and must be specified in lowercase. 
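The disconnect messages matched in ``is_disconnect()`` above only take effect once a dead connection is actually encountered; a common companion setting, sketched here with placeholder credentials, is pool pre-ping, which tests each pooled connection on checkout so that such failures result in a transparent reconnect rather than an error mid-operation::

    from sqlalchemy import create_engine

    # pool_pre_ping issues a lightweight round trip when a connection is
    # checked out from the pool; connections that fail the ping are recycled
    engine = create_engine(
        "mssql+pymssql://scott:tiger@myhost/mydatabase",
        pool_pre_ping=True,
    )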
@@ -69,6 +73,7 @@ The equivalent URL can be constructed using :class:`_sa.engine.URL`:: from sqlalchemy.engine import URL + connection_url = URL.create( "mssql+pyodbc", username="scott", @@ -83,7 +88,6 @@ }, ) - Pass through exact Pyodbc string ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -94,8 +98,11 @@ can help make this easier:: from sqlalchemy.engine import URL + connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password" - connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string}) + connection_url = URL.create( + "mssql+pyodbc", query={"odbc_connect": connection_string} + ) engine = create_engine(connection_url) @@ -127,7 +134,8 @@ from sqlalchemy.engine.url import URL from azure import identity - SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h + # Connection option for access tokens, as defined in msodbcsql.h + SQL_COPT_SS_ACCESS_TOKEN = 1256 TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server" @@ -136,14 +144,19 @@ azure_credentials = identity.DefaultAzureCredential() + @event.listens_for(engine, "do_connect") def provide_token(dialect, conn_rec, cargs, cparams): # remove the "Trusted_Connection" parameter that SQLAlchemy adds cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") # create token credential - raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le") - token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token) + raw_token = azure_credentials.get_token(TOKEN_URL).token.encode( + "utf-16-le" + ) + token_struct = struct.pack( + f"<I{len(raw_token)}s", len(raw_token), raw_token + ) # apply it to keyword arguments cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct} class _ms_numeric_pyodbc: """Turns Decimals with adjusted() < 0 or > 7 into strings. The routines here are needed for older pyodbc versions as well as current mxODBC versions. """
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index b6af683b5e0..9174c54413a 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -1,5 +1,5 @@ -# mysql/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -53,7 +53,8 @@ from .dml import Insert from .dml import insert from .expression import match -from ...util import compat +from .mariadb import INET4 +from .mariadb import INET6 # default dialect base.dialect = dialect = mysqldb.dialect @@ -71,6 +72,8 @@ "DOUBLE", "ENUM", "FLOAT", + "INET4", + "INET6", "INTEGER", "INTEGER", "JSON",
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 2a0c6ba7832..e2ac70b0294 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -1,10 +1,9 @@ -# mysql/aiomysql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/aiomysql.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" ..
dialect:: mysql+aiomysql @@ -23,207 +22,105 @@ :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4") + engine = create_async_engine( + "mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4" + ) """ # noqa +from __future__ import annotations + +from types import ModuleType +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util -from ...engine import AdaptedConnection -from ...util.concurrency import asyncio +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: -class AsyncAdapt_aiomysql_cursor: - # TODO: base on connectors/asyncio.py - # see #10415 - server_side = False - __slots__ = ( - "_adapt_connection", - "_connection", - "await_", - "_cursor", - "_rows", - ) - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor(adapt_connection.dbapi.Cursor) - - # see https://github.com/aio-libs/aiomysql/issues/543 - self._cursor = self.await_(cursor.__aenter__()) - self._rows = [] - - @property - def description(self): - return self._cursor.description - - @property - def rowcount(self): - return self._cursor.rowcount - - @property - def arraysize(self): - return self._cursor.arraysize - - @arraysize.setter - def arraysize(self, value): - self._cursor.arraysize = value - - @property - def lastrowid(self): - return self._cursor.lastrowid - - def close(self): - # note we aren't actually closing the cursor here, - # we are just letting GC do it. to allow this to be async - # we would need the Result to change how it does "Safe close cursor". - # MySQL "cursors" don't actually have state to be "closed" besides - # exhausting rows, which we already have done for sync cursor. - # another option would be to emulate aiosqlite dialect and assign - # cursor only if we are doing server side cursor operation. - self._rows[:] = [] - - def execute(self, operation, parameters=None): - return self.await_(self._execute_async(operation, parameters)) - - def executemany(self, operation, seq_of_parameters): - return self.await_( - self._executemany_async(operation, seq_of_parameters) - ) - - async def _execute_async(self, operation, parameters): - async with self._adapt_connection._execute_mutex: - result = await self._cursor.execute(operation, parameters) - - if not self.server_side: - # aiomysql has a "fake" async result, so we have to pull it out - # of that here since our default result is not async. - # we could just as easily grab "_rows" here and be done with it - # but this is safer. 
- self._rows = list(await self._cursor.fetchall()) - return result - - async def _executemany_async(self, operation, seq_of_parameters): - async with self._adapt_connection._execute_mutex: - return await self._cursor.executemany(operation, seq_of_parameters) - - def setinputsizes(self, *inputsizes): - pass - - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval + from ...connectors.asyncio import AsyncIODBAPIConnection + from ...connectors.asyncio import AsyncIODBAPICursor + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import PoolProxiedConnection + from ...engine.url import URL -class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): - # TODO: base on connectors/asyncio.py - # see #10415 +class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor): __slots__ = () - server_side = True - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor) - self._cursor = self.await_(cursor.__aenter__()) + def _make_new_cursor( + self, connection: AsyncIODBAPIConnection + ) -> AsyncIODBAPICursor: + return connection.cursor(self._adapt_connection.dbapi.Cursor) - def close(self): - if self._cursor is not None: - self.await_(self._cursor.close()) - self._cursor = None - def fetchone(self): - return self.await_(self._cursor.fetchone()) - - def fetchmany(self, size=None): - return self.await_(self._cursor.fetchmany(size=size)) +class AsyncAdapt_aiomysql_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor +): + __slots__ = () - def fetchall(self): - return self.await_(self._cursor.fetchall()) + def _make_new_cursor( + self, connection: AsyncIODBAPIConnection + ) -> AsyncIODBAPICursor: + return connection.cursor( + self._adapt_connection.dbapi.aiomysql.cursors.SSCursor + ) -class AsyncAdapt_aiomysql_connection(AdaptedConnection): - # TODO: base on connectors/asyncio.py - # see #10415 - await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_execute_mutex") +class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): + __slots__ = () - def __init__(self, dbapi, connection): - self.dbapi = dbapi - self._connection = connection - self._execute_mutex = asyncio.Lock() + _cursor_cls = AsyncAdapt_aiomysql_cursor + _ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor - def ping(self, reconnect): - return self.await_(self._connection.ping(reconnect)) + def ping(self, reconnect: bool) -> None: + assert not reconnect + self.await_(self._connection.ping(reconnect)) - def character_set_name(self): - return self._connection.character_set_name() + def character_set_name(self) -> Optional[str]: + return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501 - def autocommit(self, value): + def autocommit(self, value: Any) -> None: self.await_(self._connection.autocommit(value)) - def cursor(self, server_side=False): - if server_side: - return 
AsyncAdapt_aiomysql_ss_cursor(self) - else: - return AsyncAdapt_aiomysql_cursor(self) - - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - - def close(self): + def terminate(self) -> None: # it's not awaitable. self._connection.close() + def close(self) -> None: + self.await_(self._connection.ensure_closed()) + class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): - # TODO: base on connectors/asyncio.py - # see #10415 __slots__ = () await_ = staticmethod(await_fallback) -class AsyncAdapt_aiomysql_dbapi: - def __init__(self, aiomysql, pymysql): +class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): + def __init__(self, aiomysql: ModuleType, pymysql: ModuleType): self.aiomysql = aiomysql self.pymysql = pymysql self.paramstyle = "format" self._init_dbapi_attributes() self.Cursor, self.SSCursor = self._init_cursors_subclasses() - def _init_dbapi_attributes(self): + def _init_dbapi_attributes(self) -> None: for name in ( "Warning", "Error", @@ -249,7 +146,7 @@ def _init_dbapi_attributes(self): ): setattr(self, name, getattr(self.pymysql, name)) - def connect(self, *arg, **kw): + def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection: async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect) @@ -264,17 +161,23 @@ def connect(self, *arg, **kw): await_only(creator_fn(*arg, **kw)), ) - def _init_cursors_subclasses(self): + def _init_cursors_subclasses( + self, + ) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]: # suppress unconditional warning emitted by aiomysql - class Cursor(self.aiomysql.Cursor): - async def _show_warnings(self, conn): + class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined] + async def _show_warnings( + self, conn: AsyncIODBAPIConnection + ) -> None: pass - class SSCursor(self.aiomysql.SSCursor): - async def _show_warnings(self, conn): + class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501 + async def _show_warnings( + self, conn: AsyncIODBAPIConnection + ) -> None: pass - return Cursor, SSCursor + return Cursor, SSCursor # type: ignore[return-value] class MySQLDialect_aiomysql(MySQLDialect_pymysql): @@ -285,15 +188,16 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): _sscursor = AsyncAdapt_aiomysql_ss_cursor is_async = True + has_terminate = True @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi: return AsyncAdapt_aiomysql_dbapi( __import__("aiomysql"), __import__("pymysql") ) @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type: async_fallback = url.query.get("async_fallback", False) if util.asbool(async_fallback): @@ -301,25 +205,37 @@ def get_pool_class(cls, url): else: return pool.AsyncAdaptedQueuePool - def create_connect_args(self, url): + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.terminate() + + def create_connect_args( + self, url: URL, _translate_args: Optional[Dict[str, Any]] = None + ) -> ConnectArgsType: return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() return 
"not connected" in str_e - def _found_rows_client_flag(self): - from pymysql.constants import CLIENT + def _found_rows_client_flag(self) -> int: + from pymysql.constants import CLIENT # type: ignore - return CLIENT.FOUND_ROWS + return CLIENT.FOUND_ROWS # type: ignore[no-any-return] - def get_driver_connection(self, connection): - return connection._connection + def get_driver_connection( + self, connection: DBAPIConnection + ) -> AsyncIODBAPIConnection: + return connection._connection # type: ignore[no-any-return] dialect = MySQLDialect_aiomysql diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 92058d60dd3..750735e8f1e 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -1,10 +1,9 @@ -# mysql/asyncmy.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" .. dialect:: mysql+asyncmy @@ -21,210 +20,97 @@ :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4") + engine = create_async_engine( + "mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4" + ) """ # noqa -from contextlib import asynccontextmanager +from __future__ import annotations + +from types import ModuleType +from typing import Any +from typing import NoReturn +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util -from ...engine import AdaptedConnection -from ...util.concurrency import asyncio +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from ...connectors.asyncio import AsyncIODBAPIConnection + from ...connectors.asyncio import AsyncIODBAPICursor + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import PoolProxiedConnection + from ...engine.url import URL -class AsyncAdapt_asyncmy_cursor: - # TODO: base on connectors/asyncio.py - # see #10415 - server_side = False - __slots__ = ( - "_adapt_connection", - "_connection", - "await_", - "_cursor", - "_rows", - ) - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - - cursor = self._connection.cursor() - - self._cursor = self.await_(cursor.__aenter__()) - self._rows = [] - - @property - def description(self): - return self._cursor.description - - @property - def rowcount(self): - return self._cursor.rowcount - - @property - def arraysize(self): - return self._cursor.arraysize - - @arraysize.setter - def arraysize(self, value): - self._cursor.arraysize = value - - @property - def lastrowid(self): - return self._cursor.lastrowid - - def close(self): - # note we aren't actually closing the cursor here, - # we are just letting GC do it. 
to allow this to be async - # we would need the Result to change how it does "Safe close cursor". - # MySQL "cursors" don't actually have state to be "closed" besides - # exhausting rows, which we already have done for sync cursor. - # another option would be to emulate aiosqlite dialect and assign - # cursor only if we are doing server side cursor operation. - self._rows[:] = [] - - def execute(self, operation, parameters=None): - return self.await_(self._execute_async(operation, parameters)) - - def executemany(self, operation, seq_of_parameters): - return self.await_( - self._executemany_async(operation, seq_of_parameters) - ) - - async def _execute_async(self, operation, parameters): - async with self._adapt_connection._mutex_and_adapt_errors(): - if parameters is None: - result = await self._cursor.execute(operation) - else: - result = await self._cursor.execute(operation, parameters) - - if not self.server_side: - # asyncmy has a "fake" async result, so we have to pull it out - # of that here since our default result is not async. - # we could just as easily grab "_rows" here and be done with it - # but this is safer. - self._rows = list(await self._cursor.fetchall()) - return result - - async def _executemany_async(self, operation, seq_of_parameters): - async with self._adapt_connection._mutex_and_adapt_errors(): - return await self._cursor.executemany(operation, seq_of_parameters) - - def setinputsizes(self, *inputsizes): - pass - - def __iter__(self): - while self._rows: - yield self._rows.pop(0) - - def fetchone(self): - if self._rows: - return self._rows.pop(0) - else: - return None - - def fetchmany(self, size=None): - if size is None: - size = self.arraysize - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval - - def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] - return retval +class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () -class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): - # TODO: base on connectors/asyncio.py - # see #10415 +class AsyncAdapt_asyncmy_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor +): __slots__ = () - server_side = True - - def __init__(self, adapt_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - cursor = self._connection.cursor( - adapt_connection.dbapi.asyncmy.cursors.SSCursor + def _make_new_cursor( + self, connection: AsyncIODBAPIConnection + ) -> AsyncIODBAPICursor: + return connection.cursor( + self._adapt_connection.dbapi.asyncmy.cursors.SSCursor ) - self._cursor = self.await_(cursor.__aenter__()) - - def close(self): - if self._cursor is not None: - self.await_(self._cursor.close()) - self._cursor = None - - def fetchone(self): - return self.await_(self._cursor.fetchone()) - - def fetchmany(self, size=None): - return self.await_(self._cursor.fetchmany(size=size)) - - def fetchall(self): - return self.await_(self._cursor.fetchall()) +class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): + __slots__ = () -class AsyncAdapt_asyncmy_connection(AdaptedConnection): - # TODO: base on connectors/asyncio.py - # see #10415 - await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_execute_mutex") + _cursor_cls = AsyncAdapt_asyncmy_cursor + _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor - def __init__(self, dbapi, connection): - self.dbapi = dbapi - self._connection = connection - self._execute_mutex = asyncio.Lock() + def 
_handle_exception(self, error: Exception) -> NoReturn: + if isinstance(error, AttributeError): + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) - @asynccontextmanager - async def _mutex_and_adapt_errors(self): - async with self._execute_mutex: - try: - yield - except AttributeError: - raise self.dbapi.InternalError( - "network operation failed due to asyncmy attribute error" - ) + raise error - def ping(self, reconnect): + def ping(self, reconnect: bool) -> None: assert not reconnect return self.await_(self._do_ping()) - async def _do_ping(self): - async with self._mutex_and_adapt_errors(): - return await self._connection.ping(False) + async def _do_ping(self) -> None: + try: + async with self._execute_mutex: + await self._connection.ping(False) + except Exception as error: + self._handle_exception(error) - def character_set_name(self): - return self._connection.character_set_name() + def character_set_name(self) -> Optional[str]: + return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501 - def autocommit(self, value): + def autocommit(self, value: Any) -> None: self.await_(self._connection.autocommit(value)) - def cursor(self, server_side=False): - if server_side: - return AsyncAdapt_asyncmy_ss_cursor(self) - else: - return AsyncAdapt_asyncmy_cursor(self) - - def rollback(self): - self.await_(self._connection.rollback()) - - def commit(self): - self.await_(self._connection.commit()) - - def close(self): + def terminate(self) -> None: # it's not awaitable. self._connection.close() + def close(self) -> None: + self.await_(self._connection.ensure_closed()) + class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): __slots__ = () @@ -232,18 +118,13 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): await_ = staticmethod(await_fallback) -def _Binary(x): - """Return x as a binary type.""" - return bytes(x) - - -class AsyncAdapt_asyncmy_dbapi: - def __init__(self, asyncmy): +class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module): + def __init__(self, asyncmy: ModuleType): self.asyncmy = asyncmy self.paramstyle = "format" self._init_dbapi_attributes() - def _init_dbapi_attributes(self): + def _init_dbapi_attributes(self) -> None: for name in ( "Warning", "Error", @@ -264,9 +145,9 @@ def _init_dbapi_attributes(self): BINARY = util.symbol("BINARY") DATETIME = util.symbol("DATETIME") TIMESTAMP = util.symbol("TIMESTAMP") - Binary = staticmethod(_Binary) + Binary = staticmethod(bytes) - def connect(self, *arg, **kw): + def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection: async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect) @@ -290,13 +171,14 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): _sscursor = AsyncAdapt_asyncmy_ss_cursor is_async = True + has_terminate = True @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> type: async_fallback = url.query.get("async_fallback", False) if util.asbool(async_fallback): @@ -304,12 +186,20 @@ def get_pool_class(cls, url): else: return pool.AsyncAdaptedQueuePool - def create_connect_args(self, url): + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.terminate() + + def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: 
E501 return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: if super().is_disconnect(e, connection, cursor): return True else: @@ -318,13 +208,15 @@ def is_disconnect(self, e, connection, cursor): "not connected" in str_e or "network operation failed" in str_e ) - def _found_rows_client_flag(self): - from asyncmy.constants import CLIENT + def _found_rows_client_flag(self) -> int: + from asyncmy.constants import CLIENT # type: ignore - return CLIENT.FOUND_ROWS + return CLIENT.FOUND_ROWS # type: ignore[no-any-return] - def get_driver_connection(self, connection): - return connection._connection + def get_driver_connection( + self, connection: DBAPIConnection + ) -> AsyncIODBAPIConnection: + return connection._connection # type: ignore[no-any-return] dialect = MySQLDialect_asyncmy diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 92f90774fbe..f398fe8a04c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1,17 +1,15 @@ -# mysql/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" .. dialect:: mysql :name: MySQL / MariaDB - :full_support: 5.6, 5.7, 8.0 / 10.8, 10.9 :normal_support: 5.6+ / 10+ :best_effort: 5.0.2+ / 5.0.2+ @@ -35,7 +33,9 @@ To connect to a MariaDB database, no changes to the database URL are required:: - engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + engine = create_engine( + "mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4" + ) Upon first connect, the SQLAlchemy dialect employs a server version detection scheme that determines if the @@ -53,7 +53,9 @@ and is not compatible with a MySQL database. To use this mode of operation, replace the "mysql" token in the above URL with "mariadb":: - engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + engine = create_engine( + "mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4" + ) The above engine, upon first connect, will raise an error if the server version detection detects that the backing database is not MariaDB. 
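The result of the server version detection described above can be inspected once a first connection has been made; a small sketch follows (the URL is a placeholder, and ``server_version_info`` / ``is_mariadb`` are the dialect attributes this assumes)::

    from sqlalchemy import create_engine

    engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname")

    with engine.connect():
        # populated on first connect by the version detection scheme
        print(engine.dialect.server_version_info)
        print(engine.dialect.is_mariadb)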
@@ -99,7 +101,7 @@ a connection will be discarded and replaced with a new one if it has been present in the pool for a fixed number of seconds:: - engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + engine = create_engine("mysql+mysqldb://...", pool_recycle=3600) For more comprehensive disconnect detection of pooled connections, including accommodation of server restarts and network issues, a pre-ping approach may @@ -123,12 +125,14 @@ ``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE`` of ``1024``:: - Table('mytable', metadata, - Column('data', String(32)), - mysql_engine='InnoDB', - mysql_charset='utf8mb4', - mysql_key_block_size="1024" - ) + Table( + "mytable", + metadata, + Column("data", String(32)), + mysql_engine="InnoDB", + mysql_charset="utf8mb4", + mysql_key_block_size="1024", + ) When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against the "mariadb" prefix must be included as well. The values can of course @@ -137,19 +141,17 @@ # support both "mysql" and "mariadb-only" engine URLs - Table('mytable', metadata, - Column('data', String(32)), - - mysql_engine='InnoDB', - mariadb_engine='InnoDB', - - mysql_charset='utf8mb4', - mariadb_charset='utf8', - - mysql_key_block_size="1024" - mariadb_key_block_size="1024" - - ) + Table( + "mytable", + metadata, + Column("data", String(32)), + mysql_engine="InnoDB", + mariadb_engine="InnoDB", + mysql_charset="utf8mb4", + mariadb_charset="utf8", + mysql_key_block_size="1024", + mariadb_key_block_size="1024", + ) The MySQL / MariaDB dialects will normally transfer any keyword specified as ``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the @@ -179,6 +181,31 @@ constraints, all participating ``CREATE TABLE`` statements must specify a transactional engine, which in the vast majority of cases is ``InnoDB``. +Partitioning can similarly be specified using similar options. +In the example below the create table will specify ``PARTITION_BY``, +``PARTITIONS``, ``SUBPARTITIONS`` and ``SUBPARTITION_BY``:: + + # can also use mariadb_* prefix + Table( + "testtable", + MetaData(), + Column("id", Integer(), primary_key=True, autoincrement=True), + Column("other_id", Integer(), primary_key=True, autoincrement=False), + mysql_partitions="2", + mysql_partition_by="KEY(other_id)", + mysql_subpartition_by="HASH(some_expr)", + mysql_subpartitions="2", + ) + +This will render: + +.. sourcecode:: sql + + CREATE TABLE testtable ( + id INTEGER NOT NULL AUTO_INCREMENT, + other_id INTEGER NOT NULL, + PRIMARY KEY (id, other_id) + )PARTITION BY KEY(other_id) PARTITIONS 2 SUBPARTITION BY HASH(some_expr) SUBPARTITIONS 2 Case Sensitivity and Table Reflection ------------------------------------- @@ -215,16 +242,14 @@ To set isolation level using :func:`_sa.create_engine`:: engine = create_engine( - "mysql+mysqldb://scott:tiger@localhost/test", - isolation_level="READ UNCOMMITTED" - ) + "mysql+mysqldb://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED", + ) To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options( - isolation_level="READ COMMITTED" - ) + connection = connection.execution_options(isolation_level="READ COMMITTED") Valid values for ``isolation_level`` include: @@ -256,8 +281,8 @@ the first :class:`.Integer` primary key column which is not marked as a foreign key:: - >>> t = Table('mytable', metadata, - ... Column('mytable_id', Integer, primary_key=True) + >>> t = Table( + ... 
"mytable", metadata, Column("mytable_id", Integer, primary_key=True) ... ) >>> t.create() CREATE TABLE mytable ( @@ -271,10 +296,12 @@ can also be used to enable auto-increment on a secondary column in a multi-column key for some storage engines:: - Table('mytable', metadata, - Column('gid', Integer, primary_key=True, autoincrement=False), - Column('id', Integer, primary_key=True) - ) + Table( + "mytable", + metadata, + Column("gid", Integer, primary_key=True, autoincrement=False), + Column("id", Integer, primary_key=True), + ) .. _mysql_ss_cursors: @@ -292,7 +319,9 @@ option:: with engine.connect() as conn: - result = conn.execution_options(stream_results=True).execute(text("select * from table")) + result = conn.execution_options(stream_results=True).execute( + text("select * from table") + ) Note that some kinds of SQL statements may not be supported with server side cursors; generally, only SQL statements that return rows should be @@ -320,7 +349,8 @@ in the URL, such as:: e = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4" + ) This charset is the **client character set** for the connection. Some MySQL DBAPIs will default this to a value such as ``latin1``, and some @@ -340,7 +370,8 @@ DBAPI, as in:: e = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4" + ) All modern DBAPIs should support the ``utf8mb4`` charset. @@ -362,7 +393,9 @@ MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now emit a warning when attempting to pass binary data to the database, while a character set encoding is also in place, when the binary data itself is not -valid for that encoding:: +valid for that encoding: + +.. sourcecode:: text default.py:509: Warning: (1300, "Invalid utf8mb4 character string: 'F9876A'") @@ -372,7 +405,9 @@ interpret the binary string as a unicode object even if a datatype such as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires a binary "character set introducer" be present before any non-NULL value -that renders like this:: +that renders like this: + +.. sourcecode:: sql INSERT INTO table (data) VALUES (_binary %s) @@ -382,12 +417,13 @@ # mysqlclient engine = create_engine( - "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true" + ) # PyMySQL engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") - + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true" + ) The ``binary_prefix`` flag may or may not be supported by other MySQL drivers. 
@@ -430,7 +466,10 @@ from sqlalchemy import create_engine, event - eng = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo='debug') + eng = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test", echo="debug" + ) + # `insert=True` will ensure this is the very first listener to run @event.listens_for(eng, "connect", insert=True) @@ -438,6 +477,7 @@ def connect(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'") + conn = eng.connect() In the example illustrated above, the "connect" event will invoke the "SET" @@ -454,8 +494,8 @@ def connect(dbapi_connection, connection_record): Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic function and operator support:: - table.select(table.c.password==func.md5('plaintext')) - table.select(table.c.username.op('regexp')('^[a-d]')) + table.select(table.c.password == func.md5("plaintext")) + table.select(table.c.username.op("regexp")("^[a-d]")) And of course any valid SQL statement can be executed as a string as well. @@ -468,11 +508,18 @@ def connect(dbapi_connection, connection_record): * SELECT pragma, use :meth:`_expression.Select.prefix_with` and :meth:`_query.Query.prefix_with`:: - select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + select(...).prefix_with(["HIGH_PRIORITY", "SQL_SMALL_RESULT"]) * UPDATE with LIMIT:: - update(..., mysql_limit=10, mariadb_limit=10) + update(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) + +* DELETE + with LIMIT:: + + delete(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) + + .. versionadded:: 2.0.37 Added delete with limit * optimizer hints, use :meth:`_expression.Select.prefix_with` and :meth:`_query.Query.prefix_with`:: @@ -484,14 +531,16 @@ def connect(dbapi_connection, connection_record): select(...).with_hint(some_table, "USE INDEX xyz") -* MATCH operator support:: +* MATCH + operator support:: + + from sqlalchemy.dialects.mysql import match - from sqlalchemy.dialects.mysql import match - select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) + select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) - .. seealso:: + .. seealso:: - :class:`_mysql.match` + :class:`_mysql.match` INSERT/DELETE...RETURNING ------------------------- @@ -508,17 +557,15 @@ def connect(dbapi_connection, connection_record): # INSERT..RETURNING result = connection.execute( - table.insert(). - values(name='foo'). - returning(table.c.col1, table.c.col2) + table.insert().values(name="foo").returning(table.c.col1, table.c.col2) ) print(result.all()) # DELETE..RETURNING result = connection.execute( - table.delete(). - where(table.c.name=='foo'). - returning(table.c.col1, table.c.col2) + table.delete() + .where(table.c.name == "foo") + .returning(table.c.col1, table.c.col2) ) print(result.all()) @@ -545,12 +592,11 @@ def connect(dbapi_connection, connection_record): >>> from sqlalchemy.dialects.mysql import insert >>> insert_stmt = insert(my_table).values( - ... id='some_existing_id', - ... data='inserted value') + ... id="some_existing_id", data="inserted value" + ... ) >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( - ... data=insert_stmt.inserted.data, - ... status='U' + ... data=insert_stmt.inserted.data, status="U" ... ) >>> print(on_duplicate_key_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) @@ -575,8 +621,8 @@ def connect(dbapi_connection, connection_record): .. 
sourcecode:: pycon+sql >>> insert_stmt = insert(my_table).values( - ... id='some_existing_id', - ... data='inserted value') + ... id="some_existing_id", data="inserted value" + ... ) >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( ... data="some data", @@ -639,13 +685,11 @@ def connect(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id='some_id', - ... data='inserted value', - ... author='jlh') + ... id="some_id", data="inserted value", author="jlh" + ... ) >>> do_update_stmt = stmt.on_duplicate_key_update( - ... data="updated value", - ... author=stmt.inserted.author + ... data="updated value", author=stmt.inserted.author ... ) >>> print(do_update_stmt) @@ -690,13 +734,13 @@ def connect(dbapi_connection, connection_record): become part of the index. SQLAlchemy provides this feature via the ``mysql_length`` and/or ``mariadb_length`` parameters:: - Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10) + Index("my_index", my_table.c.data, mysql_length=10, mariadb_length=10) - Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, - 'b': 9}) + Index("a_b_idx", my_table.c.a, my_table.c.b, mysql_length={"a": 4, "b": 9}) - Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4, - 'b': 9}) + Index( + "a_b_idx", my_table.c.a, my_table.c.b, mariadb_length={"a": 4, "b": 9} + ) Prefix lengths are given in characters for nonbinary string types and in bytes for binary string types. The value passed to the keyword argument *must* be @@ -713,7 +757,7 @@ def connect(dbapi_connection, connection_record): an index. SQLAlchemy provides this feature via the ``mysql_prefix`` parameter on :class:`.Index`:: - Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT') + Index("my_index", my_table.c.data, mysql_prefix="FULLTEXT") The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL @@ -730,11 +774,13 @@ def connect(dbapi_connection, connection_record): an index or primary key constraint. SQLAlchemy provides this feature via the ``mysql_using`` parameter on :class:`.Index`:: - Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash') + Index( + "my_index", my_table.c.data, mysql_using="hash", mariadb_using="hash" + ) As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: - PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash') + PrimaryKeyConstraint("data", mysql_using="hash", mariadb_using="hash") The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index @@ -753,9 +799,12 @@ def connect(dbapi_connection, connection_record): is available using the keyword argument ``mysql_with_parser``:: Index( - 'my_index', my_table.c.data, - mysql_prefix='FULLTEXT', mysql_with_parser="ngram", - mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram", + "my_index", + my_table.c.data, + mysql_prefix="FULLTEXT", + mysql_with_parser="ngram", + mariadb_prefix="FULLTEXT", + mariadb_with_parser="ngram", ) .. 
versionadded:: 1.3 @@ -782,6 +831,7 @@ def connect(dbapi_connection, connection_record): from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import ForeignKeyConstraint + @compiles(ForeignKeyConstraint, "mysql", "mariadb") def process(element, compiler, **kw): element.deferrable = element.initially = None @@ -803,10 +853,12 @@ def process(element, compiler, **kw): reflection will not include foreign keys. For these tables, you may supply a :class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: - Table('mytable', metadata, - ForeignKeyConstraint(['other_id'], ['othertable.other_id']), - autoload_with=engine - ) + Table( + "mytable", + metadata, + ForeignKeyConstraint(["other_id"], ["othertable.other_id"]), + autoload_with=engine, + ) .. seealso:: @@ -878,13 +930,15 @@ def process(element, compiler, **kw): mytable = Table( "mytable", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), + Column("id", Integer, primary_key=True), + Column("data", String(50)), Column( - 'last_updated', + "last_updated", TIMESTAMP, - server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") - ) + server_default=text( + "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" + ), + ), ) The same instructions apply to use of the :class:`_types.DateTime` and @@ -895,34 +949,37 @@ def process(element, compiler, **kw): mytable = Table( "mytable", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), + Column("id", Integer, primary_key=True), + Column("data", String(50)), Column( - 'last_updated', + "last_updated", DateTime, - server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") - ) + server_default=text( + "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" + ), + ), ) - Even though the :paramref:`_schema.Column.server_onupdate` feature does not generate this DDL, it still may be desirable to signal to the ORM that this updated value should be fetched. This syntax looks like the following:: from sqlalchemy.schema import FetchedValue + class MyClass(Base): - __tablename__ = 'mytable' + __tablename__ = "mytable" id = Column(Integer, primary_key=True) data = Column(String(50)) last_updated = Column( TIMESTAMP, - server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), - server_onupdate=FetchedValue() + server_default=text( + "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" + ), + server_onupdate=FetchedValue(), ) - .. _mysql_timestamp_null: TIMESTAMP Columns and NULL @@ -932,7 +989,9 @@ class MyClass(Base): TIMESTAMP datatype implicitly includes a default value of CURRENT_TIMESTAMP, even though this is not stated, and additionally sets the column as NOT NULL, the opposite behavior vs. that of all -other datatypes:: +other datatypes: + +.. sourcecode:: text mysql> CREATE TABLE ts_test ( -> a INTEGER, @@ -977,19 +1036,24 @@ class MyClass(Base): from sqlalchemy.dialects.mysql import TIMESTAMP m = MetaData() - t = Table('ts_test', m, - Column('a', Integer), - Column('b', Integer, nullable=False), - Column('c', TIMESTAMP), - Column('d', TIMESTAMP, nullable=False) - ) + t = Table( + "ts_test", + m, + Column("a", Integer), + Column("b", Integer, nullable=False), + Column("c", TIMESTAMP), + Column("d", TIMESTAMP, nullable=False), + ) from sqlalchemy import create_engine + e = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo=True) m.create_all(e) -output:: +output: + +.. 
sourcecode:: sql CREATE TABLE ts_test ( a INTEGER, @@ -1001,11 +1065,22 @@ class MyClass(Base): """ # noqa from __future__ import annotations -from array import array as _array from collections import defaultdict from itertools import compress import re +from typing import Any +from typing import Callable from typing import cast +from typing import DefaultDict +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from . import reflection as _reflection from .enumerated import ENUM @@ -1048,7 +1123,6 @@ class MyClass(Base): from .types import YEAR from ... import exc from ... import literal_column -from ... import log from ... import schema as sa_schema from ... import sql from ... import util @@ -1072,10 +1146,46 @@ class MyClass(Base): from ...types import BLOB from ...types import BOOLEAN from ...types import DATE +from ...types import LargeBinary from ...types import UUID from ...types import VARBINARY from ...util import topological +if TYPE_CHECKING: + + from ...dialects.mysql import expression + from ...dialects.mysql.dml import OnDuplicateClause + from ...engine.base import Connection + from ...engine.cursor import CursorResult + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import PoolProxiedConnection + from ...engine.interfaces import ReflectedCheckConstraint + from ...engine.interfaces import ReflectedColumn + from ...engine.interfaces import ReflectedForeignKeyConstraint + from ...engine.interfaces import ReflectedIndex + from ...engine.interfaces import ReflectedPrimaryKeyConstraint + from ...engine.interfaces import ReflectedTableComment + from ...engine.interfaces import ReflectedUniqueConstraint + from ...engine.row import Row + from ...engine.url import URL + from ...schema import Table + from ...sql import ddl + from ...sql import selectable + from ...sql.dml import _DMLTableElement + from ...sql.dml import Delete + from ...sql.dml import Update + from ...sql.dml import ValuesBase + from ...sql.functions import aggregate_strings + from ...sql.functions import random + from ...sql.functions import rollup + from ...sql.functions import sysdate + from ...sql.schema import Sequence as Sequence_SchemaItem + from ...sql.type_api import TypeEngine + from ...sql.visitors import ExternallyTraversible + SET_RE = re.compile( r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE @@ -1170,7 +1280,7 @@ class MyClass(Base): class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self): + def post_exec(self) -> None: if ( self.isdelete and cast(SQLCompiler, self.compiled).effective_returning @@ -1187,7 +1297,7 @@ def post_exec(self): _cursor.FullyBufferedCursorFetchStrategy( self.cursor, [ - (entry.keyname, None) + (entry.keyname, None) # type: ignore[misc] for entry in cast( SQLCompiler, self.compiled )._result_columns @@ -1196,14 +1306,18 @@ def post_exec(self): ) ) - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: if self.dialect.supports_server_side_cursors: - return self._dbapi_connection.cursor(self.dialect._sscursor) + return self._dbapi_connection.cursor( + self.dialect._sscursor # type: ignore[attr-defined] + ) else: raise 
NotImplementedError() - def fire_sequence(self, seq, type_): - return self._execute_scalar( + def fire_sequence( + self, seq: Sequence_SchemaItem, type_: sqltypes.Integer + ) -> int: + return self._execute_scalar( # type: ignore[no-any-return] ( "select nextval(%s)" % self.identifier_preparer.format_sequence(seq) @@ -1213,46 +1327,51 @@ def fire_sequence(self, seq, type_): class MySQLCompiler(compiler.SQLCompiler): + dialect: MySQLDialect render_table_with_column_in_update_from = True """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update({"milliseconds": "millisecond"}) - def default_from(self): + def default_from(self) -> str: """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. """ if self.stack: stmt = self.stack[-1]["selectable"] - if stmt._where_criteria: + if stmt._where_criteria: # type: ignore[attr-defined] return " FROM DUAL" return "" - def visit_random_func(self, fn, **kw): + def visit_random_func(self, fn: random, **kw: Any) -> str: return "rand%s" % self.function_argspec(fn) - def visit_rollup_func(self, fn, **kw): + def visit_rollup_func(self, fn: rollup[Any], **kw: Any) -> str: clause = ", ".join( elem._compiler_dispatch(self, **kw) for elem in fn.clauses ) return f"{clause} WITH ROLLUP" - def visit_aggregate_strings_func(self, fn, **kw): + def visit_aggregate_strings_func( + self, fn: aggregate_strings, **kw: Any + ) -> str: expr, delimeter = ( elem._compiler_dispatch(self, **kw) for elem in fn.clauses ) return f"group_concat({expr} SEPARATOR {delimeter})" - def visit_sequence(self, seq, **kw): - return "nextval(%s)" % self.preparer.format_sequence(seq) + def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str: + return "nextval(%s)" % self.preparer.format_sequence(sequence) - def visit_sysdate_func(self, fn, **kw): + def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str: return "SYSDATE()" - def _render_json_extract_from_binary(self, binary, operator, **kw): + def _render_json_extract_from_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: # note we are intentionally calling upon the process() calls in the # order in which they appear in the SQL String as this is used # by positional parameter rendering @@ -1279,9 +1398,10 @@ def _render_json_extract_from_binary(self, binary, operator, **kw): ) ) elif binary.type._type_affinity is sqltypes.Numeric: + binary_type = cast(sqltypes.Numeric[Any], binary.type) if ( - binary.type.scale is not None - and binary.type.precision is not None + binary_type.scale is not None + and binary_type.precision is not None ): # using DECIMAL here because MySQL does not recognize NUMERIC type_expression = ( @@ -1289,8 +1409,8 @@ def _render_json_extract_from_binary(self, binary, operator, **kw): % ( self.process(binary.left, **kw), self.process(binary.right, **kw), - binary.type.precision, - binary.type.scale, + binary_type.precision, + binary_type.scale, ) ) else: @@ -1324,15 +1444,22 @@ def _render_json_extract_from_binary(self, binary, operator, **kw): return case_expression + " " + type_expression + " END" - def visit_json_getitem_op_binary(self, binary, operator, **kw): + def visit_json_getitem_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return self._render_json_extract_from_binary(binary, operator, **kw) - def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + def visit_json_path_getitem_op_binary( + self, 
binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return self._render_json_extract_from_binary(binary, operator, **kw) - def visit_on_duplicate_key_update(self, on_duplicate, **kw): - statement = self.current_executable + def visit_on_duplicate_key_update( + self, on_duplicate: OnDuplicateClause, **kw: Any + ) -> str: + statement: ValuesBase = self.current_executable + cols: List[elements.KeyedColumnElement[Any]] if on_duplicate._parameter_ordering: parameter_ordering = [ coercions.expect(roles.DMLColumnRole, key) @@ -1345,49 +1472,56 @@ def visit_on_duplicate_key_update(self, on_duplicate, **kw): if key in statement.table.c ] + [c for c in statement.table.c if c.key not in ordered_keys] else: - cols = statement.table.c + cols = list(statement.table.c) clauses = [] - requires_mysql8_alias = ( + requires_mysql8_alias = statement.select is None and ( self.dialect._requires_alias_for_on_duplicate_key ) if requires_mysql8_alias: - if statement.table.name.lower() == "new": + if statement.table.name.lower() == "new": # type: ignore[union-attr] # noqa: E501 _on_dup_alias_name = "new_1" else: _on_dup_alias_name = "new" + on_duplicate_update = { + coercions.expect_as_key(roles.DMLColumnRole, key): value + for key, value in on_duplicate.update.items() + } + # traverses through all table columns to preserve table column order - for column in (col for col in cols if col.key in on_duplicate.update): - val = on_duplicate.update[column.key] + for column in (col for col in cols if col.key in on_duplicate_update): + val = on_duplicate_update[column.key] + # TODO: this coercion should be up front. we can't cache + # SQL constructs with non-bound literals buried in them if coercions._is_literal(val): val = elements.BindParameter(None, val, type_=column.type) value_text = self.process(val.self_group(), use_schema=False) else: - def replace(obj): + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: if ( - isinstance(obj, elements.BindParameter) - and obj.type._isnull + isinstance(element, elements.BindParameter) + and element.type._isnull ): - obj = obj._clone() - obj.type = column.type - return obj + return element._with_binary_element_type(column.type) elif ( - isinstance(obj, elements.ColumnClause) - and obj.table is on_duplicate.inserted_alias + isinstance(element, elements.ColumnClause) + and element.table is on_duplicate.inserted_alias ): if requires_mysql8_alias: column_literal_clause = ( f"{_on_dup_alias_name}." 
- f"{self.preparer.quote(obj.name)}" + f"{self.preparer.quote(element.name)}" ) else: column_literal_clause = ( - f"VALUES({self.preparer.quote(obj.name)})" + f"VALUES({self.preparer.quote(element.name)})" ) return literal_column(column_literal_clause) else: @@ -1400,13 +1534,13 @@ def replace(obj): name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) - non_matching = set(on_duplicate.update) - {c.key for c in cols} + non_matching = set(on_duplicate_update) - {c.key for c in cols} if non_matching: util.warn( "Additional column names not matching " "any column keys in table '%s': %s" % ( - self.statement.table.name, + self.statement.table.name, # type: ignore[union-attr] (", ".join("'%s'" % c for c in non_matching)), ) ) @@ -1420,13 +1554,15 @@ def replace(obj): return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" def visit_concat_op_expression_clauselist( - self, clauselist, operator, **kw - ): + self, clauselist: elements.ClauseList, operator: Any, **kw: Any + ) -> str: return "concat(%s)" % ( ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) ) - def visit_concat_op_binary(self, binary, operator, **kw): + def visit_concat_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return "concat(%s, %s)" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), @@ -1449,10 +1585,12 @@ def visit_concat_op_binary(self, binary, operator, **kw): "WITH QUERY EXPANSION", ) - def visit_mysql_match(self, element, **kw): + def visit_mysql_match(self, element: expression.match, **kw: Any) -> str: return self.visit_match_op_binary(element, element.operator, **kw) - def visit_match_op_binary(self, binary, operator, **kw): + def visit_match_op_binary( + self, binary: expression.match, operator: Any, **kw: Any + ) -> str: """ Note that `mysql_boolean_mode` is enabled by default because of backward compatibility @@ -1473,12 +1611,11 @@ def visit_match_op_binary(self, binary, operator, **kw): "with_query_expansion=%s" % query_expansion, ) - flags = ", ".join(flags) + flags_str = ", ".join(flags) - raise exc.CompileError("Invalid MySQL match flags: %s" % flags) + raise exc.CompileError("Invalid MySQL match flags: %s" % flags_str) - match_clause = binary.left - match_clause = self.process(match_clause, **kw) + match_clause = self.process(binary.left, **kw) against_clause = self.process(binary.right, **kw) if any(flag_combination): @@ -1487,21 +1624,25 @@ def visit_match_op_binary(self, binary, operator, **kw): flag_combination, ) - against_clause = [against_clause] - against_clause.extend(flag_expressions) - - against_clause = " ".join(against_clause) + against_clause = " ".join([against_clause, *flag_expressions]) return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) - def get_from_hint_text(self, table, text): + def get_from_hint_text( + self, table: selectable.FromClause, text: Optional[str] + ) -> Optional[str]: return text - def visit_typeclause(self, typeclause, type_=None, **kw): + def visit_typeclause( + self, + typeclause: elements.TypeClause, + type_: Optional[TypeEngine[Any]] = None, + **kw: Any, + ) -> Optional[str]: if type_ is None: type_ = typeclause.type.dialect_impl(self.dialect) if isinstance(type_, sqltypes.TypeDecorator): - return self.visit_typeclause(typeclause, type_.impl, **kw) + return self.visit_typeclause(typeclause, type_.impl, **kw) # type: ignore[arg-type] # noqa: E501 elif isinstance(type_, sqltypes.Integer): if getattr(type_, "unsigned", False): return 
"UNSIGNED INTEGER" @@ -1540,7 +1681,7 @@ def visit_typeclause(self, typeclause, type_=None, **kw): else: return None - def visit_cast(self, cast, **kw): + def visit_cast(self, cast: elements.Cast[Any], **kw: Any) -> str: type_ = self.process(cast.typeclause) if type_ is None: util.warn( @@ -1554,7 +1695,9 @@ def visit_cast(self, cast, **kw): return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) - def render_literal_value(self, value, type_): + def render_literal_value( + self, value: Optional[str], type_: TypeEngine[Any] + ) -> str: value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") @@ -1562,16 +1705,18 @@ def render_literal_value(self, value, type_): # override native_boolean=False behavior here, as # MySQL still supports native boolean - def visit_true(self, element, **kw): + def visit_true(self, expr: elements.True_, **kw: Any) -> str: return "true" - def visit_false(self, element, **kw): + def visit_false(self, expr: elements.False_, **kw: Any) -> str: return "false" - def get_select_precolumns(self, select, **kw): + def get_select_precolumns( + self, select: selectable.Select[Any], **kw: Any + ) -> str: """Add special MySQL keywords in place of DISTINCT. - .. deprecated 1.4:: this usage is deprecated. + .. deprecated:: 1.4 This usage is deprecated. :meth:`_expression.Select.prefix_with` should be used for special keywords at the start of a SELECT. @@ -1588,7 +1733,13 @@ def get_select_precolumns(self, select, **kw): return super().get_select_precolumns(select, **kw) - def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + def visit_join( + self, + join: selectable.Join, + asfrom: bool = False, + from_linter: Optional[compiler.FromLinter] = None, + **kwargs: Any, + ) -> str: if from_linter: from_linter.edges.add((join.left, join.right)) @@ -1609,18 +1760,21 @@ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): join.right, asfrom=True, from_linter=from_linter, **kwargs ), " ON ", - self.process(join.onclause, from_linter=from_linter, **kwargs), + self.process(join.onclause, from_linter=from_linter, **kwargs), # type: ignore[arg-type] # noqa: E501 ) ) - def for_update_clause(self, select, **kw): + def for_update_clause( + self, select: selectable.GenerativeSelect, **kw: Any + ) -> str: + assert select._for_update_arg is not None if select._for_update_arg.read: tmp = " LOCK IN SHARE MODE" else: tmp = " FOR UPDATE" if select._for_update_arg.of and self.dialect.supports_for_update_of: - tables = util.OrderedSet() + tables: util.OrderedSet[elements.ClauseElement] = util.OrderedSet() for c in select._for_update_arg.of: tables.update(sql_util.surface_selectables_only(c)) @@ -1637,7 +1791,9 @@ def for_update_clause(self, select, **kw): return tmp - def limit_clause(self, select, **kw): + def limit_clause( + self, select: selectable.GenerativeSelect, **kw: Any + ) -> str: # MySQL supports: # LIMIT # LIMIT , @@ -1673,17 +1829,31 @@ def limit_clause(self, select, **kw): self.process(limit_clause, **kw), ) else: + assert limit_clause is not None # No offset provided, so just use the limit return " \n LIMIT %s" % (self.process(limit_clause, **kw),) - def update_limit_clause(self, update_stmt): + def update_limit_clause(self, update_stmt: Update) -> Optional[str]: limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) - if limit: - return "LIMIT %s" % limit + if limit is not None: + return f"LIMIT {int(limit)}" + else: + return None + + def delete_limit_clause(self, 
delete_stmt: Delete) -> Optional[str]: + limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + if limit is not None: + return f"LIMIT {int(limit)}" else: return None - def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + def update_tables_clause( + self, + update_stmt: Update, + from_table: _DMLTableElement, + extra_froms: List[selectable.FromClause], + **kw: Any, + ) -> str: kw["asfrom"] = True return ", ".join( t._compiler_dispatch(self, **kw) @@ -1691,11 +1861,22 @@ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): ) def update_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw - ): + self, + update_stmt: Update, + from_table: _DMLTableElement, + extra_froms: List[selectable.FromClause], + from_hints: Any, + **kw: Any, + ) -> None: return None - def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + def delete_table_clause( + self, + delete_stmt: Delete, + from_table: _DMLTableElement, + extra_froms: List[selectable.FromClause], + **kw: Any, + ) -> str: """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1705,8 +1886,13 @@ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): ) def delete_extra_from_clause( - self, delete_stmt, from_table, extra_froms, from_hints, **kw - ): + self, + delete_stmt: Delete, + from_table: _DMLTableElement, + extra_froms: List[selectable.FromClause], + from_hints: Any, + **kw: Any, + ) -> str: """Render the DELETE .. USING clause specific to MySQL.""" kw["asfrom"] = True return "USING " + ", ".join( @@ -1714,7 +1900,9 @@ def delete_extra_from_clause( for t in [from_table] + extra_froms ) - def visit_empty_set_expr(self, element_types, **kw): + def visit_empty_set_expr( + self, element_types: List[TypeEngine[Any]], **kw: Any + ) -> str: return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " "as _empty_set WHERE 1!=1" @@ -1729,25 +1917,38 @@ def visit_empty_set_expr(self, element_types, **kw): } ) - def visit_is_distinct_from_binary(self, binary, operator, **kw): + def visit_is_distinct_from_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return "NOT (%s <=> %s)" % ( self.process(binary.left), self.process(binary.right), ) - def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + def visit_is_not_distinct_from_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return "%s <=> %s" % ( self.process(binary.left), self.process(binary.right), ) - def _mariadb_regexp_flags(self, flags, pattern, **kw): + def _mariadb_regexp_flags( + self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any + ) -> str: return "CONCAT('(?', %s, ')', %s)" % ( self.render_literal_value(flags, sqltypes.STRINGTYPE), self.process(pattern, **kw), ) - def _regexp_match(self, op_string, binary, operator, **kw): + def _regexp_match( + self, + op_string: str, + binary: elements.BinaryExpression[Any], + operator: Any, + **kw: Any, + ) -> str: + assert binary.modifiers is not None flags = binary.modifiers["flags"] if flags is None: return self._generate_generic_binary(binary, op_string, **kw) @@ -1768,13 +1969,20 @@ def _regexp_match(self, op_string, binary, operator, **kw): else: return text - def visit_regexp_match_op_binary(self, binary, operator, **kw): + def visit_regexp_match_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return self._regexp_match(" 
REGEXP ", binary, operator, **kw) - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + def visit_not_regexp_match_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return self._regexp_match(" NOT REGEXP ", binary, operator, **kw) - def visit_regexp_replace_op_binary(self, binary, operator, **kw): + def visit_regexp_replace_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: + assert binary.modifiers is not None flags = binary.modifiers["flags"] if flags is None: return "REGEXP_REPLACE(%s, %s)" % ( @@ -1796,7 +2004,11 @@ def visit_regexp_replace_op_binary(self, binary, operator, **kw): class MySQLDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kw): + dialect: MySQLDialect + + def get_column_specification( + self, column: sa_schema.Column[Any], **kw: Any + ) -> str: """Builds column DDL.""" if ( self.dialect.is_mariadb is True @@ -1849,11 +2061,25 @@ def get_column_specification(self, column, **kw): colspec.append("AUTO_INCREMENT") else: default = self.get_column_default_string(column) + if default is not None: - colspec.append("DEFAULT " + default) + if ( + self.dialect._support_default_function + and not re.match(r"^\s*[\'\"\(]", default) + and not re.search(r"ON +UPDATE", default, re.I) + and not re.match( + r"\bnow\(\d+\)|\bcurrent_timestamp\(\d+\)", + default, + re.I, + ) + and re.match(r".*\W.*", default) + ): + colspec.append(f"DEFAULT ({default})") + else: + colspec.append("DEFAULT " + default) return " ".join(colspec) - def post_create_table(self, table): + def post_create_table(self, table: sa_schema.Table) -> str: """Build table-level CREATE options like ENGINE and COLLATE.""" table_opts = [] @@ -1937,25 +2163,27 @@ def post_create_table(self, table): return " ".join(table_opts) - def visit_create_index(self, create, **kw): + def visit_create_index(self, create: ddl.CreateIndex, **kw: Any) -> str: # type: ignore[override] # noqa: E501 index = create.element self._verify_index_table(index) preparer = self.preparer - table = preparer.format_table(index.table) + table = preparer.format_table(index.table) # type: ignore[arg-type] columns = [ self.sql_compiler.process( - elements.Grouping(expr) - if ( - isinstance(expr, elements.BinaryExpression) - or ( - isinstance(expr, elements.UnaryExpression) - and expr.modifier - not in (operators.desc_op, operators.asc_op) + ( + elements.Grouping(expr) # type: ignore[arg-type] + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) ) - or isinstance(expr, functions.FunctionElement) - ) - else expr, + else expr + ), include_table=False, literal_binds=True, ) @@ -1983,25 +2211,27 @@ def visit_create_index(self, create, **kw): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns = ", ".join( - "%s(%d)" % (expr, length[col.name]) - if col.name in length - else ( - "%s(%d)" % (expr, length[expr]) - if expr in length - else "%s" % expr + columns_str = ", ".join( + ( + "%s(%d)" % (expr, length[col.name]) # type: ignore[union-attr] # noqa: E501 + if col.name in length # type: ignore[union-attr] + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr + ) ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value 
specifying the same # prefix length for all columns of the index - columns = ", ".join( + columns_str = ", ".join( "%s(%d)" % (col, length) for col in columns ) else: - columns = ", ".join(columns) - text += "(%s)" % columns + columns_str = ", ".join(columns) + text += "(%s)" % columns_str parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: @@ -2013,14 +2243,16 @@ def visit_create_index(self, create, **kw): return text - def visit_primary_key_constraint(self, constraint, **kw): + def visit_primary_key_constraint( + self, constraint: sa_schema.PrimaryKeyConstraint, **kw: Any + ) -> str: text = super().visit_primary_key_constraint(constraint) using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text - def visit_drop_index(self, drop, **kw): + def visit_drop_index(self, drop: ddl.DropIndex, **kw: Any) -> str: index = drop.element text = "\nDROP INDEX " if drop.if_exists: @@ -2028,10 +2260,12 @@ def visit_drop_index(self, drop, **kw): return text + "%s ON %s" % ( self._prepared_index_name(index, include_schema=False), - self.preparer.format_table(index.table), + self.preparer.format_table(index.table), # type: ignore[arg-type] ) - def visit_drop_constraint(self, drop, **kw): + def visit_drop_constraint( + self, drop: ddl.DropConstraint, **kw: Any + ) -> str: constraint = drop.element if isinstance(constraint, sa_schema.ForeignKeyConstraint): qual = "FOREIGN KEY " @@ -2057,7 +2291,9 @@ def visit_drop_constraint(self, drop, **kw): const, ) - def define_constraint_match(self, constraint): + def define_constraint_match( + self, constraint: sa_schema.ForeignKeyConstraint + ) -> str: if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " @@ -2065,7 +2301,9 @@ def define_constraint_match(self, constraint): ) return "" - def visit_set_table_comment(self, create, **kw): + def visit_set_table_comment( + self, create: ddl.SetTableComment, **kw: Any + ) -> str: return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -2073,12 +2311,16 @@ def visit_set_table_comment(self, create, **kw): ), ) - def visit_drop_table_comment(self, create, **kw): + def visit_drop_table_comment( + self, drop: ddl.DropTableComment, **kw: Any + ) -> str: return "ALTER TABLE %s COMMENT ''" % ( - self.preparer.format_table(create.element) + self.preparer.format_table(drop.element) ) - def visit_set_column_comment(self, create, **kw): + def visit_set_column_comment( + self, create: ddl.SetColumnComment, **kw: Any + ) -> str: return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), @@ -2087,7 +2329,7 @@ def visit_set_column_comment(self, create, **kw): class MySQLTypeCompiler(compiler.GenericTypeCompiler): - def _extend_numeric(self, type_, spec): + def _extend_numeric(self, type_: _NumericType, spec: str) -> str: "Extend a numeric-type declaration with MySQL specific extensions." if not self._mysql_type(type_): @@ -2099,13 +2341,15 @@ def _extend_numeric(self, type_, spec): spec += " ZEROFILL" return spec - def _extend_string(self, type_, defaults, spec): + def _extend_string( + self, type_: _StringType, defaults: Dict[str, Any], spec: str + ) -> str: """Extend a string-type declaration with standard SQL CHARACTER SET / COLLATE annotations and MySQL specific extensions. 
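For example (illustrative values, not taken from this change), a ``VARCHAR(30)`` carrying ``charset="utf8mb4"`` and ``collation="utf8mb4_bin"`` renders as ``VARCHAR(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin``.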
""" - def attr(name): + def attr(name: str) -> Any: return getattr(type_, name, defaults.get(name)) if attr("charset"): @@ -2115,6 +2359,7 @@ def attr(name): elif attr("unicode"): charset = "UNICODE" else: + charset = None if attr("collation"): @@ -2133,10 +2378,10 @@ def attr(name): [c for c in (spec, charset, collation) if c is not None] ) - def _mysql_type(self, type_): + def _mysql_type(self, type_: Any) -> bool: return isinstance(type_, (_StringType, _NumericType)) - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: NUMERIC, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: @@ -2151,7 +2396,7 @@ def visit_NUMERIC(self, type_, **kw): % {"precision": type_.precision, "scale": type_.scale}, ) - def visit_DECIMAL(self, type_, **kw): + def visit_DECIMAL(self, type_: DECIMAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: @@ -2166,7 +2411,7 @@ def visit_DECIMAL(self, type_, **kw): % {"precision": type_.precision, "scale": type_.scale}, ) - def visit_DOUBLE(self, type_, **kw): + def visit_DOUBLE(self, type_: DOUBLE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.precision is not None and type_.scale is not None: return self._extend_numeric( type_, @@ -2176,7 +2421,7 @@ def visit_DOUBLE(self, type_, **kw): else: return self._extend_numeric(type_, "DOUBLE") - def visit_REAL(self, type_, **kw): + def visit_REAL(self, type_: REAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.precision is not None and type_.scale is not None: return self._extend_numeric( type_, @@ -2186,7 +2431,7 @@ def visit_REAL(self, type_, **kw): else: return self._extend_numeric(type_, "REAL") - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT(self, type_: FLOAT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if ( self._mysql_type(type_) and type_.scale is not None @@ -2202,7 +2447,7 @@ def visit_FLOAT(self, type_, **kw): else: return self._extend_numeric(type_, "FLOAT") - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: INTEGER, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2212,7 +2457,7 @@ def visit_INTEGER(self, type_, **kw): else: return self._extend_numeric(type_, "INTEGER") - def visit_BIGINT(self, type_, **kw): + def visit_BIGINT(self, type_: BIGINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2222,7 +2467,7 @@ def visit_BIGINT(self, type_, **kw): else: return self._extend_numeric(type_, "BIGINT") - def visit_MEDIUMINT(self, type_, **kw): + def visit_MEDIUMINT(self, type_: MEDIUMINT, **kw: Any) -> str: if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2232,7 +2477,7 @@ def visit_MEDIUMINT(self, type_, **kw): else: return self._extend_numeric(type_, "MEDIUMINT") - def visit_TINYINT(self, type_, **kw): + def visit_TINYINT(self, type_: TINYINT, **kw: Any) -> str: if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "TINYINT(%s)" % type_.display_width @@ -2240,7 +2485,7 @@ def visit_TINYINT(self, type_, **kw): else: return self._extend_numeric(type_, "TINYINT") - def 
visit_SMALLINT(self, type_, **kw): + def visit_SMALLINT(self, type_: SMALLINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2250,55 +2495,55 @@ def visit_SMALLINT(self, type_, **kw): else: return self._extend_numeric(type_, "SMALLINT") - def visit_BIT(self, type_, **kw): + def visit_BIT(self, type_: BIT, **kw: Any) -> str: if type_.length is not None: return "BIT(%s)" % type_.length else: return "BIT" - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: DATETIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if getattr(type_, "fsp", None): - return "DATETIME(%d)" % type_.fsp + return "DATETIME(%d)" % type_.fsp # type: ignore[str-format] else: return "DATETIME" - def visit_DATE(self, type_, **kw): + def visit_DATE(self, type_: DATE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 return "DATE" - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: TIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if getattr(type_, "fsp", None): - return "TIME(%d)" % type_.fsp + return "TIME(%d)" % type_.fsp # type: ignore[str-format] else: return "TIME" - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: TIMESTAMP, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if getattr(type_, "fsp", None): - return "TIMESTAMP(%d)" % type_.fsp + return "TIMESTAMP(%d)" % type_.fsp # type: ignore[str-format] else: return "TIMESTAMP" - def visit_YEAR(self, type_, **kw): + def visit_YEAR(self, type_: YEAR, **kw: Any) -> str: if type_.display_width is None: return "YEAR" else: return "YEAR(%s)" % type_.display_width - def visit_TEXT(self, type_, **kw): + def visit_TEXT(self, type_: TEXT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.length is not None: return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: return self._extend_string(type_, {}, "TEXT") - def visit_TINYTEXT(self, type_, **kw): + def visit_TINYTEXT(self, type_: TINYTEXT, **kw: Any) -> str: return self._extend_string(type_, {}, "TINYTEXT") - def visit_MEDIUMTEXT(self, type_, **kw): + def visit_MEDIUMTEXT(self, type_: MEDIUMTEXT, **kw: Any) -> str: return self._extend_string(type_, {}, "MEDIUMTEXT") - def visit_LONGTEXT(self, type_, **kw): + def visit_LONGTEXT(self, type_: LONGTEXT, **kw: Any) -> str: return self._extend_string(type_, {}, "LONGTEXT") - def visit_VARCHAR(self, type_, **kw): + def visit_VARCHAR(self, type_: VARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.length is not None: return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: @@ -2306,7 +2551,7 @@ def visit_VARCHAR(self, type_, **kw): "VARCHAR requires a length on dialect %s" % self.dialect.name ) - def visit_CHAR(self, type_, **kw): + def visit_CHAR(self, type_: CHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if type_.length is not None: return self._extend_string( type_, {}, "CHAR(%(length)s)" % {"length": type_.length} @@ -2314,7 +2559,7 @@ def visit_CHAR(self, type_, **kw): else: return self._extend_string(type_, {}, "CHAR") - def visit_NVARCHAR(self, type_, **kw): + def visit_NVARCHAR(self, type_: NVARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". 
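# A minimal sketch of the resulting DDL (the length value and import
# path are assumed for illustration, not part of this change):
#
#   from sqlalchemy.dialects import mysql
#   mysql.NVARCHAR(50)  # renders as: NATIONAL VARCHAR(50)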
if type_.length is not None: @@ -2328,7 +2573,7 @@ def visit_NVARCHAR(self, type_, **kw): "NVARCHAR requires a length on dialect %s" % self.dialect.name ) - def visit_NCHAR(self, type_, **kw): + def visit_NCHAR(self, type_: NCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length is not None: @@ -2340,61 +2585,70 @@ def visit_NCHAR(self, type_, **kw): else: return self._extend_string(type_, {"national": True}, "CHAR") - def visit_UUID(self, type_, **kw): + def visit_UUID(self, type_: UUID[Any], **kw: Any) -> str: # type: ignore[override] # NOQA: E501 return "UUID" - def visit_VARBINARY(self, type_, **kw): - return "VARBINARY(%d)" % type_.length + def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str: + return "VARBINARY(%d)" % type_.length # type: ignore[str-format] - def visit_JSON(self, type_, **kw): + def visit_JSON(self, type_: JSON, **kw: Any) -> str: return "JSON" - def visit_large_binary(self, type_, **kw): + def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str: return self.visit_BLOB(type_) - def visit_enum(self, type_, **kw): + def visit_enum(self, type_: ENUM, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 if not type_.native_enum: return super().visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: LargeBinary, **kw: Any) -> str: if type_.length is not None: return "BLOB(%d)" % type_.length else: return "BLOB" - def visit_TINYBLOB(self, type_, **kw): + def visit_TINYBLOB(self, type_: TINYBLOB, **kw: Any) -> str: return "TINYBLOB" - def visit_MEDIUMBLOB(self, type_, **kw): + def visit_MEDIUMBLOB(self, type_: MEDIUMBLOB, **kw: Any) -> str: return "MEDIUMBLOB" - def visit_LONGBLOB(self, type_, **kw): + def visit_LONGBLOB(self, type_: LONGBLOB, **kw: Any) -> str: return "LONGBLOB" - def _visit_enumerated_values(self, name, type_, enumerated_values): + def _visit_enumerated_values( + self, name: str, type_: _StringType, enumerated_values: Sequence[str] + ) -> str: quoted_enums = [] for e in enumerated_values: + if self.dialect.identifier_preparer._double_percents: + e = e.replace("%", "%%") quoted_enums.append("'%s'" % e.replace("'", "''")) return self._extend_string( type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) - def visit_ENUM(self, type_, **kw): + def visit_ENUM(self, type_: ENUM, **kw: Any) -> str: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_SET(self, type_, **kw): + def visit_SET(self, type_: SET, **kw: Any) -> str: return self._visit_enumerated_values("SET", type_, type_.values) - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str: return "BOOL" class MySQLIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS_MYSQL - def __init__(self, dialect, server_ansiquotes=False, **kw): + def __init__( + self, + dialect: default.DefaultDialect, + server_ansiquotes: bool = False, + **kw: Any, + ): if not server_ansiquotes: quote = "`" else: @@ -2402,7 +2656,7 @@ def __init__(self, dialect, server_ansiquotes=False, **kw): super().__init__(dialect, initial_quote=quote, escape_quote=quote) - def _quote_free_identifiers(self, *ids): + def _quote_free_identifiers(self, *ids: Optional[str]) -> Tuple[str, ...]: """Unilaterally identifier-quote any number of strings.""" return tuple([self.quote_identifier(i) for i in ids if i is not None]) @@ 
-2412,7 +2666,6 @@ class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): reserved_words = RESERVED_WORDS_MARIADB -@log.class_logger class MySQLDialect(default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code. @@ -2427,6 +2680,10 @@ class MySQLDialect(default.DefaultDialect): # allow for the "true" and "false" keywords, however supports_native_boolean = False + # support for BIT type; mysqlconnector coerces result values automatically, + # all other MySQL DBAPIs require a conversion routine + supports_native_bit = False + # identifiers are 64, however aliases can be 255... max_identifier_length = 255 max_index_name_length = 64 @@ -2475,9 +2732,9 @@ class MySQLDialect(default.DefaultDialect): ddl_compiler = MySQLDDLCompiler type_compiler_cls = MySQLTypeCompiler ischema_names = ischema_names - preparer = MySQLIdentifierPreparer + preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer - is_mariadb = False + is_mariadb: bool = False _mariadb_normalized_version_info = None # default SQL compilation settings - @@ -2486,9 +2743,13 @@ class MySQLDialect(default.DefaultDialect): _backslash_escapes = True _server_ansiquotes = False + server_version_info: Tuple[int, ...] + identifier_preparer: MySQLIdentifierPreparer + construct_arguments = [ (sa_schema.Table, {"*": None}), (sql.Update, {"limit": None}), + (sql.Delete, {"limit": None}), (sa_schema.PrimaryKeyConstraint, {"using": None}), ( sa_schema.Index, @@ -2503,18 +2764,20 @@ class MySQLDialect(default.DefaultDialect): def __init__( self, - json_serializer=None, - json_deserializer=None, - is_mariadb=None, - **kwargs, - ): + json_serializer: Optional[Callable[..., Any]] = None, + json_deserializer: Optional[Callable[..., Any]] = None, + is_mariadb: Optional[bool] = None, + **kwargs: Any, + ) -> None: kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self._json_serializer = json_serializer self._json_deserializer = json_deserializer - self._set_mariadb(is_mariadb, None) + self._set_mariadb(is_mariadb, ()) - def get_isolation_level_values(self, dbapi_conn): + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> Sequence[IsolationLevel]: return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -2522,13 +2785,17 @@ def get_isolation_level_values(self, dbapi_conn): "REPEATABLE READ", ) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: cursor = dbapi_connection.cursor() cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}") cursor.execute("COMMIT") cursor.close() - def get_isolation_level(self, dbapi_connection): + def get_isolation_level( + self, dbapi_connection: DBAPIConnection + ) -> IsolationLevel: cursor = dbapi_connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): cursor.execute("SELECT @@transaction_isolation") @@ -2545,10 +2812,10 @@ def get_isolation_level(self, dbapi_connection): cursor.close() if isinstance(val, bytes): val = val.decode() - return val.upper().replace("-", " ") + return val.upper().replace("-", " ") # type: ignore[no-any-return] @classmethod - def _is_mariadb_from_url(cls, url): + def _is_mariadb_from_url(cls, url: URL) -> bool: dbapi = cls.import_dbapi() dialect = cls(dbapi=dbapi) @@ -2557,7 +2824,7 @@ def _is_mariadb_from_url(cls, url): try: cursor = conn.cursor() cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") - val = cursor.fetchone()[0] + val = 
cursor.fetchone()[0] # type: ignore[index] except: raise else: @@ -2565,22 +2832,25 @@ def _is_mariadb_from_url(cls, url): finally: conn.close() - def _get_server_version_info(self, connection): + def _get_server_version_info( + self, connection: Connection + ) -> Tuple[int, ...]: # get database server version info explicitly over the wire # to avoid proxy servers like MaxScale getting in the # way with their own values, see #4205 dbapi_con = connection.connection cursor = dbapi_con.cursor() cursor.execute("SELECT VERSION()") - val = cursor.fetchone()[0] + + val = cursor.fetchone()[0] # type: ignore[index] cursor.close() if isinstance(val, bytes): val = val.decode() return self._parse_server_version(val) - def _parse_server_version(self, val): - version = [] + def _parse_server_version(self, val: str) -> Tuple[int, ...]: + version: List[int] = [] is_mariadb = False r = re.compile(r"[.\-+]") @@ -2601,7 +2871,7 @@ def _parse_server_version(self, val): server_version_info = tuple(version) self._set_mariadb( - server_version_info and is_mariadb, server_version_info + bool(server_version_info and is_mariadb), server_version_info ) if not is_mariadb: @@ -2617,7 +2887,9 @@ def _parse_server_version(self, val): self.server_version_info = server_version_info return server_version_info - def _set_mariadb(self, is_mariadb, server_version_info): + def _set_mariadb( + self, is_mariadb: Optional[bool], server_version_info: Tuple[int, ...] + ) -> None: if is_mariadb is None: return @@ -2627,10 +2899,12 @@ def _set_mariadb(self, is_mariadb, server_version_info): % (".".join(map(str, server_version_info)),) ) if is_mariadb: - self.preparer = MariaDBIdentifierPreparer - # this would have been set by the default dialect already, - # so set it again - self.identifier_preparer = self.preparer(self) + + if not issubclass(self.preparer, MariaDBIdentifierPreparer): + self.preparer = MariaDBIdentifierPreparer + # this would have been set by the default dialect already, + # so set it again + self.identifier_preparer = self.preparer(self) # this will be updated on first connect in initialize() # if using older mariadb version @@ -2639,38 +2913,54 @@ def _set_mariadb(self, is_mariadb, server_version_info): self.is_mariadb = is_mariadb - def do_begin_twophase(self, connection, xid): + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) - def do_prepare_twophase(self, connection, xid): + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: connection.execute(sql.text("XA END :xid"), dict(xid=xid)) connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid)) def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False - ): + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: if not is_prepared: connection.execute(sql.text("XA END :xid"), dict(xid=xid)) connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid)) def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False - ): + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid)) - def do_recover_twophase(self, connection): + def do_recover_twophase(self, connection: Connection) -> List[Any]: resultset = connection.exec_driver_sql("XA RECOVER") - return [row["data"][0 : 
row["gtrid_length"]] for row in resultset] + return [ + row["data"][0 : row["gtrid_length"]] + for row in resultset.mappings() + ] - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: if isinstance( e, ( - self.dbapi.OperationalError, - self.dbapi.ProgrammingError, - self.dbapi.InterfaceError, + self.dbapi.OperationalError, # type: ignore + self.dbapi.ProgrammingError, # type: ignore + self.dbapi.InterfaceError, # type: ignore ), ) and self._extract_error_code(e) in ( 1927, @@ -2683,7 +2973,7 @@ def is_disconnect(self, e, connection, cursor): ): return True elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) # type: ignore # noqa: E501 ): # if underlying connection is closed, # this is the error you get @@ -2691,13 +2981,17 @@ def is_disconnect(self, e, connection, cursor): else: return False - def _compat_fetchall(self, rp, charset=None): + def _compat_fetchall( + self, rp: CursorResult[Any], charset: Optional[str] = None + ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" return [_DecodingRow(row, charset) for row in rp.fetchall()] - def _compat_fetchone(self, rp, charset=None): + def _compat_fetchone( + self, rp: CursorResult[Any], charset: Optional[str] = None + ) -> Union[Row[Any], None, _DecodingRow]: """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" @@ -2707,7 +3001,9 @@ def _compat_fetchone(self, rp, charset=None): else: return None - def _compat_first(self, rp, charset=None): + def _compat_first( + self, rp: CursorResult[Any], charset: Optional[str] = None + ) -> Optional[_DecodingRow]: """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" @@ -2717,14 +3013,22 @@ def _compat_first(self, rp, charset=None): else: return None - def _extract_error_code(self, exception): + def _extract_error_code( + self, exception: DBAPIModule.Error + ) -> Optional[int]: raise NotImplementedError() - def _get_default_schema_name(self, connection): - return connection.exec_driver_sql("SELECT DATABASE()").scalar() + def _get_default_schema_name(self, connection: Connection) -> str: + return connection.exec_driver_sql("SELECT DATABASE()").scalar() # type: ignore[return-value] # noqa: E501 @reflection.cache - def has_table(self, connection, table_name, schema=None, **kw): + def has_table( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: self._ensure_has_table_connection(connection) if schema is None: @@ -2765,12 +3069,18 @@ def has_table(self, connection, table_name, schema=None, **kw): # # there's more "doesn't exist" kinds of messages but they are # less clear if mysql 8 would suddenly start using one of those - if self._extract_error_code(e.orig) in (1146, 1049, 1051): + if self._extract_error_code(e.orig) in (1146, 1049, 1051): # type: ignore # noqa: E501 return False raise @reflection.cache - def has_sequence(self, connection, sequence_name, schema=None, **kw): + def has_sequence( + self, + connection: Connection, + sequence_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -2790,14 +3100,16 @@ def has_sequence(self, connection, sequence_name, schema=None, **kw): ) 
return cursor.first() is not None - def _sequences_not_supported(self): + def _sequences_not_supported(self) -> NoReturn: raise NotImplementedError( "Sequences are supported only by the " "MariaDB series 10.3 or greater" ) @reflection.cache - def get_sequence_names(self, connection, schema=None, **kw): + def get_sequence_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -2817,10 +3129,12 @@ def get_sequence_names(self, connection, schema=None, **kw): ) ] - def initialize(self, connection): + def initialize(self, connection: Connection) -> None: # this is driver-based, does not need server version info # and is fairly critical for even basic SQL operations - self._connection_charset = self._detect_charset(connection) + self._connection_charset: Optional[str] = self._detect_charset( + connection + ) # call super().initialize() because we need to have # server_version_info set up. in 1.4 under python 2 only this does the @@ -2864,9 +3178,10 @@ def initialize(self, connection): self._warn_for_known_db_issues() - def _warn_for_known_db_issues(self): + def _warn_for_known_db_issues(self) -> None: if self.is_mariadb: mdb_version = self._mariadb_normalized_version_info + assert mdb_version is not None if mdb_version > (10, 2) and mdb_version < (10, 2, 9): util.warn( "MariaDB %r before 10.2.9 has known issues regarding " @@ -2879,7 +3194,7 @@ def _warn_for_known_db_issues(self): ) @property - def _support_float_cast(self): + def _support_float_cast(self) -> bool: if not self.server_version_info: return False elif self.is_mariadb: @@ -2890,32 +3205,49 @@ def _support_float_cast(self): return self.server_version_info >= (8, 0, 17) @property - def _is_mariadb(self): + def _support_default_function(self) -> bool: + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/ + return self.server_version_info >= (10, 2, 1) + else: + # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa + return self.server_version_info >= (8, 0, 13) + + @property + def _is_mariadb(self) -> bool: return self.is_mariadb @property - def _is_mysql(self): + def _is_mysql(self) -> bool: return not self.is_mariadb @property - def _is_mariadb_102(self): - return self.is_mariadb and self._mariadb_normalized_version_info > ( - 10, - 2, + def _is_mariadb_102(self) -> bool: + return ( + self.is_mariadb + and self._mariadb_normalized_version_info # type:ignore[operator] + > ( + 10, + 2, + ) ) @reflection.cache - def get_schema_names(self, connection, **kw): + def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: rp = connection.exec_driver_sql("SHOW schemas") return [r[0] for r in rp] @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def get_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: """Return a Unicode SHOW TABLES from a given schema.""" if schema is not None: - current_schema = schema + current_schema: str = schema else: - current_schema = self.default_schema_name + current_schema = self.default_schema_name # type: ignore charset = self._connection_charset @@ -2931,9 +3263,12 @@ def get_table_names(self, connection, schema=None, **kw): ] @reflection.cache - def get_view_names(self, connection, schema=None, **kw): + def get_view_names( + self, connection: Connection, schema: Optional[str] = None, 
**kw: Any + ) -> List[str]: if schema is None: schema = self.default_schema_name + assert schema is not None charset = self._connection_charset rp = connection.exec_driver_sql( "SHOW FULL TABLES FROM %s" @@ -2946,7 +3281,13 @@ def get_view_names(self, connection, schema=None, **kw): ] @reflection.cache - def get_table_options(self, connection, table_name, schema=None, **kw): + def get_table_options( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> Dict[str, Any]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -2956,7 +3297,13 @@ def get_table_options(self, connection, table_name, schema=None, **kw): return ReflectionDefaults.table_options() @reflection.cache - def get_columns(self, connection, table_name, schema=None, **kw): + def get_columns( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedColumn]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -2966,7 +3313,13 @@ def get_columns(self, connection, table_name, schema=None, **kw): return ReflectionDefaults.columns() @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): + def get_pk_constraint( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedPrimaryKeyConstraint: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -2978,13 +3331,19 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): return ReflectionDefaults.pk_constraint() @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + def get_foreign_keys( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedForeignKeyConstraint]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) default_schema = None - fkeys = [] + fkeys: List[ReflectedForeignKeyConstraint] = [] for spec in parsed_state.fk_constraints: ref_name = spec["table"][-1] @@ -3004,7 +3363,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): if spec.get(opt, False) not in ("NO ACTION", None): con_kw[opt] = spec[opt] - fkey_d = { + fkey_d: ReflectedForeignKeyConstraint = { "name": spec["name"], "constrained_columns": loc_names, "referred_schema": ref_schema, @@ -3019,7 +3378,11 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): return fkeys if fkeys else ReflectionDefaults.foreign_keys() - def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): + def _correct_for_mysql_bugs_88718_96365( + self, + fkeys: List[ReflectedForeignKeyConstraint], + connection: Connection, + ) -> None: # Foreign key is always in lower case (MySQL 8.0) # https://bugs.mysql.com/bug.php?id=88718 # issue #4344 for SQLAlchemy @@ -3035,38 +3398,60 @@ def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): if self._casing in (1, 2): - def lower(s): + def lower(s: str) -> str: return s.lower() else: # if on case sensitive, there can be two tables referenced # with the same name different casing, so we need to use # case-sensitive matching. 
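# e.g. (illustrative): with lower_case_table_names=1 or 2 the helper
# below normalizes lower("MySchema") -> "myschema", while on a
# case-sensitive server it is the identity, so "MySchema" and
# "myschema" remain distinct keys.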
- def lower(s): + def lower(s: str) -> str: return s - default_schema_name = connection.dialect.default_schema_name - col_tuples = [ - ( - lower(rec["referred_schema"] or default_schema_name), - lower(rec["referred_table"]), - col_name, + default_schema_name: str = connection.dialect.default_schema_name # type: ignore # noqa: E501 + + # NOTE: using (table_schema, table_name, lower(column_name)) in (...) + # is very slow since mysql does not seem able to properly use indexes. + # Unpack the where condition instead. + schema_by_table_by_column: DefaultDict[ + str, DefaultDict[str, List[str]] + ] = DefaultDict(lambda: DefaultDict(list)) + for rec in fkeys: + sch = lower(rec["referred_schema"] or default_schema_name) + tbl = lower(rec["referred_table"]) + for col_name in rec["referred_columns"]: + schema_by_table_by_column[sch][tbl].append(col_name) + + if schema_by_table_by_column: + + condition = sql.or_( + *( + sql.and_( + _info_columns.c.table_schema == schema, + sql.or_( + *( + sql.and_( + _info_columns.c.table_name == table, + sql.func.lower( + _info_columns.c.column_name + ).in_(columns), + ) + for table, columns in tables.items() + ) + ), + ) + for schema, tables in schema_by_table_by_column.items() + ) ) - for rec in fkeys - for col_name in rec["referred_columns"] - ] - if col_tuples: - correct_for_wrong_fk_case = connection.execute( - sql.text( - """ - select table_schema, table_name, column_name - from information_schema.columns - where (table_schema, table_name, lower(column_name)) in - :table_data; - """ - ).bindparams(sql.bindparam("table_data", expanding=True)), - dict(table_data=col_tuples), + select = sql.select( + _info_columns.c.table_schema, + _info_columns.c.table_name, + _info_columns.c.column_name, + ).where(condition) + + correct_for_wrong_fk_case: CursorResult[Tuple[str, str, str]] = ( + connection.execute(select) ) # in casing=0, table name and schema name come back in their @@ -3079,35 +3464,41 @@ def lower(s): # SHOW CREATE TABLE converts them to *lower case*, therefore # not matching.
So for this case, case-insensitive lookup # is necessary - d = defaultdict(dict) + d: DefaultDict[Tuple[str, str], Dict[str, str]] = defaultdict(dict) for schema, tname, cname in correct_for_wrong_fk_case: d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema d[(lower(schema), lower(tname))]["TABLENAME"] = tname d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - rec = d[ + rec_b = d[ ( lower(fkey["referred_schema"] or default_schema_name), lower(fkey["referred_table"]), ) ] - fkey["referred_table"] = rec["TABLENAME"] + fkey["referred_table"] = rec_b["TABLENAME"] if fkey["referred_schema"] is not None: - fkey["referred_schema"] = rec["SCHEMANAME"] + fkey["referred_schema"] = rec_b["SCHEMANAME"] fkey["referred_columns"] = [ - rec[col.lower()] for col in fkey["referred_columns"] + rec_b[col.lower()] for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints(self, connection, table_name, schema=None, **kw): + def get_check_constraints( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedCheckConstraint]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - cks = [ + cks: List[ReflectedCheckConstraint] = [ {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @@ -3115,7 +3506,13 @@ def get_check_constraints(self, connection, table_name, schema=None, **kw): return cks if cks else ReflectionDefaults.check_constraints() @reflection.cache - def get_table_comment(self, connection, table_name, schema=None, **kw): + def get_table_comment( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedTableComment: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -3126,12 +3523,18 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): return ReflectionDefaults.table_comment() @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): + def get_indexes( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedIndex]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - indexes = [] + indexes: List[ReflectedIndex] = [] for spec in parsed_state.keys: dialect_options = {} @@ -3143,32 +3546,30 @@ def get_indexes(self, connection, table_name, schema=None, **kw): unique = True elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["%s_prefix" % self.name] = flavor - elif flavor is None: - pass - else: - self.logger.info( - "Converting unknown KEY type %s to a plain KEY", flavor - ) - pass + elif flavor is not None: + util.warn( + "Converting unknown KEY type %s to a plain KEY" % flavor + ) if spec["parser"]: dialect_options["%s_with_parser" % (self.name)] = spec[ "parser" ] - index_d = {} + index_d: ReflectedIndex = { + "name": spec["name"], + "column_names": [s[0] for s in spec["columns"]], + "unique": unique, + } - index_d["name"] = spec["name"] - index_d["column_names"] = [s[0] for s in spec["columns"]] mysql_length = { s[0]: s[1] for s in spec["columns"] if s[1] is not None } if mysql_length: dialect_options["%s_length" % self.name] = mysql_length - index_d["unique"] = unique if flavor: - index_d["type"] = flavor + index_d["type"] = flavor # type: ignore[typeddict-unknown-key] if dialect_options: index_d["dialect_options"] = dialect_options @@ -3179,13 +3580,17 @@ def get_indexes(self, connection, table_name, schema=None, **kw): @reflection.cache def 
get_unique_constraints( - self, connection, table_name, schema=None, **kw - ): + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedUniqueConstraint]: parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - ucs = [ + ucs: List[ReflectedUniqueConstraint] = [ { "name": key["name"], "column_names": [col[0] for col in key["columns"]], @@ -3201,7 +3606,13 @@ def get_unique_constraints( return ReflectionDefaults.unique_constraints() @reflection.cache - def get_view_definition(self, connection, view_name, schema=None, **kw): + def get_view_definition( + self, + connection: Connection, + view_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> str: charset = self._connection_charset full_name = ".".join( self.identifier_preparer._quote_free_identifiers(schema, view_name) @@ -3215,8 +3626,12 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): return sql def _parsed_state_or_create( - self, connection, table_name, schema=None, **kw - ): + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> _reflection.ReflectedState: return self._setup_parser( connection, table_name, @@ -3225,7 +3640,7 @@ def _parsed_state_or_create( ) @util.memoized_property - def _tabledef_parser(self): + def _tabledef_parser(self) -> _reflection.MySQLTableDefinitionParser: """return the MySQLTableDefinitionParser, generate if needed. The deferred creation ensures that the dialect has @@ -3236,7 +3651,13 @@ def _tabledef_parser(self): return _reflection.MySQLTableDefinitionParser(self, preparer) @reflection.cache - def _setup_parser(self, connection, table_name, schema=None, **kw): + def _setup_parser( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> _reflection.ReflectedState: charset = self._connection_charset parser = self._tabledef_parser full_name = ".".join( @@ -3252,10 +3673,14 @@ def _setup_parser(self, connection, table_name, schema=None, **kw): columns = self._describe_table( connection, None, charset, full_name=full_name ) - sql = parser._describe_to_create(table_name, columns) + sql = parser._describe_to_create( + table_name, columns # type: ignore[arg-type] + ) return parser.parse(sql, charset) - def _fetch_setting(self, connection, setting_name): + def _fetch_setting( + self, connection: Connection, setting_name: str + ) -> Optional[str]: charset = self._connection_charset if self.server_version_info and self.server_version_info < (5, 6): @@ -3270,12 +3695,12 @@ def _fetch_setting(self, connection, setting_name): if not row: return None else: - return row[fetch_col] + return cast(Optional[str], row[fetch_col]) - def _detect_charset(self, connection): + def _detect_charset(self, connection: Connection) -> str: raise NotImplementedError() - def _detect_casing(self, connection): + def _detect_casing(self, connection: Connection) -> int: """Sniff out identifier case sensitivity. Cached per-connection. This value can not change without a server @@ -3299,7 +3724,7 @@ def _detect_casing(self, connection): self._casing = cs return cs - def _detect_collations(self, connection): + def _detect_collations(self, connection: Connection) -> Dict[str, str]: """Pull the active COLLATIONS list from the server. Cached per-connection. 
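For example (rows assumed for illustration, not taken from this change), the mapping built from ``SHOW COLLATION`` has the shape ``{"latin1_swedish_ci": "latin1", "utf8mb4_general_ci": "utf8mb4"}``, i.e. collation name to its character set.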
@@ -3312,7 +3737,7 @@ def _detect_collations(self, connection): collations[row[0]] = row[1] return collations - def _detect_sql_mode(self, connection): + def _detect_sql_mode(self, connection: Connection) -> None: setting = self._fetch_setting(connection, "sql_mode") if setting is None: @@ -3324,7 +3749,7 @@ def _detect_sql_mode(self, connection): else: self._sql_mode = setting or "" - def _detect_ansiquotes(self, connection): + def _detect_ansiquotes(self, connection: Connection) -> None: """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode @@ -3339,34 +3764,81 @@ def _detect_ansiquotes(self, connection): # as of MySQL 5.0.1 self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode + @overload def _show_create_table( - self, connection, table, charset=None, full_name=None - ): + self, + connection: Connection, + table: Optional[Table], + charset: Optional[str], + full_name: str, + ) -> str: ... + + @overload + def _show_create_table( + self, + connection: Connection, + table: Table, + charset: Optional[str] = None, + full_name: None = None, + ) -> str: ... + + def _show_create_table( + self, + connection: Connection, + table: Optional[Table], + charset: Optional[str] = None, + full_name: Optional[str] = None, + ) -> str: """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: + assert table is not None full_name = self.identifier_preparer.format_table(table) st = "SHOW CREATE TABLE %s" % full_name - rp = None try: rp = connection.execution_options( skip_user_error_events=True ).exec_driver_sql(st) except exc.DBAPIError as e: - if self._extract_error_code(e.orig) == 1146: + if self._extract_error_code(e.orig) == 1146: # type: ignore[arg-type] # noqa: E501 raise exc.NoSuchTableError(full_name) from e else: raise row = self._compat_first(rp, charset=charset) if not row: raise exc.NoSuchTableError(full_name) - return row[1].strip() + return cast(str, row[1]).strip() - def _describe_table(self, connection, table, charset=None, full_name=None): + @overload + def _describe_table( + self, + connection: Connection, + table: Optional[Table], + charset: Optional[str], + full_name: str, + ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: ... + + @overload + def _describe_table( + self, + connection: Connection, + table: Table, + charset: Optional[str] = None, + full_name: None = None, + ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: ... + + def _describe_table( + self, + connection: Connection, + table: Optional[Table], + charset: Optional[str] = None, + full_name: Optional[str] = None, + ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: + assert table is not None full_name = self.identifier_preparer.format_table(table) st = "DESCRIBE %s" % full_name @@ -3377,7 +3849,7 @@ def _describe_table(self, connection, table, charset=None, full_name=None): skip_user_error_events=True ).exec_driver_sql(st) except exc.DBAPIError as e: - code = self._extract_error_code(e.orig) + code = self._extract_error_code(e.orig) # type: ignore[arg-type] # noqa: E501 if code == 1146: raise exc.NoSuchTableError(full_name) from e @@ -3409,7 +3881,7 @@ class _DecodingRow: # sets.Set(['value']) (seriously) but thankfully that doesn't # seem to come up in DDL queries. 
- _encoding_compat = { + _encoding_compat: Dict[str, str] = { "koi8r": "koi8_r", "koi8u": "koi8_u", "utf16": "utf-16-be", # MySQL's utf16 is always big-endian @@ -3419,25 +3891,33 @@ class _DecodingRow: "eucjpms": "ujis", } - def __init__(self, rowproxy, charset): + def __init__(self, rowproxy: Row[Any], charset: Optional[str]): self.rowproxy = rowproxy - self.charset = self._encoding_compat.get(charset, charset) + self.charset = ( + self._encoding_compat.get(charset, charset) + if charset is not None + else None + ) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Any: item = self.rowproxy[index] - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, bytes): return item.decode(self.charset) else: return item - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: item = getattr(self.rowproxy, attr) - if isinstance(item, _array): - item = item.tostring() if self.charset and isinstance(item, bytes): return item.decode(self.charset) else: return item + + +_info_columns = sql.table( + "columns", + sql.column("table_schema", VARCHAR(64)), + sql.column("table_name", VARCHAR(64)), + sql.column("column_name", VARCHAR(64)), + schema="information_schema", +) diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index ed3c60694aa..1d48c4e88bc 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -1,10 +1,9 @@ -# mysql/cymysql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/cymysql.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -21,18 +20,36 @@ dialects are mysqlclient and PyMySQL. """ # noqa +from __future__ import annotations + +from typing import Any +from typing import Iterable +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union -from .base import BIT from .base import MySQLDialect from .mysqldb import MySQLDialect_mysqldb +from .types import BIT from ...
import util +if TYPE_CHECKING: + from ...engine.base import Connection + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import Dialect + from ...engine.interfaces import PoolProxiedConnection + from ...sql.type_api import _ResultProcessorType + class _cymysqlBIT(BIT): - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: """Convert MySQL's 64 bit, variable length binary string to a long.""" - def process(value): + def process(value: Optional[Iterable[int]]) -> Optional[int]: if value is not None: v = 0 for i in iter(value): @@ -55,17 +72,22 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: return __import__("cymysql") - def _detect_charset(self, connection): - return connection.connection.charset + def _detect_charset(self, connection: Connection) -> str: + return connection.connection.charset # type: ignore[no-any-return] - def _extract_error_code(self, exception): - return exception.errno + def _extract_error_code(self, exception: DBAPIModule.Error) -> int: + return exception.errno # type: ignore[no-any-return] - def is_disconnect(self, e, connection, cursor): - if isinstance(e, self.dbapi.OperationalError): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + if isinstance(e, self.loaded_dbapi.OperationalError): return self._extract_error_code(e) in ( 2006, 2013, @@ -73,7 +95,7 @@ def is_disconnect(self, e, connection, cursor): 2045, 2055, ) - elif isinstance(e, self.dbapi.InterfaceError): + elif isinstance(e, self.loaded_dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get return True diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index dfa39f6e086..cceb0818f9b 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -1,5 +1,5 @@ -# mysql/dml.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/dml.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Any +from typing import Dict from typing import List from typing import Mapping from typing import Optional @@ -141,7 +142,11 @@ def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self: in :ref:`tutorial_parameter_ordered_updates`:: insert().on_duplicate_key_update( - [("name", "some name"), ("value", "some value")]) + [ + ("name", "some name"), + ("value", "some value"), + ] + ) .. 
versionchanged:: 1.3 parameters can be specified as a dictionary or list of 2-tuples; the latter form provides for parameter @@ -181,6 +186,7 @@ class OnDuplicateClause(ClauseElement): _parameter_ordering: Optional[List[str]] = None + update: Dict[str, Any] stringify_dialect = "mysql" def __init__( diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index 2e1d3c3da9f..ab305207cc6 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -1,34 +1,51 @@ -# mysql/enumerated.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/enumerated.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations +import enum import re +from typing import Any +from typing import Dict +from typing import Optional +from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from .types import _StringType from ... import exc from ... import sql from ... import util from ...sql import sqltypes +from ...sql import type_api +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.elements import ColumnElement + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType + from ...sql.type_api import TypeEngine + from ...sql.type_api import TypeEngineMixin -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): + +class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType): """MySQL ENUM type.""" __visit_name__ = "ENUM" native_enum = True - def __init__(self, *enums, **kw): + def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None: """Construct an ENUM. E.g.:: - Column('myenum', ENUM("foo", "bar", "baz")) + Column("myenum", ENUM("foo", "bar", "baz")) :param enums: The range of valid values for this ENUM. Values in enums are not quoted, they will be escaped and surrounded by single @@ -62,21 +79,27 @@ def __init__(self, *enums, **kw): """ kw.pop("strict", None) - self._enum_init(enums, kw) + self._enum_init(enums, kw) # type: ignore[arg-type] _StringType.__init__(self, length=self.length, **kw) @classmethod - def adapt_emulated_to_native(cls, impl, **kw): + def adapt_emulated_to_native( + cls, + impl: Union[TypeEngine[Any], TypeEngineMixin], + **kw: Any, + ) -> ENUM: """Produce a MySQL native :class:`.mysql.ENUM` from plain :class:`.Enum`. """ + if TYPE_CHECKING: + assert isinstance(impl, ENUM) kw.setdefault("validate_strings", impl.validate_strings) kw.setdefault("values_callable", impl.values_callable) kw.setdefault("omit_aliases", impl._omit_aliases) return cls(**kw) - def _object_value_for_elem(self, elem): + def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]: # mysql sends back a blank string for any value that # was persisted that was not in the enums; that is, it does no # validation on the incoming data, it "truncates" it to be @@ -86,24 +109,27 @@ def _object_value_for_elem(self, elem): else: return super()._object_value_for_elem(elem) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[ENUM, _StringType, sqltypes.Enum] ) +# TODO: SET is a string as far as configuration but does not act like +# a string at the python level. 
We either need to make a py-type agnostic +# version of String as a base to be used for this, make this some kind of +# TypeDecorator, or just vendor it out as its own type. class SET(_StringType): """MySQL SET type.""" __visit_name__ = "SET" - def __init__(self, *values, **kw): + def __init__(self, *values: str, **kw: Any): """Construct a SET. E.g.:: - Column('myset', SET("foo", "bar", "baz")) - + Column("myset", SET("foo", "bar", "baz")) The list of potential values is required in the case that this set will be used to generate DDL for a table, or if the @@ -151,17 +177,19 @@ def __init__(self, *values, **kw): "setting retrieve_as_bitwise=True" ) if self.retrieve_as_bitwise: - self._bitmap = { + self._inversed_bitmap: Dict[str, int] = { value: 2**idx for idx, value in enumerate(self.values) } - self._bitmap.update( - (2**idx, value) for idx, value in enumerate(self.values) - ) + self._bitmap: Dict[int, str] = { + 2**idx: value for idx, value in enumerate(self.values) + } length = max([len(v) for v in values] + [0]) kw.setdefault("length", length) super().__init__(**kw) - def column_expression(self, colexpr): + def column_expression( + self, colexpr: ColumnElement[Any] + ) -> ColumnElement[Any]: if self.retrieve_as_bitwise: return sql.type_coerce( sql.type_coerce(colexpr, sqltypes.Integer) + 0, self @@ -169,10 +197,12 @@ def column_expression(self, colexpr): else: return colexpr - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: Any + ) -> Optional[_ResultProcessorType[Any]]: if self.retrieve_as_bitwise: - def process(value): + def process(value: Union[str, int, None]) -> Optional[Set[str]]: if value is not None: value = int(value) @@ -183,11 +213,14 @@ def process(value): else: super_convert = super().result_processor(dialect, coltype) - def process(value): + def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501 if isinstance(value, str): # MySQLdb returns a string, let's parse if super_convert: value = super_convert(value) + assert value is not None + if TYPE_CHECKING: + assert isinstance(value, str) return set(re.findall(r"[^,]+", value)) else: # mysql-connector-python does a naive @@ -198,43 +231,48 @@ def process(value): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> _BindProcessorType[Union[str, int]]: super_convert = super().bind_processor(dialect) if self.retrieve_as_bitwise: - def process(value): + def process( + value: Union[str, int, set[str], None], + ) -> Union[str, int, None]: if value is None: return None elif isinstance(value, (int, str)): if super_convert: - return super_convert(value) + return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501 else: return value else: int_value = 0 for v in value: - int_value |= self._bitmap[v] + int_value |= self._inversed_bitmap[v] return int_value else: - def process(value): + def process( + value: Union[str, int, set[str], None], + ) -> Union[str, int, None]: # accept strings and int (actually bitflag) values directly if value is not None and not isinstance(value, (int, str)): value = ",".join(value) - if super_convert: - return super_convert(value) + return super_convert(value) # type: ignore else: return value return process - def adapt(self, impltype, **kw): + def adapt(self, cls: type, **kw: Any) -> Any: kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise - return util.constructor_copy(self, impltype, *self.values, **kw) + return 
util.constructor_copy(self, cls, *self.values, **kw) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[SET, _StringType], diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py index c5bd0be02b0..9d19d52de5e 100644 --- a/lib/sqlalchemy/dialects/mysql/expression.py +++ b/lib/sqlalchemy/dialects/mysql/expression.py @@ -1,10 +1,13 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/expression.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any from ... import exc from ... import util @@ -17,7 +20,7 @@ from ...util.typing import Self -class match(Generative, elements.BinaryExpression): +class match(Generative, elements.BinaryExpression[Any]): """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. E.g.:: @@ -37,7 +40,9 @@ class match(Generative, elements.BinaryExpression): .order_by(desc(match_expr)) ) - Would produce SQL resembling:: + Would produce SQL resembling: + + .. sourcecode:: sql SELECT id, firstname, lastname FROM user @@ -70,8 +75,9 @@ class match(Generative, elements.BinaryExpression): __visit_name__ = "mysql_match" inherit_cache = True + modifiers: util.immutabledict[str, Any] - def __init__(self, *cols, **kw): + def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any): if not cols: raise exc.ArgumentError("columns are required") diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 66fcb714d54..e654a61941d 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -1,13 +1,21 @@ -# mysql/json.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/json.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING from ... import types as sqltypes +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + class JSON(sqltypes.JSON): """MySQL JSON type. 
@@ -34,13 +42,13 @@ class JSON(sqltypes.JSON): class _FormatTypeMixin: - def _format_value(self, value): + def _format_value(self, value: Any) -> str: raise NotImplementedError() - def bind_processor(self, dialect): - super_proc = self.string_bind_processor(dialect) + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> Any: value = self._format_value(value) if super_proc: value = super_proc(value) @@ -48,29 +56,31 @@ def process(value): return process - def literal_processor(self, dialect): - super_proc = self.string_literal_processor(dialect) + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> str: value = self._format_value(value) if super_proc: value = super_proc(value) - return value + return value # type: ignore[no-any-return] return process class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: if isinstance(value, int): - value = "$[%s]" % value + formatted_value = "$[%s]" % value else: - value = '$."%s"' % value - return value + formatted_value = '$."%s"' % value + return formatted_value class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: return "$%s" % ( "".join( [ diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py index a6ee5dfac93..508820e67ce 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadb.py +++ b/lib/sqlalchemy/dialects/mysql/mariadb.py @@ -1,32 +1,73 @@ -# mysql/mariadb.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/mariadb.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors + +from __future__ import annotations + +from typing import Any +from typing import Callable + from .base import MariaDBIdentifierPreparer from .base import MySQLDialect +from .base import MySQLIdentifierPreparer +from .base import MySQLTypeCompiler +from ...sql import sqltypes + + +class INET4(sqltypes.TypeEngine[str]): + """INET4 column type for MariaDB + + .. versionadded:: 2.0.37 + """ + + __visit_name__ = "INET4" + + +class INET6(sqltypes.TypeEngine[str]): + """INET6 column type for MariaDB + + .. 
versionadded:: 2.0.37 + """ + + __visit_name__ = "INET6" + + +class MariaDBTypeCompiler(MySQLTypeCompiler): + def visit_INET4(self, type_: INET4, **kwargs: Any) -> str: + return "INET4" + + def visit_INET6(self, type_: INET6, **kwargs: Any) -> str: + return "INET6" class MariaDBDialect(MySQLDialect): is_mariadb = True supports_statement_cache = True name = "mariadb" - preparer = MariaDBIdentifierPreparer + preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer + type_compiler_cls = MariaDBTypeCompiler -def loader(driver): - driver_mod = __import__( +def loader(driver: str) -> Callable[[], type[MariaDBDialect]]: + dialect_mod = __import__( "sqlalchemy.dialects.mysql.%s" % driver ).dialects.mysql - driver_cls = getattr(driver_mod, driver).dialect - - return type( - "MariaDBDialect_%s" % driver, - ( - MariaDBDialect, - driver_cls, - ), - {"supports_statement_cache": True}, - ) + + driver_mod = getattr(dialect_mod, driver) + if hasattr(driver_mod, "mariadb_dialect"): + driver_cls = driver_mod.mariadb_dialect + return driver_cls # type: ignore[no-any-return] + else: + driver_cls = driver_mod.dialect + + return type( + "MariaDBDialect_%s" % driver, + ( + MariaDBDialect, + driver_cls, + ), + {"supports_statement_cache": True}, + ) diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 9730c9b4da3..c6bb58a8d93 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -1,11 +1,9 @@ -# mysql/mariadbconnector.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/mariadbconnector.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - """ @@ -29,7 +27,15 @@ .. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python """ # noqa +from __future__ import annotations + import re +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from uuid import UUID as _python_UUID from .base import MySQLCompiler @@ -39,6 +45,19 @@ from ... import util from ...sql import sqltypes +if TYPE_CHECKING: + from ...engine.base import Connection + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import PoolProxiedConnection + from ...engine.url import URL + from ...sql.compiler import SQLCompiler + from ...sql.type_api import _ResultProcessorType + mariadb_cpy_minimum_version = (1, 0, 1) @@ -47,10 +66,12 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]): # work around JIRA issue # https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed, # this type can be removed. 
- def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[Any]]: if self.as_uuid: - def process(value): + def process(value: Any) -> Any: if value is not None: if hasattr(value, "decode"): value = value.decode("ascii") @@ -60,7 +81,7 @@ def process(value): return process else: - def process(value): + def process(value: Any) -> Any: if value is not None: if hasattr(value, "decode"): value = value.decode("ascii") @@ -71,30 +92,27 @@ def process(value): class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): - _lastrowid = None + _lastrowid: Optional[int] = None - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor(buffered=False) - def create_default_cursor(self): + def create_default_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor(buffered=True) - def post_exec(self): + def post_exec(self) -> None: super().post_exec() self._rowcount = self.cursor.rowcount + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) if self.isinsert and self.compiled.postfetch_lastrowid: self._lastrowid = self.cursor.lastrowid - @property - def rowcount(self): - if self._rowcount is not None: - return self._rowcount - else: - return self.cursor.rowcount - - def get_lastrowid(self): + def get_lastrowid(self) -> int: + if TYPE_CHECKING: + assert self._lastrowid is not None return self._lastrowid @@ -133,7 +151,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) @util.memoized_property - def _dbapi_version(self): + def _dbapi_version(self) -> Tuple[int, ...]: if self.dbapi and hasattr(self.dbapi, "__version__"): return tuple( [ @@ -146,7 +164,7 @@ def _dbapi_version(self): else: return (99, 99, 99) - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.paramstyle = "qmark" if self.dbapi is not None: @@ -158,20 +176,26 @@ def __init__(self, **kwargs): ) @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: return __import__("mariadb") - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: if super().is_disconnect(e, connection, cursor): return True - elif isinstance(e, self.dbapi.Error): + elif isinstance(e, self.loaded_dbapi.Error): str_e = str(e).lower() return "not connected" in str_e or "isn't valid" in str_e else: return False - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: opts = url.translate_connect_args() + opts.update(url.query) int_params = [ "connect_timeout", @@ -186,6 +210,7 @@ def create_connect_args(self, url): "ssl_verify_cert", "ssl", "pool_reset_connection", + "compress", ] for key in int_params: @@ -205,19 +230,21 @@ def create_connect_args(self, url): except (AttributeError, ImportError): self.supports_sane_rowcount = False opts["client_flag"] = client_flag - return [[], opts] + return [], opts - def _extract_error_code(self, exception): + def _extract_error_code(self, exception: DBAPIModule.Error) -> int: try: - rc = exception.errno + rc: int = exception.errno except: rc = -1 return rc - def _detect_charset(self, connection): + def _detect_charset(self, connection: Connection) -> str: return "utf8mb4" - def get_isolation_level_values(self, dbapi_connection): + def get_isolation_level_values( 
+ self, dbapi_conn: DBAPIConnection + ) -> Sequence[IsolationLevel]: return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -226,21 +253,23 @@ def get_isolation_level_values(self, dbapi_connection): "AUTOCOMMIT", ) - def set_isolation_level(self, connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: if level == "AUTOCOMMIT": - connection.autocommit = True + dbapi_connection.autocommit = True else: - connection.autocommit = False - super().set_isolation_level(connection, level) + dbapi_connection.autocommit = False + super().set_isolation_level(dbapi_connection, level) - def do_begin_twophase(self, connection, xid): + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: connection.execute( sql.text("XA BEGIN :xid").bindparams( sql.bindparam("xid", xid, literal_execute=True) ) ) - def do_prepare_twophase(self, connection, xid): + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: connection.execute( sql.text("XA END :xid").bindparams( sql.bindparam("xid", xid, literal_execute=True) @@ -253,8 +282,12 @@ def do_prepare_twophase(self, connection, xid): ) def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False - ): + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: if not is_prepared: connection.execute( sql.text("XA END :xid").bindparams( @@ -268,8 +301,12 @@ def do_rollback_twophase( ) def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False - ): + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute( diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index fc90c65d2ad..a830cb5afef 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -1,10 +1,9 @@ -# mysql/mysqlconnector.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/mysqlconnector.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" @@ -14,26 +13,85 @@ :connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname> :url: https://pypi.org/project/mysql-connector-python/ -.. note:: +Driver Status +------------- + +MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the +degree that the driver is functional. There are still ongoing issues +with features such as server side cursors which remain disabled until +upstream issues are repaired. + +.. warning:: The MySQL Connector/Python driver published by Oracle is subject + to frequent, major regressions of essential functionality such as being able + to correctly persist simple binary strings, which indicate it is not well + tested. The SQLAlchemy project is not able to maintain this dialect fully as + regressions in the driver prevent it from being included in continuous + integration. + +.. versionchanged:: 2.0.39 + + The MySQL Connector/Python dialect has been updated to support the + latest version of this DBAPI. Previously, MySQL Connector/Python + was not fully supported. However, support remains limited due to ongoing + regressions introduced in this driver.
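For orientation, a typical engine URL for this dialect, including the MariaDB-compatible collation override discussed in the section that follows, might look like the sketch below (the user, password, host, and database name are placeholder assumptions):

    from sqlalchemy import create_engine

    # placeholder credentials/host/database; the charset/collation query
    # parameters mirror the MariaDB workaround described below
    engine = create_engine(
        "mysql+mysqlconnector://user:password@localhost/test"
        "?charset=utf8mb4&collation=utf8mb4_general_ci"
    )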
+ +Connecting to MariaDB with MySQL Connector/Python +-------------------------------------------------- + +MySQL Connector/Python may attempt to pass an incompatible collation to the +database when connecting to MariaDB. Experimentation has shown that using +``?charset=utf8mb4&collation=utf8mb4_general_ci`` or similar MariaDB-compatible +charset/collation will allow connectivity. - The MySQL Connector/Python DBAPI has had many issues since its release, - some of which may remain unresolved, and the mysqlconnector dialect is - **not tested as part of SQLAlchemy's continuous integration**. - The recommended MySQL dialects are mysqlclient and PyMySQL. """ # noqa +from __future__ import annotations import re - -from .base import BIT +from typing import Any +from typing import cast +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from .base import MariaDBIdentifierPreparer from .base import MySQLCompiler from .base import MySQLDialect +from .base import MySQLExecutionContext from .base import MySQLIdentifierPreparer +from .mariadb import MariaDBDialect +from .types import BIT from ... import util +if TYPE_CHECKING: + + from ...engine.base import Connection + from ...engine.cursor import CursorResult + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import PoolProxiedConnection + from ...engine.row import Row + from ...engine.url import URL + from ...sql.elements import BinaryExpression + + +class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): + def create_server_side_cursor(self) -> DBAPICursor: + return self._dbapi_connection.cursor(buffered=False) + + def create_default_cursor(self) -> DBAPICursor: + return self._dbapi_connection.cursor(buffered=True) + class MySQLCompiler_mysqlconnector(MySQLCompiler): - def visit_mod_binary(self, binary, operator, **kw): + def visit_mod_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: return ( self.process(binary.left, **kw) + " % " @@ -41,22 +99,37 @@ def visit_mod_binary(self, binary, operator, **kw): ) -class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): +class IdentifierPreparerCommon_mysqlconnector: @property - def _double_percents(self): + def _double_percents(self) -> bool: return False @_double_percents.setter - def _double_percents(self, value): + def _double_percents(self, value: Any) -> None: pass - def _escape_identifier(self, value): - value = value.replace(self.escape_quote, self.escape_to_quote) + def _escape_identifier(self, value: str) -> str: + value = value.replace( + self.escape_quote, # type:ignore[attr-defined] + self.escape_to_quote, # type:ignore[attr-defined] + ) return value +class MySQLIdentifierPreparer_mysqlconnector( + IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer +): + pass + + +class MariaDBIdentifierPreparer_mysqlconnector( + IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer +): + pass + + class _myconnpyBIT(BIT): - def result_processor(self, dialect, coltype): + def result_processor(self, dialect: Any, coltype: Any) -> None: """MySQL-connector already converts mysql bits, so.""" return None @@ -71,24 +144,31 @@ class MySQLDialect_mysqlconnector(MySQLDialect): supports_native_decimal = True +
supports_native_bit = True + + # not until https://bugs.mysql.com/bug.php?id=117548 + supports_server_side_cursors = False + default_paramstyle = "format" statement_compiler = MySQLCompiler_mysqlconnector - preparer = MySQLIdentifierPreparer_mysqlconnector + execution_ctx_cls = MySQLExecutionContext_mysqlconnector + + preparer: type[MySQLIdentifierPreparer] = ( + MySQLIdentifierPreparer_mysqlconnector + ) colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) @classmethod - def import_dbapi(cls): - from mysql import connector + def import_dbapi(cls) -> DBAPIModule: + return cast("DBAPIModule", __import__("mysql.connector").connector) - return connector - - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: dbapi_connection.ping(False) return True - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: opts = url.translate_connect_args(username="user") opts.update(url.query) @@ -96,6 +176,7 @@ def create_connect_args(self, url): util.coerce_kw_type(opts, "allow_local_infile", bool) util.coerce_kw_type(opts, "autocommit", bool) util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "client_flag", int) util.coerce_kw_type(opts, "compress", bool) util.coerce_kw_type(opts, "connection_timeout", int) util.coerce_kw_type(opts, "connect_timeout", int) @@ -110,15 +191,21 @@ def create_connect_args(self, url): util.coerce_kw_type(opts, "use_pure", bool) util.coerce_kw_type(opts, "use_unicode", bool) - # unfortunately, MySQL/connector python refuses to release a - # cursor without reading fully, so non-buffered isn't an option - opts.setdefault("buffered", True) + # note that "buffered" is set to False by default in MySQL/connector + # python. If you set it to True, then there is no way to get a server + # side cursor because the logic is written to disallow that. + + # leaving this at True until + # https://bugs.mysql.com/bug.php?id=117548 can be fixed + opts["buffered"] = True # FOUND_ROWS must be set in ClientFlag to enable # supports_sane_rowcount. 
if self.dbapi is not None: try: - from mysql.connector.constants import ClientFlag + from mysql.connector import constants # type: ignore + + ClientFlag = constants.ClientFlag client_flags = opts.get( "client_flags", ClientFlag.get_default() @@ -127,24 +214,35 @@ def create_connect_args(self, url): opts["client_flags"] = client_flags except Exception: pass - return [[], opts] + + return [], opts @util.memoized_property - def _mysqlconnector_version_info(self): + def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]: if self.dbapi and hasattr(self.dbapi, "__version__"): m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + return None - def _detect_charset(self, connection): - return connection.connection.charset + def _detect_charset(self, connection: Connection) -> str: + return connection.connection.charset # type: ignore - def _extract_error_code(self, exception): - return exception.errno + def _extract_error_code(self, exception: BaseException) -> int: + return exception.errno # type: ignore - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: Exception, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: errnos = (2006, 2013, 2014, 2045, 2055, 2048) - exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) + exceptions = ( + self.loaded_dbapi.OperationalError, # + self.loaded_dbapi.InterfaceError, + self.loaded_dbapi.ProgrammingError, + ) if isinstance(e, exceptions): return ( e.errno in errnos @@ -154,26 +252,48 @@ def is_disconnect(self, e, connection, cursor): else: return False - def _compat_fetchall(self, rp, charset=None): + def _compat_fetchall( + self, + rp: CursorResult[Tuple[Any, ...]], + charset: Optional[str] = None, + ) -> Sequence[Row[Tuple[Any, ...]]]: return rp.fetchall() - def _compat_fetchone(self, rp, charset=None): + def _compat_fetchone( + self, + rp: CursorResult[Tuple[Any, ...]], + charset: Optional[str] = None, + ) -> Optional[Row[Tuple[Any, ...]]]: return rp.fetchone() - _isolation_lookup = { - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "AUTOCOMMIT", - } + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> Sequence[IsolationLevel]: + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ) - def _set_isolation_level(self, connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: if level == "AUTOCOMMIT": - connection.autocommit = True + dbapi_connection.autocommit = True else: - connection.autocommit = False - super()._set_isolation_level(connection, level) + dbapi_connection.autocommit = False + super().set_isolation_level(dbapi_connection, level) + + +class MariaDBDialect_mysqlconnector( + MariaDBDialect, MySQLDialect_mysqlconnector +): + supports_statement_cache = True + _allows_uuid_binds = False + preparer = MariaDBIdentifierPreparer_mysqlconnector dialect = MySQLDialect_mysqlconnector +mariadb_dialect = MariaDBDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index d1cf835c54e..de4ae61c047 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -1,11 +1,9 @@ -# mysql/mysqldb.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# 
dialects/mysql/mysqldb.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - """ @@ -48,9 +46,9 @@ "ssl": { "ca": "/home/gord/client-ssl/ca.pem", "cert": "/home/gord/client-ssl/client-cert.pem", - "key": "/home/gord/client-ssl/client-key.pem" + "key": "/home/gord/client-ssl/client-key.pem", } - } + }, ) For convenience, the following keys may also be specified inline within the URL @@ -74,7 +72,9 @@ ----------------------------------- Google Cloud SQL now recommends use of the MySQLdb dialect. Connect -using a URL like the following:: +using a URL like the following: + +.. sourcecode:: text mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename> @@ -84,25 +84,39 @@ The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. """ +from __future__ import annotations import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING from .base import MySQLCompiler from .base import MySQLDialect from .base import MySQLExecutionContext from .base import MySQLIdentifierPreparer -from .base import TEXT -from ... import sql from ... import util +from ...util.typing import Literal + +if TYPE_CHECKING: + + from ...engine.base import Connection + from ...engine.interfaces import _DBAPIMultiExecuteParams + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import ExecutionContext + from ...engine.interfaces import IsolationLevel + from ...engine.url import URL class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - @property - def rowcount(self): - if hasattr(self, "_rowcount"): - return self._rowcount - else: - return self.cursor.rowcount + pass class MySQLCompiler_mysqldb(MySQLCompiler): @@ -122,8 +136,9 @@ class MySQLDialect_mysqldb(MySQLDialect): execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer + server_version_info: Tuple[int, ...]
- def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(**kwargs) self._mysql_dbapi_version = ( self._parse_dbapi_version(self.dbapi.__version__) @@ -131,7 +146,7 @@ def __init__(self, **kwargs): else (0, 0, 0) ) - def _parse_dbapi_version(self, version): + def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]: m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) @@ -139,7 +154,7 @@ def _parse_dbapi_version(self, version): return (0, 0, 0) @util.langhelpers.memoized_property - def supports_server_side_cursors(self): + def supports_server_side_cursors(self) -> bool: try: cursors = __import__("MySQLdb.cursors").cursors self._sscursor = cursors.SSCursor @@ -148,13 +163,13 @@ def supports_server_side_cursors(self): return False @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: return __import__("MySQLdb") - def on_connect(self): + def on_connect(self) -> Callable[[DBAPIConnection], None]: super_ = super().on_connect() - def on_connect(conn): + def on_connect(conn: DBAPIConnection) -> None: if super_ is not None: super_(conn) @@ -167,43 +182,24 @@ def on_connect(conn): return on_connect - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]: dbapi_connection.ping() return True - def do_executemany(self, cursor, statement, parameters, context=None): + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: rowcount = cursor.executemany(statement, parameters) if context is not None: - context._rowcount = rowcount - - def _check_unicode_returns(self, connection): - # work around issue fixed in - # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 - # specific issue w/ the utf8mb4_bin collation and unicode returns - - collation = connection.exec_driver_sql( - "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation"), - ) - ).scalar() - has_utf8mb4_bin = self.server_version_info > (5,) and collation - if has_utf8mb4_bin: - additional_tests = [ - sql.collate( - sql.cast( - sql.literal_column("'test collated returns'"), - TEXT(charset="utf8mb4"), - ), - "utf8mb4_bin", - ) - ] - else: - additional_tests = [] - return super()._check_unicode_returns(connection, additional_tests) + cast(MySQLExecutionContext, context)._rowcount = rowcount - def create_connect_args(self, url, _translate_args=None): + def create_connect_args( + self, url: URL, _translate_args: Optional[Dict[str, Any]] = None + ) -> ConnectArgsType: if _translate_args is None: _translate_args = dict( database="db", username="user", password="passwd" @@ -217,7 +213,7 @@ def create_connect_args(self, url, _translate_args=None): util.coerce_kw_type(opts, "read_timeout", int) util.coerce_kw_type(opts, "write_timeout", int) util.coerce_kw_type(opts, "client_flag", int) - util.coerce_kw_type(opts, "local_infile", int) + util.coerce_kw_type(opts, "local_infile", bool) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. 
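As context for the ``coerce_kw_type`` calls above: URL query-string values always arrive as strings and are coerced to the types the DBAPI expects before being passed to ``connect()``. A minimal illustrative sketch, using a made-up URL rather than the dialect's actual code path:

    from sqlalchemy import util
    from sqlalchemy.engine.url import make_url

    # hypothetical URL; both query values arrive as strings
    url = make_url(
        "mysql+mysqldb://scott:tiger@localhost/test"
        "?connect_timeout=10&local_infile=1"
    )
    opts = dict(url.query)
    util.coerce_kw_type(opts, "connect_timeout", int)  # "10" -> 10
    util.coerce_kw_type(opts, "local_infile", bool)  # "1" -> True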
@@ -252,9 +248,9 @@ def create_connect_args(self, url, _translate_args=None): if client_flag_found_rows is not None: client_flag |= client_flag_found_rows opts["client_flag"] = client_flag - return [[], opts] + return [], opts - def _found_rows_client_flag(self): + def _found_rows_client_flag(self) -> Optional[int]: if self.dbapi is not None: try: CLIENT_FLAGS = __import__( @@ -263,20 +259,23 @@ def _found_rows_client_flag(self): except (AttributeError, ImportError): return None else: - return CLIENT_FLAGS.FOUND_ROWS + return CLIENT_FLAGS.FOUND_ROWS # type: ignore else: return None - def _extract_error_code(self, exception): - return exception.args[0] + def _extract_error_code(self, exception: DBAPIModule.Error) -> int: + return exception.args[0] # type: ignore[no-any-return] - def _detect_charset(self, connection): + def _detect_charset(self, connection: Connection) -> str: """Sniff out the character set in use for connection results.""" try: # note: the SQL here would be # "SHOW VARIABLES LIKE 'character_set%%'" - cset_name = connection.connection.character_set_name + + cset_name: Callable[[], str] = ( + connection.connection.character_set_name + ) except AttributeError: util.warn( "No 'character_set_name' can be detected with " @@ -288,7 +287,9 @@ def _detect_charset(self, connection): else: return cset_name() - def get_isolation_level_values(self, dbapi_connection): + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> Tuple[IsolationLevel, ...]: return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -297,7 +298,9 @@ def get_isolation_level_values(self, dbapi_connection): "AUTOCOMMIT", ) - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: if level == "AUTOCOMMIT": dbapi_connection.autocommit(True) else: diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py index b7faf771214..fe97672ad85 100644 --- a/lib/sqlalchemy/dialects/mysql/provision.py +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -1,5 +1,10 @@ +# dialects/mysql/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors - from ... 
import exc from ...testing.provision import configure_follower from ...testing.provision import create_db @@ -34,6 +39,13 @@ def generate_driver_url(url, driver, query_str): drivername="%s+%s" % (backend, driver) ).update_query_string(query_str) + if driver == "mariadbconnector": + new_url = new_url.difference_update_query(["charset"]) + elif driver == "mysqlconnector": + new_url = new_url.update_query_pairs( + [("collation", "utf8mb4_general_ci")] + ) + try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 6567202a45e..48b7994a82a 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -1,11 +1,9 @@ -# mysql/pymysql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/pymysql.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - r""" @@ -41,7 +39,6 @@ "&ssl_check_hostname=false" ) - MySQL-Python Compatibility -------------------------- @@ -50,9 +47,26 @@ to the pymysql driver as well. """ # noqa +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .mysqldb import MySQLDialect_mysqldb from ...util import langhelpers +from ...util.typing import Literal + +if TYPE_CHECKING: + + from ...engine.interfaces import ConnectArgsType + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import DBAPICursor + from ...engine.interfaces import DBAPIModule + from ...engine.interfaces import PoolProxiedConnection + from ...engine.url import URL class MySQLDialect_pymysql(MySQLDialect_mysqldb): @@ -62,7 +76,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): description_encoding = None @langhelpers.memoized_property - def supports_server_side_cursors(self): + def supports_server_side_cursors(self) -> bool: try: cursors = __import__("pymysql.cursors").cursors self._sscursor = cursors.SSCursor @@ -71,11 +85,11 @@ def supports_server_side_cursors(self): return False @classmethod - def import_dbapi(cls): + def import_dbapi(cls) -> DBAPIModule: return __import__("pymysql") @langhelpers.memoized_property - def _send_false_to_ping(self): + def _send_false_to_ping(self) -> bool: """determine if pymysql has deprecated, changed the default of, or removed the 'reconnect' argument of connection.ping(). 
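The docstring above describes probing how ``ping()`` is defined; a hedged sketch of equivalent signature inspection using the standard library (assuming pymysql is importable; the dialect's own detection code may differ in detail):

    import inspect

    from pymysql.connections import Connection

    # if ping() still accepts a 'reconnect' argument whose default is not
    # False, then False must be passed explicitly to suppress reconnects
    params = inspect.signature(Connection.ping).parameters
    reconnect = params.get("reconnect")
    send_false_to_ping = (
        reconnect is not None and reconnect.default is not False
    )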
@@ -86,7 +100,9 @@ def _send_false_to_ping(self): """ # noqa: E501 try: - Connection = __import__("pymysql.connections").Connection + Connection = __import__( + "pymysql.connections" + ).connections.Connection except (ImportError, AttributeError): return True else: @@ -100,7 +116,7 @@ def _send_false_to_ping(self): not insp.defaults or insp.defaults[0] is not False ) - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]: if self._send_false_to_ping: dbapi_connection.ping(False) else: @@ -108,17 +124,24 @@ def do_ping(self, dbapi_connection): return True - def create_connect_args(self, url, _translate_args=None): + def create_connect_args( + self, url: URL, _translate_args: Optional[Dict[str, Any]] = None + ) -> ConnectArgsType: if _translate_args is None: _translate_args = dict(username="user") return super().create_connect_args( url, _translate_args=_translate_args ) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: if super().is_disconnect(e, connection, cursor): return True - elif isinstance(e, self.dbapi.Error): + elif isinstance(e, self.loaded_dbapi.Error): str_e = str(e).lower() return ( "already closed" in str_e or "connection was killed" in str_e @@ -126,7 +149,7 @@ def is_disconnect(self, e, connection, cursor): else: return False - def _extract_error_code(self, exception): + def _extract_error_code(self, exception: BaseException) -> Any: if isinstance(exception.args[0], Exception): exception = exception.args[0] return exception.args[0] diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index e4b11778afc..86f1b3c89ad 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -1,15 +1,13 @@ -# mysql/pyodbc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/pyodbc.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r""" - .. 
dialect:: mysql+pyodbc :name: PyODBC :dbapi: pyodbc @@ -30,21 +28,30 @@ Pass through exact pyodbc connection string:: import urllib + connection_string = ( - 'DRIVER=MySQL ODBC 8.0 ANSI Driver;' - 'SERVER=localhost;' - 'PORT=3307;' - 'DATABASE=mydb;' - 'UID=root;' - 'PWD=(whatever);' - 'charset=utf8mb4;' + "DRIVER=MySQL ODBC 8.0 ANSI Driver;" + "SERVER=localhost;" + "PORT=3307;" + "DATABASE=mydb;" + "UID=root;" + "PWD=(whatever);" + "charset=utf8mb4;" ) params = urllib.parse.quote_plus(connection_string) connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params """ # noqa +from __future__ import annotations +import datetime import re +from typing import Any +from typing import Callable +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from .base import MySQLDialect from .base import MySQLExecutionContext @@ -54,23 +61,31 @@ from ...connectors.pyodbc import PyODBCConnector from ...sql.sqltypes import Time +if TYPE_CHECKING: + from ...engine import Connection + from ...engine.interfaces import DBAPIConnection + from ...engine.interfaces import Dialect + from ...sql.type_api import _ResultProcessorType + class _pyodbcTIME(TIME): - def result_processor(self, dialect, coltype): - def process(value): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[datetime.time]: + def process(value: Any) -> Union[datetime.time, None]: # pyodbc returns a datetime.time object; no need to convert - return value + return value # type: ignore[no-any-return] return process class MySQLExecutionContext_pyodbc(MySQLExecutionContext): - def get_lastrowid(self): + def get_lastrowid(self) -> int: cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") - lastrowid = cursor.fetchone()[0] + lastrowid = cursor.fetchone()[0] # type: ignore[index] cursor.close() - return lastrowid + return lastrowid # type: ignore[no-any-return] class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): @@ -81,7 +96,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): pyodbc_driver_name = "MySQL" - def _detect_charset(self, connection): + def _detect_charset(self, connection: Connection) -> str: """Sniff out the character set in use for connection results.""" # Prefer 'character_set_results' for the current connection over the @@ -106,21 +121,25 @@ def _detect_charset(self, connection): ) return "latin1" - def _get_server_version_info(self, connection): + def _get_server_version_info( + self, connection: Connection + ) -> Tuple[int, ...]: return MySQLDialect._get_server_version_info(self, connection) - def _extract_error_code(self, exception): + def _extract_error_code(self, exception: BaseException) -> Optional[int]: m = re.compile(r"\((\d+)\)").search(str(exception.args)) - c = m.group(1) + if m is None: + return None + c: Optional[str] = m.group(1) if c: return int(c) else: return None - def on_connect(self): + def on_connect(self) -> Callable[[DBAPIConnection], None]: super_ = super().on_connect() - def on_connect(conn): + def on_connect(conn: DBAPIConnection) -> None: if super_ is not None: super_(conn) diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index c4909fe319e..71bd8c45494 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -1,46 +1,65 @@ -# mysql/reflection.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/reflection.py +# Copyright (C) 2005-2025 
the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - +from __future__ import annotations import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from .enumerated import ENUM from .enumerated import SET from .types import DATETIME from .types import TIME from .types import TIMESTAMP -from ... import log from ... import types as sqltypes from ... import util +from ...util.typing import Literal + +if TYPE_CHECKING: + from .base import MySQLDialect + from .base import MySQLIdentifierPreparer + from ...engine.interfaces import ReflectedColumn class ReflectedState: """Stores raw information about a SHOW CREATE TABLE statement.""" - def __init__(self): - self.columns = [] - self.table_options = {} - self.table_name = None - self.keys = [] - self.fk_constraints = [] - self.ck_constraints = [] + charset: Optional[str] + + def __init__(self) -> None: + self.columns: List[ReflectedColumn] = [] + self.table_options: Dict[str, str] = {} + self.table_name: Optional[str] = None + self.keys: List[Dict[str, Any]] = [] + self.fk_constraints: List[Dict[str, Any]] = [] + self.ck_constraints: List[Dict[str, Any]] = [] -@log.class_logger class MySQLTableDefinitionParser: """Parses the results of a SHOW CREATE TABLE statement.""" - def __init__(self, dialect, preparer): + def __init__( + self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer + ): self.dialect = dialect self.preparer = preparer self._prep_regexes() - def parse(self, show_create, charset): + def parse( + self, show_create: str, charset: Optional[str] + ) -> ReflectedState: state = ReflectedState() state.charset = charset for line in re.split(r"\r?\n", show_create): @@ -65,11 +84,11 @@ def parse(self, show_create, charset): if type_ is None: util.warn("Unknown schema content: %r" % line) elif type_ == "key": - state.keys.append(spec) + state.keys.append(spec) # type: ignore[arg-type] elif type_ == "fk_constraint": - state.fk_constraints.append(spec) + state.fk_constraints.append(spec) # type: ignore[arg-type] elif type_ == "ck_constraint": - state.ck_constraints.append(spec) + state.ck_constraints.append(spec) # type: ignore[arg-type] else: pass return state @@ -77,7 +96,13 @@ def parse(self, show_create, charset): def _check_view(self, sql: str) -> bool: return bool(self._re_is_view.match(sql)) - def _parse_constraints(self, line): + def _parse_constraints(self, line: str) -> Union[ + Tuple[None, str], + Tuple[Literal["partition"], str], + Tuple[ + Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str] + ], + ]: """Parse a KEY or CONSTRAINT line. :param line: A line of SHOW CREATE TABLE output @@ -127,7 +152,7 @@ def _parse_constraints(self, line): # No match. return (None, line) - def _parse_table_name(self, line, state): + def _parse_table_name(self, line: str, state: ReflectedState) -> None: """Extract the table name. :param line: The first line of SHOW CREATE TABLE @@ -138,7 +163,7 @@ def _parse_table_name(self, line, state): if m: state.table_name = cleanup(m.group("name")) - def _parse_table_options(self, line, state): + def _parse_table_options(self, line: str, state: ReflectedState) -> None: """Build a dictionary of all reflected table-level options. 
:param line: The final line of SHOW CREATE TABLE output. @@ -164,7 +189,9 @@ def _parse_table_options(self, line, state): for opt, val in options.items(): state.table_options["%s_%s" % (self.dialect.name, opt)] = val - def _parse_partition_options(self, line, state): + def _parse_partition_options( + self, line: str, state: ReflectedState + ) -> None: options = {} new_line = line[:] @@ -220,7 +247,7 @@ def _parse_partition_options(self, line, state): else: state.table_options["%s_%s" % (self.dialect.name, opt)] = val - def _parse_column(self, line, state): + def _parse_column(self, line: str, state: ReflectedState) -> None: """Extract column details. Falls back to a 'minimal support' variant if full parse fails. @@ -283,13 +310,16 @@ def _parse_column(self, line, state): type_instance = col_type(*type_args, **type_kw) - col_kw = {} + col_kw: Dict[str, Any] = {} # NOT NULL col_kw["nullable"] = True # this can be "NULL" in the case of TIMESTAMP if spec.get("notnull", False) == "NOT NULL": col_kw["nullable"] = False + # For generated columns, the nullability is marked in a different place + if spec.get("notnull_generated", False) == "NOT NULL": + col_kw["nullable"] = False # AUTO_INCREMENT if spec.get("autoincr", False): @@ -321,9 +351,13 @@ def _parse_column(self, line, state): name=name, type=type_instance, default=default, comment=comment ) col_d.update(col_kw) - state.columns.append(col_d) + state.columns.append(col_d) # type: ignore[arg-type] - def _describe_to_create(self, table_name, columns): + def _describe_to_create( + self, + table_name: str, + columns: Sequence[Tuple[str, str, str, str, str, str]], + ) -> str: """Re-format DESCRIBE output as a SHOW CREATE TABLE string. DESCRIBE is a much simpler reflection and is sufficient for @@ -376,7 +410,9 @@ def _describe_to_create(self, table_name, columns): ] ) - def _parse_keyexprs(self, identifiers): + def _parse_keyexprs( + self, identifiers: str + ) -> List[Tuple[str, Optional[int], str]]: """Unpack '"col"(2),"col" ASC'-ish strings into components.""" return [ @@ -386,11 +422,12 @@ def _parse_keyexprs(self, identifiers): ) ] - def _prep_regexes(self): + def _prep_regexes(self) -> None: """Pre-compile regular expressions.""" - self._re_columns = [] - self._pr_options = [] + self._pr_options: List[ + Tuple[re.Pattern[Any], Optional[Callable[[str], str]]] + ] = [] _final = self.preparer.final_quote @@ -448,11 +485,13 @@ def _prep_regexes(self): r"(?: +COLLATE +(?P<collate>[\w_]+))?" r"(?: +(?P<notnull>(?:NOT )?NULL))?" r"(?: +DEFAULT +(?P<default>" - r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+" r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" r"))?" r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\(" - r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?" + r".*\))? ?(?P<persistence>VIRTUAL|STORED)?" + r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?" + r")?" r"(?: +(?P<autoincr>AUTO_INCREMENT))?" r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?" r"(?: +COLUMN_FORMAT +(?P<col_format>\w+))?"
@@ -500,7 +539,7 @@ def _prep_regexes(self): # # unique constraints come back as KEYs kw = quotes.copy() - kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT" self._re_fk_constraint = _re_compile( r" " r"CONSTRAINT +" @@ -577,21 +616,21 @@ def _prep_regexes(self): _optional_equals = r"(?:\s*(?:=\s*)|\s+)" - def _add_option_string(self, directive): + def _add_option_string(self, directive: str) -> None: regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % ( re.escape(directive), self._optional_equals, ) self._pr_options.append(_pr_compile(regex, cleanup_text)) - def _add_option_word(self, directive): + def _add_option_word(self, directive: str) -> None: regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % ( re.escape(directive), self._optional_equals, ) self._pr_options.append(_pr_compile(regex)) - def _add_partition_option_word(self, directive): + def _add_partition_option_word(self, directive: str) -> None: if directive == "PARTITION BY" or directive == "SUBPARTITION BY": regex = r"(?P<directive>%s)%s" r"(?P<val>\w+.*)" % ( re.escape(directive), @@ -606,7 +645,7 @@ def _add_partition_option_word(self, directive): regex = r"(?P<directive>%s)(?!\S)" % (re.escape(directive),) self._pr_options.append(_pr_compile(regex)) - def _add_option_regex(self, directive, regex): + def _add_option_regex(self, directive: str, regex: str) -> None: regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % ( re.escape(directive), self._optional_equals, @@ -624,21 +663,35 @@ def _add_option_regex(self, directive, regex): ) -def _pr_compile(regex, cleanup=None): +@overload +def _pr_compile( + regex: str, cleanup: Callable[[str], str] +) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ... + + +@overload +def _pr_compile( + regex: str, cleanup: None = None +) -> Tuple[re.Pattern[Any], None]: ...
+ + +def _pr_compile( + regex: str, cleanup: Optional[Callable[[str], str]] = None +) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]: """Prepare a 2-tuple of compiled regex and callable.""" return (_re_compile(regex), cleanup) -def _re_compile(regex): +def _re_compile(regex: str) -> re.Pattern[Any]: """Compile a string to regex, I and UNICODE.""" return re.compile(regex, re.I | re.UNICODE) -def _strip_values(values): +def _strip_values(values: Sequence[str]) -> List[str]: "Strip reflected values quotes" - strip_values = [] + strip_values: List[str] = [] for a in values: if a[0:1] == '"' or a[0:1] == "'": # strip enclosing quotes and unquote interior @@ -650,7 +703,9 @@ def _strip_values(values): def cleanup_text(raw_text: str) -> str: if "\\" in raw_text: raw_text = re.sub( - _control_char_regexp, lambda s: _control_char_map[s[0]], raw_text + _control_char_regexp, + lambda s: _control_char_map[s[0]], # type: ignore[index] + raw_text, ) return raw_text.replace("''", "'") diff --git a/lib/sqlalchemy/dialects/mysql/reserved_words.py b/lib/sqlalchemy/dialects/mysql/reserved_words.py index 9f3436e6379..ff526394a69 100644 --- a/lib/sqlalchemy/dialects/mysql/reserved_words.py +++ b/lib/sqlalchemy/dialects/mysql/reserved_words.py @@ -1,5 +1,5 @@ -# mysql/reserved_words.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/reserved_words.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,7 +11,6 @@ # https://mariadb.com/kb/en/reserved-words/ # includes: Reserved Words, Oracle Mode (separate set unioned) # excludes: Exceptions, Function Names -# mypy: ignore-errors RESERVED_WORDS_MARIADB = { "accessible", @@ -282,6 +281,7 @@ } ) +# https://dev.mysql.com/doc/refman/8.3/en/keywords.html # https://dev.mysql.com/doc/refman/8.0/en/keywords.html # https://dev.mysql.com/doc/refman/5.7/en/keywords.html # https://dev.mysql.com/doc/refman/5.6/en/keywords.html @@ -403,6 +403,7 @@ "int4", "int8", "integer", + "intersect", "interval", "into", "io_after_gtids", @@ -468,6 +469,7 @@ "outfile", "over", "parse_gcol_expr", + "parallel", "partition", "percent_rank", "persist", @@ -476,6 +478,7 @@ "primary", "procedure", "purge", + "qualify", "range", "rank", "read", diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index aa1de1b6992..455b0b6629e 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -1,18 +1,29 @@ -# mysql/types.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/mysql/types.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - +from __future__ import annotations import datetime +import decimal +from typing import Any +from typing import Iterable +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from ... import exc from ... import util from ...sql import sqltypes +if TYPE_CHECKING: + from .base import MySQLDialect + from ...engine.interfaces import Dialect + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _ResultProcessorType + class _NumericType: """Base for MySQL numeric types. 
@@ -22,19 +33,27 @@ class _NumericType: """ - def __init__(self, unsigned=False, zerofill=False, **kw): + def __init__( + self, unsigned: bool = False, zerofill: bool = False, **kw: Any + ): self.unsigned = unsigned self.zerofill = zerofill super().__init__(**kw) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[_NumericType, sqltypes.Numeric] ) -class _FloatType(_NumericType, sqltypes.Float): - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): +class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = True, + **kw: Any, + ): if isinstance(self, (REAL, DOUBLE)) and ( (precision is None and scale is not None) or (precision is not None and scale is None) @@ -46,18 +65,18 @@ def __init__(self, precision=None, scale=None, asdecimal=True, **kw): super().__init__(precision=precision, asdecimal=asdecimal, **kw) self.scale = scale - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] ) class _IntegerType(_NumericType, sqltypes.Integer): - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): self.display_width = display_width super().__init__(**kw) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] ) @@ -68,13 +87,13 @@ class _StringType(sqltypes.String): def __init__( self, - charset=None, - collation=None, - ascii=False, # noqa - binary=False, - unicode=False, - national=False, - **kw, + charset: Optional[str] = None, + collation: Optional[str] = None, + ascii: bool = False, # noqa + binary: bool = False, + unicode: bool = False, + national: bool = False, + **kw: Any, ): self.charset = charset @@ -87,25 +106,33 @@ def __init__( self.national = national super().__init__(**kw) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr( self, to_inspect=[_StringType, sqltypes.String] ) -class _MatchType(sqltypes.Float, sqltypes.MatchType): - def __init__(self, **kw): +class _MatchType( + sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType +): + def __init__(self, **kw: Any): # TODO: float arguments? - sqltypes.Float.__init__(self) + sqltypes.Float.__init__(self) # type: ignore[arg-type] sqltypes.MatchType.__init__(self) -class NUMERIC(_NumericType, sqltypes.NUMERIC): +class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]): """MySQL NUMERIC type.""" __visit_name__ = "NUMERIC" - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = True, + **kw: Any, + ): """Construct a NUMERIC. :param precision: Total digits in this number. If scale and precision @@ -126,12 +153,18 @@ def __init__(self, precision=None, scale=None, asdecimal=True, **kw): ) -class DECIMAL(_NumericType, sqltypes.DECIMAL): +class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]): """MySQL DECIMAL type.""" __visit_name__ = "DECIMAL" - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = True, + **kw: Any, + ): """Construct a DECIMAL. :param precision: Total digits in this number. 
If scale and precision @@ -152,12 +185,18 @@ def __init__(self, precision=None, scale=None, asdecimal=True, **kw): ) -class DOUBLE(_FloatType, sqltypes.DOUBLE): +class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]): """MySQL DOUBLE type.""" __visit_name__ = "DOUBLE" - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = True, + **kw: Any, + ): """Construct a DOUBLE. .. note:: @@ -186,12 +225,18 @@ def __init__(self, precision=None, scale=None, asdecimal=True, **kw): ) -class REAL(_FloatType, sqltypes.REAL): +class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]): """MySQL REAL type.""" __visit_name__ = "REAL" - def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = True, + **kw: Any, + ): """Construct a REAL. .. note:: @@ -220,12 +265,18 @@ def __init__(self, precision=None, scale=None, asdecimal=True, **kw): ) -class FLOAT(_FloatType, sqltypes.FLOAT): +class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]): """MySQL FLOAT type.""" __visit_name__ = "FLOAT" - def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + def __init__( + self, + precision: Optional[int] = None, + scale: Optional[int] = None, + asdecimal: bool = False, + **kw: Any, + ): """Construct a FLOAT. :param precision: Total digits in this number. If scale and precision @@ -245,7 +296,9 @@ def __init__(self, precision=None, scale=None, asdecimal=False, **kw): precision=precision, scale=scale, asdecimal=asdecimal, **kw ) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]: return None @@ -254,7 +307,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): __visit_name__ = "INTEGER" - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): """Construct an INTEGER. :param display_width: Optional, maximum display width for this number. @@ -275,7 +328,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): __visit_name__ = "BIGINT" - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): """Construct a BIGINTEGER. :param display_width: Optional, maximum display width for this number. @@ -296,7 +349,7 @@ class MEDIUMINT(_IntegerType): __visit_name__ = "MEDIUMINT" - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): """Construct a MEDIUMINTEGER :param display_width: Optional, maximum display width for this number. @@ -317,7 +370,7 @@ class TINYINT(_IntegerType): __visit_name__ = "TINYINT" - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): """Construct a TINYINT. :param display_width: Optional, maximum display width for this number. @@ -338,7 +391,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT): __visit_name__ = "SMALLINT" - def __init__(self, display_width=None, **kw): + def __init__(self, display_width: Optional[int] = None, **kw: Any): """Construct a SMALLINTEGER. :param display_width: Optional, maximum display width for this number. 
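The ``display_width`` parameter documented in the hunks above is MySQL-specific and only affects how the integer type renders in DDL; it does not change the range of the column. A minimal sketch of its use (the table and column names here are hypothetical)::

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects import mysql
    from sqlalchemy.schema import CreateTable

    t = Table(
        "counters",
        MetaData(),
        Column("hits", mysql.INTEGER(display_width=11, unsigned=True)),
        Column("flag", mysql.TINYINT(display_width=1)),
    )

    # emits e.g. "hits INTEGER(11) UNSIGNED, flag TINYINT(1)"
    print(CreateTable(t).compile(dialect=mysql.dialect()))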
@@ -354,7 +407,7 @@ def __init__(self, display_width=None, **kw): super().__init__(display_width=display_width, **kw) -class BIT(sqltypes.TypeEngine): +class BIT(sqltypes.TypeEngine[Any]): """MySQL BIT type. This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater @@ -365,7 +418,7 @@ class BIT(sqltypes.TypeEngine): __visit_name__ = "BIT" - def __init__(self, length=None): + def __init__(self, length: Optional[int] = None): """Construct a BIT. :param length: Optional, number of bits. @@ -373,20 +426,19 @@ def __init__(self, length=None): """ self.length = length - def result_processor(self, dialect, coltype): - """Convert a MySQL's 64 bit, variable length binary string to a long. - - TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector - already do this, so this logic should be moved to those dialects. + def result_processor( + self, dialect: MySQLDialect, coltype: object # type: ignore[override] + ) -> Optional[_ResultProcessorType[Any]]: + """Convert a MySQL's 64 bit, variable length binary string to a + long.""" - """ + if dialect.supports_native_bit: + return None - def process(value): + def process(value: Optional[Iterable[int]]) -> Optional[int]: if value is not None: v = 0 for i in value: - if not isinstance(i, int): - i = ord(i) # convert byte to int on Python 2 v = v << 8 | i return v return value @@ -399,7 +451,7 @@ class TIME(sqltypes.TIME): __visit_name__ = "TIME" - def __init__(self, timezone=False, fsp=None): + def __init__(self, timezone: bool = False, fsp: Optional[int] = None): """Construct a MySQL TIME type. :param timezone: not used by the MySQL dialect. @@ -418,10 +470,12 @@ def __init__(self, timezone=False, fsp=None): super().__init__(timezone=timezone) self.fsp = fsp - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[datetime.time]: time = datetime.time - def process(value): + def process(value: Any) -> Optional[datetime.time]: # convert from a timedelta value if value is not None: microseconds = value.microseconds @@ -444,7 +498,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): __visit_name__ = "TIMESTAMP" - def __init__(self, timezone=False, fsp=None): + def __init__(self, timezone: bool = False, fsp: Optional[int] = None): """Construct a MySQL TIMESTAMP type. :param timezone: not used by the MySQL dialect. @@ -469,7 +523,7 @@ class DATETIME(sqltypes.DATETIME): __visit_name__ = "DATETIME" - def __init__(self, timezone=False, fsp=None): + def __init__(self, timezone: bool = False, fsp: Optional[int] = None): """Construct a MySQL DATETIME type. :param timezone: not used by the MySQL dialect. @@ -489,26 +543,26 @@ def __init__(self, timezone=False, fsp=None): self.fsp = fsp -class YEAR(sqltypes.TypeEngine): +class YEAR(sqltypes.TypeEngine[Any]): """MySQL YEAR type, for single byte storage of years 1901-2155.""" __visit_name__ = "YEAR" - def __init__(self, display_width=None): + def __init__(self, display_width: Optional[int] = None): self.display_width = display_width class TEXT(_StringType, sqltypes.TEXT): - """MySQL TEXT type, for text up to 2^16 characters.""" + """MySQL TEXT type, for character storage encoded up to 2^16 bytes.""" __visit_name__ = "TEXT" - def __init__(self, length=None, **kw): + def __init__(self, length: Optional[int] = None, **kw: Any): """Construct a TEXT. :param length: Optional, if provided the server may optimize storage by substituting the smallest TEXT type sufficient to store - ``length`` characters. 
+ ``length`` bytes of characters. :param charset: Optional, a column-level character set for this string value. Takes precedence to 'ascii' or 'unicode' short-hand. @@ -535,11 +589,11 @@ def __init__(self, length=None, **kw): class TINYTEXT(_StringType): - """MySQL TINYTEXT type, for text up to 2^8 characters.""" + """MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes.""" __visit_name__ = "TINYTEXT" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): """Construct a TINYTEXT. :param charset: Optional, a column-level character set for this string @@ -567,11 +621,12 @@ def __init__(self, **kwargs): class MEDIUMTEXT(_StringType): - """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + """MySQL MEDIUMTEXT type, for character storage encoded up + to 2^24 bytes.""" __visit_name__ = "MEDIUMTEXT" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): """Construct a MEDIUMTEXT. :param charset: Optional, a column-level character set for this string @@ -599,11 +654,11 @@ def __init__(self, **kwargs): class LONGTEXT(_StringType): - """MySQL LONGTEXT type, for text up to 2^32 characters.""" + """MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes.""" __visit_name__ = "LONGTEXT" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): """Construct a LONGTEXT. :param charset: Optional, a column-level character set for this string @@ -635,7 +690,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): __visit_name__ = "VARCHAR" - def __init__(self, length=None, **kwargs): + def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None: """Construct a VARCHAR. :param charset: Optional, a column-level character set for this string @@ -667,7 +722,7 @@ class CHAR(_StringType, sqltypes.CHAR): __visit_name__ = "CHAR" - def __init__(self, length=None, **kwargs): + def __init__(self, length: Optional[int] = None, **kwargs: Any): """Construct a CHAR. :param length: Maximum data length, in characters. @@ -683,7 +738,7 @@ def __init__(self, length=None, **kwargs): super().__init__(length=length, **kwargs) @classmethod - def _adapt_string_for_cast(self, type_): + def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR: # copy the given string type into a CHAR # for the purposes of rendering a CAST expression type_ = sqltypes.to_instance(type_) @@ -712,7 +767,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): __visit_name__ = "NVARCHAR" - def __init__(self, length=None, **kwargs): + def __init__(self, length: Optional[int] = None, **kwargs: Any): """Construct an NVARCHAR. :param length: Maximum data length, in characters. @@ -738,7 +793,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): __visit_name__ = "NCHAR" - def __init__(self, length=None, **kwargs): + def __init__(self, length: Optional[int] = None, **kwargs: Any): """Construct an NCHAR. :param length: Maximum data length, in characters. diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index 46a5d0a2051..2265de033c9 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -1,11 +1,11 @@ -# oracle/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors - +from types import ModuleType from . 
import base # noqa from . import cx_oracle # noqa @@ -32,7 +32,16 @@ from .base import TIMESTAMP from .base import VARCHAR from .base import VARCHAR2 +from .base import VECTOR +from .base import VectorIndexConfig +from .base import VectorIndexType +from .vector import VectorDistanceType +from .vector import VectorStorageFormat +# Alias oracledb also as oracledb_async +oracledb_async = type( + "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async} +) base.dialect = dialect = cx_oracle.dialect @@ -60,4 +69,9 @@ "NVARCHAR2", "ROWID", "REAL", + "VECTOR", + "VectorDistanceType", + "VectorIndexType", + "VectorIndexConfig", + "VectorStorageFormat", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index d993ef26927..1d882def8d6 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1,5 +1,5 @@ -# oracle/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,8 +9,7 @@ r""" .. dialect:: oracle - :name: Oracle - :full_support: 18c + :name: Oracle Database :normal_support: 11+ :best_effort: 9+ @@ -18,21 +17,24 @@ Auto Increment Behavior ----------------------- -SQLAlchemy Table objects which include integer primary keys are usually -assumed to have "autoincrementing" behavior, meaning they can generate their -own primary key values upon INSERT. For use within Oracle, two options are -available, which are the use of IDENTITY columns (Oracle 12 and above only) -or the association of a SEQUENCE with the column. +SQLAlchemy Table objects which include integer primary keys are usually assumed +to have "autoincrementing" behavior, meaning they can generate their own +primary key values upon INSERT. For use within Oracle Database, two options are +available, which are the use of IDENTITY columns (Oracle Database 12 and above +only) or the association of a SEQUENCE with the column. -Specifying GENERATED AS IDENTITY (Oracle 12 and above) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Specifying GENERATED AS IDENTITY (Oracle Database 12 and above) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Starting from version 12 Oracle can make use of identity columns using -the :class:`_sql.Identity` to specify the autoincrementing behavior:: +Starting from version 12, Oracle Database can make use of identity columns +using the :class:`_sql.Identity` to specify the autoincrementing behavior:: - t = Table('mytable', metadata, - Column('id', Integer, Identity(start=3), primary_key=True), - Column(...), ... + t = Table( + "mytable", + metadata, + Column("id", Integer, Identity(start=3), primary_key=True), + Column(...), + ..., ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -47,34 +49,38 @@ The :class:`_schema.Identity` object supports many options to control the "autoincrementing" behavior of the column, like the starting value, the -incrementing value, etc. -In addition to the standard options, Oracle supports setting -:paramref:`_schema.Identity.always` to ``None`` to use the default -generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports +incrementing value, etc. In addition to the standard options, Oracle Database +supports setting :paramref:`_schema.Identity.always` to ``None`` to use the +default generated mode, rendering GENERATED AS IDENTITY in the DDL.
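As a brief sketch of the option just described (the table name is hypothetical and the rendered DDL is shown approximately), passing ``always=None`` produces the bare GENERATED AS IDENTITY form when compiled against the Oracle dialect::

    from sqlalchemy import Column, Identity, Integer, MetaData, Table
    from sqlalchemy.dialects import oracle
    from sqlalchemy.schema import CreateTable

    t = Table(
        "idtest",
        MetaData(),
        # always=None selects the default generated mode
        Column("id", Integer, Identity(always=None, start=3), primary_key=True),
    )

    # the DDL contains "GENERATED AS IDENTITY (START WITH 3)" or similar
    print(CreateTable(t).compile(dialect=oracle.dialect()))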
It also supports setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL in conjunction with a 'BY DEFAULT' identity column. -Using a SEQUENCE (all Oracle versions) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Older version of Oracle had no "autoincrement" -feature, SQLAlchemy relies upon sequences to produce these values. With the -older Oracle versions, *a sequence must always be explicitly specified to -enable autoincrement*. This is divergent with the majority of documentation -examples which assume the usage of an autoincrement-capable database. To -specify sequences, use the sqlalchemy.schema.Sequence object which is passed -to a Column construct:: - - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), - Column(...), ... +Using a SEQUENCE (all Oracle Database versions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Older versions of Oracle Database had no "autoincrement" feature: SQLAlchemy +relies upon sequences to produce these values. With the older Oracle Database +versions, *a sequence must always be explicitly specified to enable +autoincrement*. This is divergent from the majority of documentation examples +which assume the usage of an autoincrement-capable database. To specify +sequences, use the sqlalchemy.schema.Sequence object which is passed to a +Column construct:: + + t = Table( + "mytable", + metadata, + Column("id", Integer, Sequence("id_seq", start=1), primary_key=True), + Column(...), + ..., ) This step is also required when using table reflection, i.e. autoload_with=engine:: - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), - autoload_with=engine + t = Table( + "mytable", + metadata, + Column("id", Integer, Sequence("id_seq", start=1), primary_key=True), + autoload_with=engine, ) .. versionchanged:: 1.4 Added :class:`_schema.Identity` construct @@ -86,21 +92,18 @@ Transaction Isolation Level / Autocommit ---------------------------------------- -The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of -isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle -dialect. +Oracle Database supports "READ COMMITTED" and "SERIALIZABLE" modes of +isolation. The AUTOCOMMIT isolation level is also supported by the +python-oracledb and cx_Oracle dialects. To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options( - isolation_level="AUTOCOMMIT" - ) + connection = connection.execution_options(isolation_level="AUTOCOMMIT") -For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the -level at the session level using ``ALTER SESSION``, which is reverted back -to its default setting when the connection is returned to the connection -pool. +For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle Database dialects set +the level at the session level using ``ALTER SESSION``, which is reverted back +to its default setting when the connection is returned to the connection pool. Valid values for ``isolation_level`` include: @@ -110,28 +113,28 @@ .. note:: The implementation for the :meth:`_engine.Connection.get_isolation_level` method as implemented by the - Oracle dialect necessarily forces the start of a transaction using the - Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally - readable.
+ Oracle Database dialects necessarily force the start of a transaction using the + Oracle Database DBMS_TRANSACTION.LOCAL_TRANSACTION_ID function; otherwise no + level is normally readable. Additionally, the :meth:`_engine.Connection.get_isolation_level` method will raise an exception if the ``v$transaction`` view is not available due to - permissions or other reasons, which is a common occurrence in Oracle + permissions or other reasons, which is a common occurrence in Oracle Database installations. - The cx_Oracle dialect attempts to call the + The python-oracledb and cx_Oracle dialects attempt to call the :meth:`_engine.Connection.get_isolation_level` method when the dialect makes its first connection to the database in order to acquire the "default" isolation level. This default level is necessary so that the level can be reset on a connection after it has been temporarily modified using - :meth:`_engine.Connection.execution_options` method. In the common event + :meth:`_engine.Connection.execution_options` method. In the common event that the :meth:`_engine.Connection.get_isolation_level` method raises an exception due to ``v$transaction`` not being readable as well as any other database-related failure, the level is assumed to be "READ COMMITTED". No warning is emitted for this initial first-connect condition as it is expected to be a common restriction on Oracle databases. -.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect +.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_Oracle dialect as well as the notion of a default isolation level .. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live @@ -149,56 +152,182 @@ Identifier Casing ----------------- -In Oracle, the data dictionary represents all case insensitive identifier -names using UPPERCASE text. SQLAlchemy on the other hand considers an -all-lower case identifier name to be case insensitive. The Oracle dialect -converts all case insensitive identifiers to and from those two formats during -schema level communication, such as reflection of tables and indexes. Using -an UPPERCASE name on the SQLAlchemy side indicates a case sensitive -identifier, and SQLAlchemy will quote the name - this will cause mismatches -against data dictionary data received from Oracle, so unless identifier names -have been truly created as case sensitive (i.e. using quoted names), all -lowercase names should be used on the SQLAlchemy side. +In Oracle Database, the data dictionary represents all case insensitive +identifier names using UPPERCASE text. This is in contradiction to the +expectations of SQLAlchemy, which assume a case insensitive name is represented +as lowercase text. + +As an example of case insensitive identifier names, consider the following table: + +.. sourcecode:: sql + + CREATE TABLE MyTable (Identifier INTEGER PRIMARY KEY) + +If you were to ask Oracle Database for information about this table, the +table name would be reported as ``MYTABLE`` and the column name would +be reported as ``IDENTIFIER``. Compare to most other databases such as +PostgreSQL and MySQL which would report these names as ``mytable`` and +``identifier``. The names are **not quoted, therefore are case insensitive**. +The special casing of ``MyTable`` and ``Identifier`` would only be maintained +if they were quoted in the table definition: +..
sourcecode:: sql + + CREATE TABLE "MyTable" ("Identifier" INTEGER PRIMARY KEY) + +When constructing a SQLAlchemy :class:`.Table` object, **an all lowercase name +is considered to be case insensitive**. So the following table assumes +case insensitive names:: + + Table("mytable", metadata, Column("identifier", Integer, primary_key=True)) + +Whereas when mixed case or UPPERCASE names are used, case sensitivity is +assumed:: + + Table("MyTable", metadata, Column("Identifier", Integer, primary_key=True)) + +A similar situation occurs at the database driver level when emitting a +textual SQL SELECT statement and looking at column names in the DBAPI +``cursor.description`` attribute. A database like PostgreSQL will normalize +case insensitive names to be lowercase:: + + >>> pg_engine = create_engine("postgresql://scott:tiger@localhost/test") + >>> pg_connection = pg_engine.connect() + >>> result = pg_connection.exec_driver_sql("SELECT 1 AS SomeName") + >>> result.cursor.description + (Column(name='somename', type_code=23),) + +Whereas Oracle normalizes them to UPPERCASE:: + + >>> oracle_engine = create_engine("oracle+oracledb://scott:tiger@oracle18c/xe") + >>> oracle_connection = oracle_engine.connect() + >>> result = oracle_connection.exec_driver_sql( + ... "SELECT 1 AS SomeName FROM DUAL" + ... ) + >>> result.cursor.description + [('SOMENAME', <DbType DB_TYPE_NUMBER>, 127, None, 0, -127, True)] + +In order to achieve cross-database parity for the two cases of a. table +reflection and b. textual-only SQL statement round trips, SQLAlchemy performs a step +called **name normalization** when using the Oracle dialect. This process may +also apply to other third party dialects that have similar UPPERCASE handling +of case insensitive names. + +When using name normalization, SQLAlchemy attempts to detect if a name is +case insensitive by checking if all characters are UPPERCASE letters only; +if so, then it assumes this is a case insensitive name and delivers it as +a lowercase name. + +For table reflection, a tablename that is seen represented as all UPPERCASE +in Oracle Database's catalog tables will be assumed to have a case insensitive +name. This is what allows the ``Table`` definition to use lower case names +and be equally compatible from a reflection point of view on Oracle Database +and all other databases such as PostgreSQL and MySQL:: + + # matches a table created with CREATE TABLE mytable + Table("mytable", metadata, autoload_with=some_engine) + +Above, the all lowercase name ``"mytable"`` is case insensitive; it will match +a table reported by PostgreSQL as ``"mytable"`` and a table reported by +Oracle as ``"MYTABLE"``. If name normalization were not present, it would +not be possible for the above :class:`.Table` definition to be introspectable +in a cross-database way, since we are dealing with a case insensitive name +that is not reported by each database in the same way.
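A rough sketch of the detection rule just described, for illustration only (this is not SQLAlchemy's actual implementation)::

    def sketch_normalize_name(name: str) -> str:
        # an all-UPPERCASE name from Oracle Database's data dictionary
        # is assumed to be case insensitive and is delivered as lowercase
        if name == name.upper() and name != name.lower():
            return name.lower()
        # mixed case and lowercase names pass through unchanged
        return name


    assert sketch_normalize_name("MYTABLE") == "mytable"
    assert sketch_normalize_name("MyTable") == "MyTable"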
+ +Case sensitivity can be forced on in this case, such as if we wanted to represent +the quoted tablename ``"MYTABLE"`` with that exact casing, most simply by using +that casing directly, which will be seen as a case sensitive name:: + + # matches a table created with CREATE TABLE "MYTABLE" + Table("MYTABLE", metadata, autoload_with=some_engine) + +For the unusual case of a quoted all-lowercase name, the :class:`.quoted_name` +construct may be used:: + + from sqlalchemy import quoted_name + + # matches a table created with CREATE TABLE "mytable" + Table( + quoted_name("mytable", quote=True), metadata, autoload_with=some_engine + ) + +Name normalization also takes place when handling result sets from **purely +textual SQL strings**, that have no other :class:`.Table` or :class:`.Column` +metadata associated with them. This includes SQL strings executed using +:meth:`.Connection.exec_driver_sql` and SQL strings executed using the +:func:`.text` construct which do not include :class:`.Column` metadata. + +Returning to the Oracle Database SELECT statement, we see that even though +``cursor.description`` reports the column name as ``SOMENAME``, SQLAlchemy +name normalizes this to ``somename``:: + + >>> oracle_engine = create_engine("oracle+oracledb://scott:tiger@oracle18c/xe") + >>> oracle_connection = oracle_engine.connect() + >>> result = oracle_connection.exec_driver_sql( + ... "SELECT 1 AS SomeName FROM DUAL" + ... ) + >>> result.cursor.description + [('SOMENAME', , 127, None, 0, -127, True)] + >>> result.keys() + RMKeyView(['somename']) + +The single scenario where the above behavior produces inaccurate results +is when using an all-uppercase, quoted name. SQLAlchemy has no way to determine +that a particular name in ``cursor.description`` was quoted, and is therefore +case sensitive, or was not quoted, and should be name normalized:: + + >>> result = oracle_connection.exec_driver_sql( + ... 'SELECT 1 AS "SOMENAME" FROM DUAL' + ... ) + >>> result.cursor.description + [('SOMENAME', , 127, None, 0, -127, True)] + >>> result.keys() + RMKeyView(['somename']) + +For this case, a new feature will be available in SQLAlchemy 2.1 to disable +the name normalization behavior in specific cases. + .. _oracle_max_identifier_lengths: -Max Identifier Lengths ----------------------- +Maximum Identifier Lengths +-------------------------- -Oracle has changed the default max identifier length as of Oracle Server -version 12.2. Prior to this version, the length was 30, and for 12.2 and -greater it is now 128. This change impacts SQLAlchemy in the area of -generated SQL label names as well as the generation of constraint names, -particularly in the case where the constraint naming convention feature -described at :ref:`constraint_naming_conventions` is being used. - -To assist with this change and others, Oracle includes the concept of a -"compatibility" version, which is a version number that is independent of the -actual server version in order to assist with migration of Oracle databases, -and may be configured within the Oracle server itself. This compatibility -version is retrieved using the query ``SELECT value FROM v$parameter WHERE -name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with -determining the default max identifier length, will attempt to use this query -upon first connect in order to determine the effective compatibility version of -the server, which determines what the maximum allowed identifier length is for -the server. 
If the table is not available, the server version information is -used instead. - -As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect -is 128 characters. Upon first connect, the compatibility version is detected -and if it is less than Oracle version 12.2, the max identifier length is -changed to be 30 characters. In all cases, setting the +SQLAlchemy is sensitive to the maximum identifier length supported by Oracle +Database. This affects generated SQL label names as well as the generation of +constraint names, particularly in the case where the constraint naming +convention feature described at :ref:`constraint_naming_conventions` is being +used. + +Oracle Database 12.2 increased the default maximum identifier length from 30 to +128. As of SQLAlchemy 1.4, the default maximum identifier length for the Oracle +dialects is 128 characters. Upon first connection, the maximum length actually +supported by the database is obtained. In all cases, setting the :paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this change and the value given will be used as is:: engine = create_engine( - "oracle+cx_oracle://scott:tiger@oracle122", - max_identifier_length=30) + "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1", + max_identifier_length=30, + ) + +If :paramref:`_sa.create_engine.max_identifier_length` is not set, the oracledb +dialect internally uses the ``max_identifier_length`` attribute available on +driver connections since python-oracledb version 2.5. When using an older +driver version, or using the cx_Oracle dialect, SQLAlchemy will instead attempt +to use the query ``SELECT value FROM v$parameter WHERE name = 'compatible'`` +upon first connect in order to determine the effective compatibility version of +the database. The "compatibility" version is a version number that is +independent of the actual database version. It is used to assist database +migration. It is configured by an Oracle Database initialization parameter. The +compatibility version then determines the maximum allowed identifier length for +the database. If the V$ view is not available, the database version information +is used instead. The maximum identifier length comes into play both when generating anonymized SQL labels in SELECT statements, but more crucially when generating constraint names from a naming convention. It is this area that has created the need for -SQLAlchemy to change this default conservatively. For example, the following +SQLAlchemy to change this default conservatively. For example, the following naming convention produces two very different constraint names based on the identifier length:: @@ -230,68 +359,71 @@ oracle_dialect = oracle.dialect(max_identifier_length=30) print(CreateIndex(ix).compile(dialect=oracle_dialect)) -With an identifier length of 30, the above CREATE INDEX looks like:: +With an identifier length of 30, the above CREATE INDEX looks like: + +.. sourcecode:: sql CREATE INDEX ix_some_column_name_1s_70cd ON t (some_column_name_1, some_column_name_2, some_column_name_3) -However with length=128, it becomes:: +However with length of 128, it becomes:: + +.. 
sourcecode:: sql CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t (some_column_name_1, some_column_name_2, some_column_name_3) -Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle -server version 12.2 or greater are therefore subject to the scenario of a +Applications which have run versions of SQLAlchemy prior to 1.4 on Oracle +Database version 12.2 or greater are therefore subject to the scenario of a database migration that wishes to "DROP CONSTRAINT" on a name that was previously generated with the shorter length. This migration will fail when the identifier length is changed without the name of the index or constraint first being adjusted. Such applications are strongly advised to make use of -:paramref:`_sa.create_engine.max_identifier_length` -in order to maintain control -of the generation of truncated names, and to fully review and test all database -migrations in a staging environment when changing this value to ensure that the -impact of this change has been mitigated. +:paramref:`_sa.create_engine.max_identifier_length` in order to maintain +control of the generation of truncated names, and to fully review and test all +database migrations in a staging environment when changing this value to ensure +that the impact of this change has been mitigated. -.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128 - characters, which is adjusted down to 30 upon first connect if an older - version of Oracle server (compatibility version < 12.2) is detected. +.. versionchanged:: 1.4 the default max_identifier_length for Oracle Database + is 128 characters, which is adjusted down to 30 upon first connect if the + Oracle Database, or its compatibility setting, are lower than version 12.2. LIMIT/OFFSET/FETCH Support -------------------------- -Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make -use of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming -Oracle 12c or above, and assuming the SELECT statement is not embedded within -a compound statement like UNION. This syntax is also available directly by using -the :meth:`_sql.Select.fetch` method. - -.. versionchanged:: 2.0 the Oracle dialect now uses - ``FETCH FIRST N ROW / OFFSET N ROWS`` for all - :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` usage including - within the ORM and legacy :class:`_orm.Query`. To force the legacy - behavior using window functions, specify the ``enable_offset_fetch=False`` - dialect parameter to :func:`_sa.create_engine`. - -The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle version -by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, which -will force the use of "legacy" mode that makes use of window functions. +Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make use +of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming Oracle Database 12c or +above, and assuming the SELECT statement is not embedded within a compound +statement like UNION. This syntax is also available directly by using the +:meth:`_sql.Select.fetch` method. + +.. versionchanged:: 2.0 the Oracle Database dialects now use ``FETCH FIRST N + ROW / OFFSET N ROWS`` for all :meth:`_sql.Select.limit` and + :meth:`_sql.Select.offset` usage including within the ORM and legacy + :class:`_orm.Query`. To force the legacy behavior using window functions, + specify the ``enable_offset_fetch=False`` dialect parameter to + :func:`_sa.create_engine`. 
+ +The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle Database +version by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, +which will force the use of "legacy" mode that makes use of window functions. This mode is also selected automatically when using a version of Oracle -prior to 12c. +Database prior to 12c. -When using legacy mode, or when a :class:`.Select` statement -with limit/offset is embedded in a compound statement, an emulated approach for -LIMIT / OFFSET based on window functions is used, which involves creation of a -subquery using ``ROW_NUMBER`` that is prone to performance issues as well as -SQL construction issues for complex statements. However, this approach is -supported by all Oracle versions. See notes below. +When using legacy mode, or when a :class:`.Select` statement with limit/offset +is embedded in a compound statement, an emulated approach for LIMIT / OFFSET +based on window functions is used, which involves creation of a subquery using +``ROW_NUMBER`` that is prone to performance issues as well as SQL construction +issues for complex statements. However, this approach is supported by all +Oracle Database versions. See notes below. Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`, or with the ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods on an -Oracle version prior to 12c, the following notes apply: +Oracle Database version prior to 12c, the following notes apply: * SQLAlchemy currently makes use of ROWNUM to achieve LIMIT/OFFSET; the exact methodology is taken from @@ -302,10 +434,11 @@ to :func:`_sa.create_engine`. .. versionchanged:: 1.4 - The Oracle dialect renders limit/offset integer values using a "post - compile" scheme which renders the integer directly before passing the - statement to the cursor for execution. The ``use_binds_for_limits`` flag - no longer has an effect. + + The Oracle Database dialect renders limit/offset integer values using a + "post compile" scheme which renders the integer directly before passing + the statement to the cursor for execution. The ``use_binds_for_limits`` + flag no longer has an effect. .. seealso:: @@ -316,37 +449,36 @@ RETURNING Support ----------------- -The Oracle database supports RETURNING fully for INSERT, UPDATE and DELETE -statements that are invoked with a single collection of bound parameters -(that is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally +Oracle Database supports RETURNING fully for INSERT, UPDATE and DELETE +statements that are invoked with a single collection of bound parameters (that +is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally support RETURNING with :term:`executemany` statements). Multiple rows may be returned as well. -.. versionchanged:: 2.0 the Oracle backend has full support for RETURNING - on parity with other backends. - +.. versionchanged:: 2.0 the Oracle Database backend has full support for + RETURNING on parity with other backends. ON UPDATE CASCADE ----------------- -Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based -solution is available at -https://asktom.oracle.com/tkyte/update_cascade/index.html . +Oracle Database doesn't have native ON UPDATE CASCADE functionality. 
A trigger +based solution is available at +https://web.archive.org/web/20090317041251/https://asktom.oracle.com/tkyte/update_cascade/index.html When using the SQLAlchemy ORM, the ORM has limited ability to manually issue cascading updates - specify ForeignKey objects using the "deferrable=True, initially='deferred'" keyword arguments, and specify "passive_updates=False" on each relationship(). -Oracle 8 Compatibility ----------------------- +Oracle Database 8 Compatibility +------------------------------- -.. warning:: The status of Oracle 8 compatibility is not known for SQLAlchemy - 2.0. +.. warning:: The status of Oracle Database 8 compatibility is not known for + SQLAlchemy 2.0. -When Oracle 8 is detected, the dialect internally configures itself to the -following behaviors: +When Oracle Database 8 is detected, the dialect internally configures itself to +the following behaviors: * the use_ansi flag is set to False. This has the effect of converting all JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN @@ -368,14 +500,15 @@ accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as a keyword argument to the :class:`_schema.Table` construct:: - some_table = Table('some_table', autoload_with=some_engine, - oracle_resolve_synonyms=True) + some_table = Table( + "some_table", autoload_with=some_engine, oracle_resolve_synonyms=True + ) -When this flag is set, the given name (such as ``some_table`` above) will -be searched not just in the ``ALL_TABLES`` view, but also within the +When this flag is set, the given name (such as ``some_table`` above) will be +searched not just in the ``ALL_TABLES`` view, but also within the ``ALL_SYNONYMS`` view to see if this name is actually a synonym to another -name. If the synonym is located and refers to a DBLINK, the oracle dialect -knows how to locate the table's information using DBLINK syntax(e.g. +name. If the synonym is located and refers to a DBLINK, the Oracle Database +dialects know how to locate the table's information using DBLINK syntax (e.g. ``@dblink``). ``oracle_resolve_synonyms`` is accepted wherever reflection arguments are @@ -389,8 +522,8 @@ Constraint Reflection --------------------- -The Oracle dialect can return information about foreign key, unique, and -CHECK constraints, as well as indexes on tables. +The Oracle Database dialects can return information about foreign key, unique, +and CHECK constraints, as well as indexes on tables. Raw information regarding these constraints can be acquired using :meth:`_reflection.Inspector.get_foreign_keys`, @@ -398,7 +531,7 @@ :meth:`_reflection.Inspector.get_check_constraints`, and :meth:`_reflection.Inspector.get_indexes`. -.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and +.. versionchanged:: 1.2 The Oracle Database dialect can now reflect UNIQUE and CHECK constraints. When using reflection at the :class:`_schema.Table` level, the @@ -408,29 +541,29 @@ Note the following caveats: * When using the :meth:`_reflection.Inspector.get_check_constraints` method, - Oracle - builds a special "IS NOT NULL" constraint for columns that specify - "NOT NULL". This constraint is **not** returned by default; to include - the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: + Oracle Database builds a special "IS NOT NULL" constraint for columns that + specify "NOT NULL".
This constraint is **not** returned by default; to + include the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: from sqlalchemy import create_engine, inspect - engine = create_engine("oracle+cx_oracle://s:t@dsn") + engine = create_engine( + "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1" + ) inspector = inspect(engine) all_check_constraints = inspector.get_check_constraints( - "some_table", include_all=True) + "some_table", include_all=True + ) -* in most cases, when reflecting a :class:`_schema.Table`, - a UNIQUE constraint will - **not** be available as a :class:`.UniqueConstraint` object, as Oracle - mirrors unique constraints with a UNIQUE index in most cases (the exception - seems to be when two or more unique constraints represent the same columns); - the :class:`_schema.Table` will instead represent these using - :class:`.Index` - with the ``unique=True`` flag set. +* in most cases, when reflecting a :class:`_schema.Table`, a UNIQUE constraint + will **not** be available as a :class:`.UniqueConstraint` object, as Oracle + Database mirrors unique constraints with a UNIQUE index in most cases (the + exception seems to be when two or more unique constraints represent the same + columns); the :class:`_schema.Table` will instead represent these using + :class:`.Index` with the ``unique=True`` flag set. -* Oracle creates an implicit index for the primary key of a table; this index - is **excluded** from all index results. +* Oracle Database creates an implicit index for the primary key of a table; + this index is **excluded** from all index results. * the list of columns reflected for an index will not include column names that start with SYS_NC. @@ -450,50 +583,112 @@ # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM e = create_engine( - "oracle+cx_oracle://scott:tiger@xe", - exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"]) + "oracle+oracledb://scott:tiger@localhost:1521/?service_name=freepdb1", + exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"], + ) + +.. _oracle_float_support: + +FLOAT / DOUBLE Support and Behaviors +------------------------------------ + +The SQLAlchemy :class:`.Float` and :class:`.Double` datatypes are generic +datatypes that resolve to the "least surprising" datatype for a given backend. +For Oracle Database, this means they resolve to the ``FLOAT`` and ``DOUBLE`` +types:: + + >>> from sqlalchemy import cast, literal, Float + >>> from sqlalchemy.dialects import oracle + >>> float_datatype = Float() + >>> print(cast(literal(5.0), float_datatype).compile(dialect=oracle.dialect())) + CAST(:param_1 AS FLOAT) + +Oracle's ``FLOAT`` / ``DOUBLE`` datatypes are aliases for ``NUMBER``. Oracle +Database stores ``NUMBER`` values with full precision, not floating point +precision, which means that ``FLOAT`` / ``DOUBLE`` do not actually behave like +native FP values. Oracle Database instead offers special datatypes +``BINARY_FLOAT`` and ``BINARY_DOUBLE`` to deliver real 4- and 8-byte FP +values. + +SQLAlchemy supports these datatypes directly using :class:`.BINARY_FLOAT` and +:class:`.BINARY_DOUBLE`.
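For instance, a short sketch of using the Oracle Database-specific type directly, in contrast to the ``with_variant`` approach shown next::

    >>> from sqlalchemy import cast, literal
    >>> from sqlalchemy.dialects import oracle
    >>> print(
    ...     cast(literal(5.0), oracle.BINARY_DOUBLE()).compile(
    ...         dialect=oracle.dialect()
    ...     )
    ... )
    CAST(:param_1 AS BINARY_DOUBLE)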
To use the :class:`.Float` or :class:`.Double` +datatypes in a database agnostic way, while allowing Oracle backends to utilize +one of these types, use the :meth:`.TypeEngine.with_variant` method to set up a +variant:: + + >>> from sqlalchemy import cast, literal, Float + >>> from sqlalchemy.dialects import oracle + >>> float_datatype = Float().with_variant(oracle.BINARY_FLOAT(), "oracle") + >>> print(cast(literal(5.0), float_datatype).compile(dialect=oracle.dialect())) + CAST(:param_1 AS BINARY_FLOAT) + +E.g. to use this datatype in a :class:`.Table` definition:: + + my_table = Table( + "my_table", + metadata, + Column( + "fp_data", Float().with_variant(oracle.BINARY_FLOAT(), "oracle") + ), + ) DateTime Compatibility ---------------------- -Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, -which can actually store a date and time value. For this reason, the Oracle -dialect provides a type :class:`_oracle.DATE` which is a subclass of -:class:`.DateTime`. This type has no special behavior, and is only -present as a "marker" for this type; additionally, when a database column -is reflected and the type is reported as ``DATE``, the time-supporting +Oracle Database has no datatype known as ``DATETIME``, it instead has only +``DATE``, which can actually store a date and time value. For this reason, the +Oracle Database dialects provide a type :class:`_oracle.DATE` which is a +subclass of :class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column is +reflected and the type is reported as ``DATE``, the time-supporting :class:`_oracle.DATE` type is used. .. _oracle_table_options: -Oracle Table Options -------------------------- +Oracle Database Table Options +----------------------------- -The CREATE TABLE phrase supports the following options with Oracle -in conjunction with the :class:`_schema.Table` construct: +The CREATE TABLE phrase supports the following options with Oracle Database +dialects in conjunction with the :class:`_schema.Table` construct: * ``ON COMMIT``:: Table( - "some_table", metadata, ..., - prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + "some_table", + metadata, + ..., + prefixes=["GLOBAL TEMPORARY"], + oracle_on_commit="PRESERVE ROWS", + ) + +* + ``COMPRESS``:: -* ``COMPRESS``:: + Table( + "mytable", metadata, Column("data", String(32)), oracle_compress=True + ) - Table('mytable', metadata, Column('data', String(32)), - oracle_compress=True) + Table("mytable", metadata, Column("data", String(32)), oracle_compress=6) - Table('mytable', metadata, Column('data', String(32)), - oracle_compress=6) + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. - The ``oracle_compress`` parameter accepts either an integer compression - level, or ``True`` to use the default compression level. +* + ``TABLESPACE``:: + + Table("mytable", metadata, ..., oracle_tablespace="EXAMPLE_TABLESPACE") + + The ``oracle_tablespace`` parameter specifies the tablespace in which the + table is to be created. This is useful when you want to create a table in a + tablespace other than the default tablespace of the user. + + .. versionadded:: 2.0.37 .. 
_oracle_index_options:

-Oracle Specific Index Options
------------------------------
+Oracle Database Specific Index Options
+--------------------------------------

 Bitmap Indexes
 ~~~~~~~~~~~~~~
@@ -501,7 +696,7 @@
 You can specify the ``oracle_bitmap`` parameter to create a bitmap index
 instead of a B-tree index::

-    Index('my_index', my_table.c.data, oracle_bitmap=True)
+    Index("my_index", my_table.c.data, oracle_bitmap=True)

 Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not
 check for such limitations, only the database will.
@@ -509,24 +704,195 @@
 Index compression
 ~~~~~~~~~~~~~~~~~

-Oracle has a more efficient storage mode for indexes containing lots of
-repeated values. Use the ``oracle_compress`` parameter to turn on key
+Oracle Database has a more efficient storage mode for indexes containing lots
+of repeated values. Use the ``oracle_compress`` parameter to turn on key
 compression::

-    Index('my_index', my_table.c.data, oracle_compress=True)
+    Index("my_index", my_table.c.data, oracle_compress=True)

-    Index('my_index', my_table.c.data1, my_table.c.data2, unique=True,
-           oracle_compress=1)
+    Index(
+        "my_index",
+        my_table.c.data1,
+        my_table.c.data2,
+        unique=True,
+        oracle_compress=1,
+    )

 The ``oracle_compress`` parameter accepts either an integer specifying the
 number of prefix columns to compress, or ``True`` to use the default (all
 columns for non-unique indexes, all but the last column for unique indexes).

+.. _oracle_vector_datatype:
+
+VECTOR Datatype
+---------------
+
+Oracle Database 23ai introduced a new VECTOR datatype for artificial
+intelligence and machine learning search operations. The VECTOR datatype is a
+homogeneous array of 8-bit signed integers, 8-bit unsigned integers (binary),
+32-bit floating-point numbers, or 64-bit floating-point numbers.
+
+.. seealso::
+
+    `Using VECTOR Data
+    `_ - in the documentation
+    for the :ref:`oracledb` driver.
+
+.. versionadded:: 2.0.41
+
+CREATE TABLE support for VECTOR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+With the :class:`.VECTOR` datatype, you can specify the dimension for the data
+and the storage format. Valid values for storage format are enum values from
+:class:`.VectorStorageFormat`. To create a table that includes a
+:class:`.VECTOR` column::
+
+    from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat
+
+    t = Table(
+        "t1",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column(
+            "embedding",
+            VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+        ),
+        Column(...),
+        ...,
+    )
+
+Vectors can also be defined with an arbitrary number of dimensions and formats.
+This allows you to specify vectors of different dimensions with the various
+storage formats mentioned above.
+
+**Examples**
+
+* In this case, the storage format is flexible, allowing any vector type data,
+  such as INT8 or BINARY, to be inserted::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR(dim=3))
+
+* The dimension is flexible in this case, meaning that a vector of any
+  dimension can be used::
+
+    vector_col: Mapped[array.array] = mapped_column(
+        VECTOR(storage_format=VectorStorageFormat.INT8)
+    )
+
+* Both the dimensions and the storage format are flexible::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR)
+
+Python Datatypes for VECTOR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+VECTOR data can be inserted using Python list or Python ``array.array()``
+objects. Python arrays of type FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit
+signed integer) are used as bind values when inserting VECTOR columns::
+
+    from sqlalchemy import insert, select
+
+    with engine.begin() as conn:
+        conn.execute(
+            insert(t1),
+            {"id": 1, "embedding": [1, 2, 3]},
+        )
+
+VECTOR Indexes
+~~~~~~~~~~~~~~
+
+The VECTOR feature supports an Oracle-specific parameter ``oracle_vector``
+on the :class:`.Index` construct, which allows the construction of VECTOR
+indexes.
+
+To utilize VECTOR indexing, set the ``oracle_vector`` parameter to ``True`` to
+use the default values provided by Oracle. HNSW is the default indexing
+method::
+
+    from sqlalchemy import Index
+
+    Index(
+        "vector_index",
+        t1.c.embedding,
+        oracle_vector=True,
+    )
+
+The full range of parameters for vector indexes is available by using the
+:class:`.VectorIndexConfig` dataclass in place of a boolean; this dataclass
+allows full configuration of the index::
+
+    Index(
+        "hnsw_vector_index",
+        t1.c.embedding,
+        oracle_vector=VectorIndexConfig(
+            index_type=VectorIndexType.HNSW,
+            distance=VectorDistanceType.COSINE,
+            accuracy=90,
+            hnsw_neighbors=5,
+            hnsw_efconstruction=20,
+            parallel=10,
+        ),
+    )
+
+    Index(
+        "ivf_vector_index",
+        t1.c.embedding,
+        oracle_vector=VectorIndexConfig(
+            index_type=VectorIndexType.IVF,
+            distance=VectorDistanceType.DOT,
+            accuracy=90,
+            ivf_neighbor_partitions=5,
+        ),
+    )
+
+For a complete explanation of these parameters, see the Oracle documentation
+linked below.
+
+.. seealso::
+
+    `CREATE VECTOR INDEX `_ - in the Oracle documentation
+
+
+Similarity Searching
+~~~~~~~~~~~~~~~~~~~~
+
+When using the :class:`_oracle.VECTOR` datatype with a :class:`.Column` or
+similar ORM mapped construct, additional comparison functions are available,
+including:
+
+* ``l2_distance``
+* ``cosine_distance``
+* ``inner_product``
+
+Example Usage::
+
+    results = connection.execute(
+        select(t1).order_by(t1.c.embedding.l2_distance([2, 3, 4])).limit(3)
+    )
+
+    for row in results:
+        print(row.id, row.embedding)
+
+FETCH APPROXIMATE support
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Approximate vector search can only be performed when all syntax and semantic
+rules are satisfied, the corresponding vector index is available, and the
+query optimizer determines that it should be used. If any of these conditions
+is unmet, then an approximate search is not performed and the query returns
+exact results.
+
+To enable approximate searching during similarity searches on VECTOR columns,
+the ``oracle_fetch_approximate`` parameter may be used with the
+:meth:`.Select.fetch` clause to add ``FETCH APPROX`` to the SELECT statement::
+
+    select(users_table).fetch(5, oracle_fetch_approximate=True)
+
 """  # noqa
 from __future__ import annotations

 from collections import defaultdict
+from dataclasses import fields
 from functools import lru_cache
 from functools import wraps
 import re
@@ -549,6 +915,9 @@
 from .types import ROWID  # noqa
 from .types import TIMESTAMP
 from .types import VARCHAR2  # noqa
+from .vector import VECTOR
+from .vector import VectorIndexConfig
+from .vector import VectorIndexType
 from ... import Computed
 from ... import exc
 from ... 
import schema as sa_schema @@ -567,6 +936,7 @@ from ...sql import null from ...sql import or_ from ...sql import select +from ...sql import selectable as sa_selectable from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors @@ -594,7 +964,7 @@ ) NO_ARG_FNS = set( - "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() + "UID CURRENT_DATE SYSDATE USER CURRENT_TIME CURRENT_TIMESTAMP".split() ) @@ -628,6 +998,7 @@ "BINARY_DOUBLE": BINARY_DOUBLE, "BINARY_FLOAT": BINARY_FLOAT, "ROWID": ROWID, + "VECTOR": VECTOR, } @@ -708,16 +1079,16 @@ def _generate_numeric( # https://www.oracletutorial.com/oracle-basics/oracle-float/ estimated_binary_precision = int(precision / 0.30103) raise exc.ArgumentError( - "Oracle FLOAT types use 'binary precision', which does " - "not convert cleanly from decimal 'precision'. Please " - "specify " - f"this type with a separate Oracle variant, such as " - f"{type_.__class__.__name__}(precision={precision})." + "Oracle Database FLOAT types use 'binary precision', " + "which does not convert cleanly from decimal " + "'precision'. Please specify " + "this type with a separate Oracle Database variant, such " + f"as {type_.__class__.__name__}(precision={precision})." f"with_variant(oracle.FLOAT" f"(binary_precision=" f"{estimated_binary_precision}), 'oracle'), so that the " - "Oracle specific 'binary_precision' may be specified " - "accurately." + "Oracle Database specific 'binary_precision' may be " + "specified accurately." ) else: precision = binary_precision @@ -785,6 +1156,16 @@ def visit_RAW(self, type_, **kw): def visit_ROWID(self, type_, **kw): return "ROWID" + def visit_VECTOR(self, type_, **kw): + if type_.dim is None and type_.storage_format is None: + return "VECTOR(*,*)" + elif type_.storage_format is None: + return f"VECTOR({type_.dim},*)" + elif type_.dim is None: + return f"VECTOR(*,{type_.storage_format.value})" + else: + return f"VECTOR({type_.dim},{type_.storage_format.value})" + class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select @@ -839,7 +1220,7 @@ def function_argspec(self, fn, **kw): def visit_function(self, func, **kw): text = super().visit_function(func, **kw) - if kw.get("asfrom", False): + if kw.get("asfrom", False) and func.name.lower() != "table": text = "TABLE (%s)" % text return text @@ -946,13 +1327,13 @@ def returning_clause( and not self.dialect._supports_update_returning_computed_cols ): util.warn( - "Computed columns don't work with Oracle UPDATE " + "Computed columns don't work with Oracle Database UPDATE " "statements that use RETURNING; the value of the column " "*before* the UPDATE takes place is returned. It is " - "advised to not use RETURNING with an Oracle computed " - "column. Consider setting implicit_returning to False on " - "the Table object in order to avoid implicit RETURNING " - "clauses from being generated for this Table." + "advised to not use RETURNING with an Oracle Database " + "computed column. Consider setting implicit_returning " + "to False on the Table object in order to avoid implicit " + "RETURNING clauses from being generated for this Table." ) if column.type._has_column_expression: col_expr = column.type.column_expression(column) @@ -976,7 +1357,7 @@ def returning_clause( raise exc.InvalidRequestError( "Using explicit outparam() objects with " "UpdateBase.returning() in the same Core DML statement " - "is not supported in the Oracle dialect." + "is not supported in the Oracle Database dialects." 
) self._oracle_returning = True @@ -997,7 +1378,7 @@ def returning_clause( return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) def _row_limit_clause(self, select, **kw): - """ORacle 12c supports OFFSET/FETCH operators + """Oracle Database 12c supports OFFSET/FETCH operators Use it instead subquery with row_number """ @@ -1023,6 +1404,29 @@ def _get_limit_or_fetch(self, select): else: return select._fetch_clause + def fetch_clause( + self, + select, + fetch_clause=None, + require_offset=False, + use_literal_execute_for_simple_int=False, + **kw, + ): + text = super().fetch_clause( + select, + fetch_clause=fetch_clause, + require_offset=require_offset, + use_literal_execute_for_simple_int=( + use_literal_execute_for_simple_int + ), + **kw, + ) + + if select.dialect_options["oracle"]["fetch_approximate"]: + text = re.sub("FETCH FIRST", "FETCH APPROX FIRST", text) + + return text + def translate_select_structure(self, select_stmt, **kwargs): select = select_stmt @@ -1244,8 +1648,75 @@ def visit_regexp_replace_op_binary(self, binary, operator, **kw): def visit_aggregate_strings_func(self, fn, **kw): return "LISTAGG%s" % self.function_argspec(fn, **kw) + def _visit_bitwise(self, binary, fn_name, custom_right=None, **kw): + left = self.process(binary.left, **kw) + right = self.process( + custom_right if custom_right is not None else binary.right, **kw + ) + return f"{fn_name}({left}, {right})" + + def visit_bitwise_xor_op_binary(self, binary, operator, **kw): + return self._visit_bitwise(binary, "BITXOR", **kw) + + def visit_bitwise_or_op_binary(self, binary, operator, **kw): + return self._visit_bitwise(binary, "BITOR", **kw) + + def visit_bitwise_and_op_binary(self, binary, operator, **kw): + return self._visit_bitwise(binary, "BITAND", **kw) + + def visit_bitwise_rshift_op_binary(self, binary, operator, **kw): + raise exc.CompileError("Cannot compile bitwise_rshift in oracle") + + def visit_bitwise_lshift_op_binary(self, binary, operator, **kw): + raise exc.CompileError("Cannot compile bitwise_lshift in oracle") + + def visit_bitwise_not_op_unary_operator(self, element, operator, **kw): + raise exc.CompileError("Cannot compile bitwise_not in oracle") + class OracleDDLCompiler(compiler.DDLCompiler): + + def _build_vector_index_config( + self, vector_index_config: VectorIndexConfig + ) -> str: + parts = [] + sql_param_name = { + "hnsw_neighbors": "neighbors", + "hnsw_efconstruction": "efconstruction", + "ivf_neighbor_partitions": "neighbor partitions", + "ivf_sample_per_partition": "sample_per_partition", + "ivf_min_vectors_per_partition": "min_vectors_per_partition", + } + if vector_index_config.index_type == VectorIndexType.HNSW: + parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") + elif vector_index_config.index_type == VectorIndexType.IVF: + parts.append("ORGANIZATION NEIGHBOR PARTITIONS") + if vector_index_config.distance is not None: + parts.append(f"DISTANCE {vector_index_config.distance.value}") + + if vector_index_config.accuracy is not None: + parts.append( + f"WITH TARGET ACCURACY {vector_index_config.accuracy}" + ) + + parameters_str = [f"type {vector_index_config.index_type.name}"] + prefix = vector_index_config.index_type.name.lower() + "_" + + for field in fields(vector_index_config): + if field.name.startswith(prefix): + key = sql_param_name.get(field.name) + value = getattr(vector_index_config, field.name) + if value is not None: + parameters_str.append(f"{key} {value}") + + parameters_str = ", ".join(parameters_str) + parts.append(f"PARAMETERS 
({parameters_str})") + + if vector_index_config.parallel is not None: + parts.append(f"PARALLEL {vector_index_config.parallel}") + + return " ".join(parts) + def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -1253,10 +1724,10 @@ def define_constraint_cascades(self, constraint): # oracle has no ON UPDATE CASCADE - # its only available via triggers - # https://asktom.oracle.com/tkyte/update_cascade/index.html + # https://web.archive.org/web/20090317041251/https://asktom.oracle.com/tkyte/update_cascade/index.html if constraint.onupdate is not None: util.warn( - "Oracle does not contain native UPDATE CASCADE " + "Oracle Database does not contain native UPDATE CASCADE " "functionality - onupdates will not be rendered for foreign " "keys. Consider using deferrable=True, initially='deferred' " "or triggers." @@ -1278,6 +1749,9 @@ def visit_create_index(self, create, **kw): text += "UNIQUE " if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " + vector_options = index.dialect_options["oracle"]["vector"] + if vector_options: + text += "VECTOR " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), @@ -1295,6 +1769,11 @@ def visit_create_index(self, create, **kw): text += " COMPRESS %d" % ( index.dialect_options["oracle"]["compress"] ) + if vector_options: + if vector_options is True: + vector_options = VectorIndexConfig() + + text += " " + self._build_vector_index_config(vector_options) return text def post_create_table(self, table): @@ -1310,7 +1789,10 @@ def post_create_table(self, table): table_opts.append("\n COMPRESS") else: table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) - + if opts["tablespace"]: + table_opts.append( + "\n TABLESPACE %s" % self.preparer.quote(opts["tablespace"]) + ) return "".join(table_opts) def get_identity_options(self, identity_options): @@ -1328,8 +1810,9 @@ def visit_computed_column(self, generated, **kw): ) if generated.persisted is True: raise exc.CompileError( - "Oracle computed columns do not support 'stored' persistence; " - "set the 'persisted' flag to None or False for Oracle support." + "Oracle Database computed columns do not support 'stored' " + "persistence; set the 'persisted' flag to None or False for " + "Oracle Database support." ) elif generated.persisted is False: text += " VIRTUAL" @@ -1434,16 +1917,30 @@ class OracleDialect(default.DefaultDialect): construct_arguments = [ ( sa_schema.Table, - {"resolve_synonyms": False, "on_commit": None, "compress": False}, + { + "resolve_synonyms": False, + "on_commit": None, + "compress": False, + "tablespace": None, + }, + ), + ( + sa_schema.Index, + { + "bitmap": False, + "compress": False, + "vector": False, + }, ), - (sa_schema.Index, {"bitmap": False, "compress": False}), + (sa_selectable.Select, {"fetch_approximate": False}), + (sa_selectable.CompoundSelect, {"fetch_approximate": False}), ] @util.deprecated_params( use_binds_for_limits=( "1.4", - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated. The dialect now renders LIMIT /OFFSET integers " + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated. 
The dialect now renders LIMIT / OFFSET integers " "inline in all cases using a post-compilation hook, so that the " "value is still represented by a 'bound parameter' on the Core " "Expression side.", @@ -1464,9 +1961,9 @@ def __init__( self.use_ansi = use_ansi self.optimize_limits = optimize_limits self.exclude_tablespaces = exclude_tablespaces - self.enable_offset_fetch = ( - self._supports_offset_fetch - ) = enable_offset_fetch + self.enable_offset_fetch = self._supports_offset_fetch = ( + enable_offset_fetch + ) def initialize(self, connection): super().initialize(connection) @@ -2036,8 +2533,17 @@ def _table_options_query( ): query = select( dictionary.all_tables.c.table_name, - dictionary.all_tables.c.compression, - dictionary.all_tables.c.compress_for, + ( + dictionary.all_tables.c.compression + if self._supports_table_compression + else sql.null().label("compression") + ), + ( + dictionary.all_tables.c.compress_for + if self._supports_table_compress_for + else sql.null().label("compress_for") + ), + dictionary.all_tables.c.tablespace_name, ).where(dictionary.all_tables.c.owner == owner) if has_filter_names: query = query.where( @@ -2129,11 +2635,12 @@ def get_multi_table_options( connection, query, dblink, returns_long=False, params=params ) - for table, compression, compress_for in result: + for table, compression, compress_for, tablespace in result: + data = default() if compression == "ENABLED": - data = {"oracle_compress": compress_for} - else: - data = default() + data["oracle_compress"] = compress_for + if tablespace: + data["oracle_tablespace"] = tablespace options[(schema, self.normalize_name(table))] = data if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope: # add the views (no temporary views) @@ -2523,10 +3030,12 @@ def get_multi_table_comment( return ( ( (schema, self.normalize_name(table)), - {"text": comment} - if comment is not None - and not comment.startswith(ignore_mat_view) - else default(), + ( + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default() + ), ) for table, comment in result ) @@ -3068,9 +3577,11 @@ def get_multi_unique_constraints( table_uc[constraint_name] = uc = { "name": constraint_name, "column_names": [], - "duplicates_index": constraint_name - if constraint_name_orig in index_names - else None, + "duplicates_index": ( + constraint_name + if constraint_name_orig in index_names + else None + ), } else: uc = table_uc[constraint_name] @@ -3082,9 +3593,11 @@ def get_multi_unique_constraints( return ( ( key, - list(unique_cons[key].values()) - if key in unique_cons - else default(), + ( + list(unique_cons[key].values()) + if key in unique_cons + else default() + ), ) for key in ( (schema, self.normalize_name(obj_name)) @@ -3207,9 +3720,11 @@ def get_multi_check_constraints( return ( ( key, - check_constraints[key] - if key in check_constraints - else default(), + ( + check_constraints[key] + if key in check_constraints + else default() + ), ) for key in ( (schema, self.normalize_name(obj_name)) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index c595b56c562..0514ebbcd41 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -1,4 +1,5 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/cx_oracle.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -6,13 +7,18 @@ # 
mypy: ignore-errors

-r"""
-.. dialect:: oracle+cx_oracle
+r""".. dialect:: oracle+cx_oracle
     :name: cx-Oracle
     :dbapi: cx_oracle
     :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]]
     :url: https://oracle.github.io/python-cx_Oracle/

+Description
+-----------
+
+cx_Oracle was the original driver for Oracle Database. It was superseded by
+python-oracledb, which should be used instead.
+
 DSN vs. Hostname connections
 -----------------------------

 Hostname Connections with Easy Connect Syntax
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Given a hostname, port and service name of the target Oracle Database, for
-example from Oracle's `Easy Connect syntax
-`_,
-then connect in SQLAlchemy using the ``service_name`` query string parameter::
+Given a hostname, port and service name of the target database, for example
+from Oracle Database's Easy Connect syntax, then connect in SQLAlchemy using
+the ``service_name`` query string parameter::

-    engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8")
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@hostname:port?service_name=myservice&encoding=UTF-8&nencoding=UTF-8"
+    )

-The `full Easy Connect syntax
-`_
-is not supported. Instead, use a ``tnsnames.ora`` file and connect using a
-DSN.
+Note that the default driver value for encoding and nencoding was changed to
+"UTF-8" in cx_Oracle 8.0, so these parameters can be omitted when using that
+version, or later.

-Connections with tnsnames.ora or Oracle Cloud
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+To use a full Easy Connect string, pass it as the ``dsn`` key value in a
+:paramref:`_sa.create_engine.connect_args` dictionary::

-Alternatively, if no port, database name, or ``service_name`` is provided, the
-dialect will use an Oracle DSN "connection string". This takes the "hostname"
-portion of the URL as the data source name. For example, if the
-``tnsnames.ora`` file contains a `Net Service Name
-`_
-of ``myalias`` as below::
+    import cx_Oracle
+
+    e = create_engine(
+        "oracle+cx_oracle://@",
+        connect_args={
+            "user": "scott",
+            "password": "tiger",
+            "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
+        },
+    )
+
+Connections with tnsnames.ora or to Oracle Autonomous Database
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, if no port, database name, or service name is provided, the
+dialect will use an Oracle Database DSN "connection string". This takes the
+"hostname" portion of the URL as the data source name. For example, if the
+``tnsnames.ora`` file contains a `TNS Alias
+`_
+of ``myalias`` as below:
+
+.. sourcecode:: text

     myalias =
       (DESCRIPTION =
@@ -57,19 +77,22 @@
 hostname portion of the URL, without specifying a port, database name or
 ``service_name``::

-    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8")
+    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias")

-Users of Oracle Cloud should use this syntax and also configure the cloud
+Users of Oracle Autonomous Database should use this syntax. If the database is
+configured for mutual TLS ("mTLS"), then you must also configure the cloud
 wallet as shown in cx_Oracle documentation `Connecting to Autonomous
 Databases
-`_.
+`_.
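+
+As a rough sketch only (the ``/opt/oracle/wallet_dir`` path and the
+``mydb_high`` alias below are placeholder values, not part of the cx_Oracle
+API), a wallet-based connection might point the Oracle client libraries at the
+extracted wallet directory via the ``TNS_ADMIN`` environment variable before
+connecting with the alias::
+
+    import os
+
+    # directory containing tnsnames.ora, sqlnet.ora and the wallet files
+    os.environ["TNS_ADMIN"] = "/opt/oracle/wallet_dir"
+
+    engine = create_engine("oracle+cx_oracle://scott:tiger@mydb_high")
+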
SID Connections ^^^^^^^^^^^^^^^ -To use Oracle's obsolete SID connection syntax, the SID can be passed in a -"database name" portion of the URL as below:: +To use Oracle Database's obsolete System Identifier connection syntax, the SID +can be passed in a "database name" portion of the URL:: - engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8") + engine = create_engine( + "oracle+cx_oracle://scott:tiger@hostname:port/dbname" + ) Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as follows:: @@ -78,17 +101,23 @@ >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname") '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))' +Note that although the SQLAlchemy syntax ``hostname:port/dbname`` looks like +Oracle's Easy Connect syntax it is different. It uses a SID in place of the +service name required by Easy Connect. The Easy Connect syntax does not +support SIDs. + Passing cx_Oracle connect arguments ----------------------------------- -Additional connection arguments can usually be passed via the URL -query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted -and converted to the correct symbol:: +Additional connection arguments can usually be passed via the URL query string; +particular symbols like ``SYSDBA`` are intercepted and converted to the correct +symbol:: e = create_engine( - "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true") + "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true" + ) -.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names +.. versionchanged:: 1.3 the cx_Oracle dialect now accepts all argument names within the URL string itself, to be passed to the cx_Oracle DBAPI. As was the case earlier but not correctly documented, the :paramref:`_sa.create_engine.connect_args` parameter also accepts all @@ -99,19 +128,20 @@ Any cx_Oracle parameter value and/or constant may be passed, such as:: import cx_Oracle + e = create_engine( "oracle+cx_oracle://user:pass@dsn", connect_args={ "encoding": "UTF-8", "nencoding": "UTF-8", "mode": cx_Oracle.SYSDBA, - "events": True - } + "events": True, + }, ) -Note that the default value for ``encoding`` and ``nencoding`` was changed to -"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that -version, or later. +Note that the default driver value for ``encoding`` and ``nencoding`` was +changed to "UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when +using that version, or later. Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver -------------------------------------------------------------------------- @@ -121,14 +151,19 @@ , such as:: e = create_engine( - "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False) + "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False + ) The parameters accepted by the cx_oracle dialect are as follows: -* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted - to 50. This setting is significant with cx_Oracle as the contents of LOB - objects are only readable within a "live" row (e.g. within a batch of - 50 rows). +* ``arraysize`` - set the cx_oracle.arraysize value on cursors; defaults + to ``None``, indicating that the driver default should be used (typically + the value is 100). 
This setting controls how many rows are buffered when
+  fetching rows, and can have a significant effect on performance when
+  modified.
+
+  .. versionchanged:: 2.0.26 - changed the default value from 50 to None,
+     to use the default value of the driver itself.

 * ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.

@@ -141,10 +176,16 @@
 Using cx_Oracle SessionPool
 ---------------------------

-The cx_Oracle library provides its own connection pool implementation that may
-be used in place of SQLAlchemy's pooling functionality. This can be achieved
-by using the :paramref:`_sa.create_engine.creator` parameter to provide a
-function that returns a new connection, along with setting
+The cx_Oracle driver provides its own connection pool implementation that may
+be used in place of SQLAlchemy's pooling functionality. The driver pool
+supports Oracle Database features such as dead connection detection,
+connection draining for planned database downtime, Oracle Application
+Continuity and Transparent Application Continuity, and Database Resident
+Connection Pooling (DRCP).
+
+Using the driver pool can be achieved by using the
+:paramref:`_sa.create_engine.creator` parameter to provide a function that
+returns a new connection, along with setting
 :paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
 SQLAlchemy's pooling::

@@ -153,32 +194,41 @@
     from sqlalchemy.pool import NullPool

     pool = cx_Oracle.SessionPool(
-        user="scott", password="tiger", dsn="orclpdb",
-        min=2, max=5, increment=1, threaded=True,
-        encoding="UTF-8", nencoding="UTF-8"
+        user="scott",
+        password="tiger",
+        dsn="orclpdb",
+        min=1,
+        max=4,
+        increment=1,
+        threaded=True,
+        encoding="UTF-8",
+        nencoding="UTF-8",
     )

-    engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool)
+    engine = create_engine(
+        "oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool
+    )

 The above engine may then be used normally where cx_Oracle's pool handles
 connection pooling::

     with engine.connect() as conn:
-        print(conn.scalar("select 1 FROM dual"))
-
+        print(conn.scalar("select 1 from dual"))

 As well as providing a scalable solution for multi-user applications, the
 cx_Oracle session pool supports some Oracle features such as DRCP and
 `Application Continuity
 `_.

+Note that the pool creation parameters ``threaded``, ``encoding`` and
+``nencoding`` were deprecated in later cx_Oracle releases.
+
 Using Oracle Database Resident Connection Pooling (DRCP)
 --------------------------------------------------------

-When using Oracle's `DRCP
-`_,
-the best practice is to pass a connection class and "purity" when acquiring a
-connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation
+When using Oracle Database's DRCP, the best practice is to pass a connection
+class and "purity" when acquiring a connection from the SessionPool. Refer to
+the `cx_Oracle DRCP documentation
 `_. 
This can be achieved by wrapping ``pool.acquire()``:: @@ -188,21 +238,33 @@ from sqlalchemy.pool import NullPool pool = cx_Oracle.SessionPool( - user="scott", password="tiger", dsn="orclpdb", - min=2, max=5, increment=1, threaded=True, - encoding="UTF-8", nencoding="UTF-8" + user="scott", + password="tiger", + dsn="orclpdb", + min=2, + max=5, + increment=1, + threaded=True, + encoding="UTF-8", + nencoding="UTF-8", ) + def creator(): - return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF) + return pool.acquire( + cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF + ) - engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool) + + engine = create_engine( + "oracle+cx_oracle://", creator=creator, poolclass=NullPool + ) The above engine may then be used normally where cx_Oracle handles session pooling and Oracle Database additionally uses DRCP:: with engine.connect() as conn: - print(conn.scalar("select 1 FROM dual")) + print(conn.scalar("select 1 from dual")) .. _cx_oracle_unicode: @@ -210,24 +272,28 @@ def creator(): ------- As is the case for all DBAPIs under Python 3, all strings are inherently -Unicode strings. In all cases however, the driver requires an explicit +Unicode strings. In all cases however, the driver requires an explicit encoding configuration. Ensuring the Correct Client Encoding ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The long accepted standard for establishing client encoding for nearly all -Oracle related software is via the `NLS_LANG `_ -environment variable. cx_Oracle like most other Oracle drivers will use -this environment variable as the source of its encoding configuration. The -format of this variable is idiosyncratic; a typical value would be -``AMERICAN_AMERICA.AL32UTF8``. - -The cx_Oracle driver also supports a programmatic alternative which is to -pass the ``encoding`` and ``nencoding`` parameters directly to its -``.connect()`` function. These can be present in the URL as follows:: - - engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8") +Oracle Database related software is via the `NLS_LANG +`_ environment +variable. Older versions of cx_Oracle use this environment variable as the +source of its encoding configuration. The format of this variable is +Territory_Country.CharacterSet; a typical value would be +``AMERICAN_AMERICA.AL32UTF8``. cx_Oracle version 8 and later use the character +set "UTF-8" by default, and ignore the character set component of NLS_LANG. + +The cx_Oracle driver also supported a programmatic alternative which is to pass +the ``encoding`` and ``nencoding`` parameters directly to its ``.connect()`` +function. These can be present in the URL as follows:: + + engine = create_engine( + "oracle+cx_oracle://scott:tiger@tnsalias?encoding=UTF-8&nencoding=UTF-8" + ) For the meaning of the ``encoding`` and ``nencoding`` parameters, please consult @@ -242,25 +308,24 @@ def creator(): Unicode-specific Column datatypes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The Core expression language handles unicode data by use of the :class:`.Unicode` -and :class:`.UnicodeText` -datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by -default. When using these datatypes with Unicode data, it is expected that -the Oracle database is configured with a Unicode-aware character set, as well -as that the ``NLS_LANG`` environment variable is set appropriately, so that -the VARCHAR2 and CLOB datatypes can accommodate the data. 
+The Core expression language handles unicode data by use of the +:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond +to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using +these datatypes with Unicode data, it is expected that the database is +configured with a Unicode-aware character set, as well as that the ``NLS_LANG`` +environment variable is set appropriately (this applies to older versions of +cx_Oracle), so that the VARCHAR2 and CLOB datatypes can accommodate the data. -In the case that the Oracle database is not configured with a Unicode character +In the case that Oracle Database is not configured with a Unicode character set, the two options are to use the :class:`_types.NCHAR` and :class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag -``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, -which will cause the -SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause +the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / :class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. -.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` - datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes - unless the ``use_nchar_for_unicode=True`` is passed to the dialect +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database + datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect when :func:`_sa.create_engine` is called. @@ -269,7 +334,7 @@ def creator(): Encoding Errors ^^^^^^^^^^^^^^^ -For the unusual case that data in the Oracle database is present with a broken +For the unusual case that data in Oracle Database is present with a broken encoding, the dialect accepts a parameter ``encoding_errors`` which will be passed to Unicode decoding functions in order to affect how decoding errors are handled. The value is ultimately consumed by the Python `decode @@ -287,13 +352,13 @@ def creator(): ------------------------------------------------------------------------------- The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the -DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the +DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the datatypes that are bound to a SQL statement for Python values being passed as parameters. While virtually no other DBAPI assigns any use to the ``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its -interactions with the Oracle client interface, and in some scenarios it is not -possible for SQLAlchemy to know exactly how data should be bound, as some -settings can cause profoundly different performance characteristics, while +interactions with the Oracle Database client interface, and in some scenarios +it is not possible for SQLAlchemy to know exactly how data should be bound, as +some settings can cause profoundly different performance characteristics, while altering the type coercion behavior at the same time. 
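+
+As a rough illustration of what ``setinputsizes()`` does at the raw DBAPI
+level (the ``documents`` table and its bind names here are hypothetical, not
+part of SQLAlchemy or the cx_Oracle docs), cx_Oracle might be used directly
+as::
+
+    import cx_Oracle
+
+    connection = cx_Oracle.connect(
+        user="scott", password="tiger", dsn="myalias"
+    )
+    cursor = connection.cursor()
+
+    # declare up front that the :content bind should be bound as a CLOB
+    cursor.setinputsizes(content=cx_Oracle.DB_TYPE_CLOB)
+    cursor.execute(
+        "INSERT INTO documents (id, content) VALUES (:id, :content)",
+        id=1,
+        content="a very large block of text ...",
+    )
+    connection.commit()
+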
Users of the cx_Oracle dialect are **strongly encouraged** to read through @@ -322,13 +387,16 @@ def creator(): engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + @event.listens_for(engine, "do_setinputsizes") def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): for bindparam, dbapitype in inputsizes.items(): - log.info( - "Bound parameter name: %s SQLAlchemy type: %r " - "DBAPI object: %s", - bindparam.key, bindparam.type, dbapitype) + log.info( + "Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s", + bindparam.key, + bindparam.type, + dbapitype, + ) Example 2 - remove all bindings to CLOB ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -342,43 +410,42 @@ def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + @event.listens_for(engine, "do_setinputsizes") def _remove_clob(inputsizes, cursor, statement, parameters, context): for bindparam, dbapitype in list(inputsizes.items()): if dbapitype is CLOB: del inputsizes[bindparam] -.. _cx_oracle_returning: - -RETURNING Support ------------------ - -The cx_Oracle dialect implements RETURNING using OUT parameters. -The dialect supports RETURNING fully. - .. _cx_oracle_lob: LOB Datatypes -------------- LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and -BLOB. Modern versions of cx_Oracle and oracledb are optimized for these -datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of -these newer type handlers by default. +BLOB. Modern versions of cx_Oracle is optimized for these datatypes to be +delivered as a single buffer. As such, SQLAlchemy makes use of these newer type +handlers by default. To disable the use of newer type handlers and deliver LOB objects as classic buffered objects with a ``read()`` method, the parameter ``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`, which takes place only engine-wide. +.. _cx_oracle_returning: + +RETURNING Support +----------------- + +The cx_Oracle dialect implements RETURNING using OUT parameters. +The dialect supports RETURNING fully. + Two Phase Transactions Not Supported -------------------------------------- +------------------------------------ -Two phase transactions are **not supported** under cx_Oracle due to poor -driver support. As of cx_Oracle 6.0b1, the interface for -two phase transactions has been changed to be more of a direct pass-through -to the underlying OCI layer with less automation. The additional logic -to support this system is not implemented in SQLAlchemy. +Two phase transactions are **not supported** under cx_Oracle due to poor driver +support. The newer :ref:`oracledb` dialect however **does** support two phase +transactions. .. _cx_oracle_numeric: @@ -389,20 +456,21 @@ def _remove_clob(inputsizes, cursor, statement, parameters, context): ``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in use, the :paramref:`.Numeric.asdecimal` flag determines if values should be -coerced to ``Decimal`` upon return, or returned as float objects. To make -matters more complicated under Oracle, Oracle's ``NUMBER`` type can also -represent integer values if the "scale" is zero, so the Oracle-specific -:class:`_oracle.NUMBER` type takes this into account as well. +coerced to ``Decimal`` upon return, or returned as float objects. 
To make +matters more complicated under Oracle Database, the ``NUMBER`` type can also +represent integer values if the "scale" is zero, so the Oracle +Database-specific :class:`_oracle.NUMBER` type takes this into account as well. The cx_Oracle dialect makes extensive use of connection- and cursor-level "outputtypehandler" callables in order to coerce numeric values as requested. These callables are specific to the specific flavor of :class:`.Numeric` in -use, as well as if no SQLAlchemy typing objects are present. There are -observed scenarios where Oracle may sends incomplete or ambiguous information -about the numeric types being returned, such as a query where the numeric types -are buried under multiple levels of subquery. The type handlers do their best -to make the right decision in all cases, deferring to the underlying cx_Oracle -DBAPI for all those cases where the driver can make the best decision. +use, as well as if no SQLAlchemy typing objects are present. There are +observed scenarios where Oracle Database may send incomplete or ambiguous +information about the numeric types being returned, such as a query where the +numeric types are buried under multiple levels of subquery. The type handlers +do their best to make the right decision in all cases, deferring to the +underlying cx_Oracle DBAPI for all those cases where the driver can make the +best decision. When no typing objects are present, as when executing plain SQL strings, a default "outputtypehandler" is present which will generally return numeric @@ -814,6 +882,8 @@ def _generate_out_parameter_vars(self): out_parameters[name] = self.cursor.var( dbtype, + # this is fine also in oracledb_async since + # the driver will await the read coroutine outconverter=lambda value: value.read(), arraysize=len_params, ) @@ -832,9 +902,9 @@ def _generate_out_parameter_vars(self): ) for param in self.parameters: - param[ - quoted_bind_names.get(name, name) - ] = out_parameters[name] + param[quoted_bind_names.get(name, name)] = ( + out_parameters[name] + ) def _generate_cursor_outputtype_handler(self): output_handlers = {} @@ -1030,7 +1100,7 @@ def __init__( self, auto_convert_lobs=True, coerce_to_decimal=True, - arraysize=50, + arraysize=None, encoding_errors=None, threaded=None, **kwargs, @@ -1283,8 +1353,13 @@ def output_type_handler( cx_Oracle.CLOB, cx_Oracle.NCLOB, ): + typ = ( + cx_Oracle.DB_TYPE_VARCHAR + if default_type is cx_Oracle.CLOB + else cx_Oracle.DB_TYPE_NVARCHAR + ) return cursor.var( - cx_Oracle.DB_TYPE_NVARCHAR, + typ, _CX_ORACLE_MAGIC_LOB_SIZE, cursor.arraysize, **dialect._cursor_var_unicode_kwargs, @@ -1415,13 +1490,6 @@ def is_disconnect(self, e, connection, cursor): return False def create_xid(self): - """create a two-phase transaction ID. - - this id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). its format is unspecified. 
- - """ - id_ = random.randint(0, 2**128) return (0x1234, "%032x" % id_, "%032x" % 9) diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py index fdf47ef31ed..f785a66ef71 100644 --- a/lib/sqlalchemy/dialects/oracle/dictionary.py +++ b/lib/sqlalchemy/dialects/oracle/dictionary.py @@ -1,4 +1,5 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/dictionary.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index 7defbc9f064..c09d2bae0df 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -1,68 +1,639 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/oracledb.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors -r""" -.. dialect:: oracle+oracledb +r""".. dialect:: oracle+oracledb :name: python-oracledb :dbapi: oracledb :connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] :url: https://oracle.github.io/python-oracledb/ -python-oracledb is released by Oracle to supersede the cx_Oracle driver. -It is fully compatible with cx_Oracle and features both a "thin" client -mode that requires no dependencies, as well as a "thick" mode that uses -the Oracle Client Interface in the same way as cx_Oracle. +Description +----------- -.. seealso:: +Python-oracledb is the Oracle Database driver for Python. It features a default +"thin" client mode that requires no dependencies, and an optional "thick" mode +that uses Oracle Client libraries. It supports SQLAlchemy features including +two phase transactions and Asyncio. + +Python-oracle is the renamed, updated cx_Oracle driver. Oracle is no longer +doing any releases in the cx_Oracle namespace. + +The SQLAlchemy ``oracledb`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will + automatically select the sync version:: + + from sqlalchemy import create_engine + + sync_engine = create_engine( + "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1" + ) + +* calling :func:`_asyncio.create_async_engine` with ``oracle+oracledb://...`` + will automatically select the async version:: - :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver - as well. + from sqlalchemy.ext.asyncio import create_async_engine + + asyncio_engine = create_async_engine( + "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1" + ) + + The asyncio version of the dialect may also be specified explicitly using the + ``oracledb_async`` suffix:: + + from sqlalchemy.ext.asyncio import create_async_engine + + asyncio_engine = create_async_engine( + "oracle+oracledb_async://scott:tiger@localhost?service_name=FREEPDB1" + ) + +.. versionadded:: 2.0.25 added support for the async version of oracledb. Thick mode support ------------------ -By default the ``python-oracledb`` is started in thin mode, that does not -require oracle client libraries to be installed in the system. 
The
-``python-oracledb`` driver also support a "thick" mode, that behaves
-similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI)
-is installed.
+By default, the python-oracledb driver runs in a "thin" mode that does not
+require Oracle Client libraries to be installed. The driver also supports a
+"thick" mode that uses Oracle Client libraries to get functionality such as
+Oracle Application Continuity.

-To enable this mode, the user may call ``oracledb.init_oracle_client``
-manually, or by passing the parameter ``thick_mode=True`` to
-:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``,
-like the ``lib_dir`` path, a dict may be passed to this parameter, as in::
+To enable thick mode, call `oracledb.init_oracle_client()
+`_
+explicitly, or pass the parameter ``thick_mode=True`` to
+:func:`_sa.create_engine`. To pass custom arguments to
+``init_oracle_client()``, like the ``lib_dir`` path, a dict may be passed, for
+example::

-    engine = sa.create_engine("oracle+oracledb://...", thick_mode={
-        "lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app"
-    })
+    engine = sa.create_engine(
+        "oracle+oracledb://...",
+        thick_mode={
+            "lib_dir": "/path/to/oracle/client/lib",
+            "config_dir": "/path/to/network_config_file_directory",
+            "driver_name": "my-app : 1.0.0",
+        },
+    )
+
+Note that passing a ``lib_dir`` path should only be done on macOS or
+Windows. On Linux it does not behave as you might expect.

 .. seealso::

-    https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client
+    python-oracledb documentation `Enabling python-oracledb Thick mode
+    `_
+
+Connecting to Oracle Database
+-----------------------------
+
+python-oracledb provides several methods of indicating the target database.
+The dialect translates from a series of different URL forms.
+
+Given the hostname, port and service name of the target database, you can
+connect in SQLAlchemy using the ``service_name`` query string parameter::
+
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@hostname:port?service_name=myservice"
+    )
+
+Connecting with Easy Connect strings
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can pass any valid python-oracledb connection string as the ``dsn`` key
+value in a :paramref:`_sa.create_engine.connect_args` dictionary. See
+python-oracledb documentation `Oracle Net Services Connection Strings
+`_.
+
+For example to use an `Easy Connect string
+`_
+with a timeout to prevent connection establishment from hanging if the network
+transport to the database cannot be established in 30 seconds, and also setting
+a keep-alive time of 60 seconds to stop idle network connections from being
+terminated by a firewall::
+
+    e = create_engine(
+        "oracle+oracledb://@",
+        connect_args={
+            "user": "scott",
+            "password": "tiger",
+            "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
+        },
+    )
+
+The Easy Connect syntax has been enhanced during the life of Oracle Database.
+Review the documentation for your database version. The current documentation
+is at `Understanding the Easy Connect Naming Method
+`_.
+
+The general syntax is similar to:
+
+.. sourcecode:: text
+
+    [[protocol:]//]host[:port][/[service_name]][?parameter_name=value{&parameter_name=value}]
+
+Note that although the SQLAlchemy URL syntax ``hostname:port/dbname`` looks
+like Oracle's Easy Connect syntax, it is different. SQLAlchemy's URL requires a
+system identifier (SID) for the ``dbname`` component::
+
+    engine = create_engine("oracle+oracledb://scott:tiger@hostname:port/sid")
+
+Easy Connect syntax does not support SIDs. It uses service names, which are
+the preferred choice for connecting to Oracle Database.
+
+Passing python-oracledb connect arguments
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Other python-oracledb driver `connection options
+`_
+can be passed in ``connect_args``. For example::
+
+    import oracledb
+
+    e = create_engine(
+        "oracle+oracledb://@",
+        connect_args={
+            "user": "scott",
+            "password": "tiger",
+            "dsn": "hostname:port/myservice",
+            "events": True,
+            "mode": oracledb.AUTH_MODE_SYSDBA,
+        },
+    )
+
+Connecting with tnsnames.ora TNS aliases
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If no port, database name, or service name is provided, the dialect will use an
+Oracle Database DSN "connection string". This takes the "hostname" portion of
+the URL as the data source name. For example, if the ``tnsnames.ora`` file
+contains a `TNS Alias
+`_
+of ``myalias`` as below:
+
+.. sourcecode:: text
+
+    myalias =
+      (DESCRIPTION =
+        (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521))
+        (CONNECT_DATA =
+          (SERVER = DEDICATED)
+          (SERVICE_NAME = orclpdb1)
+        )
+      )
+
+The python-oracledb dialect connects to this database service when ``myalias``
+is the hostname portion of the URL, without specifying a port, database name or
+``service_name``::
+
+    engine = create_engine("oracle+oracledb://scott:tiger@myalias")
+
+Connecting to Oracle Autonomous Database
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Users of Oracle Autonomous Database should either use the TNS Alias URL
+shown above, or pass the TNS Alias as the ``dsn`` key value in a
+:paramref:`_sa.create_engine.connect_args` dictionary.
+
+If Oracle Autonomous Database is configured for mutual TLS ("mTLS")
+connections, then additional configuration is required as shown in `Connecting
+to Oracle Cloud Autonomous Databases
+`_. In
+summary, Thick mode users should configure file locations and set the wallet
+path in ``sqlnet.ora`` appropriately::
+
+    e = create_engine(
+        "oracle+oracledb://@",
+        thick_mode={
+            # directory containing tnsnames.ora and cwallet.so
+            "config_dir": "/opt/oracle/wallet_dir",
+        },
+        connect_args={
+            "user": "scott",
+            "password": "tiger",
+            "dsn": "mydb_high",
+        },
+    )
+
+Thin mode users of mTLS should pass the appropriate directories and PEM wallet
+password when creating the engine, similar to::
+
+    e = create_engine(
+        "oracle+oracledb://@",
+        connect_args={
+            "user": "scott",
+            "password": "tiger",
+            "dsn": "mydb_high",
+            "config_dir": "/opt/oracle/wallet_dir",  # directory containing tnsnames.ora
+            "wallet_location": "/opt/oracle/wallet_dir",  # directory containing ewallet.pem
+            "wallet_password": "top secret",  # password for the PEM file
+        },
+    )
+
+Typically ``config_dir`` and ``wallet_location`` are the same directory, which
+is where the Oracle Autonomous Database wallet zip file was extracted. Note
+this directory should be protected.
+
+Connection Pooling
+------------------
+
+Applications with multiple concurrent users should use connection pooling. A
+minimal sized connection pool is also beneficial for long-running, single-user
+applications that do not frequently use a connection.
+
+The python-oracledb driver provides its own connection pool implementation that
+may be used in place of SQLAlchemy's pooling functionality. 
The driver pool +gives support for high availability features such as dead connection detection, +connection draining for planned database downtime, support for Oracle +Application Continuity and Transparent Application Continuity, and gives +support for `Database Resident Connection Pooling (DRCP) +`_. + +To take advantage of python-oracledb's pool, use the +:paramref:`_sa.create_engine.creator` parameter to provide a function that +returns a new connection, along with setting +:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable +SQLAlchemy's pooling:: + + import oracledb + from sqlalchemy import create_engine + from sqlalchemy import text + from sqlalchemy.pool import NullPool + + # Uncomment to use the optional python-oracledb Thick mode. + # Review the python-oracledb doc for the appropriate parameters + # oracledb.init_oracle_client() + + pool = oracledb.create_pool( + user="scott", + password="tiger", + dsn="localhost:1521/freepdb1", + min=1, + max=4, + increment=1, + ) + engine = create_engine( + "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool + ) + +The above engine may then be used normally. Internally, python-oracledb handles +connection pooling:: + + with engine.connect() as conn: + print(conn.scalar(text("select 1 from dual"))) + +Refer to the python-oracledb documentation for `oracledb.create_pool() +`_ +for the arguments that can be used when creating a connection pool. + +.. _drcp: + +Using Oracle Database Resident Connection Pooling (DRCP) +-------------------------------------------------------- + +When using Oracle Database's Database Resident Connection Pooling (DRCP), the +best practice is to specify a connection class and "purity". Refer to the +`python-oracledb documentation on DRCP +`_. +For example:: + + import oracledb + from sqlalchemy import create_engine + from sqlalchemy import text + from sqlalchemy.pool import NullPool + + # Uncomment to use the optional python-oracledb Thick mode. + # Review the python-oracledb doc for the appropriate parameters + # oracledb.init_oracle_client() + + pool = oracledb.create_pool( + user="scott", + password="tiger", + dsn="localhost:1521/freepdb1", + min=1, + max=4, + increment=1, + cclass="MYCLASS", + purity=oracledb.PURITY_SELF, + ) + engine = create_engine( + "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool + ) + +The above engine may then be used normally where python-oracledb handles +application connection pooling and Oracle Database additionally uses DRCP:: + + with engine.connect() as conn: + print(conn.scalar(text("select 1 from dual"))) + +If you wish to use different connection classes or purities for different +connections, then wrap ``pool.acquire()``:: + + import oracledb + from sqlalchemy import create_engine + from sqlalchemy import text + from sqlalchemy.pool import NullPool + + # Uncomment to use python-oracledb Thick mode. 
+ # Review the python-oracledb doc for the appropriate parameters + # oracledb.init_oracle_client() + + pool = oracledb.create_pool( + user="scott", + password="tiger", + dsn="localhost:1521/freepdb1", + min=1, + max=4, + increment=1, + cclass="MYCLASS", + purity=oracledb.PURITY_SELF, + ) + + + def creator(): + return pool.acquire(cclass="MYOTHERCLASS", purity=oracledb.PURITY_NEW) + + + engine = create_engine( + "oracle+oracledb://", creator=creator, poolclass=NullPool + ) + +Engine Options consumed by the SQLAlchemy oracledb dialect outside of the driver +-------------------------------------------------------------------------------- + +There are also options that are consumed by the SQLAlchemy oracledb dialect +itself. These options are always passed directly to :func:`_sa.create_engine`, +such as:: + + e = create_engine("oracle+oracledb://user:pass@tnsalias", arraysize=500) + +The parameters accepted by the oracledb dialect are as follows: + +* ``arraysize`` - set the driver cursor.arraysize value. It defaults to + ``None``, indicating that the driver default value of 100 should be used. + This setting controls how many rows are buffered when fetching rows, and can + have a significant effect on performance if increased for queries that return + large numbers of rows. + + .. versionchanged:: 2.0.26 - changed the default value from 50 to None, + to use the default value of the driver itself. + +* ``auto_convert_lobs`` - defaults to True; See :ref:`oracledb_lob`. + +* ``coerce_to_decimal`` - see :ref:`oracledb_numeric` for detail. + +* ``encoding_errors`` - see :ref:`oracledb_unicode_encoding_errors` for detail. + +.. _oracledb_unicode: + +Unicode +------- + +As is the case for all DBAPIs under Python 3, all strings are inherently +Unicode strings. + +Ensuring the Correct Client Encoding +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In python-oracledb, the encoding used for all character data is "UTF-8". + +Unicode-specific Column datatypes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Core expression language handles unicode data by use of the +:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond +to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using +these datatypes with Unicode data, it is expected that the database is +configured with a Unicode-aware character set so that the VARCHAR2 and CLOB +datatypes can accommodate the data. + +In the case that Oracle Database is not configured with a Unicode character +set, the two options are to use the :class:`_types.NCHAR` and +:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause +the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. + +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database + datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect + when :func:`_sa.create_engine` is called. + + +.. _oracledb_unicode_encoding_errors: + +Encoding Errors +^^^^^^^^^^^^^^^ + +For the unusual case that data in Oracle Database is present with a broken +encoding, the dialect accepts a parameter ``encoding_errors`` which will be +passed to Unicode decoding functions in order to affect how decoding errors are +handled. 
The value is ultimately consumed by the Python `decode
+`_ function, and
+is passed both via python-oracledb's ``encodingErrors`` parameter consumed by
+``Cursor.var()``, as well as via SQLAlchemy's own decoding function, as the
+python-oracledb dialect makes use of both under different circumstances.
+
+.. versionadded:: 1.3.11
+
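+For example, to substitute the Unicode replacement character for bytes that
+cannot be decoded (a sketch; ``"replace"`` is one of the standard Python
+codec error handlers accepted by ``str.decode()``, and any other such handler
+name may be used)::
+
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1",
+        encoding_errors="replace",
+    )
+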
+
+.. _oracledb_setinputsizes:
+
+Fine grained control over python-oracledb data binding with setinputsizes
+--------------------------------------------------------------------------
+
+The python-oracledb DBAPI has a deep and fundamental reliance upon the usage of
+the DBAPI ``setinputsizes()`` call. The purpose of this call is to establish
+the datatypes that are bound to a SQL statement for Python values being passed
+as parameters. While virtually no other DBAPI assigns any use to the
+``setinputsizes()`` call, the python-oracledb DBAPI relies upon it heavily in
+its interactions with the Oracle Database, and in some scenarios it is not
+possible for SQLAlchemy to know exactly how data should be bound, as some
+settings can cause profoundly different performance characteristics, while
+altering the type coercion behavior at the same time.
+
+Users of the oracledb dialect are **strongly encouraged** to read through
+python-oracledb's list of built-in datatype symbols at `Database Types
+`_.
+Note that in some cases, significant performance degradation can occur when
+using these types vs. not.
+
+On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can
+be used both for runtime visibility (e.g. logging) of the setinputsizes step as
+well as to fully control how ``setinputsizes()`` is used on a per-statement
+basis.
+
+.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.do_setinputsizes`
+
+
+Example 1 - logging all setinputsizes calls
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following example illustrates how to log the intermediary values from a
+SQLAlchemy perspective before they are converted to the raw ``setinputsizes()``
+parameter dictionary. The keys of the dictionary are :class:`.BindParameter`
+objects which have a ``.key`` and a ``.type`` attribute::
+
+    import logging
+
+    from sqlalchemy import create_engine, event
+
+    log = logging.getLogger(__name__)
+
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
+    )
+
+
+    @event.listens_for(engine, "do_setinputsizes")
+    def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
+        for bindparam, dbapitype in inputsizes.items():
+            log.info(
+                "Bound parameter name: %s  SQLAlchemy type: %r DBAPI object: %s",
+                bindparam.key,
+                bindparam.type,
+                dbapitype,
+            )
+
+Example 2 - remove all bindings to CLOB
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For performance, fetching LOB datatypes from Oracle Database is set by default
+for the ``Text`` type within SQLAlchemy. This setting can be modified as
+follows::
+
+
+    from sqlalchemy import create_engine, event
+    from oracledb import CLOB
+
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
+    )
+
+
+    @event.listens_for(engine, "do_setinputsizes")
+    def _remove_clob(inputsizes, cursor, statement, parameters, context):
+        for bindparam, dbapitype in list(inputsizes.items()):
+            if dbapitype is CLOB:
+                del inputsizes[bindparam]
+
+.. _oracledb_lob:
+
+LOB Datatypes
+--------------
+
+LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
+BLOB. Oracle Database can efficiently return these datatypes as a single
+buffer. SQLAlchemy makes use of type handlers to do this by default.
+
+To disable the use of the type handlers and deliver LOB objects as classic
+buffered objects with a ``read()`` method, the parameter
+``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`.
+
+.. _oracledb_returning:
+
+RETURNING Support
+-----------------
+
+The oracledb dialect implements RETURNING using OUT parameters. The dialect
+supports RETURNING fully.
+
 Two Phase Transaction Support
 -----------------------------
 
-.. versionadded:: 2.0.0 added support for oracledb driver.
+Two phase transactions are fully supported with python-oracledb. (Thin mode
+requires python-oracledb 2.3). APIs for two phase transactions are provided at
+the Core level via :meth:`_engine.Connection.begin_twophase` and
+:paramref:`_orm.Session.twophase` for transparent ORM use.
+
+.. versionchanged:: 2.0.32 added support for two phase transactions
+
+.. _oracledb_numeric:
+
+Precision Numerics
+------------------
+
+SQLAlchemy's numeric types can handle receiving and returning values as Python
+``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
+subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
+use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
+coerced to ``Decimal`` upon return, or returned as float objects. To make
+matters more complicated under Oracle Database, the ``NUMBER`` type can also
+represent integer values if the "scale" is zero, so the Oracle
+Database-specific :class:`_oracle.NUMBER` type takes this into account as well.
+
+The oracledb dialect makes extensive use of connection- and cursor-level
+"outputtypehandler" callables in order to coerce numeric values as requested.
+These callables are specific to the flavor of :class:`.Numeric` in use, as
+well as to whether any SQLAlchemy typing objects are present at all. There are
+observed scenarios where Oracle Database may send incomplete or ambiguous
+information about the numeric types being returned, such as a query where the
+numeric types are buried under multiple levels of subquery. The type handlers
+do their best to make the right decision in all cases, deferring to the
+underlying python-oracledb DBAPI for all those cases where the driver can make
+the best decision.
+
+When no typing objects are present, as when executing plain SQL strings, a
+default "outputtypehandler" is present which will generally return numeric
+values which specify precision and scale as Python ``Decimal`` objects. To
+disable this coercion to decimal for performance reasons, pass the flag
+``coerce_to_decimal=False`` to :func:`_sa.create_engine`::
+
+    engine = create_engine(
+        "oracle+oracledb://scott:tiger@tnsalias", coerce_to_decimal=False
+    )
+
+The ``coerce_to_decimal`` flag only impacts the results of plain string
+SQL statements that are not otherwise associated with a :class:`.Numeric`
+SQLAlchemy type (or a subclass of such).
+
+.. versionchanged:: 1.2 The numeric handling system for the oracle dialects has
+   been reworked to take advantage of newer driver features as well as better
+   integration of outputtypehandlers.
+
+.. versionadded:: 2.0.0 added support for the python-oracledb driver.
+
 """  # noqa
+from __future__ import annotations
+
+import collections
 import re
+from typing import Any
+from typing import TYPE_CHECKING
 
-from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
+from . import cx_oracle as _cx_oracle
 from ... import exc
+from ... 
import pool +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection +from ...engine import default +from ...util import asbool +from ...util import await_fallback +from ...util import await_only + +if TYPE_CHECKING: + from oracledb import AsyncConnection + from oracledb import AsyncCursor + +class OracleExecutionContext_oracledb( + _cx_oracle.OracleExecutionContext_cx_oracle +): + pass -class OracleDialect_oracledb(_OracleDialect_cx_oracle): + +class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_oracledb + driver = "oracledb" + _min_version = (1,) def __init__( self, auto_convert_lobs=True, coerce_to_decimal=True, - arraysize=50, + arraysize=None, encoding_errors=None, thick_mode=None, **kwargs, @@ -91,6 +662,10 @@ def import_dbapi(cls): def is_thin_mode(cls, connection): return connection.connection.dbapi_connection.thin + @classmethod + def get_async_dialect_cls(cls, url): + return OracleDialectAsync_oracledb + def _load_version(self, dbapi_module): version = (0, 0, 0) if dbapi_module is not None: @@ -100,10 +675,273 @@ def _load_version(self, dbapi_module): int(x) for x in m.group(1, 2, 3) if x is not None ) self.oracledb_ver = version - if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0): + if ( + self.oracledb_ver > (0, 0, 0) + and self.oracledb_ver < self._min_version + ): raise exc.InvalidRequestError( - "oracledb version 1 and above are supported" + f"oracledb version {self._min_version} and above are supported" + ) + + def do_begin_twophase(self, connection, xid): + conn_xis = connection.connection.xid(*xid) + connection.connection.tpc_begin(conn_xis) + connection.connection.info["oracledb_xid"] = conn_xis + + def do_prepare_twophase(self, connection, xid): + should_commit = connection.connection.tpc_prepare() + connection.info["oracledb_should_commit"] = should_commit + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if recover: + conn_xid = connection.connection.xid(*xid) + else: + conn_xid = None + connection.connection.tpc_rollback(conn_xid) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + conn_xid = None + if not is_prepared: + should_commit = connection.connection.tpc_prepare() + elif recover: + conn_xid = connection.connection.xid(*xid) + should_commit = True + else: + should_commit = connection.info["oracledb_should_commit"] + if should_commit: + connection.connection.tpc_commit(conn_xid) + + def do_recover_twophase(self, connection): + return [ + # oracledb seems to return bytes + ( + fi, + gti.decode() if isinstance(gti, bytes) else gti, + bq.decode() if isinstance(bq, bytes) else bq, + ) + for fi, gti, bq in connection.connection.tpc_recover() + ] + + def _check_max_identifier_length(self, connection): + if self.oracledb_ver >= (2, 5): + max_len = connection.connection.max_identifier_length + if max_len is not None: + return max_len + return super()._check_max_identifier_length(connection) + + +class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor): + _cursor: AsyncCursor + __slots__ = () + + @property + def outputtypehandler(self): + return self._cursor.outputtypehandler + + @outputtypehandler.setter + def outputtypehandler(self, value): + self._cursor.outputtypehandler = value + + def 
var(self, *args, **kwargs): + return self._cursor.var(*args, **kwargs) + + def close(self): + self._rows.clear() + self._cursor.close() + + def setinputsizes(self, *args: Any, **kwargs: Any) -> Any: + return self._cursor.setinputsizes(*args, **kwargs) + + def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor: + try: + return cursor.__enter__() + except Exception as error: + self._adapt_connection._handle_exception(error) + + async def _execute_async(self, operation, parameters): + # override to not use mutex, oracledb already has a mutex + + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if self._cursor.description and not self.server_side: + self._rows = collections.deque(await self._cursor.fetchall()) + return result + + async def _executemany_async( + self, + operation, + seq_of_parameters, + ): + # override to not use mutex, oracledb already has a mutex + return await self._cursor.executemany(operation, seq_of_parameters) + + def __enter__(self): + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + +class AsyncAdapt_oracledb_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor +): + __slots__ = () + + def close(self) -> None: + if self._cursor is not None: + self._cursor.close() + self._cursor = None # type: ignore + + +class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): + _connection: AsyncConnection + __slots__ = () + + thin = True + + _cursor_cls = AsyncAdapt_oracledb_cursor + _ss_cursor_cls = None + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self._connection.autocommit = value + + @property + def outputtypehandler(self): + return self._connection.outputtypehandler + + @outputtypehandler.setter + def outputtypehandler(self, value): + self._connection.outputtypehandler = value + + @property + def version(self): + return self._connection.version + + @property + def stmtcachesize(self): + return self._connection.stmtcachesize + + @stmtcachesize.setter + def stmtcachesize(self, value): + self._connection.stmtcachesize = value + + @property + def max_identifier_length(self): + return self._connection.max_identifier_length + + def cursor(self): + return AsyncAdapt_oracledb_cursor(self) + + def ss_cursor(self): + return AsyncAdapt_oracledb_ss_cursor(self) + + def xid(self, *args: Any, **kwargs: Any) -> Any: + return self._connection.xid(*args, **kwargs) + + def tpc_begin(self, *args: Any, **kwargs: Any) -> Any: + return self.await_(self._connection.tpc_begin(*args, **kwargs)) + + def tpc_commit(self, *args: Any, **kwargs: Any) -> Any: + return self.await_(self._connection.tpc_commit(*args, **kwargs)) + + def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any: + return self.await_(self._connection.tpc_prepare(*args, **kwargs)) + + def tpc_recover(self, *args: Any, **kwargs: Any) -> Any: + return self.await_(self._connection.tpc_recover(*args, **kwargs)) + + def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any: + return self.await_(self._connection.tpc_rollback(*args, **kwargs)) + + +class AsyncAdaptFallback_oracledb_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection +): + __slots__ = () + + +class OracledbAdaptDBAPI: + def __init__(self, oracledb) -> None: + self.oracledb = oracledb + + for k, v in self.oracledb.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def 
connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async) + + if asbool(async_fallback): + return AsyncAdaptFallback_oracledb_connection( + self, await_fallback(creator_fn(*arg, **kw)) + ) + + else: + return AsyncAdapt_oracledb_connection( + self, await_only(creator_fn(*arg, **kw)) ) +class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb): + # restore default create cursor + create_cursor = default.DefaultExecutionContext.create_cursor + + def create_default_cursor(self): + # copy of OracleExecutionContext_cx_oracle.create_cursor + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def create_server_side_cursor(self): + c = self._dbapi_connection.ss_cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + +class OracleDialectAsync_oracledb(OracleDialect_oracledb): + is_async = True + supports_server_side_cursors = True + supports_statement_cache = True + execution_ctx_cls = OracleExecutionContextAsync_oracledb + + _min_version = (2,) + + # thick_mode mode is not supported by asyncio, oracledb will raise + @classmethod + def import_dbapi(cls): + import oracledb + + return OracledbAdaptDBAPI(oracledb) + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def get_driver_connection(self, connection): + return connection._connection + + dialect = OracleDialect_oracledb +dialect_async = OracleDialectAsync_oracledb diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index c8599e8e225..3587de9d011 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -1,3 +1,9 @@ +# dialects/oracle/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... import create_engine @@ -83,7 +89,7 @@ def _oracle_drop_db(cfg, eng, ident): # cx_Oracle seems to occasionally leak open connections when a large # suite it run, even if we confirm we have zero references to # connection objects. - # while there is a "kill session" command in Oracle, + # while there is a "kill session" command in Oracle Database, # it unfortunately does not release the connection sufficiently. _ora_drop_ignore(conn, ident) _ora_drop_ignore(conn, "%s_ts1" % ident) diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 4f82c43c699..06aeaace2f5 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -1,4 +1,5 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/oracle/types.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -63,17 +64,18 @@ def _type_affinity(self): class FLOAT(sqltypes.FLOAT): - """Oracle FLOAT. + """Oracle Database FLOAT. 
This is the same as :class:`_sqltypes.FLOAT` except that - an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + an Oracle Database -specific :paramref:`_oracle.FLOAT.binary_precision` parameter is accepted, and the :paramref:`_sqltypes.Float.precision` parameter is not accepted. - Oracle FLOAT types indicate precision in terms of "binary precision", which - defaults to 126. For a REAL type, the value is 63. This parameter does not - cleanly map to a specific number of decimal places but is roughly - equivalent to the desired number of decimal places divided by 0.3103. + Oracle Database FLOAT types indicate precision in terms of "binary + precision", which defaults to 126. For a REAL type, the value is 63. This + parameter does not cleanly map to a specific number of decimal places but + is roughly equivalent to the desired number of decimal places divided by + 0.3103. .. versionadded:: 2.0 @@ -90,10 +92,11 @@ def __init__( r""" Construct a FLOAT - :param binary_precision: Oracle binary precision value to be rendered - in DDL. This may be approximated to the number of decimal characters - using the formula "decimal precision = 0.30103 * binary precision". - The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. + :param binary_precision: Oracle Database binary precision value to be + rendered in DDL. This may be approximated to the number of decimal + characters using the formula "decimal precision = 0.30103 * binary + precision". The default value used by Oracle Database for FLOAT / + DOUBLE PRECISION is 126. :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` @@ -108,10 +111,36 @@ def __init__( class BINARY_DOUBLE(sqltypes.Double): + """Implement the Oracle ``BINARY_DOUBLE`` datatype. + + This datatype differs from the Oracle ``DOUBLE`` datatype in that it + delivers a true 8-byte FP value. The datatype may be combined with a + generic :class:`.Double` datatype using :meth:`.TypeEngine.with_variant`. + + .. seealso:: + + :ref:`oracle_float_support` + + + """ + __visit_name__ = "BINARY_DOUBLE" class BINARY_FLOAT(sqltypes.Float): + """Implement the Oracle ``BINARY_FLOAT`` datatype. + + This datatype differs from the Oracle ``FLOAT`` datatype in that it + delivers a true 4-byte FP value. The datatype may be combined with a + generic :class:`.Float` datatype using :meth:`.TypeEngine.with_variant`. + + .. seealso:: + + :ref:`oracle_float_support` + + + """ + __visit_name__ = "BINARY_FLOAT" @@ -162,10 +191,10 @@ def process(value): class DATE(_OracleDateLiteralRender, sqltypes.DateTime): - """Provide the oracle DATE type. + """Provide the Oracle Database DATE type. This type has no special Python behavior, except that it subclasses - :class:`_types.DateTime`; this is to suit the fact that the Oracle + :class:`_types.DateTime`; this is to suit the fact that the Oracle Database ``DATE`` type supports a time value. """ @@ -245,8 +274,8 @@ def process(value: dt.timedelta) -> str: class TIMESTAMP(sqltypes.TIMESTAMP): - """Oracle implementation of ``TIMESTAMP``, which supports additional - Oracle-specific modes + """Oracle Database implementation of ``TIMESTAMP``, which supports + additional Oracle Database-specific modes .. versionadded:: 2.0 @@ -256,10 +285,11 @@ def __init__(self, timezone: bool = False, local_timezone: bool = False): """Construct a new :class:`_oracle.TIMESTAMP`. :param timezone: boolean. Indicates that the TIMESTAMP type should - use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype. + use Oracle Database's ``TIMESTAMP WITH TIME ZONE`` datatype. 
        :param local_timezone: boolean. Indicates that the TIMESTAMP type
-         should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.
+         should use Oracle Database's ``TIMESTAMP WITH LOCAL TIME ZONE``
+         datatype.
 
         """
 
@@ -272,7 +302,7 @@ def __init__(self, timezone: bool = False, local_timezone: bool = False):
 
 
 class ROWID(sqltypes.TypeEngine):
-    """Oracle ROWID type.
+    """Oracle Database ROWID type.
 
     When used in a cast() or similar, generates ROWID.
 
diff --git a/lib/sqlalchemy/dialects/oracle/vector.py b/lib/sqlalchemy/dialects/oracle/vector.py
new file mode 100644
index 00000000000..dae89d3418d
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/vector.py
@@ -0,0 +1,266 @@
+# dialects/oracle/vector.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+#
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import array
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+import sqlalchemy.types as types
+from sqlalchemy.types import Float
+
+
+class VectorIndexType(Enum):
+    """Enum representing different types of VECTOR index structures.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    HNSW = "HNSW"
+    """
+    The HNSW (Hierarchical Navigable Small World) index type.
+    """
+    IVF = "IVF"
+    """
+    The IVF (Inverted File Index) index type.
+    """
+
+
+class VectorDistanceType(Enum):
+    """Enum representing different types of vector distance metrics.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    EUCLIDEAN = "EUCLIDEAN"
+    """Euclidean distance (L2 norm).
+
+    Measures the straight-line distance between two vectors in space.
+    """
+    DOT = "DOT"
+    """Dot product similarity.
+
+    Measures the algebraic similarity between two vectors.
+    """
+    COSINE = "COSINE"
+    """Cosine similarity.
+
+    Measures the cosine of the angle between two vectors.
+    """
+    MANHATTAN = "MANHATTAN"
+    """Manhattan distance (L1 norm).
+
+    Calculates the sum of absolute differences across dimensions.
+    """
+
+
+class VectorStorageFormat(Enum):
+    """Enum representing the data format used to store vector components.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    INT8 = "INT8"
+    """
+    8-bit integer format.
+    """
+    BINARY = "BINARY"
+    """
+    Binary format.
+    """
+    FLOAT32 = "FLOAT32"
+    """
+    32-bit floating-point format.
+    """
+    FLOAT64 = "FLOAT64"
+    """
+    64-bit floating-point format.
+    """
+
+
+@dataclass
+class VectorIndexConfig:
+    """Define the configuration for Oracle VECTOR Index.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    :param index_type: Enum value from :class:`.VectorIndexType`.
+     Specifies the indexing method. For HNSW, this must be
+     :attr:`.VectorIndexType.HNSW`.
+
+    :param distance: Enum value from :class:`.VectorDistanceType`.
+     Specifies the metric for calculating distance between VECTORS.
+
+    :param accuracy: integer. Should be in the range 0 to 100.
+     Specifies the accuracy of the nearest neighbor search during
+     query execution.
+
+    :param parallel: integer. Specifies degree of parallelism.
+
+    :param hnsw_neighbors: integer. Should be in the range 0 to
+     2048. Specifies the number of nearest neighbors considered
+     during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors`
+     is HNSW index specific.
+
+    :param hnsw_efconstruction: integer. Should be in the range 0
+     to 65535. Controls the trade-off between indexing speed and
+     recall quality during index construction. The attribute
+     :attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index
+     specific.
+
+    :param ivf_neighbor_partitions: integer. Should be in the range
+     0 to 10,000,000. Specifies the number of partitions used to
+     divide the dataset. The attribute
+     :attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index
+     specific.
+
+    :param ivf_sample_per_partition: integer. Should be between 1
+     and ``num_vectors / neighbor partitions``. Specifies the
+     number of samples used per partition. The attribute
+     :attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index
+     specific.
+
+    :param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
+     to the total number of vectors (results in 1 partition). Specifies
+     the minimum number of vectors per partition. The attribute
+     :attr:`.VectorIndexConfig.ivf_min_vectors_per_partition`
+     is IVF index specific.
+
+    """
+
+    index_type: VectorIndexType = VectorIndexType.HNSW
+    distance: Optional[VectorDistanceType] = None
+    accuracy: Optional[int] = None
+    hnsw_neighbors: Optional[int] = None
+    hnsw_efconstruction: Optional[int] = None
+    ivf_neighbor_partitions: Optional[int] = None
+    ivf_sample_per_partition: Optional[int] = None
+    ivf_min_vectors_per_partition: Optional[int] = None
+    parallel: Optional[int] = None
+
+    def __post_init__(self):
+        self.index_type = VectorIndexType(self.index_type)
+        for field in [
+            "hnsw_neighbors",
+            "hnsw_efconstruction",
+            "ivf_neighbor_partitions",
+            "ivf_sample_per_partition",
+            "ivf_min_vectors_per_partition",
+            "parallel",
+            "accuracy",
+        ]:
+            value = getattr(self, field)
+            if value is not None and not isinstance(value, int):
+                raise TypeError(
+                    f"{field} must be an integer if "
+                    f"provided, got {type(value).__name__}"
+                )
+
+
+class VECTOR(types.TypeEngine):
+    """Oracle VECTOR datatype.
+
+    For complete background on using this type, see
+    :ref:`oracle_vector_datatype`.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    cache_ok = True
+    __visit_name__ = "VECTOR"
+
+    _typecode_map = {
+        VectorStorageFormat.INT8: "b",  # Signed int
+        VectorStorageFormat.BINARY: "B",  # Unsigned int
+        VectorStorageFormat.FLOAT32: "f",  # Float
+        VectorStorageFormat.FLOAT64: "d",  # Double
+    }
+
+    def __init__(self, dim=None, storage_format=None):
+        """Construct a VECTOR.
+
+        :param dim: integer. The dimension of the VECTOR datatype. This
+         should be an integer value.
+
+        :param storage_format: VectorStorageFormat. The VECTOR storage
+         type format. This may be an Enum value from
+         :class:`.VectorStorageFormat`: INT8, BINARY, FLOAT32, or FLOAT64.
+
+        """
+        if dim is not None and not isinstance(dim, int):
+            raise TypeError("dim must be an integer")
+        if storage_format is not None and not isinstance(
+            storage_format, VectorStorageFormat
+        ):
+            raise TypeError(
+                "storage_format must be an enum of type VectorStorageFormat"
+            )
+        self.dim = dim
+        self.storage_format = storage_format
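+
+    # For example, a column using this type may be declared as
+    #
+    #     Column("embedding", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32))
+    #
+    # (an illustrative sketch only); Python lists bound to such a column
+    # are converted to array.array by the bind processor below.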
+
+    def _cached_bind_processor(self, dialect):
+        """
+        Convert a list to an array.array before binding it to the database.
+        """
+
+        def process(value):
+            if value is None or isinstance(value, array.array):
+                return value
+
+            # Convert list to an array.array
+            elif isinstance(value, list):
+                typecode = self._array_typecode(self.storage_format)
+                value = array.array(typecode, value)
+                return value
+
+            else:
+                raise TypeError("VECTOR accepts list or array.array()")
+
+        return process
+
+    def _cached_result_processor(self, dialect, coltype):
+        """
+        Convert an array.array to a list after fetching it from the database.
+        """
+
+        def process(value):
+            if isinstance(value, array.array):
+                return list(value)
+            # pass through non-array values (e.g. None) unchanged
+            return value
+
+        return process
+
+    def _array_typecode(self, typecode):
+        """
+        Map storage format to array typecode.
+        """
+        return self._typecode_map.get(typecode, "d")
+
+    class comparator_factory(types.TypeEngine.Comparator):
+        def l2_distance(self, other):
+            return self.op("<->", return_type=Float)(other)
+
+        def inner_product(self, other):
+            return self.op("<#>", return_type=Float)(other)
+
+        def cosine_distance(self, other):
+            return self.op("<=>", return_type=Float)(other)
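+
+    # As an illustrative sketch only: given a hypothetical table ``items``
+    # with a VECTOR column ``embedding``, an expression such as
+    #
+    #     items.c.embedding.l2_distance([1.0, 2.0, 3.0])
+    #
+    # renders the ``<->`` operator defined above, returning a Float.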
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index c3ed7c1fc00..88935e20245 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -1,5 +1,5 @@
-# postgresql/__init__.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# dialects/postgresql/__init__.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
@@ -8,6 +8,7 @@
 from types import ModuleType
 
+from . import array as arraylib  # noqa # keep above base and other dialects
 from . import asyncpg  # noqa
 from . import base
 from . import pg8000  # noqa
@@ -56,12 +57,14 @@
 from .named_types import NamedType
 from .ranges import AbstractMultiRange
 from .ranges import AbstractRange
+from .ranges import AbstractSingleRange
 from .ranges import DATEMULTIRANGE
 from .ranges import DATERANGE
 from .ranges import INT4MULTIRANGE
 from .ranges import INT4RANGE
 from .ranges import INT8MULTIRANGE
 from .ranges import INT8RANGE
+from .ranges import MultiRange
 from .ranges import NUMMULTIRANGE
 from .ranges import NUMRANGE
 from .ranges import Range
@@ -86,6 +89,7 @@
 from .types import TSQUERY
 from .types import TSVECTOR
 
+
 # Alias psycopg also as psycopg_async
 psycopg_async = type(
     "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
index dfb25a56890..9b09868bd3a 100644
--- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
+++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
@@ -1,4 +1,5 @@
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# dialects/postgresql/_psycopg_common.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
@@ -170,7 +171,6 @@ def _do_autocommit(self, connection, value):
         connection.autocommit = value
 
     def do_ping(self, dbapi_connection):
-        cursor = None
         before_autocommit = dbapi_connection.autocommit
 
         if not before_autocommit:
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 3496ed6b636..96f6dc21a2d 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -1,18 +1,21 @@
-# postgresql/array.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# dialects/postgresql/array.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 from __future__ import annotations
 
 import re
-from typing import Any
+from typing import Any as typing_Any
+from typing import Iterable
 from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from .operators import CONTAINED_BY
 from .operators import CONTAINS
@@ -21,32 +24,55 @@
 from ... import util
 from ...sql import expression
 from ...sql import operators
-from ...sql._typing import _TypeEngineArgument
-
-
-_T = TypeVar("_T", bound=Any)
-
-
-def Any(other, arrexpr, operator=operators.eq):
+from ...sql.visitors import InternalTraversal
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql._typing import _ColumnExpressionArgument
+    from ...sql._typing import _TypeEngineArgument
+    from ...sql.elements import ColumnElement
+    from ...sql.elements import Grouping
+    from ...sql.expression import BindParameter
+    from ...sql.operators import OperatorType
+    from ...sql.selectable import _SelectIterable
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+    from ...sql.type_api import _ResultProcessorType
+    from ...sql.type_api import TypeEngine
+    from ...sql.visitors import _TraverseInternalsType
+    from ...util.typing import Self
+
+
+_T = TypeVar("_T", bound=typing_Any)
+
+
+def Any(
+    other: typing_Any,
+    arrexpr: _ColumnExpressionArgument[_T],
+    operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
     """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
     See that method for details.
 
     """
-    return arrexpr.any(other, operator)
+    return arrexpr.any(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501
 
 
-def All(other, arrexpr, operator=operators.eq):
+def All(
+    other: typing_Any,
+    arrexpr: _ColumnExpressionArgument[_T],
+    operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
     """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
     See that method for details.
 
     """
-    return arrexpr.all(other, operator)
+    return arrexpr.all(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501
 
 
 class array(expression.ExpressionClauseList[_T]):
-
     """A PostgreSQL ARRAY literal.
 
    This is used to produce ARRAY literals in SQL expressions, e.g.::
@@ -55,20 +81,43 @@ class array(expression.ExpressionClauseList[_T]):
         from sqlalchemy.dialects import postgresql
         from sqlalchemy import select, func
 
-        stmt = select(array([1,2]) + array([3,4,5]))
+        stmt = select(array([1, 2]) + array([3, 4, 5]))
 
         print(stmt.compile(dialect=postgresql.dialect()))
 
-    Produces the SQL::
+    Produces the SQL:
+
+    .. sourcecode:: sql
 
         SELECT ARRAY[%(param_1)s, %(param_2)s] ||
             ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1
 
     An instance of :class:`.array` will always have the datatype
-    :class:`_types.ARRAY`. The "inner" type of the array is inferred from
-    the values present, unless the ``type_`` keyword argument is passed::
+    :class:`_types.ARRAY`. 
The "inner" type of the array is inferred from the + values present, unless the :paramref:`_postgresql.array.type_` keyword + argument is passed:: + + array(["foo", "bar"], type_=CHAR) + + When constructing an empty array, the :paramref:`_postgresql.array.type_` + argument is particularly important as PostgreSQL server typically requires + a cast to be rendered for the inner type in order to render an empty array. + SQLAlchemy's compilation for the empty array will produce this cast so + that:: + + stmt = array([], type_=Integer) + print(stmt.compile(dialect=postgresql.dialect())) + + Produces: + + .. sourcecode:: sql - array(['foo', 'bar'], type_=CHAR) + ARRAY[]::INTEGER[] + + As required by PostgreSQL for empty arrays. + + .. versionadded:: 2.0.40 added support to render empty PostgreSQL array + literals with a required cast. Multidimensional arrays are produced by nesting :class:`.array` constructs. The dimensionality of the final :class:`_types.ARRAY` @@ -77,16 +126,21 @@ class array(expression.ExpressionClauseList[_T]): type:: stmt = select( - array([ - array([1, 2]), array([3, 4]), array([column('q'), column('x')]) - ]) + array( + [array([1, 2]), array([3, 4]), array([column("q"), column("x")])] + ) ) print(stmt.compile(dialect=postgresql.dialect())) - Produces:: + Produces: - SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s], - ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1 + .. sourcecode:: sql + + SELECT ARRAY[ + ARRAY[%(param_1)s, %(param_2)s], + ARRAY[%(param_3)s, %(param_4)s], + ARRAY[q, x] + ] AS anon_1 .. versionadded:: 1.3.6 added support for multidimensional array literals @@ -94,42 +148,63 @@ class array(expression.ExpressionClauseList[_T]): :class:`_postgresql.ARRAY` - """ + """ # noqa: E501 __visit_name__ = "array" stringify_dialect = "postgresql" - inherit_cache = True - def __init__(self, clauses, **kw): - type_arg = kw.pop("type_", None) - super().__init__(operators.comma_op, *clauses, **kw) + _traverse_internals: _TraverseInternalsType = [ + ("clauses", InternalTraversal.dp_clauseelement_tuple), + ("type", InternalTraversal.dp_type), + ] - self._type_tuple = [arg.type for arg in self.clauses] + def __init__( + self, + clauses: Iterable[_T], + *, + type_: Optional[_TypeEngineArgument[_T]] = None, + **kw: typing_Any, + ): + r"""Construct an ARRAY literal. + + :param clauses: iterable, such as a list, containing elements to be + rendered in the array + :param type\_: optional type. If omitted, the type is inferred + from the contents of the array. 
+ + """ + super().__init__(operators.comma_op, *clauses, **kw) main_type = ( - type_arg - if type_arg is not None - else self._type_tuple[0] - if self._type_tuple - else sqltypes.NULLTYPE + type_ + if type_ is not None + else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE ) if isinstance(main_type, ARRAY): self.type = ARRAY( main_type.item_type, - dimensions=main_type.dimensions + 1 - if main_type.dimensions is not None - else 2, - ) + dimensions=( + main_type.dimensions + 1 + if main_type.dimensions is not None + else 2 + ), + ) # type: ignore[assignment] else: - self.type = ARRAY(main_type) + self.type = ARRAY(main_type) # type: ignore[assignment] @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) - def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + def _bind_param( + self, + operator: OperatorType, + obj: typing_Any, + type_: Optional[TypeEngine[_T]] = None, + _assume_scalar: bool = False, + ) -> BindParameter[_T]: if _assume_scalar or operator is operators.getitem: return expression.BindParameter( None, @@ -148,16 +223,18 @@ def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): ) for o in obj ] - ) + ) # type: ignore[return-value] - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: return self -class ARRAY(sqltypes.ARRAY): +class ARRAY(sqltypes.ARRAY[_T]): """PostgreSQL ARRAY type. The :class:`_postgresql.ARRAY` type is constructed in the same way @@ -167,9 +244,11 @@ class ARRAY(sqltypes.ARRAY): from sqlalchemy.dialects import postgresql - mytable = Table("mytable", metadata, - Column("data", postgresql.ARRAY(Integer, dimensions=2)) - ) + mytable = Table( + "mytable", + metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)), + ) The :class:`_postgresql.ARRAY` type provides all operations defined on the core :class:`_types.ARRAY` type, including support for "dimensions", @@ -184,8 +263,9 @@ class also mytable.c.data.contains([1, 2]) - The :class:`_postgresql.ARRAY` type may not be supported on all - PostgreSQL DBAPIs; it is currently known to work on psycopg2 only. + Indexed access is one-based by default, to match that of PostgreSQL; + for zero-based indexed access, set + :paramref:`_postgresql.ARRAY.zero_indexes`. Additionally, the :class:`_postgresql.ARRAY` type does not work directly in @@ -204,6 +284,7 @@ class also from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.ext.mutable import MutableList + class SomeOrmClass(Base): # ... @@ -225,45 +306,9 @@ class SomeOrmClass(Base): """ - class Comparator(sqltypes.ARRAY.Comparator): - - """Define comparison operations for :class:`_types.ARRAY`. - - Note that these operations are in addition to those provided - by the base :class:`.types.ARRAY.Comparator` class, including - :meth:`.types.ARRAY.Comparator.any` and - :meth:`.types.ARRAY.Comparator.all`. - - """ - - def contains(self, other, **kwargs): - """Boolean expression. Test if elements are a superset of the - elements of the argument array expression. - - kwargs may be ignored by this operator but are required for API - conformance. - """ - return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - - def contained_by(self, other): - """Boolean expression. Test if elements are a proper subset of the - elements of the argument array expression. 
- """ - return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean - ) - - def overlap(self, other): - """Boolean expression. Test if array has elements in common with - an argument array expression. - """ - return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) - - comparator_factory = Comparator - def __init__( self, - item_type: _TypeEngineArgument[Any], + item_type: _TypeEngineArgument[_T], as_tuple: bool = False, dimensions: Optional[int] = None, zero_indexes: bool = False, @@ -272,7 +317,7 @@ def __init__( E.g.:: - Column('myarray', ARRAY(Integer)) + Column("myarray", ARRAY(Integer)) Arguments are: @@ -312,35 +357,63 @@ def __init__( self.dimensions = dimensions self.zero_indexes = zero_indexes - @property - def hashable(self): - return self.as_tuple + class Comparator(sqltypes.ARRAY.Comparator[_T]): + """Define comparison operations for :class:`_types.ARRAY`. - @property - def python_type(self): - return list + Note that these operations are in addition to those provided + by the base :class:`.types.ARRAY.Comparator` class, including + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. - def compare_values(self, x, y): - return x == y + """ + + def contains( + self, other: typing_Any, **kwargs: typing_Any + ) -> ColumnElement[bool]: + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other: typing_Any) -> ColumnElement[bool]: + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def overlap(self, other: typing_Any) -> ColumnElement[bool]: + """Boolean expression. Test if array has elements in common with + an argument array expression. 
+ """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator @util.memoized_property - def _against_native_enum(self): + def _against_native_enum(self) -> bool: return ( isinstance(self.item_type, sqltypes.Enum) and self.item_type.native_enum ) - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: item_proc = self.item_type.dialect_impl(dialect).literal_processor( dialect ) if item_proc is None: return None - def to_str(elements): + def to_str(elements: Iterable[typing_Any]) -> str: return f"ARRAY[{', '.join(elements)}]" - def process(value): + def process(value: Sequence[typing_Any]) -> str: inner = self._apply_item_processor( value, item_proc, self.dimensions, to_str ) @@ -348,12 +421,16 @@ def process(value): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]: item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect ) - def process(value): + def process( + value: Optional[Sequence[typing_Any]], + ) -> Optional[list[typing_Any]]: if value is None: return value else: @@ -363,12 +440,16 @@ def process(value): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[Sequence[typing_Any]]: item_proc = self.item_type.dialect_impl(dialect).result_processor( dialect, coltype ) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value else: @@ -383,11 +464,13 @@ def process(value): super_rp = process pattern = re.compile(r"^{(.*)}$") - def handle_raw_string(value): - inner = pattern.match(value).group(1) + def handle_raw_string(value: str) -> list[str]: + inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501 return _split_enum_values(inner) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value # isinstance(value, str) is required to handle @@ -402,7 +485,7 @@ def process(value): return process -def _split_enum_values(array_string): +def _split_enum_values(array_string: str) -> list[str]: if '"' not in array_string: # no escape char is present so it can just split on the comma return array_string.split(",") if array_string else [] diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index ca35bf96075..096892127ba 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -1,5 +1,5 @@ -# postgresql/asyncpg.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under @@ -23,18 +23,10 @@ :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") - -The dialect can also be run as a "synchronous" dialect within the -:func:`_sa.create_engine` function, which will pass "await" calls into -an ad-hoc event loop. This mode of operation is of **limited use** -and is for special testing scenarios only. 
The mode can be enabled by -adding the SQLAlchemy-specific flag ``async_fallback`` to the URL -in conjunction with :func:`_sa.create_engine`:: - - # for testing purposes only; do not use in production! - engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") + engine = create_async_engine( + "postgresql+asyncpg://user:pass@hostname/dbname" + ) .. versionadded:: 1.4 @@ -89,11 +81,15 @@ argument):: - engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") + engine = create_async_engine( + "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500" + ) To disable the prepared statement cache, use a value of zero:: - engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") + engine = create_async_engine( + "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0" + ) .. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. @@ -123,8 +119,8 @@ .. _asyncpg_prepared_statement_name: -Prepared Statement Name ------------------------ +Prepared Statement Name with PGBouncer +-------------------------------------- By default, asyncpg enumerates prepared statements in numeric order, which can lead to errors if a name has already been taken for another prepared @@ -139,10 +135,10 @@ from uuid import uuid4 engine = create_async_engine( - "postgresql+asyncpg://user:pass@hostname/dbname", + "postgresql+asyncpg://user:pass@somepgbouncer/dbname", poolclass=NullPool, connect_args={ - 'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__', + "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__", }, ) @@ -152,7 +148,7 @@ https://github.com/sqlalchemy/sqlalchemy/issues/6467 -.. warning:: To prevent a buildup of useless prepared statements in +.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in your application, it's important to use the :class:`.NullPool` pool class, and to configure PgBouncer to use `DISCARD `_ when returning connections. The DISCARD command is used to release resources held by the db connection, @@ -182,13 +178,11 @@ from __future__ import annotations -import collections +from collections import deque import decimal import json as _py_json import re import time -from typing import cast -from typing import TYPE_CHECKING from . import json from . 
import ranges @@ -218,9 +212,6 @@ from ...util.concurrency import await_fallback from ...util.concurrency import await_only -if TYPE_CHECKING: - from typing import Iterable - class AsyncpgARRAY(PGARRAY): render_bind_cast = True @@ -274,20 +265,20 @@ class AsyncpgInteger(sqltypes.Integer): render_bind_cast = True -class AsyncpgBigInteger(sqltypes.BigInteger): +class AsyncpgSmallInteger(sqltypes.SmallInteger): render_bind_cast = True -class AsyncpgJSON(json.JSON): +class AsyncpgBigInteger(sqltypes.BigInteger): render_bind_cast = True + +class AsyncpgJSON(json.JSON): def result_processor(self, dialect, coltype): return None class AsyncpgJSONB(json.JSONB): - render_bind_cast = True - def result_processor(self, dialect, coltype): return None @@ -372,7 +363,7 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True -class _AsyncpgRange(ranges.AbstractRangeImpl): +class _AsyncpgRange(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): asyncpg_Range = dialect.dbapi.asyncpg.Range @@ -426,10 +417,7 @@ def to_range(value): ) return value - return [ - to_range(element) - for element in cast("Iterable[ranges.Range]", value) - ] + return [to_range(element) for element in value] return to_range @@ -448,7 +436,7 @@ def to_range(rvalue): return rvalue if value is not None: - value = [to_range(elem) for elem in value] + value = ranges.MultiRange(to_range(elem) for elem in value) return value @@ -506,7 +494,7 @@ class AsyncAdapt_asyncpg_cursor: def __init__(self, adapt_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self._rows = [] + self._rows = deque() self._cursor = None self.description = None self.arraysize = 1 @@ -514,7 +502,7 @@ def __init__(self, adapt_connection): self._invalidate_schema_cache_asof = 0 def close(self): - self._rows[:] = [] + self._rows.clear() def _handle_exception(self, error): self._adapt_connection._handle_exception(error) @@ -554,11 +542,12 @@ async def _prepare_and_execute(self, operation, parameters): self._cursor = await prepared_stmt.cursor(*parameters) self.rowcount = -1 else: - self._rows = await prepared_stmt.fetch(*parameters) + self._rows = deque(await prepared_stmt.fetch(*parameters)) status = prepared_stmt.get_statusmsg() reg = re.match( - r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status + r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", + status or "", ) if reg: self.rowcount = int(reg.group(1)) @@ -602,11 +591,11 @@ def setinputsizes(self, *inputsizes): def __iter__(self): while self._rows: - yield self._rows.pop(0) + yield self._rows.popleft() def fetchone(self): if self._rows: - return self._rows.pop(0) + return self._rows.popleft() else: return None @@ -614,13 +603,12 @@ def fetchmany(self, size=None): if size is None: size = self.arraysize - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval + rr = self._rows + return [rr.popleft() for _ in range(min(size, len(rr)))] def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] + retval = list(self._rows) + self._rows.clear() return retval @@ -630,23 +618,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): def __init__(self, adapt_connection): super().__init__(adapt_connection) - self._rowbuffer = None + self._rowbuffer = deque() def close(self): self._cursor = None - self._rowbuffer = None + self._rowbuffer.clear() def _buffer_rows(self): + assert self._cursor is not None new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) - self._rowbuffer = collections.deque(new_rows) + 
self._rowbuffer.extend(new_rows)
 
     def __aiter__(self):
         return self
 
     async def __anext__(self):
-        if not self._rowbuffer:
-            self._buffer_rows()
-
         while True:
             while self._rowbuffer:
                 yield self._rowbuffer.popleft()
@@ -669,21 +655,19 @@ def fetchmany(self, size=None):
         if not self._rowbuffer:
             self._buffer_rows()
 
-        buf = list(self._rowbuffer)
-        lb = len(buf)
+        assert self._cursor is not None
+        rb = self._rowbuffer
+        lb = len(rb)
         if size > lb:
-            buf.extend(
+            rb.extend(
                 self._adapt_connection.await_(self._cursor.fetch(size - lb))
             )
 
-        result = buf[0:size]
-        self._rowbuffer = collections.deque(buf[size:])
-        return result
+        return [rb.popleft() for _ in range(min(size, len(rb)))]
 
     def fetchall(self):
-        ret = list(self._rowbuffer) + list(
-            self._adapt_connection.await_(self._all())
-        )
+        ret = list(self._rowbuffer)
+        ret.extend(self._adapt_connection.await_(self._all()))
         self._rowbuffer.clear()
         return ret
 
@@ -733,7 +717,7 @@ def __init__(
     ):
         self.dbapi = dbapi
         self._connection = connection
-        self.isolation_level = self._isolation_setting = "read_committed"
+        self.isolation_level = self._isolation_setting = None
         self.readonly = False
         self.deferrable = False
         self._transaction = None
@@ -802,9 +786,9 @@ def _handle_exception(self, error):
                 translated_error = exception_mapping[super_](
                     "%s: %s" % (type(error), error)
                 )
-                translated_error.pgcode = (
-                    translated_error.sqlstate
-                ) = getattr(error, "sqlstate", None)
+                translated_error.pgcode = translated_error.sqlstate = (
+                    getattr(error, "sqlstate", None)
+                )
                 raise translated_error from error
             else:
                 raise error
@@ -868,25 +852,45 @@ def cursor(self, server_side=False):
         else:
             return AsyncAdapt_asyncpg_cursor(self)
 
+    async def _rollback_and_discard(self):
+        try:
+            await self._transaction.rollback()
+        finally:
+            # if asyncpg .rollback() was actually called, then whether or
+            # not it raised or succeeded, the transaction is done, discard it
+            self._transaction = None
+            self._started = False
+
+    async def _commit_and_discard(self):
+        try:
+            await self._transaction.commit()
+        finally:
+            # if asyncpg .commit() was actually called, then whether or
+            # not it raised or succeeded, the transaction is done, discard it
+            self._transaction = None
+            self._started = False
+
     def rollback(self):
         if self._started:
             try:
-                self.await_(self._transaction.rollback())
-            except Exception as error:
-                self._handle_exception(error)
-            finally:
+                self.await_(self._rollback_and_discard())
                 self._transaction = None
                 self._started = False
+            except Exception as error:
+                # don't dereference asyncpg transaction if we didn't
+                # actually try to call rollback() on it
+                self._handle_exception(error)
 
     def commit(self):
         if self._started:
             try:
-                self.await_(self._transaction.commit())
-            except Exception as error:
-                self._handle_exception(error)
-            finally:
+                self.await_(self._commit_and_discard())
                 self._transaction = None
                 self._started = False
+            except Exception as error:
+                # don't dereference asyncpg transaction if we didn't
+                # actually try to call commit() on it
+                self._handle_exception(error)
 
     def close(self):
         self.rollback()
@@ -894,7 +898,28 @@ def close(self):
             self.await_(self._connection.close())
 
     def terminate(self):
-        self._connection.terminate()
+        if util.concurrency.in_greenlet():
+            # in a greenlet; this is the case where the connection
+            # was invalidated.
+ try: + # try to gracefully close; see #10717 + # timeout added in asyncpg 0.14.0 December 2017 + self.await_(asyncio.shield(self._connection.close(timeout=2))) + except ( + asyncio.TimeoutError, + asyncio.CancelledError, + OSError, + self.dbapi.asyncpg.PostgresError, + ): + # in the case where we are recycling an old connection + # that may have already been disconnected, close() will + # fail with the above timeout. in this case, terminate + # the connection without any further waiting. + # see issue #8419 + self._connection.terminate() + else: + # not in a greenlet; this is the gc cleanup case + self._connection.terminate() self._started = False @staticmethod @@ -1031,6 +1056,7 @@ class PGDialect_asyncpg(PGDialect): INTERVAL: AsyncPgInterval, sqltypes.Boolean: AsyncpgBoolean, sqltypes.Integer: AsyncpgInteger, + sqltypes.SmallInteger: AsyncpgSmallInteger, sqltypes.BigInteger: AsyncpgBigInteger, sqltypes.Numeric: AsyncpgNumeric, sqltypes.Float: AsyncpgFloat, @@ -1045,7 +1071,7 @@ class PGDialect_asyncpg(PGDialect): OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, sqltypes.CHAR: AsyncpgCHAR, - ranges.AbstractRange: _AsyncpgRange, + ranges.AbstractSingleRange: _AsyncpgRange, ranges.AbstractMultiRange: _AsyncpgMultiRange, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index b9fd8c8baba..52f4721da9d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1,5 +1,5 @@ -# postgresql/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,6 @@ r""" .. dialect:: postgresql :name: PostgreSQL - :full_support: 12, 13, 14, 15 :normal_support: 9.6+ :best_effort: 9+ @@ -32,7 +31,7 @@ metadata, Column( "id", Integer, Sequence("some_id_seq", start=1), primary_key=True - ) + ), ) When SQLAlchemy issues a single INSERT statement, to fulfill the contract of @@ -64,9 +63,9 @@ "data", metadata, Column( - 'id', Integer, Identity(start=42, cycle=True), primary_key=True + "id", Integer, Identity(start=42, cycle=True), primary_key=True ), - Column('data', String) + Column("data", String), ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -93,23 +92,21 @@ from sqlalchemy.ext.compiler import compiles - @compiles(CreateColumn, 'postgresql') + @compiles(CreateColumn, "postgresql") def use_identity(element, compiler, **kw): text = compiler.visit_create_column(element, **kw) - text = text.replace( - "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY" - ) + text = text.replace("SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY") return text Using the above, a table such as:: t = Table( - 't', m, - Column('id', Integer, primary_key=True), - Column('data', String) + "t", m, Column("id", Integer, primary_key=True), Column("data", String) ) - Will generate on the backing database as:: + Will generate on the backing database as: + + .. 
sourcecode:: sql CREATE TABLE t ( id INT GENERATED BY DEFAULT AS IDENTITY, @@ -130,7 +127,9 @@ def use_identity(element, compiler, **kw): option:: with engine.connect() as conn: - result = conn.execution_options(stream_results=True).execute(text("select * from table")) + result = conn.execution_options(stream_results=True).execute( + text("select * from table") + ) Note that some kinds of SQL statements may not be supported with server side cursors; generally, only SQL statements that return rows should be @@ -169,17 +168,15 @@ def use_identity(element, compiler, **kw): engine = create_engine( "postgresql+pg8000://scott:tiger@localhost/test", - isolation_level = "REPEATABLE READ" + isolation_level="REPEATABLE READ", ) To set using per-connection execution options:: with engine.connect() as conn: - conn = conn.execution_options( - isolation_level="REPEATABLE READ" - ) + conn = conn.execution_options(isolation_level="REPEATABLE READ") with conn.begin(): - # ... work with transaction + ... # work with transaction There are also more options for isolation level configurations, such as "sub-engine" objects linked to a main :class:`_engine.Engine` which each apply @@ -222,10 +219,10 @@ def use_identity(element, compiler, **kw): conn = conn.execution_options( isolation_level="SERIALIZABLE", postgresql_readonly=True, - postgresql_deferrable=True + postgresql_deferrable=True, ) with conn.begin(): - # ... work with transaction + ... # work with transaction Note that some DBAPIs such as asyncpg only support "readonly" with SERIALIZABLE isolation. @@ -269,8 +266,7 @@ def use_identity(element, compiler, **kw): from sqlalchemy import event postgresql_engine = create_engine( - "postgresql+pyscopg2://scott:tiger@hostname/dbname", - + "postgresql+psycopg2://scott:tiger@hostname/dbname", # disable default reset-on-return scheme pool_reset_on_return=None, ) @@ -317,6 +313,7 @@ def _reset_postgresql(dbapi_connection, connection_record, reset_state): engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname") + @event.listens_for(engine, "connect", insert=True) def set_search_path(dbapi_connection, connection_record): existing_autocommit = dbapi_connection.autocommit @@ -335,9 +332,6 @@ def set_search_path(dbapi_connection, connection_record): :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation - - - .. _postgresql_schema_reflection: Remote-Schema Table Introspection and PostgreSQL search_path @@ -346,7 +340,9 @@ def set_search_path(dbapi_connection, connection_record): .. admonition:: Section Best Practices Summarized keep the ``search_path`` variable set to its default of ``public``, without - any other schema names. For other schema names, name these explicitly + any other schema names. Ensure the username used to connect **does not** + match remote schemas, or ensure the ``"$user"`` token is **removed** from + ``search_path``. For other schema names, name these explicitly within :class:`_schema.Table` definitions. Alternatively, the ``postgresql_ignore_search_path`` option will cause all reflected :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema` @@ -355,19 +351,78 @@ def set_search_path(dbapi_connection, connection_record): The PostgreSQL dialect can reflect tables from any schema, as outlined in :ref:`metadata_reflection_schemas`. +In all cases, the first thing SQLAlchemy does when reflecting tables is +to **determine the default schema for the current database connection**. 
+It does this using the PostgreSQL ``current_schema()`` +function, illustrated below using a PostgreSQL client session (i.e. using +the ``psql`` tool): + +.. sourcecode:: sql + + test=> select current_schema(); + current_schema + ---------------- + public + (1 row) + +Above we see that on a plain install of PostgreSQL, the default schema name +is the name ``public``. + +However, if your database username **matches the name of a schema**, PostgreSQL's +default is to then **use that name as the default schema**. Below, we log in +using the username ``scott``. When we create a schema named ``scott``, **it +implicitly changes the default schema**: + +.. sourcecode:: sql + + test=> select current_schema(); + current_schema + ---------------- + public + (1 row) + + test=> create schema scott; + CREATE SCHEMA + test=> select current_schema(); + current_schema + ---------------- + scott + (1 row) + +The behavior of ``current_schema()`` is derived from the +`PostgreSQL search path +<https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_ +variable ``search_path``, which in modern PostgreSQL versions defaults to this: + +.. sourcecode:: sql + + test=> show search_path; + search_path + ----------------- + "$user", public + (1 row) + +Where above, the ``"$user"`` variable will inject the current username as the +default schema, if one exists. Otherwise, ``public`` is used. + +When a :class:`_schema.Table` object is reflected, if it is present in the +schema indicated by the ``current_schema()`` function, **the schema name assigned +to the ".schema" attribute of the Table is the Python "None" value**. Otherwise, the +".schema" attribute will be assigned the string name of that schema. + With regards to tables which these :class:`_schema.Table` objects refer to via foreign key constraint, a decision must be made as to how the ``.schema`` is represented in those remote tables, in the case where that -remote schema name is also a member of the current -`PostgreSQL search path -<https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_. +remote schema name is also a member of the current ``search_path``. By default, the PostgreSQL dialect mimics the behavior encouraged by PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function returns a sample definition for a particular foreign key constraint, omitting the referenced schema name from that definition when the name is also in the PostgreSQL schema search path. The interaction below -illustrates this behavior:: +illustrates this behavior: + +.. sourcecode:: sql + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); CREATE TABLE @@ -394,13 +449,17 @@ def set_search_path(dbapi_connection, connection_record): the function. On the other hand, if we set the search path back to the typical default -of ``public``:: +of ``public``: + +.. sourcecode:: sql test=> SET search_path TO public; SET The same query against ``pg_get_constraintdef()`` now returns the fully -schema-qualified name for us:: +schema-qualified name for us: + +.. sourcecode:: sql test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n @@ -422,16 +481,14 @@ def set_search_path(dbapi_connection, connection_record): >>> with engine.connect() as conn: ... conn.execute(text("SET search_path TO test_schema, public")) ... metadata_obj = MetaData() - ... referring = Table('referring', metadata_obj, - ... autoload_with=conn) - ... + ...
referring = Table("referring", metadata_obj, autoload_with=conn) The above process would deliver to the :attr:`_schema.MetaData.tables` collection ``referred`` table named **without** the schema:: - >>> metadata_obj.tables['referred'].schema is None + >>> metadata_obj.tables["referred"].schema is None True To alter the behavior of reflection such that the referred schema is @@ -443,15 +500,17 @@ def set_search_path(dbapi_connection, connection_record): >>> with engine.connect() as conn: ... conn.execute(text("SET search_path TO test_schema, public")) ... metadata_obj = MetaData() - ... referring = Table('referring', metadata_obj, - ... autoload_with=conn, - ... postgresql_ignore_search_path=True) - ... + ... referring = Table( + ... "referring", + ... metadata_obj, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True, + ... ) We will now have ``test_schema.referred`` stored as schema-qualified:: - >>> metadata_obj.tables['test_schema.referred'].schema + >>> metadata_obj.tables["test_schema.referred"].schema 'test_schema' .. sidebar:: Best Practices for PostgreSQL Schema reflection @@ -466,13 +525,6 @@ def set_search_path(dbapi_connection, connection_record): described here are only for those users who can't, or prefer not to, stay within these guidelines. -Note that **in all cases**, the "default" schema is always reflected as -``None``. The "default" schema on PostgreSQL is that which is returned by the -PostgreSQL ``current_schema()`` function. On a typical PostgreSQL -installation, this is the name ``public``. So a table that refers to another -which is in the ``public`` (i.e. default) schema will always have the -``.schema`` attribute set to ``None``. - .. seealso:: :ref:`reflection_schema_qualified_interaction` - discussion of the issue @@ -492,18 +544,26 @@ def set_search_path(dbapi_connection, connection_record): use the :meth:`._UpdateBase.returning` method on a per-statement basis:: # INSERT..RETURNING - result = table.insert().returning(table.c.col1, table.c.col2).\ - values(name='foo') + result = ( + table.insert().returning(table.c.col1, table.c.col2).values(name="foo") + ) print(result.fetchall()) # UPDATE..RETURNING - result = table.update().returning(table.c.col1, table.c.col2).\ - where(table.c.name=='foo').values(name='bar') + result = ( + table.update() + .returning(table.c.col1, table.c.col2) + .where(table.c.name == "foo") + .values(name="bar") + ) print(result.fetchall()) # DELETE..RETURNING - result = table.delete().returning(table.c.col1, table.c.col2).\ - where(table.c.name=='foo') + result = ( + table.delete() + .returning(table.c.col1, table.c.col2) + .where(table.c.name == "foo") + ) print(result.fetchall()) .. _postgresql_insert_on_conflict: @@ -533,19 +593,16 @@ def set_search_path(dbapi_connection, connection_record): >>> from sqlalchemy.dialects.postgresql import insert >>> insert_stmt = insert(my_table).values( - ... id='some_existing_id', - ... data='inserted value') - >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( - ... index_elements=['id'] + ... id="some_existing_id", data="inserted value" ... ) + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"]) >>> print(do_nothing_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) ON CONFLICT (id) DO NOTHING {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint='pk_my_table', - ... set_=dict(data='updated value') + ... constraint="pk_my_table", set_=dict(data="updated value") ... 
) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -571,8 +628,7 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value') + ... index_elements=["id"], set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -580,8 +636,7 @@ def set_search_path(dbapi_connection, connection_record): {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=[my_table.c.id], - ... set_=dict(data='updated value') + ... index_elements=[my_table.c.id], set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -593,11 +648,11 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + >>> stmt = insert(my_table).values(user_email="a@b.com", data="inserted data") >>> stmt = stmt.on_conflict_do_update( ... index_elements=[my_table.c.user_email], - ... index_where=my_table.c.user_email.like('%@gmail.com'), - ... set_=dict(data=stmt.excluded.data) + ... index_where=my_table.c.user_email.like("%@gmail.com"), + ... set_=dict(data=stmt.excluded.data), ... ) >>> print(stmt) {printsql}INSERT INTO my_table (data, user_email) @@ -611,8 +666,7 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint='my_table_idx_1', - ... set_=dict(data='updated value') + ... constraint="my_table_idx_1", set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -620,8 +674,7 @@ def set_search_path(dbapi_connection, connection_record): {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint='my_table_pk', - ... set_=dict(data='updated value') + ... constraint="my_table_pk", set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -643,8 +696,7 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint=my_table.primary_key, - ... set_=dict(data='updated value') + ... constraint=my_table.primary_key, set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -662,10 +714,9 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value') + ... index_elements=["id"], set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -694,13 +745,11 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id='some_id', - ... data='inserted value', - ... author='jlh' + ... id="some_id", data="inserted value", author="jlh" ... ) >>> do_update_stmt = stmt.on_conflict_do_update( - ... 
index_elements=['id'], - ... set_=dict(data='updated value', author=stmt.excluded.author) + ... index_elements=["id"], + ... set_=dict(data="updated value", author=stmt.excluded.author), ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data, author) @@ -717,14 +766,12 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id='some_id', - ... data='inserted value', - ... author='jlh' + ... id="some_id", data="inserted value", author="jlh" ... ) >>> on_update_stmt = stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value', author=stmt.excluded.author), - ... where=(my_table.c.status == 2) + ... index_elements=["id"], + ... set_=dict(data="updated value", author=stmt.excluded.author), + ... where=(my_table.c.status == 2), ... ) >>> print(on_update_stmt) {printsql}INSERT INTO my_table (id, data, author) @@ -742,8 +789,8 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') - >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) ON CONFLICT (id) DO NOTHING @@ -754,7 +801,7 @@ def set_search_path(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") >>> stmt = stmt.on_conflict_do_nothing() >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -785,7 +832,9 @@ def set_search_path(dbapi_connection, connection_record): select(sometable.c.text.match("search string")) -would emit to the database:: +would emit to the database: + +.. sourcecode:: sql SELECT text @@ plainto_tsquery('search string') FROM table @@ -801,11 +850,11 @@ def set_search_path(dbapi_connection, connection_record): from sqlalchemy import func - select( - sometable.c.text.bool_op("@@")(func.to_tsquery("search string")) - ) + select(sometable.c.text.bool_op("@@")(func.to_tsquery("search string"))) - Which would emit:: + Which would emit: + + .. sourcecode:: sql SELECT text @@ to_tsquery('search string') FROM table @@ -819,9 +868,7 @@ def set_search_path(dbapi_connection, connection_record): For example, the query:: - select( - func.to_tsquery('cat').bool_op("@>")(func.to_tsquery('cat & rat')) - ) + select(func.to_tsquery("cat").bool_op("@>")(func.to_tsquery("cat & rat"))) would generate: @@ -834,9 +881,12 @@ def set_search_path(dbapi_connection, connection_record): from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy import select, cast + select(cast("some text", TSVECTOR)) -produces a statement equivalent to:: +produces a statement equivalent to: + +.. sourcecode:: sql SELECT CAST('some text' AS TSVECTOR) AS anon_1 @@ -864,10 +914,12 @@ def set_search_path(dbapi_connection, connection_record): specified using the ``postgresql_regconfig`` parameter, such as:: select(mytable.c.id).where( - mytable.c.title.match('somestring', postgresql_regconfig='english') + mytable.c.title.match("somestring", postgresql_regconfig="english") ) -Which would emit:: +Which would emit: + +.. 
sourcecode:: sql SELECT mytable.id FROM mytable WHERE mytable.title @@ plainto_tsquery('english', 'somestring') @@ -881,7 +933,9 @@ def set_search_path(dbapi_connection, connection_record): ) ) -produces a statement equivalent to:: +produces a statement equivalent to: + +.. sourcecode:: sql SELECT mytable.id FROM mytable WHERE to_tsvector('english', mytable.title) @@ @@ -905,16 +959,16 @@ def set_search_path(dbapi_connection, connection_record): syntaxes. It uses SQLAlchemy's hints mechanism:: # SELECT ... FROM ONLY ... - result = table.select().with_hint(table, 'ONLY', 'postgresql') + result = table.select().with_hint(table, "ONLY", "postgresql") print(result.fetchall()) # UPDATE ONLY ... - table.update(values=dict(foo='bar')).with_hint('ONLY', - dialect_name='postgresql') + table.update(values=dict(foo="bar")).with_hint( + "ONLY", dialect_name="postgresql" + ) # DELETE FROM ONLY ... - table.delete().with_hint('ONLY', dialect_name='postgresql') - + table.delete().with_hint("ONLY", dialect_name="postgresql") .. _postgresql_indexes: @@ -924,18 +978,24 @@ def set_search_path(dbapi_connection, connection_record): Several extensions to the :class:`.Index` construct are available, specific to the PostgreSQL dialect. +.. _postgresql_covering_indexes: + Covering Indexes ^^^^^^^^^^^^^^^^ The ``postgresql_include`` option renders INCLUDE(colname) for the given string names:: - Index("my_index", table.c.x, postgresql_include=['y']) + Index("my_index", table.c.x, postgresql_include=["y"]) would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` Note that this feature requires PostgreSQL 11 or later. +.. seealso:: + + :ref:`postgresql_constraint_options` + .. versionadded:: 1.4 .. _postgresql_partial_indexes: @@ -947,7 +1007,7 @@ def set_search_path(dbapi_connection, connection_record): applied to a subset of rows. These can be specified on :class:`.Index` using the ``postgresql_where`` keyword argument:: - Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) + Index("my_index", my_table.c.id, postgresql_where=my_table.c.value > 10) .. _postgresql_operator_classes: @@ -961,11 +1021,11 @@ def set_search_path(dbapi_connection, connection_record): ``postgresql_ops`` keyword argument:: Index( - 'my_index', my_table.c.id, my_table.c.data, - postgresql_ops={ - 'data': 'text_pattern_ops', - 'id': 'int4_ops' - }) + "my_index", + my_table.c.id, + my_table.c.data, + postgresql_ops={"data": "text_pattern_ops", "id": "int4_ops"}, + ) Note that the keys in the ``postgresql_ops`` dictionaries are the "key" name of the :class:`_schema.Column`, i.e. the name used to access it from @@ -977,12 +1037,11 @@ def set_search_path(dbapi_connection, connection_record): that is identified in the dictionary by name, e.g.:: Index( - 'my_index', my_table.c.id, - func.lower(my_table.c.data).label('data_lower'), - postgresql_ops={ - 'data_lower': 'text_pattern_ops', - 'id': 'int4_ops' - }) + "my_index", + my_table.c.id, + func.lower(my_table.c.data).label("data_lower"), + postgresql_ops={"data_lower": "text_pattern_ops", "id": "int4_ops"}, + ) Operator classes are also supported by the :class:`_postgresql.ExcludeConstraint` construct using the @@ -1001,7 +1060,7 @@ def set_search_path(dbapi_connection, connection_record): https://www.postgresql.org/docs/current/static/indexes-types.html). 
These can be specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: - Index('my_index', my_table.c.data, postgresql_using='gin') + Index("my_index", my_table.c.data, postgresql_using="gin") The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX command, so it *must* be a valid index type for your @@ -1017,13 +1076,13 @@ def set_search_path(dbapi_connection, connection_record): parameters can be specified on :class:`.Index` using the ``postgresql_with`` keyword argument:: - Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) + Index("my_index", my_table.c.data, postgresql_with={"fillfactor": 50}) PostgreSQL allows defining the tablespace in which to create the index. The tablespace can be specified on :class:`.Index` using the ``postgresql_tablespace`` keyword argument:: - Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') + Index("my_index", my_table.c.data, postgresql_tablespace="my_tablespace") Note that the same option is available on :class:`_schema.Table` as well. @@ -1035,17 +1094,21 @@ def set_search_path(dbapi_connection, connection_record): The PostgreSQL index option CONCURRENTLY is supported by passing the flag ``postgresql_concurrently`` to the :class:`.Index` construct:: - tbl = Table('testtbl', m, Column('data', Integer)) + tbl = Table("testtbl", m, Column("data", Integer)) - idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + idx1 = Index("test_idx1", tbl.c.data, postgresql_concurrently=True) The above index construct will render DDL for CREATE INDEX, assuming -PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as: + +.. sourcecode:: sql CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for -a connection-less dialect, it will emit:: +a connection-less dialect, it will emit: + +.. sourcecode:: sql DROP INDEX CONCURRENTLY test_idx1 @@ -1055,14 +1118,11 @@ def set_search_path(dbapi_connection, connection_record): construct, the DBAPI's "autocommit" mode must be used:: metadata = MetaData() - table = Table( - "foo", metadata, - Column("id", String)) - index = Index( - "foo_idx", table.c.id, postgresql_concurrently=True) + table = Table("foo", metadata, Column("id", String)) + index = Index("foo_idx", table.c.id, postgresql_concurrently=True) with engine.connect() as conn: - with conn.execution_options(isolation_level='AUTOCOMMIT'): + with conn.execution_options(isolation_level="AUTOCOMMIT"): table.create(conn) .. seealso:: @@ -1112,36 +1172,49 @@ def set_search_path(dbapi_connection, connection_record): Several options for CREATE TABLE are supported directly by the PostgreSQL dialect in conjunction with the :class:`_schema.Table` construct: -* ``TABLESPACE``:: +* ``INHERITS``:: - Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") - The above option is also available on the :class:`.Index` construct.
+ Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) * ``ON COMMIT``:: - Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') + Table("some_table", metadata, ..., postgresql_on_commit="PRESERVE ROWS") -* ``WITH OIDS``:: +* + ``PARTITION BY``:: - Table("some_table", metadata, ..., postgresql_with_oids=True) + Table( + "some_table", + metadata, + ..., + postgresql_partition_by="LIST (part_column)", + ) -* ``WITHOUT OIDS``:: + .. versionadded:: 1.2.6 - Table("some_table", metadata, ..., postgresql_with_oids=False) +* + ``TABLESPACE``:: -* ``INHERITS``:: + Table("some_table", metadata, ..., postgresql_tablespace="some_tablespace") - Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + The above option is also available on the :class:`.Index` construct. - Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) +* + ``USING``:: -* ``PARTITION BY``:: + Table("some_table", metadata, ..., postgresql_using="heap") - Table("some_table", metadata, ..., - postgresql_partition_by='LIST (part_column)') + .. versionadded:: 2.0.26 - .. versionadded:: 1.2.6 +* ``WITH OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=True) + +* ``WITHOUT OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=False) .. seealso:: @@ -1174,7 +1247,7 @@ def update(): "user", ["user_id"], ["id"], - postgresql_not_valid=True + postgresql_not_valid=True, ) The keyword is ultimately accepted directly by the @@ -1185,7 +1258,9 @@ def update(): CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True) - ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True) + ForeignKeyConstraint( + ["some_id"], ["some_table.some_id"], postgresql_not_valid=True + ) .. versionadded:: 1.4.32 @@ -1195,6 +1270,65 @@ def update(): `_ - in the PostgreSQL documentation. +* ``INCLUDE``: This option adds one or more columns as a "payload" to the + unique index created automatically by PostgreSQL for the constraint. + For example, the following table definition:: + + Table( + "mytable", + metadata, + Column("id", Integer, nullable=False), + Column("value", Integer, nullable=False), + UniqueConstraint("id", postgresql_include=["value"]), + ) + + would produce the DDL statement + + .. sourcecode:: sql + + CREATE TABLE mytable ( + id INTEGER NOT NULL, + value INTEGER NOT NULL, + UNIQUE (id) INCLUDE (value) + ) + + Note that this feature requires PostgreSQL 11 or later. + + .. versionadded:: 2.0.41 + + .. seealso:: + + :ref:`postgresql_covering_indexes` + + .. seealso:: + + `PostgreSQL CREATE TABLE options + `_ - + in the PostgreSQL documentation. + +* Column list with foreign key ``ON DELETE SET`` actions: This applies to + :class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the :paramref:`.ForeignKey.ondelete` + parameter will accept on the PostgreSQL backend only a string list of column + names inside parenthesis, following the ``SET NULL`` or ``SET DEFAULT`` + phrases, which will limit the set of columns that are subject to the + action:: + + fktable = Table( + "fktable", + metadata, + Column("tid", Integer), + Column("id", Integer), + Column("fk_id_del_set_null", Integer), + ForeignKeyConstraint( + columns=["tid", "fk_id_del_set_null"], + refcolumns=[pktable.c.tid, pktable.c.id], + ondelete="SET NULL (fk_id_del_set_null)", + ), + ) + + .. versionadded:: 2.0.40 + + .. 
_postgresql_table_valued_overview: Table values, Table and Column valued functions, Row and Tuple objects @@ -1228,7 +1362,9 @@ def update(): .. sourcecode:: pycon+sql >>> from sqlalchemy import select, func - >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value")) + >>> stmt = select( + ... func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value") + ... ) >>> print(stmt) {printsql}SELECT anon_1.key, anon_1.value FROM json_each(:json_each_1) AS anon_1 @@ -1240,8 +1376,7 @@ def update(): >>> from sqlalchemy import select, func, literal_column >>> stmt = select( ... func.json_populate_record( - ... literal_column("null::myrowtype"), - ... '{"a":1,"b":2}' + ... literal_column("null::myrowtype"), '{"a":1,"b":2}' ... ).table_valued("a", "b", name="x") ... ) >>> print(stmt) @@ -1259,9 +1394,13 @@ def update(): >>> from sqlalchemy import select, func, column, Integer, Text >>> stmt = select( - ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued( - ... column("a", Integer), column("b", Text), column("d", Text), - ... ).render_derived(name="x", with_types=True) + ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}') + ... .table_valued( + ... column("a", Integer), + ... column("b", Text), + ... column("d", Text), + ... ) + ... .render_derived(name="x", with_types=True) ... ) >>> print(stmt) {printsql}SELECT x.a, x.b, x.d @@ -1278,9 +1417,9 @@ def update(): >>> from sqlalchemy import select, func >>> stmt = select( - ... func.generate_series(4, 1, -1). - ... table_valued("value", with_ordinality="ordinality"). - ... render_derived() + ... func.generate_series(4, 1, -1) + ... .table_valued("value", with_ordinality="ordinality") + ... .render_derived() ... ) >>> print(stmt) {printsql}SELECT anon_1.value, anon_1.ordinality @@ -1309,7 +1448,9 @@ def update(): .. sourcecode:: pycon+sql >>> from sqlalchemy import select, func - >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x")) + >>> stmt = select( + ... func.json_array_elements('["one", "two"]').column_valued("x") + ... ) >>> print(stmt) {printsql}SELECT x FROM json_array_elements(:json_array_elements_1) AS x @@ -1333,7 +1474,7 @@ def update(): >>> from sqlalchemy import table, column, ARRAY, Integer >>> from sqlalchemy import select, func - >>> t = table("t", column('value', ARRAY(Integer))) + >>> t = table("t", column("value", ARRAY(Integer))) >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value")) >>> print(stmt) {printsql}SELECT unnested_value @@ -1355,10 +1496,10 @@ def update(): >>> from sqlalchemy import table, column, func, tuple_ >>> t = table("t", column("id"), column("fk")) - >>> stmt = t.select().where( - ... tuple_(t.c.id, t.c.fk) > (1,2) - ... ).where( - ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7) + >>> stmt = ( + ... t.select() + ... .where(tuple_(t.c.id, t.c.fk) > (1, 2)) + ... .where(func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7)) ... ) >>> print(stmt) {printsql}SELECT t.id, t.fk @@ -1387,7 +1528,7 @@ def update(): .. 
sourcecode:: pycon+sql >>> from sqlalchemy import table, column, func, select - >>> a = table( "a", column("id"), column("x"), column("y")) + >>> a = table("a", column("id"), column("x"), column("y")) >>> stmt = select(func.row_to_json(a.table_valued())) >>> print(stmt) {printsql}SELECT row_to_json(a) AS row_to_json_1 @@ -1406,19 +1547,20 @@ def update(): import re from typing import Any from typing import cast +from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import Union -from . import array as _array -from . import hstore as _hstore +from . import arraylib as _array from . import json as _json from . import pg_catalog from . import ranges as _ranges from .ext import _regconfig_fn from .ext import aggregate_order_by +from .hstore import HSTORE from .named_types import CreateDomainType as CreateDomainType # noqa: F401 from .named_types import CreateEnumType as CreateEnumType # noqa: F401 from .named_types import DOMAIN as DOMAIN # noqa: F401 @@ -1596,6 +1738,7 @@ def update(): "verbose", } + colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -1608,7 +1751,7 @@ def update(): ischema_names = { "_array": _array.ARRAY, - "hstore": _hstore.HSTORE, + "hstore": HSTORE, "json": _json.JSON, "jsonb": _json.JSONB, "int4range": _ranges.INT4RANGE, @@ -1706,12 +1849,14 @@ def render_bind_cast(self, type_, dbapi_type, sqltext): # see #9511 dbapi_type = sqltypes.STRINGTYPE return f"""{sqltext}::{ - self.dialect.type_compiler_instance.process( - dbapi_type, identifier_preparer=self.preparer - ) - }""" + self.dialect.type_compiler_instance.process( + dbapi_type, identifier_preparer=self.preparer + ) + }""" def visit_array(self, element, **kw): + if not element.clauses and not element.type.item_type._isnull: + return "ARRAY[]::%s" % element.type.compile(self.dialect) return "ARRAY[%s]" % self.visit_clauselist(element, **kw) def visit_slice(self, element, **kw): @@ -1925,9 +2070,10 @@ def for_update_clause(self, select, **kw): for c in select._for_update_arg.of: tables.update(sql_util.surface_selectables_only(c)) + of_kw = dict(kw) + of_kw.update(ashint=True, use_schema=False) tmp += " OF " + ", ".join( - self.process(table, ashint=True, use_schema=False, **kw) - for table in tables + self.process(table, **of_kw) for table in tables ) if select._for_update_arg.nowait: @@ -2009,6 +2155,8 @@ def visit_on_conflict_do_update(self, on_conflict, **kw): else: continue + # TODO: this coercion should be up front. 
we can't cache + # SQL constructs with non-bound literals buried in them if coercions._is_literal(value): value = elements.BindParameter(None, value, type_=c.type) @@ -2086,9 +2234,11 @@ def fetch_clause(self, select, **kw): text += "\n FETCH FIRST (%s)%s ROWS %s" % ( self.process(select._fetch_clause, **kw), " PERCENT" if select._fetch_clause_options["percent"] else "", - "WITH TIES" - if select._fetch_clause_options["with_ties"] - else "ONLY", + ( + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY" + ), ) return text @@ -2152,6 +2302,18 @@ def _define_constraint_validity(self, constraint): not_valid = constraint.dialect_options["postgresql"]["not_valid"] return " NOT VALID" if not_valid else "" + def _define_include(self, obj): + includeclause = obj.dialect_options["postgresql"]["include"] + if not includeclause: + return "" + inclusions = [ + obj.table.c[col] if isinstance(col, str) else col + for col in includeclause + ] + return " INCLUDE (%s)" % ", ".join( + [self.preparer.quote(c.name) for c in inclusions] + ) + def visit_check_constraint(self, constraint, **kw): if constraint._type_bound: typ = list(constraint.columns)[0].type @@ -2175,6 +2337,29 @@ def visit_foreign_key_constraint(self, constraint, **kw): text += self._define_constraint_validity(constraint) return text + def visit_primary_key_constraint(self, constraint, **kw): + text = super().visit_primary_key_constraint(constraint) + text += self._define_include(constraint) + return text + + def visit_unique_constraint(self, constraint, **kw): + text = super().visit_unique_constraint(constraint) + text += self._define_include(constraint) + return text + + @util.memoized_property + def _fk_ondelete_pattern(self): + return re.compile( + r"^(?:RESTRICT|CASCADE|SET (?:NULL|DEFAULT)(?:\s*\(.+\))?" 
+ r"|NO ACTION)$", + re.I, + ) + + def define_constraint_ondelete_cascade(self, constraint): + return " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, self._fk_ondelete_pattern + ) + def visit_create_enum_type(self, create, **kw): type_ = create.element @@ -2258,9 +2443,11 @@ def visit_create_index(self, create, **kw): ", ".join( [ self.sql_compiler.process( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr, + ( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr + ), include_table=False, literal_binds=True, ) @@ -2274,15 +2461,7 @@ def visit_create_index(self, create, **kw): ) ) - includeclause = index.dialect_options["postgresql"]["include"] - if includeclause: - inclusions = [ - index.table.c[col] if isinstance(col, str) else col - for col in includeclause - ] - text += " INCLUDE (%s)" % ", ".join( - [preparer.quote(c.name) for c in inclusions] - ) + text += self._define_include(index) nulls_not_distinct = index.dialect_options["postgresql"][ "nulls_not_distinct" @@ -2395,6 +2574,9 @@ def post_create_table(self, table): if pg_opts["partition_by"]: table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) + if pg_opts["using"]: + table_opts.append("\n USING %s" % pg_opts["using"]) + if pg_opts["with_oids"] is True: table_opts.append("\n WITH OIDS") elif pg_opts["with_oids"] is False: @@ -2582,17 +2764,21 @@ def visit_DOMAIN(self, type_, identifier_preparer=None, **kw): def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -2713,6 +2899,8 @@ class ReflectedDomain(ReflectedNamedType): """The constraints defined in the domain, if any. The constraint are in order of evaluation by postgresql. 
""" + collation: Optional[str] + """The collation for the domain.""" class ReflectedEnum(ReflectedNamedType): @@ -3006,6 +3194,7 @@ class PGDialect(default.DefaultDialect): "with_oids": None, "on_commit": None, "inherits": None, + "using": None, }, ), ( @@ -3020,9 +3209,16 @@ class PGDialect(default.DefaultDialect): "not_valid": False, }, ), + ( + schema.PrimaryKeyConstraint, + {"include": None}, + ), ( schema.UniqueConstraint, - {"nulls_not_distinct": None}, + { + "include": None, + "nulls_not_distinct": None, + }, ), ] @@ -3097,9 +3293,7 @@ def set_deferrable(self, connection, value): def get_deferrable(self, connection): raise NotImplementedError() - def _split_multihost_from_url( - self, url: URL - ) -> Union[ + def _split_multihost_from_url(self, url: URL) -> Union[ Tuple[None, None], Tuple[Tuple[Optional[str], ...], Tuple[Optional[int], ...]], ]: @@ -3511,6 +3705,7 @@ def _columns_query(self, schema, has_filter_names, scope, kind): pg_catalog.pg_sequence.c.seqcache, "cycle", pg_catalog.pg_sequence.c.seqcycle, + type_=sqltypes.JSON(), ) ) .select_from(pg_catalog.pg_sequence) @@ -3631,9 +3826,11 @@ def get_multi_columns( # dictionary with (name, ) if default search path or (schema, name) # as keys enums = dict( - ((rec["name"],), rec) - if rec["visible"] - else ((rec["schema"], rec["name"]), rec) + ( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + ) for rec in self._load_enums( connection, schema="*", info_cache=kw.get("info_cache") ) @@ -3643,155 +3840,187 @@ def get_multi_columns( return columns.items() - def _get_columns_info(self, rows, domains, enums, schema): - array_type_pattern = re.compile(r"\[\]$") - attype_pattern = re.compile(r"\(.*\)") - charlen_pattern = re.compile(r"\(([\d,]+)\)") - args_pattern = re.compile(r"\((.*)\)") - args_split_pattern = re.compile(r"\s*,\s*") - - def _handle_array_type(attype): - return ( - # strip '[]' from integer[], etc. - array_type_pattern.sub("", attype), - attype.endswith("[]"), + _format_type_args_pattern = re.compile(r"\((.*)\)") + _format_type_args_delim = re.compile(r"\s*,\s*") + _format_array_spec_pattern = re.compile(r"((?:\[\])*)$") + + def _reflect_type( + self, + format_type: Optional[str], + domains: Dict[str, ReflectedDomain], + enums: Dict[str, ReflectedEnum], + type_description: str, + ) -> sqltypes.TypeEngine[Any]: + """ + Attempts to reconstruct a column type defined in ischema_names based + on the information available in the format_type. + + If the `format_type` cannot be associated with a known `ischema_names`, + it is treated as a reference to a known PostgreSQL named `ENUM` or + `DOMAIN` type. 
+ """ + type_description = type_description or "unknown type" + if format_type is None: + util.warn( + "PostgreSQL format_type() returned NULL for %s" + % type_description + ) + return sqltypes.NULLTYPE + + attype_args_match = self._format_type_args_pattern.search(format_type) + if attype_args_match and attype_args_match.group(1): + attype_args = self._format_type_args_delim.split( + attype_args_match.group(1) + ) + else: + attype_args = () + + match_array_dim = self._format_array_spec_pattern.search(format_type) + # Each "[]" in array specs corresponds to an array dimension + array_dim = len(match_array_dim.group(1) or "") // 2 + + # Remove all parameters and array specs from format_type to obtain an + # ischema_name candidate + attype = self._format_type_args_pattern.sub("", format_type) + attype = self._format_array_spec_pattern.sub("", attype) + + schema_type = self.ischema_names.get(attype.lower(), None) + args, kwargs = (), {} + + if attype == "numeric": + if len(attype_args) == 2: + precision, scale = map(int, attype_args) + args = (precision, scale) + + elif attype == "double precision": + args = (53,) + + elif attype == "integer": + args = () + + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype == "bit varying": + kwargs["varying"] = True + if len(attype_args) == 1: + charlen = int(attype_args[0]) + args = (charlen,) + + elif attype.startswith("interval"): + schema_type = INTERVAL + + field_match = re.match(r"interval (.+)", attype) + if field_match: + kwargs["fields"] = field_match.group(1) + + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + else: + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + if enum_or_domain_key in enums: + schema_type = ENUM + enum = enums[enum_or_domain_key] + + kwargs["name"] = enum["name"] + + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + elif enum_or_domain_key in domains: + schema_type = DOMAIN + domain = domains[enum_or_domain_key] + + data_type = self._reflect_type( + domain["type"], + domains, + enums, + type_description="DOMAIN '%s'" % domain["name"], + ) + args = (domain["name"], data_type) + + kwargs["collation"] = domain["collation"] + kwargs["default"] = domain["default"] + kwargs["not_null"] = not domain["nullable"] + kwargs["create_type"] = False + + if domain["constraints"]: + # We only support a single constraint + check_constraint = domain["constraints"][0] + + kwargs["constraint_name"] = check_constraint["name"] + kwargs["check"] = check_constraint["check"] + + if not domain["visible"]: + kwargs["schema"] = domain["schema"] + + else: + try: + charlen = int(attype_args[0]) + args = (charlen, *attype_args[1:]) + except (ValueError, IndexError): + args = attype_args + + if not schema_type: + util.warn( + "Did not recognize type '%s' of %s" + % (attype, type_description) ) + return sqltypes.NULLTYPE + data_type = schema_type(*args, **kwargs) + if array_dim >= 1: + # postgres does not preserve dimensionality or size of array types. 
+ data_type = _array.ARRAY(data_type) + + return data_type + + def _get_columns_info(self, rows, domains, enums, schema): columns = defaultdict(list) for row_dict in rows: # ensure that each table has an entry, even if it has no columns if row_dict["name"] is None: - columns[ - (schema, row_dict["table_name"]) - ] = ReflectionDefaults.columns() + columns[(schema, row_dict["table_name"])] = ( + ReflectionDefaults.columns() + ) continue table_cols = columns[(schema, row_dict["table_name"])] - format_type = row_dict["format_type"] + coltype = self._reflect_type( + row_dict["format_type"], + domains, + enums, + type_description="column '%s'" % row_dict["name"], + ) + default = row_dict["default"] name = row_dict["name"] generated = row_dict["generated"] - identity = row_dict["identity_options"] - - if format_type is None: - no_format_type = True - attype = format_type = "no format_type()" - is_array = False - else: - no_format_type = False - - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. - attype = attype_pattern.sub("", format_type) - - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) - - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - nullable = not row_dict["not_null"] - charlen = charlen_pattern.search(format_type) - if charlen: - charlen = charlen.group(1) - args = args_pattern.search(format_type) - if args and args.group(1): - args = tuple(args_split_pattern.split(args.group(1))) - else: - args = () - kwargs = {} + if isinstance(coltype, DOMAIN): + if not default: + # domain can override the default value but + # cant set it to None + if coltype.default is not None: + default = coltype.default - if attype == "numeric": - if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) - else: - args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: - args = (int(charlen),) - else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["type"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple( - util.quoted_token_parser(attype) - ) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if 
domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue - else: - coltype = None - break - - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - elif no_format_type: - util.warn( - "PostgreSQL format_type() returned NULL for column '%s'" - % (name,) - ) - coltype = sqltypes.NULLTYPE - else: - util.warn( - "Did not recognize type '%s' of column '%s'" - % (attype, name) - ) - coltype = sqltypes.NULLTYPE + nullable = nullable and not coltype.not_null + + identity = row_dict["identity_options"] # If a zero byte or blank string depending on driver (is also # absent for older PG versions), then not a generated column. @@ -3870,21 +4099,35 @@ def _get_table_oids( result = connection.execute(oid_q, params) return result.all() - @lru_cache() - def _constraint_query(self, is_unique): + @util.memoized_property + def _constraint_query(self): + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts") + + if self.server_version_info >= (15,): + indnullsnotdistinct = pg_catalog.pg_index.c.indnullsnotdistinct + else: + indnullsnotdistinct = sql.false().label("indnullsnotdistinct") + con_sq = ( select( pg_catalog.pg_constraint.c.conrelid, pg_catalog.pg_constraint.c.conname, - pg_catalog.pg_constraint.c.conindid, - sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( - "attnum" - ), + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), sql.func.generate_subscripts( - pg_catalog.pg_constraint.c.conkey, 1 + pg_catalog.pg_index.c.indkey, 1 ).label("ord"), + indnkeyatts, + indnullsnotdistinct, pg_catalog.pg_description.c.description, ) + .join( + pg_catalog.pg_index, + pg_catalog.pg_constraint.c.conindid + == pg_catalog.pg_index.c.indexrelid, + ) .outerjoin( pg_catalog.pg_description, pg_catalog.pg_description.c.objoid @@ -3893,6 +4136,9 @@ def _constraint_query(self, is_unique): .where( pg_catalog.pg_constraint.c.contype == bindparam("contype"), pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + # NOTE: filtering also on pg_index.indrelid for oids does + # not seem to have a performance effect, but it may be an + # option if perf problems are reported ) .subquery("con") ) @@ -3901,9 +4147,10 @@ def _constraint_query(self, is_unique): select( con_sq.c.conrelid, con_sq.c.conname, - con_sq.c.conindid, con_sq.c.description, con_sq.c.ord, + con_sq.c.indnkeyatts, + con_sq.c.indnullsnotdistinct, pg_catalog.pg_attribute.c.attname, ) .select_from(pg_catalog.pg_attribute) @@ -3926,7 +4173,7 @@ def _constraint_query(self, is_unique): .subquery("attr") ) - constraint_query = ( + return ( select( attr_sq.c.conrelid, sql.func.array_agg( @@ -3938,31 +4185,15 @@ def _constraint_query(self, is_unique): ).label("cols"), attr_sq.c.conname, sql.func.min(attr_sq.c.description).label("description"), + sql.func.min(attr_sq.c.indnkeyatts).label("indnkeyatts"), + sql.func.bool_and(attr_sq.c.indnullsnotdistinct).label( + "indnullsnotdistinct" + ), ) .group_by(attr_sq.c.conrelid, attr_sq.c.conname) .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) - if is_unique: - if self.server_version_info >= (15,): - constraint_query = constraint_query.join( - pg_catalog.pg_index, - attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid, - ).add_columns( - sql.func.bool_and( - pg_catalog.pg_index.c.indnullsnotdistinct - ).label("indnullsnotdistinct") - ) 
- else: - constraint_query = constraint_query.add_columns( - sql.false().label("indnullsnotdistinct") - ) - else: - constraint_query = constraint_query.add_columns( - sql.null().label("extra") - ) - return constraint_query - def _reflect_constraint( self, connection, contype, schema, filter_names, scope, kind, **kw ): @@ -3978,26 +4209,42 @@ def _reflect_constraint( batches[0:3000] = [] result = connection.execute( - self._constraint_query(is_unique), + self._constraint_query, {"oids": [r[0] for r in batch], "contype": contype}, - ) + ).mappings() result_by_oid = defaultdict(list) - for oid, cols, constraint_name, comment, extra in result: - result_by_oid[oid].append( - (cols, constraint_name, comment, extra) - ) + for row_dict in result: + result_by_oid[row_dict["conrelid"]].append(row_dict) for oid, tablename in batch: for_oid = result_by_oid.get(oid, ()) if for_oid: - for cols, constraint, comment, extra in for_oid: - if is_unique: - yield tablename, cols, constraint, comment, { - "nullsnotdistinct": extra - } + for row in for_oid: + # See note in get_multi_indexes + all_cols = row["cols"] + indnkeyatts = row["indnkeyatts"] + if len(all_cols) > indnkeyatts: + inc_cols = all_cols[indnkeyatts:] + cst_cols = all_cols[:indnkeyatts] else: - yield tablename, cols, constraint, comment, None + inc_cols = [] + cst_cols = all_cols + + opts = {} + if self.server_version_info >= (11,): + opts["postgresql_include"] = inc_cols + if is_unique: + opts["postgresql_nulls_not_distinct"] = row[ + "indnullsnotdistinct" + ] + yield ( + tablename, + cst_cols, + row["conname"], + row["description"], + opts, + ) else: yield tablename, None, None, None, None @@ -4023,18 +4270,27 @@ def get_multi_pk_constraint( # only a single pk can be present for each table. Return an entry # even if a table has no primary key default = ReflectionDefaults.pk_constraint + + def pk_constraint(pk_name, cols, comment, opts): + info = { + "constrained_columns": cols, + "name": pk_name, + "comment": comment, + } + if opts: + info["dialect_options"] = opts + return info + return ( ( (schema, table_name), - { - "constrained_columns": [] if cols is None else cols, - "name": pk_name, - "comment": comment, - } - if pk_name is not None - else default(), + ( + pk_constraint(pk_name, cols, comment, opts) + if pk_name is not None + else default() + ), ) - for table_name, cols, pk_name, comment, _ in result + for table_name, cols, pk_name, comment, opts in result ) @reflection.cache @@ -4128,7 +4384,8 @@ def _fk_regex_pattern(self): r"[\s]?(ON UPDATE " r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" r"[\s]?(ON DELETE " - r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"(CASCADE|RESTRICT|NO ACTION|" + r"SET (?:NULL|DEFAULT)(?:\s\(.+\))?)+)?" r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) @@ -4244,7 +4501,10 @@ def get_indexes(self, connection, table_name, schema=None, **kw): @util.memoized_property def _index_query(self): - pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: pg_index is used in the FROM clause twice to improve performance, + # since extracting all the index information from `idx_sq` to avoid + # the second pg_index use leads to a worse performing query, in + # particular when querying for a single table (as of pg 17) # NOTE: repeating oids clause improves query performance # subquery to get the columns @@ -4253,6 +4513,9 @@ def _index_query(self): pg_catalog.pg_index.c.indexrelid, pg_catalog.pg_index.c.indrelid, sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.unnest(pg_catalog.pg_index.c.indclass).label( + "att_opclass" + ), sql.func.generate_subscripts( pg_catalog.pg_index.c.indkey, 1 ).label("ord"), @@ -4284,6 +4547,8 @@ def _index_query(self): else_=pg_catalog.pg_attribute.c.attname.cast(TEXT), ).label("element"), (idx_sq.c.attnum == 0).label("is_expr"), + pg_catalog.pg_opclass.c.opcname, + pg_catalog.pg_opclass.c.opcdefault, ) .select_from(idx_sq) .outerjoin( @@ -4294,6 +4559,10 @@ def _index_query(self): pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, ), ) + .outerjoin( + pg_catalog.pg_opclass, + pg_catalog.pg_opclass.c.oid == idx_sq.c.att_opclass, + ) .where(idx_sq.c.indrelid.in_(bindparam("oids"))) .subquery("idx_attr") ) @@ -4308,6 +4577,12 @@ def _index_query(self): sql.func.array_agg( aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord) ).label("elements_is_expr"), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.opcname, attr_sq.c.ord) + ).label("elements_opclass"), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.opcdefault, attr_sq.c.ord) + ).label("elements_opdefault"), ) .group_by(attr_sq.c.indexrelid) .subquery("idx_cols") @@ -4316,7 +4591,7 @@ def _index_query(self): if self.server_version_info >= (11, 0): indnkeyatts = pg_catalog.pg_index.c.indnkeyatts else: - indnkeyatts = sql.null().label("indnkeyatts") + indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts") if self.server_version_info >= (15,): nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct @@ -4326,13 +4601,13 @@ def _index_query(self): return ( select( pg_catalog.pg_index.c.indrelid, - pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_class.c.relname, pg_catalog.pg_index.c.indisunique, pg_catalog.pg_constraint.c.conrelid.is_not(None).label( "has_constraint" ), pg_catalog.pg_index.c.indoption, - pg_class_index.c.reloptions, + pg_catalog.pg_class.c.reloptions, pg_catalog.pg_am.c.amname, # NOTE: pg_get_expr is very fast so this case has almost no # performance impact @@ -4350,6 +4625,8 @@ def _index_query(self): nulls_not_distinct, cols_sq.c.elements, cols_sq.c.elements_is_expr, + cols_sq.c.elements_opclass, + cols_sq.c.elements_opdefault, ) .select_from(pg_catalog.pg_index) .where( @@ -4357,12 +4634,12 @@ def _index_query(self): ~pg_catalog.pg_index.c.indisprimary, ) .join( - pg_class_index, - pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + pg_catalog.pg_class, + pg_catalog.pg_index.c.indexrelid == pg_catalog.pg_class.c.oid, ) .join( pg_catalog.pg_am, - pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + pg_catalog.pg_class.c.relam == pg_catalog.pg_am.c.oid, ) .outerjoin( cols_sq, @@ -4379,7 +4656,9 @@ def _index_query(self): == sql.any_(_array.array(("p", "u", "x"))), ), ) - .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + .order_by( +
pg_catalog.pg_index.c.indrelid, pg_catalog.pg_class.c.relname + ) ) def get_multi_indexes( @@ -4414,17 +4693,19 @@ def get_multi_indexes( continue for row in result_by_oid[oid]: - index_name = row["relname_index"] + index_name = row["relname"] table_indexes = indexes[(schema, table_name)] all_elements = row["elements"] all_elements_is_expr = row["elements_is_expr"] + all_elements_opclass = row["elements_opclass"] + all_elements_opdefault = row["elements_opdefault"] indnkeyatts = row["indnkeyatts"] # "The number of key columns in the index, not counting any # included columns, which are merely stored and do not # participate in the index semantics" - if indnkeyatts and len(all_elements) > indnkeyatts: + if len(all_elements) > indnkeyatts: # this is a "covering index" which has INCLUDE columns # as well as regular index columns inc_cols = all_elements[indnkeyatts:] @@ -4439,10 +4720,18 @@ def get_multi_indexes( not is_expr for is_expr in all_elements_is_expr[indnkeyatts:] ) + idx_elements_opclass = all_elements_opclass[ + :indnkeyatts + ] + idx_elements_opdefault = all_elements_opdefault[ + :indnkeyatts + ] else: idx_elements = all_elements idx_elements_is_expr = all_elements_is_expr inc_cols = [] + idx_elements_opclass = all_elements_opclass + idx_elements_opdefault = all_elements_opdefault index = {"name": index_name, "unique": row["indisunique"]} if any(idx_elements_is_expr): @@ -4456,6 +4745,19 @@ def get_multi_indexes( else: index["column_names"] = idx_elements + dialect_options = {} + + if not all(idx_elements_opdefault): + dialect_options["postgresql_ops"] = { + name: opclass + for name, opclass, is_default in zip( + idx_elements, + idx_elements_opclass, + idx_elements_opdefault, + ) + if not is_default + } + sorting = {} for col_index, col_flags in enumerate(row["indoption"]): col_sorting = () @@ -4475,10 +4777,12 @@ def get_multi_indexes( if row["has_constraint"]: index["duplicates_constraint"] = index_name - dialect_options = {} if row["reloptions"]: dialect_options["postgresql_with"] = dict( - [option.split("=") for option in row["reloptions"]] + [ + option.split("=", 1) + for option in row["reloptions"] + ] ) # it *might* be nice to include that this is 'btree' in the # reflection info. 
But we don't want an Index object @@ -4551,12 +4855,7 @@ def get_multi_unique_constraints( "comment": comment, } if options: - if options["nullsnotdistinct"]: - uc_dict["dialect_options"] = { - "postgresql_nulls_not_distinct": options[ - "nullsnotdistinct" - ] - } + uc_dict["dialect_options"] = options uniques[(schema, table_name)].append(uc_dict) return uniques.items() @@ -4588,6 +4887,8 @@ def _comment_query(self, schema, has_filter_names, scope, kind): pg_catalog.pg_class.c.oid == pg_catalog.pg_description.c.objoid, pg_catalog.pg_description.c.objsubid == 0, + pg_catalog.pg_description.c.classoid + == sql.func.cast("pg_catalog.pg_class", REGCLASS), ), ) .where(self._pg_class_relkind_condition(relkinds)) @@ -4696,9 +4997,13 @@ def get_multi_check_constraints( # "CHECK (((a > 1) AND (a < 5))) NOT VALID" # "CHECK (some_boolean_function(a))" # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)" + # "CHECK (a NOT NULL) NO INHERIT" + # "CHECK (a NOT NULL) NO INHERIT NOT VALID" m = re.match( - r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL + r"^CHECK *\((.+)\)( NO INHERIT)?( NOT VALID)?$", + src, + flags=re.DOTALL, ) if not m: util.warn("Could not parse CHECK constraint text: %r" % src) @@ -4712,8 +5017,14 @@ def get_multi_check_constraints( "sqltext": sqltext, "comment": comment, } - if m and m.group(2): - entry["dialect_options"] = {"not_valid": True} + if m: + do = {} + if " NOT VALID" in m.groups(): + do["not_valid"] = True + if " NO INHERIT" in m.groups(): + do["no_inherit"] = True + if do: + entry["dialect_options"] = do check_constraints[(schema, table_name)].append(entry) return check_constraints.items() @@ -4828,12 +5139,18 @@ def _domain_query(self, schema): pg_catalog.pg_namespace.c.nspname.label("schema"), con_sq.c.condefs, con_sq.c.connames, + pg_catalog.pg_collation.c.collname, ) .join( pg_catalog.pg_namespace, pg_catalog.pg_namespace.c.oid == pg_catalog.pg_type.c.typnamespace, ) + .outerjoin( + pg_catalog.pg_collation, + pg_catalog.pg_type.c.typcollation + == pg_catalog.pg_collation.c.oid, + ) .outerjoin( con_sq, pg_catalog.pg_type.c.oid == con_sq.c.contypid, @@ -4847,14 +5164,13 @@ def _domain_query(self, schema): @reflection.cache def _load_domains(self, connection, schema=None, **kw): - # Load data types for domains: result = connection.execute(self._domain_query(schema)) - domains = [] + domains: List[ReflectedDomain] = [] for domain in result.mappings(): # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) - constraints = [] + constraints: List[ReflectedDomainConstraint] = [] if domain["connames"]: # When a domain has multiple CHECK constraints, they will # be tested in alphabetical order by name. @@ -4863,12 +5179,13 @@ def _load_domains(self, connection, schema=None, **kw): key=lambda t: t[0], ) for name, def_ in sorted_constraints: - # constraint is in the form "CHECK (expression)". + # constraint is in the form "CHECK (expression)" + # or "NOT NULL". Ignore the "NOT NULL" and # remove "CHECK (" and the tailing ")". 
- check = def_[7:-1] - constraints.append({"name": name, "check": check}) - - domain_rec = { + if def_.casefold().startswith("check"): + check = def_[7:-1] + constraints.append({"name": name, "check": check}) + domain_rec: ReflectedDomain = { "name": domain["name"], "schema": domain["schema"], "visible": domain["visible"], @@ -4876,6 +5193,7 @@ def _load_domains(self, connection, schema=None, **kw): "nullable": domain["nullable"], "default": domain["default"], "constraints": constraints, + "collation": domain["collname"], } domains.append(domain_rec) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index dee7af3311e..1187b6bf5f0 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -1,5 +1,5 @@ -# postgresql/dml.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/dml.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,7 +7,10 @@ from __future__ import annotations from typing import Any +from typing import List from typing import Optional +from typing import Tuple +from typing import Union from . import ext from .._typing import _OnConflictConstraintT @@ -26,7 +29,9 @@ from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement +from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -153,11 +158,10 @@ def on_conflict_do_update( :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. If present, can be a literal SQL - string or an acceptable expression for a ``WHERE`` clause - that restricts the rows affected by ``DO UPDATE SET``. Rows - not meeting the ``WHERE`` condition will not be updated - (effectively a ``DO NOTHING`` for those rows). + Optional argument. An expression object representing a ``WHERE`` + clause that restricts the rows affected by ``DO UPDATE SET``. Rows not + meeting the ``WHERE`` condition will not be updated (effectively a + ``DO NOTHING`` for those rows). .. 
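A minimal sketch of the ``where`` parameter in context; ``my_table`` and the
``conn`` connection are assumed for illustration::

    from sqlalchemy.dialects.postgresql import insert

    stmt = insert(my_table).values(id=1, data="new value")
    stmt = stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],
        set_={"data": stmt.excluded.data},
        # rows whose existing data already matches are left untouched
        where=(my_table.c.data != stmt.excluded.data),
    )
    conn.execute(stmt)

..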
seealso:: @@ -212,8 +216,10 @@ class OnConflictClause(ClauseElement): stringify_dialect = "postgresql" constraint_target: Optional[str] - inferred_target_elements: _OnConflictIndexElementsT - inferred_target_whereclause: _OnConflictIndexWhereT + inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] + inferred_target_whereclause: Optional[ + Union[ColumnElement[Any], TextClause] + ] def __init__( self, @@ -254,12 +260,28 @@ def __init__( if index_elements is not None: self.constraint_target = None - self.inferred_target_elements = index_elements - self.inferred_target_whereclause = index_where + self.inferred_target_elements = [ + coercions.expect(roles.DDLConstraintColumnRole, column) + for column in index_elements + ] + + self.inferred_target_whereclause = ( + coercions.expect( + ( + roles.StatementOptionRole + if isinstance(constraint, ext.ExcludeConstraint) + else roles.WhereHavingRole + ), + index_where, + ) + if index_where is not None + else None + ) + elif constraint is None: - self.constraint_target = ( - self.inferred_target_elements - ) = self.inferred_target_whereclause = None + self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None class OnConflictDoNothing(OnConflictClause): @@ -269,6 +291,9 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" + update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_whereclause: Optional[ColumnElement[Any]] + def __init__( self, constraint: _OnConflictConstraintT = None, @@ -307,4 +332,8 @@ def __init__( (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = where + self.update_whereclause = ( + coercions.expect(roles.WhereHavingRole, where) + if where is not None + else None + ) diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index ad1267750bb..54bacd94471 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -1,5 +1,5 @@ -# postgresql/ext.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/ext.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,6 +8,10 @@ from __future__ import annotations from typing import Any +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar @@ -23,34 +27,44 @@ from ...sql.sqltypes import TEXT from ...sql.visitors import InternalTraversal -_T = TypeVar("_T", bound=Any) - if TYPE_CHECKING: + from ...sql._typing import _ColumnExpressionArgument + from ...sql.elements import ClauseElement + from ...sql.elements import ColumnElement + from ...sql.operators import OperatorType + from ...sql.selectable import FromClause + from ...sql.visitors import _CloneCallableType from ...sql.visitors import _TraverseInternalsType +_T = TypeVar("_T", bound=Any) + -class aggregate_order_by(expression.ColumnElement): +class aggregate_order_by(expression.ColumnElement[_T]): """Represent a PostgreSQL aggregate order by expression. E.g.:: from sqlalchemy.dialects.postgresql import aggregate_order_by + expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) stmt = select(expr) - would represent the expression:: + would represent the expression: + + .. 
sourcecode:: sql SELECT array_agg(a ORDER BY b DESC) FROM table; Similarly:: expr = func.string_agg( - table.c.a, - aggregate_order_by(literal_column("','"), table.c.a) + table.c.a, aggregate_order_by(literal_column("','"), table.c.a) ) stmt = select(expr) - Would represent:: + Would represent: + + .. sourcecode:: sql SELECT string_agg(a, ',' ORDER BY a) FROM table; @@ -71,11 +85,32 @@ class aggregate_order_by(expression.ColumnElement): ("order_by", InternalTraversal.dp_clauseelement), ] - def __init__(self, target, *order_by): - self.target = coercions.expect(roles.ExpressionElementRole, target) + @overload + def __init__( + self, + target: ColumnElement[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + @overload + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): ... + + def __init__( + self, + target: _ColumnExpressionArgument[_T], + *order_by: _ColumnExpressionArgument[Any], + ): + self.target: ClauseElement = coercions.expect( + roles.ExpressionElementRole, target + ) self.type = self.target.type _lob = len(order_by) + self.order_by: ClauseElement if _lob == 0: raise TypeError("at least one ORDER BY element is required") elif _lob == 1: @@ -87,18 +122,22 @@ def __init__(self, target, *order_by): *order_by, _literal_as_text_role=roles.ExpressionElementRole ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: return self - def get_children(self, **kwargs): + def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]: return self.target, self.order_by - def _copy_internals(self, clone=elements._clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = elements._clone, **kw: Any + ) -> None: self.target = clone(self.target, **kw) self.order_by = clone(self.order_by, **kw) @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return self.target._from_objects + self.order_by._from_objects @@ -131,10 +170,10 @@ def __init__(self, *elements, **kw): E.g.:: const = ExcludeConstraint( - (Column('period'), '&&'), - (Column('group'), '='), - where=(Column('group') != 'some group'), - ops={'group': 'my_operator_class'} + (Column("period"), "&&"), + (Column("group"), "="), + where=(Column("group") != "some group"), + ops={"group": "my_operator_class"}, ) The constraint is normally embedded into the :class:`_schema.Table` @@ -142,19 +181,20 @@ def __init__(self, *elements, **kw): directly, or added later using :meth:`.append_constraint`:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, primary_key=True), - Column('period', TSRANGE()), - Column('group', String) + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("period", TSRANGE()), + Column("group", String), ) some_table.append_constraint( ExcludeConstraint( - (some_table.c.period, '&&'), - (some_table.c.group, '='), - where=some_table.c.group != 'some group', - name='some_table_excl_const', - ops={'group': 'my_operator_class'} + (some_table.c.period, "&&"), + (some_table.c.group, "="), + where=some_table.c.group != "some group", + name="some_table_excl_const", + ops={"group": "my_operator_class"}, ) ) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 83c4932a6ea..0a915b17dff 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -1,5 +1,5 @@ -# postgresql/hstore.py -# Copyright (C) 
2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/hstore.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,28 +28,29 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: - data_table = Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', HSTORE) + data_table = Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", HSTORE), ) with engine.connect() as conn: conn.execute( - data_table.insert(), - data = {"key1": "value1", "key2": "value2"} + data_table.insert(), data={"key1": "value1", "key2": "value2"} ) :class:`.HSTORE` provides for a wide range of operations, including: * Index operations:: - data_table.c.data['some key'] == 'some value' + data_table.c.data["some key"] == "some value" * Containment operations:: - data_table.c.data.has_key('some key') + data_table.c.data.has_key("some key") - data_table.c.data.has_all(['one', 'two', 'three']) + data_table.c.data.has_all(["one", "two", "three"]) * Concatenation:: @@ -72,17 +73,19 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): from sqlalchemy.ext.mutable import MutableDict + class MyClass(Base): - __tablename__ = 'data_table' + __tablename__ = "data_table" id = Column(Integer, primary_key=True) data = Column(MutableDict.as_mutable(HSTORE)) + my_object = session.query(MyClass).one() # in-place mutation, requires Mutable extension # in order for the ORM to detect - my_object.data['some_key'] = 'some value' + my_object.data["some_key"] = "some value" session.commit() @@ -96,7 +99,7 @@ class MyClass(Base): :class:`.hstore` - render the PostgreSQL ``hstore()`` function. 
- """ + """ # noqa: E501 __visit_name__ = "HSTORE" hashable = False @@ -192,6 +195,9 @@ def matrix(self): comparator_factory = Comparator def bind_processor(self, dialect): + # note that dialect-specific types like that of psycopg and + # psycopg2 will override this method to allow driver-level conversion + # instead, see _PsycopgHStore def process(value): if isinstance(value, dict): return _serialize_hstore(value) @@ -201,6 +207,9 @@ def process(value): return process def result_processor(self, dialect, coltype): + # note that dialect-specific types like that of psycopg and + # psycopg2 will override this method to allow driver-level conversion + # instead, see _PsycopgHStore def process(value): if value is not None: return _parse_hstore(value) @@ -221,12 +230,12 @@ class hstore(sqlfunc.GenericFunction): from sqlalchemy.dialects.postgresql import array, hstore - select(hstore('key1', 'value1')) + select(hstore("key1", "value1")) select( hstore( - array(['key1', 'key2', 'key3']), - array(['value1', 'value2', 'value3']) + array(["key1", "key2", "key3"]), + array(["value1", "value2", "value3"]), ) ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index ee56a745048..06f8db5b2af 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -1,11 +1,18 @@ -# postgresql/json.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/json.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .array import ARRAY from .array import array as _pg_array @@ -21,13 +28,23 @@ from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast +from ...sql._typing import _T + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.elements import ColumnElement + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + from ...sql.type_api import TypeEngine __all__ = ("JSON", "JSONB") class JSONPathType(sqltypes.JSON.JSONPathType): - def _processor(self, dialect, super_proc): - def process(value): + def _processor( + self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]] + ) -> Callable[[Any], Any]: + def process(value: Any) -> Any: if isinstance(value, str): # If it's already a string assume that it's in json path # format. 
This allows using cast with json paths literals @@ -44,11 +61,13 @@ def process(value): return process - def bind_processor(self, dialect): - return self._processor(dialect, self.string_bind_processor(dialect)) + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501 - def literal_processor(self, dialect): - return self._processor(dialect, self.string_literal_processor(dialect)) + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501 class JSONPATH(JSONPathType): @@ -90,14 +109,14 @@ class JSON(sqltypes.JSON): * Index operations (the ``->`` operator):: - data_table.c.data['some key'] + data_table.c.data["some key"] data_table.c.data[5] + * Index operations returning text + (the ``->>`` operator):: - * Index operations returning text (the ``->>`` operator):: - - data_table.c.data['some key'].astext == 'some value' + data_table.c.data["some key"].astext == "some value" Note that equivalent functionality is available via the :attr:`.JSON.Comparator.as_string` accessor. @@ -105,18 +124,20 @@ class JSON(sqltypes.JSON): * Index operations with CAST (equivalent to ``CAST(col ->> ['some key'] AS )``):: - data_table.c.data['some key'].astext.cast(Integer) == 5 + data_table.c.data["some key"].astext.cast(Integer) == 5 Note that equivalent functionality is available via the :attr:`.JSON.Comparator.as_integer` and similar accessors. * Path index operations (the ``#>`` operator):: - data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] + data_table.c.data[("key_1", "key_2", 5, ..., "key_n")] * Path index operations returning text (the ``#>>`` operator):: - data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value' + data_table.c.data[ + ("key_1", "key_2", 5, ..., "key_n") + ].astext == "some value" Index operations return an expression object whose type defaults to :class:`_types.JSON` by default, @@ -128,10 +149,11 @@ class JSON(sqltypes.JSON): using psycopg2, the DBAPI only allows serializers at the per-cursor or per-connection level. E.g.:: - engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", - json_serializer=my_serialize_fn, - json_deserializer=my_deserialize_fn - ) + engine = create_engine( + "postgresql+psycopg2://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn, + ) When using the psycopg2 dialect, the json_deserializer is registered against the database using ``psycopg2.extras.register_default_json``. @@ -144,9 +166,14 @@ class JSON(sqltypes.JSON): """ # noqa - astext_type = sqltypes.Text() + render_bind_cast = True + astext_type: TypeEngine[str] = sqltypes.Text() - def __init__(self, none_as_null=False, astext_type=None): + def __init__( + self, + none_as_null: bool = False, + astext_type: Optional[TypeEngine[str]] = None, + ): """Construct a :class:`_types.JSON` type. :param none_as_null: if True, persist the value ``None`` as a @@ -155,7 +182,8 @@ def __init__(self, none_as_null=False, astext_type=None): be used to persist a NULL value:: from sqlalchemy import null - conn.execute(table.insert(), data=null()) + + conn.execute(table.insert(), {"data": null()}) .. 
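A short sketch contrasting the two ways of getting SQL NULL into a JSON
column; the table and ``engine`` are assumed for illustration::

    from sqlalchemy import Column, Integer, MetaData, Table, null
    from sqlalchemy.dialects.postgresql import JSON

    metadata = MetaData()
    data_table = Table(
        "data_table",
        metadata,
        Column("id", Integer, primary_key=True),
        # none_as_null=True: Python None becomes SQL NULL, not JSON 'null'
        Column("data", JSON(none_as_null=True)),
    )

    with engine.begin() as conn:
        conn.execute(data_table.insert(), {"data": None})  # SQL NULL
        conn.execute(data_table.insert(), {"data": null()})  # SQL NULL always

..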
seealso:: @@ -170,17 +198,19 @@ def __init__(self, none_as_null=False, astext_type=None): if astext_type is not None: self.astext_type = astext_type - class Comparator(sqltypes.JSON.Comparator): + class Comparator(sqltypes.JSON.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" + type: JSON + @property - def astext(self): + def astext(self) -> ColumnElement[str]: """On an indexed expression, use the "astext" (e.g. "->>") conversion when rendered in SQL. E.g.:: - select(data_table.c.data['some key'].astext) + select(data_table.c.data["some key"].astext) .. seealso:: @@ -188,13 +218,13 @@ def astext(self): """ if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): - return self.expr.left.operate( + return self.expr.left.operate( # type: ignore[no-any-return] JSONPATH_ASTEXT, self.expr.right, result_type=self.type.astext_type, ) else: - return self.expr.left.operate( + return self.expr.left.operate( # type: ignore[no-any-return] ASTEXT, self.expr.right, result_type=self.type.astext_type ) @@ -207,15 +237,16 @@ class JSONB(JSON): The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, e.g.:: - data_table = Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', JSONB) + data_table = Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", JSONB), ) with engine.connect() as conn: conn.execute( - data_table.insert(), - data = {"key1": "value1", "key2": "value2"} + data_table.insert(), data={"key1": "value1", "key2": "value2"} ) The :class:`_postgresql.JSONB` type includes all operations provided by @@ -252,43 +283,53 @@ class JSONB(JSON): __visit_name__ = "JSONB" - class Comparator(JSON.Comparator): + class Comparator(JSON.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" - def has_key(self, other): - """Boolean expression. Test for presence of a key. Note that the - key may be a SQLA expression. + type: JSONB + + def has_key(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Test for presence of a key (equivalent of + the ``?`` operator). Note that the key may be a SQLA expression. """ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) - def has_all(self, other): - """Boolean expression. Test for presence of all keys in jsonb""" + def has_all(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Test for presence of all keys in jsonb + (equivalent of the ``?&`` operator) + """ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) - def has_any(self, other): - """Boolean expression. Test for presence of any key in jsonb""" + def has_any(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Test for presence of any key in jsonb + (equivalent of the ``?|`` operator) + """ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) - def contains(self, other, **kwargs): + def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: """Boolean expression. Test if keys (or array) are a superset - of/contained the keys of the argument jsonb expression. + of/contained the keys of the argument jsonb expression + (equivalent of the ``@>`` operator). kwargs may be ignored by this operator but are required for API conformance. """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other): + def contained_by(self, other: Any) -> ColumnElement[bool]: """Boolean expression. 
Test if keys are a proper subset of the - keys of the argument jsonb expression. + keys of the argument jsonb expression + (equivalent of the ``<@`` operator). """ return self.operate( CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def delete_path(self, array): + def delete_path( + self, array: Union[List[str], _pg_array[str]] + ) -> ColumnElement[JSONB]: """JSONB expression. Deletes field or array element specified in - the argument array. + the argument array (equivalent of the ``#-`` operator). The input may be a list of strings that will be coerced to an ``ARRAY`` or an instance of :meth:`_postgres.array`. @@ -300,9 +341,9 @@ def delete_path(self, array): right_side = cast(array, ARRAY(sqltypes.TEXT)) return self.operate(DELETE_PATH, right_side, result_type=JSONB) - def path_exists(self, other): + def path_exists(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of item given by the - argument JSONPath expression. + argument JSONPath expression (equivalent of the ``@?`` operator). .. versionadded:: 2.0 """ @@ -310,9 +351,10 @@ def path_exists(self, other): PATH_EXISTS, other, result_type=sqltypes.Boolean ) - def path_match(self, other): + def path_match(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test if JSONPath predicate given by the - argument JSONPath expression matches. + argument JSONPath expression matches + (equivalent of the ``@@`` operator). Only the first item of the result is taken into account. diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index 19994d4b99f..5807041ead3 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -1,5 +1,5 @@ -# postgresql/named_types.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/named_types.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,7 +7,9 @@ # mypy: ignore-errors from __future__ import annotations +from types import ModuleType from typing import Any +from typing import Dict from typing import Optional from typing import Type from typing import TYPE_CHECKING @@ -25,10 +27,11 @@ from ...sql.ddl import InvokeDropDDLBase if TYPE_CHECKING: + from ...sql._typing import _CreateDropBind from ...sql._typing import _TypeEngineArgument -class NamedType(sqltypes.TypeEngine): +class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): """Base for named types.""" __abstract__ = True @@ -36,7 +39,9 @@ class NamedType(sqltypes.TypeEngine): DDLDropper: Type[NamedTypeDropper] create_type: bool - def create(self, bind, checkfirst=True, **kw): + def create( + self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any + ) -> None: """Emit ``CREATE`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -50,7 +55,9 @@ def create(self, bind, checkfirst=True, **kw): """ bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) - def drop(self, bind, checkfirst=True, **kw): + def drop( + self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any + ) -> None: """Emit ``DROP`` DDL for this type. 
:param bind: a connectable :class:`_engine.Engine`, @@ -63,7 +70,9 @@ def drop(self, bind, checkfirst=True, **kw): """ bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) - def _check_for_name_in_memos(self, checkfirst, kw): + def _check_for_name_in_memos( + self, checkfirst: bool, kw: Dict[str, Any] + ) -> bool: """Look in the 'ddl runner' for 'memos', then note our name in that collection. @@ -87,7 +96,13 @@ def _check_for_name_in_memos(self, checkfirst, kw): else: return False - def _on_table_create(self, target, bind, checkfirst=False, **kw): + def _on_table_create( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if ( checkfirst or ( @@ -97,7 +112,13 @@ def _on_table_create(self, target, bind, checkfirst=False, **kw): ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_table_drop(self, target, bind, checkfirst=False, **kw): + def _on_table_drop( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if ( not self.metadata and not kw.get("_is_metadata_operation", False) @@ -105,11 +126,23 @@ def _on_table_drop(self, target, bind, checkfirst=False, **kw): ): self.drop(bind=bind, checkfirst=checkfirst) - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + def _on_metadata_create( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + def _on_metadata_drop( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) @@ -163,7 +196,6 @@ def visit_enum(self, enum): class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): - """PostgreSQL ENUM type. This is a subclass of :class:`_types.Enum` which includes @@ -186,8 +218,10 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): :meth:`_schema.Table.drop` methods are called:: - table = Table('sometable', metadata, - Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + table = Table( + "sometable", + metadata, + Column("some_enum", ENUM("a", "b", "c", name="myenum")), ) table.create(engine) # will emit CREATE ENUM and CREATE TABLE @@ -198,21 +232,17 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): :class:`_postgresql.ENUM` independently, and associate it with the :class:`_schema.MetaData` object itself:: - my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata) - t1 = Table('sometable_one', metadata, - Column('some_enum', myenum) - ) + t1 = Table("sometable_one", metadata, Column("some_enum", myenum)) - t2 = Table('sometable_two', metadata, - Column('some_enum', myenum) - ) + t2 = Table("sometable_two", metadata, Column("some_enum", myenum)) When this pattern is used, care must still be taken at the level of individual table creates. 
Emitting CREATE TABLE without also specifying ``checkfirst=True`` will still cause issues:: - t1.create(engine) # will fail: no such type 'myenum' + t1.create(engine) # will fail: no such type 'myenum' If we specify ``checkfirst=True``, the individual table-level create operation will check for the ``ENUM`` and create if not exists:: @@ -317,7 +347,7 @@ def adapt_emulated_to_native(cls, impl, **kw): return cls(**kw) - def create(self, bind=None, checkfirst=True): + def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``CREATE TYPE`` for this :class:`_postgresql.ENUM`. @@ -338,7 +368,7 @@ def create(self, bind=None, checkfirst=True): super().create(bind, checkfirst=checkfirst) - def drop(self, bind=None, checkfirst=True): + def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``DROP TYPE`` for this :class:`_postgresql.ENUM`. @@ -358,7 +388,7 @@ def drop(self, bind=None, checkfirst=True): super().drop(bind, checkfirst=checkfirst) - def get_dbapi_type(self, dbapi): + def get_dbapi_type(self, dbapi: ModuleType) -> None: """dont return dbapi.STRING for ENUM in PostgreSQL, since that's a different type""" @@ -388,14 +418,12 @@ class DOMAIN(NamedType, sqltypes.SchemaType): A domain is essentially a data type with optional constraints that restrict the allowed set of values. E.g.:: - PositiveInt = DOMAIN( - "pos_int", Integer, check="VALUE > 0", not_null=True - ) + PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True) UsPostalCode = DOMAIN( "us_postal_code", Text, - check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'" + check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'", ) See the `PostgreSQL documentation`__ for additional details @@ -404,7 +432,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType): .. 
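A brief sketch of how a domain participates in table DDL, on the assumption
that it behaves like the other named types here; the engine URL and table
names are illustrative::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from sqlalchemy.dialects.postgresql import DOMAIN

    metadata = MetaData()
    positive_int = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True)

    orders = Table("orders", metadata, Column("qty", positive_int))

    engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
    # CREATE DOMAIN should be emitted ahead of CREATE TABLE here
    metadata.create_all(engine)

..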
versionadded:: 2.0 - """ + """ # noqa: E501 DDLGenerator = DomainGenerator DDLDropper = DomainDropper @@ -417,10 +445,10 @@ def __init__( data_type: _TypeEngineArgument[Any], *, collation: Optional[str] = None, - default: Optional[Union[str, elements.TextClause]] = None, + default: Union[elements.TextClause, str, None] = None, constraint_name: Optional[str] = None, not_null: Optional[bool] = None, - check: Optional[str] = None, + check: Union[elements.TextClause, str, None] = None, create_type: bool = True, **kw: Any, ): @@ -464,7 +492,7 @@ def __init__( self.default = default self.collation = collation self.constraint_name = constraint_name - self.not_null = not_null + self.not_null = bool(not_null) if check is not None: check = coercions.expect(roles.DDLExpressionRole, check) self.check = check diff --git a/lib/sqlalchemy/dialects/postgresql/operators.py b/lib/sqlalchemy/dialects/postgresql/operators.py index f393451c6e1..ebcafcba991 100644 --- a/lib/sqlalchemy/dialects/postgresql/operators.py +++ b/lib/sqlalchemy/dialects/postgresql/operators.py @@ -1,5 +1,5 @@ -# postgresql/operators.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/operators.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 71ee4ebd63e..bf113230e07 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -1,5 +1,5 @@ -# postgresql/pg8000.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under @@ -27,19 +27,21 @@ the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. Typically, this can be changed to ``utf-8``, as a more useful default:: - #client_encoding = sql_ascii # actually, defaults to database - # encoding + # client_encoding = sql_ascii # actually, defaults to database encoding client_encoding = utf8 The ``client_encoding`` can be overridden for a session by executing the SQL: -SET CLIENT_ENCODING TO 'utf8'; +.. sourcecode:: sql + + SET CLIENT_ENCODING TO 'utf8'; SQLAlchemy will execute this SQL on all new connections based on the value passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter:: engine = create_engine( - "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') + "postgresql+pg8000://user:pass@host/dbname", client_encoding="utf8" + ) .. 
_pg8000_ssl: @@ -50,6 +52,7 @@ :paramref:`_sa.create_engine.connect_args` dictionary:: import ssl + ssl_context = ssl.create_default_context() engine = sa.create_engine( "postgresql+pg8000://scott:tiger@192.168.0.199/test", @@ -61,6 +64,7 @@ necessary to disable hostname checking:: import ssl + ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -253,7 +257,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): pass -class _Pg8000Range(ranges.AbstractRangeImpl): +class _Pg8000Range(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): pg8000_Range = dialect.dbapi.Range @@ -304,15 +308,13 @@ def result_processor(self, dialect, coltype): def to_multirange(value): if value is None: return None - - mr = [] - for v in value: - mr.append( + else: + return ranges.MultiRange( ranges.Range( v.lower, v.upper, bounds=v.bounds, empty=v.is_empty ) + for v in value ) - return mr return to_multirange @@ -584,8 +586,8 @@ def _set_client_encoding(self, dbapi_connection, client_encoding): cursor = dbapi_connection.cursor() cursor.execute( f"""SET CLIENT_ENCODING TO '{ - client_encoding.replace("'", "''") - }'""" + client_encoding.replace("'", "''") + }'""" ) cursor.execute("COMMIT") cursor.close() diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py index fa4b30f03f4..9625ccf3347 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -1,10 +1,16 @@ -# postgresql/pg_catalog.py -# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# dialects/postgresql/pg_catalog.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors + +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING from .array import ARRAY from .types import OID @@ -23,31 +29,37 @@ from ...types import Text from ...types import TypeDecorator +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _ResultProcessorType + # types -class NAME(TypeDecorator): +class NAME(TypeDecorator[str]): impl = String(64, collation="C") cache_ok = True -class PG_NODE_TREE(TypeDecorator): +class PG_NODE_TREE(TypeDecorator[str]): impl = Text(collation="C") cache_ok = True -class INT2VECTOR(TypeDecorator): +class INT2VECTOR(TypeDecorator[Sequence[int]]): impl = ARRAY(SmallInteger) cache_ok = True -class OIDVECTOR(TypeDecorator): +class OIDVECTOR(TypeDecorator[Sequence[int]]): impl = ARRAY(OID) cache_ok = True class _SpaceVector: - def result_processor(self, dialect, coltype): - def process(value): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[list[int]]: + def process(value: Any) -> Optional[list[int]]: if value is None: return value return [int(p) for p in value.split(" ")] @@ -77,7 +89,7 @@ def process(value): RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW # tables -pg_catalog_meta = MetaData() +pg_catalog_meta = MetaData(schema="pg_catalog") pg_namespace = Table( "pg_namespace", @@ -85,7 +97,6 @@ def process(value): Column("oid", OID), Column("nspname", NAME), Column("nspowner", OID), - schema="pg_catalog", ) pg_class = Table( @@ -120,7 +131,6 @@ def process(value): 
Column("relispartition", Boolean, info={"server_version": (10,)}), Column("relrewrite", OID, info={"server_version": (11,)}), Column("reloptions", ARRAY(Text)), - schema="pg_catalog", ) pg_type = Table( @@ -155,7 +165,6 @@ def process(value): Column("typndims", Integer), Column("typcollation", OID, info={"server_version": (9, 1)}), Column("typdefault", Text), - schema="pg_catalog", ) pg_index = Table( @@ -182,7 +191,6 @@ def process(value): Column("indoption", INT2VECTOR), Column("indexprs", PG_NODE_TREE), Column("indpred", PG_NODE_TREE), - schema="pg_catalog", ) pg_attribute = Table( @@ -209,7 +217,6 @@ def process(value): Column("attislocal", Boolean), Column("attinhcount", Integer), Column("attcollation", OID, info={"server_version": (9, 1)}), - schema="pg_catalog", ) pg_constraint = Table( @@ -235,7 +242,6 @@ def process(value): Column("connoinherit", Boolean, info={"server_version": (9, 2)}), Column("conkey", ARRAY(SmallInteger)), Column("confkey", ARRAY(SmallInteger)), - schema="pg_catalog", ) pg_sequence = Table( @@ -249,7 +255,6 @@ def process(value): Column("seqmin", BigInteger), Column("seqcache", BigInteger), Column("seqcycle", Boolean), - schema="pg_catalog", info={"server_version": (10,)}, ) @@ -260,7 +265,6 @@ def process(value): Column("adrelid", OID), Column("adnum", SmallInteger), Column("adbin", PG_NODE_TREE), - schema="pg_catalog", ) pg_description = Table( @@ -270,7 +274,6 @@ def process(value): Column("classoid", OID), Column("objsubid", Integer), Column("description", Text(collation="C")), - schema="pg_catalog", ) pg_enum = Table( @@ -280,7 +283,6 @@ def process(value): Column("enumtypid", OID), Column("enumsortorder", Float(), info={"server_version": (9, 1)}), Column("enumlabel", NAME), - schema="pg_catalog", ) pg_am = Table( @@ -290,5 +292,35 @@ def process(value): Column("amname", NAME), Column("amhandler", REGPROC, info={"server_version": (9, 6)}), Column("amtype", CHAR, info={"server_version": (9, 6)}), - schema="pg_catalog", +) + +pg_collation = Table( + "pg_collation", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("collname", NAME), + Column("collnamespace", OID), + Column("collowner", OID), + Column("collprovider", CHAR, info={"server_version": (10,)}), + Column("collisdeterministic", Boolean, info={"server_version": (12,)}), + Column("collencoding", Integer), + Column("collcollate", Text), + Column("collctype", Text), + Column("colliculocale", Text), + Column("collicurules", Text, info={"server_version": (16,)}), + Column("collversion", Text, info={"server_version": (10,)}), +) + +pg_opclass = Table( + "pg_opclass", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("opcmethod", NAME), + Column("opcname", NAME), + Column("opsnamespace", OID), + Column("opsowner", OID), + Column("opcfamily", OID), + Column("opcintype", OID), + Column("opcdefault", Boolean), + Column("opckeytype", OID), ) diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index 87f1c9a4cea..c76f5f51849 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -1,3 +1,9 @@ +# dialects/postgresql/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import time @@ -91,7 +97,7 @@ def drop_all_schema_objects_pre_tables(cfg, eng): 
for xid in conn.exec_driver_sql( "select gid from pg_prepared_xacts" ).scalars(): - conn.execute("ROLLBACK PREPARED '%s'" % xid) + conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid) @drop_all_schema_objects_post_tables.for_db("postgresql") diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index dcd69ce6631..0554048c2bf 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -1,5 +1,5 @@ -# postgresql/psycopg2.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/psycopg.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,20 +29,29 @@ automatically select the sync version, e.g.:: from sqlalchemy import create_engine - sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") + + sync_engine = create_engine( + "postgresql+psycopg://scott:tiger@localhost/test" + ) * calling :func:`_asyncio.create_async_engine` with ``postgresql+psycopg://...`` will automatically select the async version, e.g.:: from sqlalchemy.ext.asyncio import create_async_engine - asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") + + asyncio_engine = create_async_engine( + "postgresql+psycopg://scott:tiger@localhost/test" + ) The asyncio version of the dialect may also be specified explicitly using the ``psycopg_async`` suffix, as:: from sqlalchemy.ext.asyncio import create_async_engine - asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") + + asyncio_engine = create_async_engine( + "postgresql+psycopg_async://scott:tiger@localhost/test" + ) .. seealso:: @@ -50,9 +59,42 @@ dialect shares most of its behavior with the ``psycopg2`` dialect. Further documentation is available there. +Using a different Cursor class +------------------------------ + +One of the differences between ``psycopg`` and the older ``psycopg2`` +is how bound parameters are handled: ``psycopg2`` would bind them +client side, while ``psycopg`` by default will bind them server side. + +It's possible to configure ``psycopg`` to do client side binding by +specifying the ``cursor_factory`` to be ``ClientCursor`` when creating +the engine:: + + from psycopg import ClientCursor + + client_side_engine = create_engine( + "postgresql+psycopg://...", + connect_args={"cursor_factory": ClientCursor}, + ) + +Similarly when using an async engine the ``AsyncClientCursor`` can be +specified:: + + from psycopg import AsyncClientCursor + + client_side_engine = create_async_engine( + "postgresql+psycopg://...", + connect_args={"cursor_factory": AsyncClientCursor}, + ) + +.. 
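To sanity-check that the factory took effect, one can look at the driver
connection directly; a sketch only, with attribute names taken from the
psycopg and SQLAlchemy documentation::

    with client_side_engine.connect() as conn:
        raw = conn.connection.driver_connection  # the psycopg Connection
        print(raw.cursor_factory)  # <class 'psycopg.ClientCursor'>

..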
seealso:: + + `Client-side-binding cursors `_ + """ # noqa from __future__ import annotations +from collections import deque import logging import re from typing import cast @@ -79,6 +121,8 @@ if TYPE_CHECKING: from typing import Iterable + from psycopg import AsyncConnection + logger = logging.getLogger("sqlalchemy.dialects.postgresql") @@ -91,8 +135,6 @@ class _PGREGCONFIG(REGCONFIG): class _PGJSON(JSON): - render_bind_cast = True - def bind_processor(self, dialect): return self._make_bind_processor(None, dialect._psycopg_Json) @@ -101,8 +143,6 @@ def result_processor(self, dialect, coltype): class _PGJSONB(JSONB): - render_bind_cast = True - def bind_processor(self, dialect): return self._make_bind_processor(None, dialect._psycopg_Jsonb) @@ -162,7 +202,7 @@ class _PGBoolean(sqltypes.Boolean): render_bind_cast = True -class _PsycopgRange(ranges.AbstractRangeImpl): +class _PsycopgRange(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range @@ -218,8 +258,10 @@ def to_range(value): def result_processor(self, dialect, coltype): def to_range(value): - if value is not None: - value = [ + if value is None: + return None + else: + return ranges.MultiRange( ranges.Range( elem._lower, elem._upper, @@ -227,9 +269,7 @@ def to_range(value): empty=not elem._bounds, ) for elem in value - ] - - return value + ) return to_range @@ -286,7 +326,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): sqltypes.Integer: _PGInteger, sqltypes.SmallInteger: _PGSmallInteger, sqltypes.BigInteger: _PGBigInteger, - ranges.AbstractRange: _PsycopgRange, + ranges.AbstractSingleRange: _PsycopgRange, ranges.AbstractMultiRange: _PsycopgMultiRange, }, ) @@ -366,10 +406,12 @@ def initialize(self, connection): # register the adapter for connections made subsequent to # this one + assert self._psycopg_adapters_map register_hstore(info, self._psycopg_adapters_map) # register the adapter for this connection - register_hstore(info, connection.connection) + assert connection.connection + register_hstore(info, connection.connection.driver_connection) @classmethod def import_dbapi(cls): @@ -530,7 +572,7 @@ class AsyncAdapt_psycopg_cursor: def __init__(self, cursor, await_) -> None: self._cursor = cursor self.await_ = await_ - self._rows = [] + self._rows = deque() def __getattr__(self, name): return getattr(self._cursor, name) @@ -557,24 +599,19 @@ def execute(self, query, params=None, **kw): # eq/ne if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: rows = self.await_(self._cursor.fetchall()) - if not isinstance(rows, list): - self._rows = list(rows) - else: - self._rows = rows + self._rows = deque(rows) return result def executemany(self, query, params_seq): return self.await_(self._cursor.executemany(query, params_seq)) def __iter__(self): - # TODO: try to avoid pop(0) on a list while self._rows: - yield self._rows.pop(0) + yield self._rows.popleft() def fetchone(self): if self._rows: - # TODO: try to avoid pop(0) on a list - return self._rows.pop(0) + return self._rows.popleft() else: return None @@ -582,13 +619,12 @@ def fetchmany(self, size=None): if size is None: size = self._cursor.arraysize - retval = self._rows[0:size] - self._rows = self._rows[size:] - return retval + rr = self._rows + return [rr.popleft() for _ in range(min(size, len(rr)))] def fetchall(self): - retval = self._rows - self._rows = [] + retval = list(self._rows) + self._rows.clear() return retval @@ -619,6 +655,7 @@ def __iter__(self): class 
AsyncAdapt_psycopg_connection(AdaptedConnection): + _connection: AsyncConnection __slots__ = () await_ = staticmethod(await_only) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 2719f3dc5e5..eeb7604f796 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -1,5 +1,5 @@ -# postgresql/psycopg2.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/psycopg2.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -88,7 +88,6 @@ "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require" ) - Unix Domain Connections ------------------------ @@ -103,13 +102,17 @@ was built. This value can be overridden by passing a pathname to psycopg2, using ``host`` as an additional keyword argument:: - create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql" + ) .. warning:: The format accepted here allows for a hostname in the main URL in addition to the "host" query string argument. **When using this URL format, the initial host is silently ignored**. That is, this URL:: - engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2") + engine = create_engine( + "postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2" + ) Above, the hostname ``myhost1`` is **silently ignored and discarded.** The host which is connected is the ``myhost2`` host. @@ -190,7 +193,7 @@ For this form, the URL can be passed without any elements other than the initial scheme:: - engine = create_engine('postgresql+psycopg2://') + engine = create_engine("postgresql+psycopg2://") In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()`` function which in turn represents an empty DSN passed to libpq. @@ -242,7 +245,7 @@ Modern versions of psycopg2 include a feature known as `Fast Execution Helpers \ -`_, which +`_, which have been shown in benchmarking to improve psycopg2's executemany() performance, primarily with INSERT statements, by at least an order of magnitude. @@ -264,8 +267,8 @@ engine = create_engine( "postgresql+psycopg2://scott:tiger@host/dbname", - executemany_mode='values_plus_batch') - + executemany_mode="values_plus_batch", + ) Possible options for ``executemany_mode`` include: @@ -311,8 +314,10 @@ engine = create_engine( "postgresql+psycopg2://scott:tiger@host/dbname", - executemany_mode='values_plus_batch', - insertmanyvalues_page_size=5000, executemany_batch_page_size=500) + executemany_mode="values_plus_batch", + insertmanyvalues_page_size=5000, + executemany_batch_page_size=500, + ) .. 
seealso:: @@ -338,7 +343,9 @@ passed in the database URL; this parameter is consumed by the underlying ``libpq`` PostgreSQL client library:: - engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8") + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8" + ) Alternatively, the above ``client_encoding`` value may be passed using :paramref:`_sa.create_engine.connect_args` for programmatic establishment with @@ -346,7 +353,7 @@ engine = create_engine( "postgresql+psycopg2://user:pass@host/dbname", - connect_args={'client_encoding': 'utf8'} + connect_args={"client_encoding": "utf8"}, ) * For all PostgreSQL versions, psycopg2 supports a client-side encoding @@ -355,8 +362,7 @@ ``client_encoding`` parameter passed to :func:`_sa.create_engine`:: engine = create_engine( - "postgresql+psycopg2://user:pass@host/dbname", - client_encoding="utf8" + "postgresql+psycopg2://user:pass@host/dbname", client_encoding="utf8" ) .. tip:: The above ``client_encoding`` parameter admittedly is very similar @@ -375,11 +381,9 @@ # postgresql.conf file # client_encoding = sql_ascii # actually, defaults to database - # encoding + # encoding client_encoding = utf8 - - Transactions ------------ @@ -426,15 +430,15 @@ import logging - logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) Above, it is assumed that logging is configured externally. If this is not the case, configuration such as ``logging.basicConfig()`` must be utilized:: import logging - logging.basicConfig() # log messages to stdout - logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + logging.basicConfig() # log messages to stdout + logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) .. seealso:: @@ -471,8 +475,10 @@ use of the hstore extension by setting ``use_native_hstore`` to ``False`` as follows:: - engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", - use_native_hstore=False) + engine = create_engine( + "postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False, + ) The ``HSTORE`` type is **still supported** when the ``psycopg2.extensions.register_hstore()`` extension is not used. It merely @@ -513,7 +519,7 @@ def result_processor(self, dialect, coltype): return None -class _Psycopg2Range(ranges.AbstractRangeImpl): +class _Psycopg2Range(ranges.AbstractSingleRangeImpl): _psycopg2_range_cls = "none" def bind_processor(self, dialect): @@ -844,33 +850,43 @@ def is_disconnect(self, e, connection, cursor): # checks based on strings. in the case that .closed # didn't cut it, fall back onto these. str_e = str(e).partition("\n")[0] - for msg in [ - # these error messages from libpq: interfaces/libpq/fe-misc.c - # and interfaces/libpq/fe-secure.c. - "terminating connection", - "closed the connection", - "connection not open", - "could not receive data from server", - "could not send data to server", - # psycopg2 client errors, psycopg2/connection.h, - # psycopg2/cursor.h - "connection already closed", - "cursor already closed", - # not sure where this path is originally from, it may - # be obsolete. It really says "losed", not "closed". 
- "losed the connection unexpectedly", - # these can occur in newer SSL - "connection has been closed unexpectedly", - "SSL error: decryption failed or bad record mac", - "SSL SYSCALL error: Bad file descriptor", - "SSL SYSCALL error: EOF detected", - "SSL SYSCALL error: Operation timed out", - "SSL SYSCALL error: Bad address", - ]: + for msg in self._is_disconnect_messages: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: return True return False + @util.memoized_property + def _is_disconnect_messages(self): + return ( + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + # This can occur in OpenSSL 1 when an unexpected EOF occurs. + # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS + # It may also occur in newer OpenSSL for a non-recoverable I/O + # error as a result of a system call that does not set 'errno' + # in libc. + "SSL SYSCALL error: Success", + ) + dialect = PGDialect_psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index 211432c6dc7..55e17607044 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -1,5 +1,5 @@ -# testing/engines.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/psycopg2cffi.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index f1c29897d01..0ce4ea29137 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -1,4 +1,5 @@ -# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/ranges.py +# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -14,8 +15,10 @@ from typing import Any from typing import cast from typing import Generic +from typing import List from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -151,8 +154,8 @@ def upper_inf(self) -> bool: return not self.empty and self.upper is None @property - def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: - return AbstractRange() + def __sa_type_engine__(self) -> AbstractSingleRange[_T]: + return AbstractSingleRange() def _contains_value(self, value: _T) -> bool: """Return True if this range contains the given value.""" @@ -268,9 +271,9 @@ def _compare_edges( value2 += step value2_inc = False - if value1 < value2: # type: ignore + if value1 < value2: return -1 
-        elif value1 > value2:  # type: ignore
+        elif value1 > value2:
             return 1
         elif only_values:
             return 0
@@ -357,6 +360,8 @@ def contains(self, value: Union[_T, Range[_T]]) -> bool:
         else:
             return self._contains_value(value)

+    __contains__ = contains
+
     def overlaps(self, other: Range[_T]) -> bool:
         "Determine whether this range overlaps with `other`."

@@ -707,27 +712,46 @@ def _stringify(self) -> str:
         return f"{b0}{l},{r}{b1}"


-class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
-    """
-    Base for PostgreSQL RANGE types.
+class MultiRange(List[Range[_T]]):
+    """Represents a multirange sequence.
+
+    This list subclass is a utility that allows automatic type inference of
+    the proper multi-range SQL type based on the contained single-range
+    values. This is useful when operating on literal multi-ranges::
+
+        import sqlalchemy as sa
+        from sqlalchemy.dialects.postgresql import MultiRange, Range
+
+        value = sa.literal(MultiRange([Range(2, 4)]))
+
+        sa.select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)])))
+
+    .. versionadded:: 2.0.26

     .. seealso::

-        `PostgreSQL range functions `_
+        - :ref:`postgresql_multirange_list_use`.
+    """
+
+    @property
+    def __sa_type_engine__(self) -> AbstractMultiRange[_T]:
+        return AbstractMultiRange()

-    """  # noqa: E501
+
+class AbstractRange(sqltypes.TypeEngine[_T]):
+    """Base class for single and multi Range SQL types."""

     render_bind_cast = True

     __abstract__ = True

     @overload
-    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
-        ...
+    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...

     @overload
-    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
-        ...
+    def adapt(
+        self, cls: Type[TypeEngineMixin], **kw: Any
+    ) -> TypeEngine[Any]: ...

     def adapt(
         self,
@@ -741,7 +765,10 @@ def adapt(
         and also render as ``INT4RANGE`` in SQL and DDL.

         """
-        if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
+        if (
+            issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl))
+            and cls is not self.__class__
+        ):
             # two ways to do this are:  1. create a new type on the fly
             # or 2. have AbstractRangeImpl(visit_name) constructor and a
             # visit_abstract_range_impl() method in the PG compiler.
@@ -760,21 +787,6 @@ def adapt(
         else:
             return super().adapt(cls)

-    def _resolve_for_literal(self, value: Any) -> Any:
-        spec = value.lower if value.lower is not None else value.upper
-
-        if isinstance(spec, int):
-            return INT8RANGE()
-        elif isinstance(spec, (Decimal, float)):
-            return NUMRANGE()
-        elif isinstance(spec, datetime):
-            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
-        elif isinstance(spec, date):
-            return DATERANGE()
-        else:
-            # empty Range, SQL datatype can't be determined here
-            return sqltypes.NULLTYPE
-
     class comparator_factory(TypeEngine.Comparator[Range[Any]]):
         """Define comparison operations for range types."""

@@ -856,91 +868,164 @@ def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
             return self.expr.operate(operators.mul, other)


-class AbstractRangeImpl(AbstractRange[Range[_T]]):
-    """Marker for AbstractRange that will apply a subclass-specific
+class AbstractSingleRange(AbstractRange[Range[_T]]):
+    """Base for PostgreSQL RANGE types.
+
+    These are types that return a single :class:`_postgresql.Range` object.
+
+    ..
seealso::
+
+        `PostgreSQL range functions `_
+
+    """  # noqa: E501
+
+    __abstract__ = True
+
+    def _resolve_for_literal(self, value: Range[Any]) -> Any:
+        spec = value.lower if value.lower is not None else value.upper
+
+        if isinstance(spec, int):
+            # pg is unreasonably picky here: the query
+            # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises
+            # "operator does not exist: integer <@ int8range" as of pg 16
+            if _is_int32(value):
+                return INT4RANGE()
+            else:
+                return INT8RANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMRANGE()
+        elif isinstance(spec, datetime):
+            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
+        elif isinstance(spec, date):
+            return DATERANGE()
+        else:
+            # empty Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+
+
+class AbstractSingleRangeImpl(AbstractSingleRange[_T]):
+    """Marker for AbstractSingleRange that will apply a subclass-specific
     adaptation"""


-class AbstractMultiRange(AbstractRange[Range[_T]]):
-    """base for PostgreSQL MULTIRANGE types"""
+class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]):
+    """Base for PostgreSQL MULTIRANGE types.
+
+    These are types that return a sequence of :class:`_postgresql.Range`
+    objects.
+
+    """

     __abstract__ = True

+    def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any:
+        if not value:
+            # empty MultiRange, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+        first = value[0]
+        spec = first.lower if first.lower is not None else first.upper
+
+        if isinstance(spec, int):
+            # pg is unreasonably picky here: the query
+            # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises
+            # "operator does not exist: integer <@ int8multirange" as of pg 16
+            if all(_is_int32(r) for r in value):
+                return INT4MULTIRANGE()
+            else:
+                return INT8MULTIRANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMMULTIRANGE()
+        elif isinstance(spec, datetime):
+            return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE()
+        elif isinstance(spec, date):
+            return DATEMULTIRANGE()
+        else:
+            # empty first Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+

-class AbstractMultiRangeImpl(
-    AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
-):
-    """Marker for AbstractRange that will apply a subclass-specific
+class AbstractMultiRangeImpl(AbstractMultiRange[_T]):
+    """Marker for AbstractMultiRange that will apply a subclass-specific
     adaptation"""


-class INT4RANGE(AbstractRange[Range[int]]):
+class INT4RANGE(AbstractSingleRange[int]):
     """Represent the PostgreSQL INT4RANGE type."""

     __visit_name__ = "INT4RANGE"


-class INT8RANGE(AbstractRange[Range[int]]):
+class INT8RANGE(AbstractSingleRange[int]):
     """Represent the PostgreSQL INT8RANGE type."""

     __visit_name__ = "INT8RANGE"


-class NUMRANGE(AbstractRange[Range[Decimal]]):
+class NUMRANGE(AbstractSingleRange[Decimal]):
     """Represent the PostgreSQL NUMRANGE type."""

     __visit_name__ = "NUMRANGE"


-class DATERANGE(AbstractRange[Range[date]]):
+class DATERANGE(AbstractSingleRange[date]):
     """Represent the PostgreSQL DATERANGE type."""

     __visit_name__ = "DATERANGE"


-class TSRANGE(AbstractRange[Range[datetime]]):
+class TSRANGE(AbstractSingleRange[datetime]):
     """Represent the PostgreSQL TSRANGE type."""

     __visit_name__ = "TSRANGE"


-class TSTZRANGE(AbstractRange[Range[datetime]]):
+class TSTZRANGE(AbstractSingleRange[datetime]):
     """Represent the PostgreSQL TSTZRANGE type."""

     __visit_name__ = "TSTZRANGE"


-class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
+class INT4MULTIRANGE(AbstractMultiRange[int]):
     """Represent
the PostgreSQL INT4MULTIRANGE type.""" __visit_name__ = "INT4MULTIRANGE" -class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): +class INT8MULTIRANGE(AbstractMultiRange[int]): """Represent the PostgreSQL INT8MULTIRANGE type.""" __visit_name__ = "INT8MULTIRANGE" -class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): +class NUMMULTIRANGE(AbstractMultiRange[Decimal]): """Represent the PostgreSQL NUMMULTIRANGE type.""" __visit_name__ = "NUMMULTIRANGE" -class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): +class DATEMULTIRANGE(AbstractMultiRange[date]): """Represent the PostgreSQL DATEMULTIRANGE type.""" __visit_name__ = "DATEMULTIRANGE" -class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSMULTIRANGE" -class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSTZMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZMULTIRANGE" + + +_max_int_32 = 2**31 - 1 +_min_int_32 = -(2**31) + + +def _is_int32(r: Range[int]) -> bool: + return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and ( + r.upper is None or _min_int_32 <= r.upper <= _max_int_32 + ) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 2cac5d816dd..1aed2bf4724 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -1,4 +1,5 @@ -# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors +# dialects/postgresql/types.py +# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -37,43 +38,52 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): @overload def __init__( self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... - ) -> None: - ... + ) -> None: ... @overload - def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None: - ... + def __init__( + self: PGUuid[str], as_uuid: Literal[False] = ... + ) -> None: ... - def __init__(self, as_uuid: bool = True) -> None: - ... + def __init__(self, as_uuid: bool = True) -> None: ... 
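A quick aside on the single-range literal inference introduced in ``ranges.py``
above: the same ``_is_int32()`` check drives whether a literal ``Range`` is
typed as ``INT4RANGE`` or ``INT8RANGE``. A minimal sketch, assuming SQLAlchemy
2.0.26+ (the behavior follows the ``_resolve_for_literal()`` logic shown in the
hunks above)::

    from sqlalchemy import literal
    from sqlalchemy.dialects.postgresql import Range

    # both edges fit in 32 bits -> the literal resolves to INT4RANGE
    small = literal(Range(1, 10))

    # an edge outside the int32 bounds -> the literal resolves to INT8RANGE
    big = literal(Range(1, 2**40))

    # a Range with neither edge set can't be resolved here and falls
    # back to NULLTYPE, per the final branch of _resolve_for_literal()
    unknown = literal(Range(None, None))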
class BYTEA(sqltypes.LargeBinary): __visit_name__ = "BYTEA" -class INET(sqltypes.TypeEngine[str]): +class _NetworkAddressTypeMixin: + + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + if TYPE_CHECKING: + assert isinstance(self, TypeEngine) + return self + + +class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "INET" PGInet = INET -class CIDR(sqltypes.TypeEngine[str]): +class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "CIDR" PGCidr = CIDR -class MACADDR(sqltypes.TypeEngine[str]): +class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR" PGMacAddr = MACADDR -class MACADDR8(sqltypes.TypeEngine[str]): +class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR8" @@ -94,12 +104,11 @@ class MONEY(sqltypes.TypeEngine[str]): from sqlalchemy import Dialect from sqlalchemy import TypeDecorator + class NumericMoney(TypeDecorator): impl = MONEY - def process_result_value( - self, value: Any, dialect: Dialect - ) -> None: + def process_result_value(self, value: Any, dialect: Dialect) -> None: if value is not None: # adjust this for the currency and numeric m = re.match(r"\$([\d.]+)", value) @@ -114,6 +123,7 @@ def process_result_value( from sqlalchemy import cast from sqlalchemy import TypeDecorator + class NumericMoney(TypeDecorator): impl = MONEY @@ -122,20 +132,18 @@ def column_expression(self, column: Any): .. versionadded:: 1.2 - """ + """ # noqa: E501 __visit_name__ = "MONEY" class OID(sqltypes.TypeEngine[int]): - """Provide the PostgreSQL OID type.""" __visit_name__ = "OID" class REGCONFIG(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL REGCONFIG type. .. versionadded:: 2.0.0rc1 @@ -146,7 +154,6 @@ class REGCONFIG(sqltypes.TypeEngine[str]): class TSQUERY(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL TSQUERY type. .. versionadded:: 2.0.0rc1 @@ -157,7 +164,6 @@ class TSQUERY(sqltypes.TypeEngine[str]): class REGCLASS(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL REGCLASS type. .. versionadded:: 1.2.7 @@ -168,7 +174,6 @@ class REGCLASS(sqltypes.TypeEngine[str]): class TIMESTAMP(sqltypes.TIMESTAMP): - """Provide the PostgreSQL TIMESTAMP type.""" __visit_name__ = "TIMESTAMP" @@ -189,7 +194,6 @@ def __init__( class TIME(sqltypes.TIME): - """PostgreSQL TIME type.""" __visit_name__ = "TIME" @@ -210,7 +214,6 @@ def __init__( class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): - """PostgreSQL INTERVAL type.""" __visit_name__ = "INTERVAL" @@ -280,7 +283,6 @@ def __init__( class TSVECTOR(sqltypes.TypeEngine[str]): - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL text search type TSVECTOR. @@ -297,7 +299,6 @@ class TSVECTOR(sqltypes.TypeEngine[str]): class CITEXT(sqltypes.TEXT): - """Provide the PostgreSQL CITEXT type. .. 
versionadded:: 2.0.7

diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
index 56bca47faeb..7b381fa6f52 100644
--- a/lib/sqlalchemy/dialects/sqlite/__init__.py
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -1,5 +1,5 @@
-# sqlite/__init__.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# dialects/sqlite/__init__.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
index d9438d1880e..b8cb8c3819b 100644
--- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
@@ -1,5 +1,5 @@
-# sqlite/aiosqlite.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# dialects/sqlite/aiosqlite.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
@@ -31,6 +31,7 @@
 :func:`_asyncio.create_async_engine` engine creation function::

     from sqlalchemy.ext.asyncio import create_async_engine
+
     engine = create_async_engine("sqlite+aiosqlite:///filename")

 The URL passes through all arguments to the ``pysqlite`` driver, so all
@@ -49,35 +50,37 @@
 Serializable isolation / Savepoints / Transactional DDL (asyncio version)
 -------------------------------------------------------------------------

-Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature.
+A newly revised version of this important section is now available
+at the top level of the SQLAlchemy SQLite documentation, in the section
+:ref:`sqlite_transactions`.

-The solution is similar to :ref:`pysqlite_serializable`. This is achieved
-by the event listeners in async::

-    from sqlalchemy import create_engine, event
-    from sqlalchemy.ext.asyncio import create_async_engine
+.. _aiosqlite_pooling:

-    engine = create_async_engine("sqlite+aiosqlite:///myfile.db")
+Pooling Behavior
+----------------

-    @event.listens_for(engine.sync_engine, "connect")
-    def do_connect(dbapi_connection, connection_record):
-        # disable aiosqlite's emitting of the BEGIN statement entirely.
-        # also stops it from emitting COMMIT before any DDL.
-        dbapi_connection.isolation_level = None
+The SQLAlchemy ``aiosqlite`` DBAPI establishes the connection pool differently
+based on the kind of SQLite database that's requested:

-    @event.listens_for(engine.sync_engine, "begin")
-    def do_begin(conn):
-        # emit our own BEGIN
-        conn.exec_driver_sql("BEGIN")
+* When a ``:memory:`` SQLite database is specified, the dialect by default
+  will use :class:`.StaticPool`. This pool maintains a single
+  connection, so that all access to the engine
+  uses the same ``:memory:`` database.
+* When a file-based database is specified, the dialect will use
+  :class:`.AsyncAdaptedQueuePool` as the source of connections.

-.. warning:: When using the above recipe, it is advised to not use the
-   :paramref:`.Connection.execution_options.isolation_level` setting on
-   :class:`_engine.Connection` and :func:`_sa.create_engine`
-   with the SQLite driver,
-   as this function necessarily will also alter the ".isolation_level" setting.
+  .. versionchanged:: 2.0.38
+
+     SQLite file database engines now use :class:`.AsyncAdaptedQueuePool` by default.
+     Previously, :class:`.NullPool` was used. The :class:`.NullPool` class
+     may be used by specifying it via the
+     :paramref:`_sa.create_engine.poolclass` parameter.
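+
+For example, an application that wants the prior behavior back for file-based
+databases can pass the pool class explicitly; a minimal sketch::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    from sqlalchemy.pool import NullPool
+
+    # opt back into the pre-2.0.38 pooling behavior for file databases
+    engine = create_async_engine(
+        "sqlite+aiosqlite:///myfile.db", poolclass=NullPool
+    )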
""" # noqa import asyncio +from collections import deque from functools import partial from .base import SQLiteExecutionContext @@ -113,10 +116,10 @@ def __init__(self, adapt_connection): self.arraysize = 1 self.rowcount = -1 self.description = None - self._rows = [] + self._rows = deque() def close(self): - self._rows[:] = [] + self._rows.clear() def execute(self, operation, parameters=None): try: @@ -132,7 +135,7 @@ def execute(self, operation, parameters=None): self.lastrowid = self.rowcount = -1 if not self.server_side: - self._rows = self.await_(_cursor.fetchall()) + self._rows = deque(self.await_(_cursor.fetchall())) else: self.description = None self.lastrowid = _cursor.lastrowid @@ -161,11 +164,11 @@ def setinputsizes(self, *inputsizes): def __iter__(self): while self._rows: - yield self._rows.pop(0) + yield self._rows.popleft() def fetchone(self): if self._rows: - return self._rows.pop(0) + return self._rows.popleft() else: return None @@ -173,13 +176,12 @@ def fetchmany(self, size=None): if size is None: size = self.arraysize - retval = self._rows[0:size] - self._rows[:] = self._rows[size:] - return retval + rr = self._rows + return [rr.popleft() for _ in range(min(size, len(rr)))] def fetchall(self): - retval = self._rows[:] - self._rows[:] = [] + retval = list(self._rows) + self._rows.clear() return retval @@ -377,7 +379,7 @@ def import_dbapi(cls): @classmethod def get_pool_class(cls, url): if cls._is_url_file_db(url): - return pool.NullPool + return pool.AsyncAdaptedQueuePool else: return pool.StaticPool diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d4eb3bca41b..cc43a826f5a 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1,5 +1,5 @@ -# sqlite/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/sqlite/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,10 +7,9 @@ # mypy: ignore-errors -r""" +r''' .. 
dialect:: sqlite :name: SQLite - :full_support: 3.36.0 :normal_support: 3.12+ :best_effort: 3.7.16+ @@ -70,9 +69,12 @@ when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table construct:: - Table('sometable', metadata, - Column('id', Integer, primary_key=True), - sqlite_autoincrement=True) + Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + sqlite_autoincrement=True, + ) Allowing autoincrement behavior SQLAlchemy types other than Integer/INTEGER ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -92,8 +94,13 @@ only using :meth:`.TypeEngine.with_variant`:: table = Table( - "my_table", metadata, - Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) + "my_table", + metadata, + Column( + "id", + BigInteger().with_variant(Integer, "sqlite"), + primary_key=True, + ), ) Another is to use a subclass of :class:`.BigInteger` that overrides its DDL @@ -102,21 +109,23 @@ from sqlalchemy import BigInteger from sqlalchemy.ext.compiler import compiles + class SLBigInteger(BigInteger): pass - @compiles(SLBigInteger, 'sqlite') + + @compiles(SLBigInteger, "sqlite") def bi_c(element, compiler, **kw): return "INTEGER" + @compiles(SLBigInteger) def bi_c(element, compiler, **kw): return compiler.visit_BIGINT(element, **kw) table = Table( - "my_table", metadata, - Column("id", SLBigInteger(), primary_key=True) + "my_table", metadata, Column("id", SLBigInteger(), primary_key=True) ) .. seealso:: @@ -127,99 +136,199 @@ def bi_c(element, compiler, **kw): `Datatypes In SQLite Version 3 `_ -.. _sqlite_concurrency: - -Database Locking Behavior / Concurrency ---------------------------------------- - -SQLite is not designed for a high level of write concurrency. The database -itself, being a file, is locked completely during write operations within -transactions, meaning exactly one "connection" (in reality a file handle) -has exclusive access to the database during this period - all other -"connections" will be blocked during this time. - -The Python DBAPI specification also calls for a connection model that is -always in a transaction; there is no ``connection.begin()`` method, -only ``connection.commit()`` and ``connection.rollback()``, upon which a -new transaction is to be begun immediately. This may seem to imply -that the SQLite driver would in theory allow only a single filehandle on a -particular database file at any time; however, there are several -factors both within SQLite itself as well as within the pysqlite driver -which loosen this restriction significantly. - -However, no matter what locking modes are used, SQLite will still always -lock the database file once a transaction is started and DML (e.g. INSERT, -UPDATE, DELETE) has at least been emitted, and this will block -other transactions at least at the point that they also attempt to emit DML. -By default, the length of time on this block is very short before it times out -with an error. - -This behavior becomes more critical when used in conjunction with the -SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs -within a transaction, and with its autoflush model, may emit DML preceding -any SELECT statement. This may lead to a SQLite database that locks -more quickly than is expected. The locking mode of SQLite and the pysqlite -driver can be manipulated to some degree, however it should be noted that -achieving a high degree of write-concurrency with SQLite is a losing battle. 
-
-For more information on SQLite's lack of write concurrency by design, please
-see
-`Situations Where Another RDBMS May Work Better - High Concurrency
-`_ near the bottom of the page.
-
-The following subsections introduce areas that are impacted by SQLite's
-file-based architecture and additionally will usually require workarounds to
-work when using the pysqlite driver.
+.. _sqlite_transactions:
+
+Transactions with SQLite and the sqlite3 driver
+-----------------------------------------------
+
+As a file-based database, SQLite's approach to transactions differs from
+traditional databases in many ways. Additionally, the ``sqlite3`` driver
+standard with Python (as well as the async version ``aiosqlite`` which builds
+on top of it) has several quirks, workarounds, and API features in the
+area of transaction control, all of which generally need to be addressed when
+constructing a SQLAlchemy application that uses SQLite.
+
+Legacy Transaction Mode with the sqlite3 driver
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The most important aspect of transaction handling with the sqlite3 driver is
+that it defaults to legacy transactional behavior that does not strictly
+follow :pep:`249`; this default will continue through Python 3.15 before
+being removed in Python 3.16. The way in which the driver diverges from the
+PEP is that it does not "begin" a transaction automatically as dictated by
+:pep:`249` except in the case of DML statements, e.g. INSERT, UPDATE, and
+DELETE. Normally, :pep:`249` dictates that a BEGIN must be emitted upon
+the first SQL statement of any kind, so that all subsequent operations will
+be established within a transaction until ``connection.commit()`` has been
+called. The ``sqlite3`` driver, in an effort to be easier to use in
+highly concurrent environments, skips this step for DQL (e.g. SELECT)
+statements, and, for more historical reasons, also skips it for DDL
+(e.g. CREATE TABLE etc.) statements. Statements such as SAVEPOINT are also
+skipped.
+
+In modern versions of the ``sqlite3`` driver as of Python 3.12, this legacy
+mode of operation is referred to as
+`"legacy transaction control" `_, and is in
+effect by default due to the ``Connection.autocommit`` parameter being set to
+the constant ``sqlite3.LEGACY_TRANSACTION_CONTROL``. Prior to Python 3.12,
+the ``Connection.autocommit`` attribute did not exist.
+
+The implications of legacy transaction mode include:
+
+* **Incorrect support for transactional DDL** - statements like CREATE TABLE, ALTER TABLE,
+  CREATE INDEX etc. will not automatically BEGIN a transaction if one were not
+  started already, leading to the changes by each statement being
+  "autocommitted" immediately unless BEGIN were otherwise emitted first. Very
+  old (pre Python 3.6) versions of the ``sqlite3`` driver would also force a
+  COMMIT for these operations even if a transaction were present; however, this
+  is no longer the case.
+* **SERIALIZABLE behavior not fully functional** - SQLite's transaction isolation
+  behavior is normally consistent with SERIALIZABLE isolation, as it is a
+  file-based system that locks the database file entirely for write operations,
+  preventing COMMIT until all reader transactions (and associated file locks)
+  have completed. However, sqlite3's legacy transaction mode fails to emit
+  BEGIN for SELECT statements, which causes these SELECT statements to no
+  longer be "repeatable",
+  failing one of the consistency guarantees of SERIALIZABLE.
+* **Incorrect behavior for SAVEPOINT** - as the SAVEPOINT statement does not
+  imply a BEGIN, a new SAVEPOINT emitted before a BEGIN will function on its
+  own but fails to participate in the enclosing transaction, meaning a ROLLBACK
+  of the transaction will not roll back elements that were part of a released
+  savepoint.
+
+Legacy transaction mode first existed in order to facilitate working around
+SQLite's file locks. Because SQLite relies upon whole-file locks, it is easy to
+get "database is locked" errors, particularly when newer features like "write
+ahead logging" are disabled. This is a key reason why ``sqlite3``'s legacy
+transaction mode is still the default mode of operation; disabling it will
+produce behavior that is more susceptible to locked database errors. However,
+note that **legacy transaction mode will no longer be the default** in a future
+Python version (3.16 as of this writing).
+
+.. _sqlite_enabling_transactions:
+
+Enabling Non-Legacy SQLite Transactional Modes with the sqlite3 or aiosqlite driver
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Current SQLAlchemy support allows either setting the
+``Connection.autocommit`` attribute, most directly by using a
+:func:`_sa.create_engine` parameter, or, on older versions of Python where
+the attribute is not available, using event hooks to control the behavior of
+BEGIN.
+
+* **Enabling modern sqlite3 transaction control via the autocommit connect parameter** (Python 3.12 and above)
+
+  To use SQLite in the mode described at `Transaction control via the autocommit attribute `_,
+  the most straightforward approach is to set the attribute to its recommended value
+  of ``False`` at the connect level using :paramref:`_sa.create_engine.connect_args`::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "sqlite:///myfile.db", connect_args={"autocommit": False}
+    )
+
+  This parameter is also passed through when using the aiosqlite driver::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+
+    engine = create_async_engine(
+        "sqlite+aiosqlite:///myfile.db", connect_args={"autocommit": False}
+    )
+
+  The parameter can also be set at the attribute level using the :meth:`.PoolEvents.connect`
+  event hook, however this will only work for sqlite3, as aiosqlite does not yet expose this
+  attribute on its ``Connection`` object::
+
+    from sqlalchemy import create_engine, event
+
+    engine = create_engine("sqlite:///myfile.db")
+
+
+    @event.listens_for(engine, "connect")
+    def do_connect(dbapi_connection, connection_record):
+        # enable autocommit=False mode
+        dbapi_connection.autocommit = False
+
+* **Using SQLAlchemy to emit BEGIN in lieu of SQLite's transaction control** (all Python versions, sqlite3 and aiosqlite)
+
+  For older versions of ``sqlite3`` or for cross-compatibility with older and
+  newer versions, SQLAlchemy can also take over the job of transaction control.
+  This is achieved by using the :meth:`.ConnectionEvents.begin` hook
+  to emit the "BEGIN" command directly, while also disabling SQLite's control
+  of this command using the :meth:`.PoolEvents.connect` event hook to set the
+  ``Connection.isolation_level`` attribute to ``None``::
+
+
+    from sqlalchemy import create_engine, event
+
+    engine = create_engine("sqlite:///myfile.db")
+
+
+    @event.listens_for(engine, "connect")
+    def do_connect(dbapi_connection, connection_record):
+        # disable sqlite3's emitting of the BEGIN statement entirely.
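+        # (with isolation_level set to None, sqlite3 itself will never
+        # emit BEGIN; the "begin" hook below takes over that job, while
+        # COMMIT and ROLLBACK continue to pass through as usual)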
+ dbapi_connection.isolation_level = None + + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN. sqlite3 still emits COMMIT/ROLLBACK correctly + conn.exec_driver_sql("BEGIN") + + When using the asyncio variant ``aiosqlite``, refer to ``engine.sync_engine`` + as in the example below:: + + from sqlalchemy import create_engine, event + from sqlalchemy.ext.asyncio import create_async_engine + + engine = create_async_engine("sqlite+aiosqlite:///myfile.db") + + + @event.listens_for(engine.sync_engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable aiosqlite's emitting of the BEGIN statement entirely. + dbapi_connection.isolation_level = None + + + @event.listens_for(engine.sync_engine, "begin") + def do_begin(conn): + # emit our own BEGIN. aiosqlite still emits COMMIT/ROLLBACK correctly + conn.exec_driver_sql("BEGIN") .. _sqlite_isolation_level: -Transaction Isolation Level / Autocommit ----------------------------------------- - -SQLite supports "transaction isolation" in a non-standard way, along two -axes. One is that of the -`PRAGMA read_uncommitted `_ -instruction. This setting can essentially switch SQLite between its -default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation -mode normally referred to as ``READ UNCOMMITTED``. - -SQLAlchemy ties into this PRAGMA statement using the -:paramref:`_sa.create_engine.isolation_level` parameter of -:func:`_sa.create_engine`. -Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"`` -and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively. -SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by -the pysqlite driver's default behavior. - -When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also -available, which will alter the pysqlite connection using the ``.isolation_level`` -attribute on the DBAPI connection and set it to None for the duration -of the setting. - -.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level - when using the pysqlite / sqlite3 SQLite driver. - - -The other axis along which SQLite's transactional locking is impacted is -via the nature of the ``BEGIN`` statement used. The three varieties -are "deferred", "immediate", and "exclusive", as described at -`BEGIN TRANSACTION `_. A straight -``BEGIN`` statement uses the "deferred" mode, where the database file is -not locked until the first read or write operation, and read access remains -open to other transactions until the first write operation. But again, -it is critical to note that the pysqlite driver interferes with this behavior -by *not even emitting BEGIN* until the first write operation. +Using SQLAlchemy's Driver Level AUTOCOMMIT Feature with SQLite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. warning:: +SQLAlchemy has a comprehensive database isolation feature with optional +autocommit support that is introduced in the section :ref:`dbapi_autocommit`. - SQLite's transactional scope is impacted by unresolved - issues in the pysqlite driver, which defers BEGIN statements to a greater - degree than is often feasible. See the section :ref:`pysqlite_serializable` - or :ref:`aiosqlite_serializable` for techniques to work around this behavior. +For the ``sqlite3`` and ``aiosqlite`` drivers, SQLAlchemy only includes +built-in support for "AUTOCOMMIT". 
Note that this mode is currently incompatible +with the non-legacy isolation mode hooks documented in the previous +section at :ref:`sqlite_enabling_transactions`. -.. seealso:: +To use the ``sqlite3`` driver with SQLAlchemy driver-level autocommit, +create an engine setting the :paramref:`_sa.create_engine.isolation_level` +parameter to "AUTOCOMMIT":: + + eng = create_engine("sqlite:///myfile.db", isolation_level="AUTOCOMMIT") + +When using the above mode, any event hooks that set the sqlite3 ``Connection.autocommit`` +parameter away from its default of ``sqlite3.LEGACY_TRANSACTION_CONTROL`` +as well as hooks that emit ``BEGIN`` should be disabled. + +Additional Reading for SQLite / sqlite3 transaction control +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Links with important information on SQLite, the sqlite3 driver, +as well as long historical conversations on how things got to their current state: + +* `Isolation in SQLite `_ - on the SQLite website +* `Transaction control `_ - describes the sqlite3 autocommit attribute as well + as the legacy isolation_level attribute. +* `sqlite3 SELECT does not BEGIN a transaction, but should according to spec `_ - imported Python standard library issue on github +* `sqlite3 module breaks transactions and potentially corrupts data `_ - imported Python standard library issue on github - :ref:`dbapi_autocommit` INSERT/UPDATE/DELETE...RETURNING --------------------------------- @@ -236,63 +345,29 @@ def bi_c(element, compiler, **kw): # INSERT..RETURNING result = connection.execute( - table.insert(). - values(name='foo'). - returning(table.c.col1, table.c.col2) + table.insert().values(name="foo").returning(table.c.col1, table.c.col2) ) print(result.all()) # UPDATE..RETURNING result = connection.execute( - table.update(). - where(table.c.name=='foo'). - values(name='bar'). - returning(table.c.col1, table.c.col2) + table.update() + .where(table.c.name == "foo") + .values(name="bar") + .returning(table.c.col1, table.c.col2) ) print(result.all()) # DELETE..RETURNING result = connection.execute( - table.delete(). - where(table.c.name=='foo'). - returning(table.c.col1, table.c.col2) + table.delete() + .where(table.c.name == "foo") + .returning(table.c.col1, table.c.col2) ) print(result.all()) .. versionadded:: 2.0 Added support for SQLite RETURNING -SAVEPOINT Support ----------------------------- - -SQLite supports SAVEPOINTs, which only function once a transaction is -begun. SQLAlchemy's SAVEPOINT support is available using the -:meth:`_engine.Connection.begin_nested` method at the Core level, and -:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs -won't work at all with pysqlite unless workarounds are taken. - -.. warning:: - - SQLite's SAVEPOINT feature is impacted by unresolved - issues in the pysqlite and aiosqlite drivers, which defer BEGIN statements - to a greater degree than is often feasible. See the sections - :ref:`pysqlite_serializable` and :ref:`aiosqlite_serializable` - for techniques to work around this behavior. - -Transactional DDL ----------------------------- - -The SQLite database supports transactional :term:`DDL` as well. -In this case, the pysqlite driver is not only failing to start transactions, -it also is ending any existing transaction when DDL is detected, so again, -workarounds are required. - -.. 
warning:: - - SQLite's transactional DDL is impacted by unresolved issues - in the pysqlite driver, which fails to emit BEGIN and additionally - forces a COMMIT to cancel any transaction when DDL is encountered. - See the section :ref:`pysqlite_serializable` - for techniques to work around this behavior. .. _sqlite_foreign_keys: @@ -318,6 +393,7 @@ def bi_c(element, compiler, **kw): from sqlalchemy.engine import Engine from sqlalchemy import event + @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() @@ -380,13 +456,16 @@ def set_sqlite_pragma(dbapi_connection, connection_record): that specifies the IGNORE algorithm:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer), - UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE') + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Integer), + UniqueConstraint("id", "data", sqlite_on_conflict="IGNORE"), ) -The above renders CREATE TABLE DDL as:: +The above renders CREATE TABLE DDL as: + +.. sourcecode:: sql CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -403,13 +482,17 @@ def set_sqlite_pragma(dbapi_connection, connection_record): UNIQUE constraint in the DDL:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer, unique=True, - sqlite_on_conflict_unique='IGNORE') + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "data", Integer, unique=True, sqlite_on_conflict_unique="IGNORE" + ), ) -rendering:: +rendering: + +.. sourcecode:: sql CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -422,13 +505,17 @@ def set_sqlite_pragma(dbapi_connection, connection_record): ``sqlite_on_conflict_not_null`` is used:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer, nullable=False, - sqlite_on_conflict_not_null='FAIL') + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "data", Integer, nullable=False, sqlite_on_conflict_not_null="FAIL" + ), ) -this renders the column inline ON CONFLICT phrase:: +this renders the column inline ON CONFLICT phrase: + +.. sourcecode:: sql CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -440,13 +527,20 @@ def set_sqlite_pragma(dbapi_connection, connection_record): Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, primary_key=True, - sqlite_on_conflict_primary_key='FAIL') + "some_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + sqlite_on_conflict_primary_key="FAIL", + ), ) SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict -resolution algorithm is applied to the constraint itself:: +resolution algorithm is applied to the constraint itself: + +.. sourcecode:: sql CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -456,7 +550,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. _sqlite_on_conflict_insert: INSERT...ON CONFLICT (Upsert) ------------------------------------ +----------------------------- .. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for SQLite, which occurs within an INSERT statement. 
For "ON CONFLICT" as @@ -484,21 +578,18 @@ def set_sqlite_pragma(dbapi_connection, connection_record): >>> from sqlalchemy.dialects.sqlite import insert >>> insert_stmt = insert(my_table).values( - ... id='some_existing_id', - ... data='inserted value') + ... id="some_existing_id", data="inserted value" + ... ) >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value') + ... index_elements=["id"], set_=dict(data="updated value") ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO UPDATE SET data = ?{stop} - >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( - ... index_elements=['id'] - ... ) + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"]) >>> print(do_nothing_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) @@ -529,13 +620,13 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + >>> stmt = insert(my_table).values(user_email="a@b.com", data="inserted data") >>> do_update_stmt = stmt.on_conflict_do_update( ... index_elements=[my_table.c.user_email], - ... index_where=my_table.c.user_email.like('%@gmail.com'), - ... set_=dict(data=stmt.excluded.data) - ... ) + ... index_where=my_table.c.user_email.like("%@gmail.com"), + ... set_=dict(data=stmt.excluded.data), + ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (data, user_email) VALUES (?, ?) @@ -555,11 +646,10 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value') + ... index_elements=["id"], set_=dict(data="updated value") ... ) >>> print(do_update_stmt) @@ -587,14 +677,12 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id='some_id', - ... data='inserted value', - ... author='jlh' + ... id="some_id", data="inserted value", author="jlh" ... ) >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value', author=stmt.excluded.author) + ... index_elements=["id"], + ... set_=dict(data="updated value", author=stmt.excluded.author), ... ) >>> print(do_update_stmt) @@ -611,15 +699,13 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id='some_id', - ... data='inserted value', - ... author='jlh' + ... id="some_id", data="inserted value", author="jlh" ... ) >>> on_update_stmt = stmt.on_conflict_do_update( - ... index_elements=['id'], - ... set_=dict(data='updated value', author=stmt.excluded.author), - ... where=(my_table.c.status == 2) + ... index_elements=["id"], + ... set_=dict(data="updated value", author=stmt.excluded.author), + ... where=(my_table.c.status == 2), ... ) >>> print(on_update_stmt) {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) @@ -636,8 +722,8 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. 
sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') - >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING @@ -648,7 +734,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record): .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = insert(my_table).values(id="some_id", data="inserted value") >>> stmt = stmt.on_conflict_do_nothing() >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING @@ -708,11 +794,16 @@ def set_sqlite_pragma(dbapi_connection, connection_record): A partial index, e.g. one which uses a WHERE clause, can be specified with the DDL system using the argument ``sqlite_where``:: - tbl = Table('testtbl', m, Column('data', Integer)) - idx = Index('test_idx1', tbl.c.data, - sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + tbl = Table("testtbl", m, Column("data", Integer)) + idx = Index( + "test_idx1", + tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10), + ) + +The index will be rendered at create time as: -The index will be rendered at create time as:: +.. sourcecode:: sql CREATE INDEX test_idx1 ON testtbl (data) WHERE data > 5 AND data < 10 @@ -732,7 +823,11 @@ def set_sqlite_pragma(dbapi_connection, connection_record): import sqlite3 - assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version" + assert sqlite3.sqlite_version_info < ( + 3, + 10, + 0, + ), "bug is fixed in this version" conn = sqlite3.connect(":memory:") cursor = conn.cursor() @@ -742,17 +837,22 @@ def set_sqlite_pragma(dbapi_connection, connection_record): cursor.execute("insert into x (a, b) values (2, 2)") cursor.execute("select x.a, x.b from x") - assert [c[0] for c in cursor.description] == ['a', 'b'] + assert [c[0] for c in cursor.description] == ["a", "b"] - cursor.execute(''' + cursor.execute( + """ select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - ''') - assert [c[0] for c in cursor.description] == ['a', 'b'], \ - [c[0] for c in cursor.description] + """ + ) + assert [c[0] for c in cursor.description] == ["a", "b"], [ + c[0] for c in cursor.description + ] -The second assertion fails:: +The second assertion fails: + +.. 
sourcecode:: text Traceback (most recent call last): File "test.py", line 19, in @@ -780,11 +880,13 @@ def set_sqlite_pragma(dbapi_connection, connection_record): result = conn.exec_driver_sql("select x.a, x.b from x") assert result.keys() == ["a", "b"] - result = conn.exec_driver_sql(''' + result = conn.exec_driver_sql( + """ select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - ''') + """ + ) assert result.keys() == ["a", "b"] Note that above, even though SQLAlchemy filters out the dots, *both @@ -808,16 +910,20 @@ def set_sqlite_pragma(dbapi_connection, connection_record): the ``sqlite_raw_colnames`` execution option may be provided, either on a per-:class:`_engine.Connection` basis:: - result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql(''' + result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql( + """ select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - ''') + """ + ) assert result.keys() == ["x.a", "x.b"] or on a per-:class:`_engine.Engine` basis:: - engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True}) + engine = create_engine( + "sqlite://", execution_options={"sqlite_raw_colnames": True} + ) When using the per-:class:`_engine.Engine` execution option, note that **Core and ORM queries that use UNION may not function properly**. @@ -832,12 +938,18 @@ def set_sqlite_pragma(dbapi_connection, connection_record): Table("some_table", metadata, ..., sqlite_with_rowid=False) +* + ``STRICT``:: + + Table("some_table", metadata, ..., sqlite_strict=True) + + .. versionadded:: 2.0.37 + .. seealso:: `SQLite CREATE TABLE options `_ - .. _sqlite_include_internal: Reflecting internal schema tables @@ -866,7 +978,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record): `SQLite Internal Schema Objects `_ - in the SQLite documentation. -""" # noqa +''' # noqa from __future__ import annotations import datetime @@ -888,7 +1000,6 @@ def set_sqlite_pragma(dbapi_connection, connection_record): from ...engine import reflection from ...engine.reflection import ReflectionDefaults from ...sql import coercions -from ...sql import ColumnElement from ...sql import compiler from ...sql import elements from ...sql import roles @@ -980,7 +1091,9 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" - e.g.:: + e.g.: + + .. sourcecode:: text 2021-03-15 12:05:57.105542 @@ -996,11 +1109,17 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): import re from sqlalchemy.dialects.sqlite import DATETIME - dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d " - "%(hour)02d:%(minute)02d:%(second)02d", - regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" + dt = DATETIME( + storage_format=( + "%(year)04d/%(month)02d/%(day)02d %(hour)02d:%(minute)02d:%(second)02d" + ), + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)", ) + :param truncate_microseconds: when ``True`` microseconds will be truncated + from the datetime. Can't be specified together with ``storage_format`` + or ``regexp``. + :param storage_format: format string which will be applied to the dict with keys year, month, day, hour, minute, second, and microsecond. @@ -1088,7 +1207,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): "%(year)04d-%(month)02d-%(day)02d" - e.g.:: + e.g.: + + .. 
sourcecode:: text 2011-03-15 @@ -1106,9 +1227,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): from sqlalchemy.dialects.sqlite import DATE d = DATE( - storage_format="%(month)02d/%(day)02d/%(year)04d", - regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)") - ) + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)"), + ) :param storage_format: format string which will be applied to the dict with keys year, month, and day. @@ -1162,7 +1283,9 @@ class TIME(_DateTimeMixin, sqltypes.Time): "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" - e.g.:: + e.g.: + + .. sourcecode:: text 12:05:57.10558 @@ -1178,11 +1301,15 @@ class TIME(_DateTimeMixin, sqltypes.Time): import re from sqlalchemy.dialects.sqlite import TIME - t = TIME(storage_format="%(hour)02d-%(minute)02d-" - "%(second)02d-%(microsecond)06d", - regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") + t = TIME( + storage_format="%(hour)02d-%(minute)02d-%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?"), ) + :param truncate_microseconds: when ``True`` microseconds will be truncated + from the time. Can't be specified together with ``storage_format`` + or ``regexp``. + :param storage_format: format string which will be applied to the dict with keys hour, minute, second, and microsecond. @@ -1308,7 +1435,7 @@ def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" def visit_localtimestamp_func(self, func, **kw): - return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + return "DATETIME(CURRENT_TIMESTAMP, 'localtime')" def visit_true(self, expr, **kw): return "1" @@ -1429,9 +1556,7 @@ def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) def _on_conflict_target(self, clause, **kw): - if clause.constraint_target is not None: - target_text = "(%s)" % clause.constraint_target - elif clause.inferred_target_elements is not None: + if clause.inferred_target_elements is not None: target_text = "(%s)" % ", ".join( ( self.preparer.quote(c) @@ -1445,7 +1570,7 @@ def _on_conflict_target(self, clause, **kw): clause.inferred_target_whereclause, include_table=False, use_schema=False, - literal_binds=True, + literal_execute=True, ) else: @@ -1528,6 +1653,13 @@ def visit_on_conflict_do_update(self, on_conflict, **kw): return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + def visit_bitwise_xor_op_binary(self, binary, operator, **kw): + # sqlite has no xor. Use "a XOR b" = "(a | b) - (a & b)". 
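+        # e.g. 6 ^ 3: (6 | 3) - (6 & 3) = 7 - 2 = 5, which matches 6 XOR 3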
+ kw["eager_grouping"] = True + or_ = self._generate_generic_binary(binary, " | ", **kw) + and_ = self._generate_generic_binary(binary, " & ", **kw) + return f"({or_} - {and_})" + class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): @@ -1537,9 +1669,13 @@ def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: - if isinstance(column.server_default.arg, ColumnElement): - default = "(" + default + ")" - colspec += " DEFAULT " + default + + if not re.match(r"""^\s*[\'\"\(]""", default) and re.match( + r".*\W.*", default + ): + colspec += f" DEFAULT ({default})" + else: + colspec += f" DEFAULT {default}" if not column.nullable: colspec += " NOT NULL" @@ -1701,9 +1837,18 @@ def visit_create_index( return text def post_create_table(self, table): - if table.dialect_options["sqlite"]["with_rowid"] is False: - return "\n WITHOUT ROWID" - return "" + table_options = [] + + if not table.dialect_options["sqlite"]["with_rowid"]: + table_options.append("WITHOUT ROWID") + + if table.dialect_options["sqlite"]["strict"]: + table_options.append("STRICT") + + if table_options: + return "\n " + ",\n ".join(table_options) + else: + return "" class SQLiteTypeCompiler(compiler.GenericTypeCompiler): @@ -1938,6 +2083,7 @@ class SQLiteDialect(default.DefaultDialect): { "autoincrement": False, "with_rowid": True, + "strict": False, }, ), (sa_schema.Index, {"where": None}), @@ -2030,9 +2176,9 @@ def __init__( ) if self.dbapi.sqlite_version_info < (3, 35) or util.pypy: - self.update_returning = ( - self.delete_returning - ) = self.insert_returning = False + self.update_returning = self.delete_returning = ( + self.insert_returning + ) = False if self.dbapi.sqlite_version_info < (3, 32, 0): # https://www.sqlite.org/limits.html @@ -2231,6 +2377,14 @@ def get_columns(self, connection, table_name, schema=None, **kw): tablesql = self._get_table_sql( connection, table_name, schema, **kw ) + # remove create table + match = re.match( + r"create table .*?\((.*)\)$", + tablesql.strip(), + re.DOTALL | re.IGNORECASE, + ) + assert match, f"create table not found in {tablesql}" + tablesql = match.group(1).strip() columns.append( self._get_column_info( @@ -2285,7 +2439,10 @@ def _get_column_info( if generated: sqltext = "" if tablesql: - pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" + pattern = ( + r"[^,]*\s+GENERATED\s+ALWAYS\s+AS" + r"\s+\((.*)\)\s*(?:virtual|stored)?" + ) match = re.search( re.escape(name) + pattern, tablesql, re.IGNORECASE ) @@ -2570,8 +2727,8 @@ def parse_uqs(): return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( - r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' - r"+[a-z0-9_ ]+? +UNIQUE" + r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?)[\t ]' + r"+[a-z0-9_ ]+?[\t ]+UNIQUE" ) for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): @@ -2606,15 +2763,21 @@ def get_check_constraints(self, connection, table_name, schema=None, **kw): connection, table_name, schema=schema, **kw ) - CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" - cks = [] - # NOTE: we aren't using re.S here because we actually are - # taking advantage of each CHECK constraint being all on one - # line in the table definition in order to delineate. This + # NOTE NOTE NOTE + # DO NOT CHANGE THIS REGULAR EXPRESSION. 
There is no known way + # to parse CHECK constraints that contain newlines themselves using + # regular expressions, and the approach here relies upon each + # individual + # CHECK constraint being on a single line by itself. This # necessarily makes assumptions as to how the CREATE TABLE - # was emitted. + # was emitted. A more comprehensive DDL parsing solution would be + # needed to improve upon the current situation. See #11840 for + # background + CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?CHECK *\( *(.+) *\),? *" + cks = [] for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): + name = match.group(1) if name: diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index ec428f5b172..84cdb8bec23 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -1,5 +1,5 @@ -# sqlite/dml.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/sqlite/dml.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,6 +7,10 @@ from __future__ import annotations from typing import Any +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from .._typing import _OnConflictIndexElementsT from .._typing import _OnConflictIndexWhereT @@ -15,6 +19,7 @@ from ... import util from ...sql import coercions from ...sql import roles +from ...sql import schema from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative @@ -22,7 +27,9 @@ from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement +from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -141,11 +148,10 @@ def on_conflict_do_update( :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. If present, can be a literal SQL - string or an acceptable expression for a ``WHERE`` clause - that restricts the rows affected by ``DO UPDATE SET``. Rows - not meeting the ``WHERE`` condition will not be updated - (effectively a ``DO NOTHING`` for those rows). + Optional argument. An expression object representing a ``WHERE`` + clause that restricts the rows affected by ``DO UPDATE SET``. Rows not + meeting the ``WHERE`` condition will not be updated (effectively a + ``DO NOTHING`` for those rows). 
""" @@ -184,9 +190,10 @@ def on_conflict_do_nothing( class OnConflictClause(ClauseElement): stringify_dialect = "sqlite" - constraint_target: None - inferred_target_elements: _OnConflictIndexElementsT - inferred_target_whereclause: _OnConflictIndexWhereT + inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] + inferred_target_whereclause: Optional[ + Union[ColumnElement[Any], TextClause] + ] def __init__( self, @@ -194,13 +201,22 @@ def __init__( index_where: _OnConflictIndexWhereT = None, ): if index_elements is not None: - self.constraint_target = None - self.inferred_target_elements = index_elements - self.inferred_target_whereclause = index_where + self.inferred_target_elements = [ + coercions.expect(roles.DDLConstraintColumnRole, column) + for column in index_elements + ] + self.inferred_target_whereclause = ( + coercions.expect( + roles.WhereHavingRole, + index_where, + ) + if index_where is not None + else None + ) else: - self.constraint_target = ( - self.inferred_target_elements - ) = self.inferred_target_whereclause = None + self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None class OnConflictDoNothing(OnConflictClause): @@ -210,6 +226,9 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" + update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_whereclause: Optional[ColumnElement[Any]] + def __init__( self, index_elements: _OnConflictIndexElementsT = None, @@ -237,4 +256,8 @@ def __init__( (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = where + self.update_whereclause = ( + coercions.expect(roles.WhereHavingRole, where) + if where is not None + else None + ) diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py index 69df3171c22..02f4ea4c90f 100644 --- a/lib/sqlalchemy/dialects/sqlite/json.py +++ b/lib/sqlalchemy/dialects/sqlite/json.py @@ -1,3 +1,9 @@ +# dialects/sqlite/json.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... 
import types as sqltypes diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index 2ed8253ab47..e1df005e72c 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -1,3 +1,9 @@ +# dialects/sqlite/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import os @@ -46,8 +52,6 @@ def _format_url(url, driver, ident): assert "test_schema" not in filename tokens = re.split(r"[_\.]", filename) - new_filename = f"{driver}" - for token in tokens: if token in _drivernames: if driver is None: diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 28b900ea53d..7a3dc1bae13 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -1,5 +1,5 @@ -# sqlite/pysqlcipher.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# dialects/sqlite/pysqlcipher.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -39,7 +39,7 @@ e = create_engine( "sqlite+pysqlcipher://:password@/dbname.db", - module=sqlcipher_compatible_driver + module=sqlcipher_compatible_driver, ) These drivers make use of the SQLCipher engine. This system essentially @@ -55,12 +55,12 @@ of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the "password" field is now accepted, which should contain a passphrase:: - e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + e = create_engine("sqlite+pysqlcipher://:testing@/foo.db") For an absolute file path, two leading slashes should be used for the database name:: - e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + e = create_engine("sqlite+pysqlcipher://:testing@//path/to/foo.db") A selection of additional encryption-related pragmas supported by SQLCipher as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed @@ -68,7 +68,9 @@ new connection. Currently, ``cipher``, ``kdf_iter``, ``cipher_page_size`` and ``cipher_use_hmac`` are supported:: - e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + e = create_engine( + "sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000" + ) .. warning:: Previous versions of SQLAlchemy did not take into consideration the encryption-related pragmas passed in the url string, which were silently
sourcecode:: text driver://user:pass@host/database looks like:: # relative path - e = create_engine('sqlite:///path/to/database.db') + e = create_engine("sqlite:///path/to/database.db") An absolute path, which is denoted by starting with a slash, means you need **four** slashes:: # absolute path - e = create_engine('sqlite:////path/to/database.db') + e = create_engine("sqlite:////path/to/database.db") To use a Windows path, regular drive specifications and backslashes can be used. Double backslashes are probably needed:: # absolute path on Windows - e = create_engine('sqlite:///C:\\path\\to\\database.db') + e = create_engine("sqlite:///C:\\path\\to\\database.db") -The sqlite ``:memory:`` identifier is the default if no filepath is -present. Specify ``sqlite://`` and nothing else:: +To use a SQLite ``:memory:`` database, specify it as the filename using +``sqlite:///:memory:``. It is also the default if no filepath is +present; specify only ``sqlite://`` and nothing else:: - # in-memory database - e = create_engine('sqlite://') + # in-memory database (note three slashes) + e = create_engine("sqlite:///:memory:") + # also in-memory database + e2 = create_engine("sqlite://") .. _pysqlite_uri_connections: @@ -95,7 +100,9 @@ sqlite3.connect( "file:path/to/database?mode=ro&nolock=1", - check_same_thread=True, timeout=10, uri=True + check_same_thread=True, + timeout=10, + uri=True, ) Regarding future parameters added to either the Python or native drivers, new @@ -141,8 +148,11 @@ def regexp(a, b): return re.search(a, b) is not None + sqlite_connection.create_function( - "regexp", 2, regexp, + "regexp", + 2, + regexp, ) There is currently no support for regular expression flags as a separate @@ -183,10 +193,12 @@ def regexp(a, b): nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES can be forced if one configures "native_datetime=True" on create_engine():: - engine = create_engine('sqlite://', - connect_args={'detect_types': - sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, - native_datetime=True + engine = create_engine( + "sqlite://", + connect_args={ + "detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES + }, + native_datetime=True, ) With this flag enabled, the DATE and TIMESTAMP types (but note - not the @@ -241,6 +253,7 @@ def regexp(a, b): parameter:: from sqlalchemy import NullPool + engine = create_engine("sqlite:///myfile.db", poolclass=NullPool) It's been observed that the :class:`.NullPool` implementation incurs an @@ -260,9 +273,12 @@ def regexp(a, b): as ``False``:: from sqlalchemy.pool import StaticPool + - engine = create_engine('sqlite://', - connect_args={'check_same_thread':False}, - poolclass=StaticPool) + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) Note that using a ``:memory:`` database in multiple threads requires a recent version of SQLite.
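As a runnable sketch of the multi-threaded ``:memory:`` pattern described above (assuming a recent SQLite build), the single :class:`.StaticPool` connection can be handed from thread to thread; the threads below run one at a time, since the one shared DBAPI connection is not safe for concurrent use::

    import threading

    from sqlalchemy import create_engine, text
    from sqlalchemy.pool import StaticPool

    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )

    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE t (x INTEGER)"))

    def worker():
        # checks out the same in-memory connection as the main thread
        with engine.begin() as conn:
            conn.execute(text("INSERT INTO t (x) VALUES (1)"))

    for _ in range(2):
        th = threading.Thread(target=worker)
        th.start()
        th.join()

    with engine.connect() as conn:
        print(conn.scalar(text("SELECT COUNT(*) FROM t")))  # 2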
@@ -281,14 +297,14 @@ def regexp(a, b): # maintain the same connection per thread from sqlalchemy.pool import SingletonThreadPool - engine = create_engine('sqlite:///mydb.db', - poolclass=SingletonThreadPool) + + engine = create_engine("sqlite:///mydb.db", poolclass=SingletonThreadPool) # maintain the same connection across all threads from sqlalchemy.pool import StaticPool - engine = create_engine('sqlite:///mydb.db', - poolclass=StaticPool) + + engine = create_engine("sqlite:///mydb.db", poolclass=StaticPool) Note that :class:`.SingletonThreadPool` should be configured for the number of threads that are to be used; beyond that number, connections will be @@ -317,13 +333,14 @@ def regexp(a, b): from sqlalchemy import String from sqlalchemy import TypeDecorator + class MixedBinary(TypeDecorator): impl = String cache_ok = True def process_result_value(self, value, dialect): if isinstance(value, str): - value = bytes(value, 'utf-8') + value = bytes(value, "utf-8") elif value is not None: value = bytes(value) @@ -337,74 +354,10 @@ def process_result_value(self, value, dialect): Serializable isolation / Savepoints / Transactional DDL ------------------------------------------------------- -In the section :ref:`sqlite_concurrency`, we refer to the pysqlite -driver's assortment of issues that prevent several features of SQLite -from working correctly. The pysqlite DBAPI driver has several -long-standing bugs which impact the correctness of its transactional -behavior. In its default mode of operation, SQLite features such as -SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are -non-functional, and in order to use these features, workarounds must -be taken. - -The issue is essentially that the driver attempts to second-guess the user's -intent, failing to start transactions and sometimes ending them prematurely, in -an effort to minimize the SQLite databases's file locking behavior, even -though SQLite itself uses "shared" locks for read-only activities. - -SQLAlchemy chooses to not alter this behavior by default, as it is the -long-expected behavior of the pysqlite driver; if and when the pysqlite -driver attempts to repair these issues, that will be more of a driver towards -defaults for SQLAlchemy. +A newly revised version of this important section is now available +at the top level of the SQLAlchemy SQLite documentation, in the section +:ref:`sqlite_transactions`. -The good news is that with a few events, we can implement transactional -support fully, by disabling pysqlite's feature entirely and emitting BEGIN -ourselves. This is achieved using two event listeners:: - - from sqlalchemy import create_engine, event - - engine = create_engine("sqlite:///myfile.db") - - @event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): - # disable pysqlite's emitting of the BEGIN statement entirely. - # also stops it from emitting COMMIT before any DDL. - dbapi_connection.isolation_level = None - - @event.listens_for(engine, "begin") - def do_begin(conn): - # emit our own BEGIN - conn.exec_driver_sql("BEGIN") - -.. warning:: When using the above recipe, it is advised to not use the - :paramref:`.Connection.execution_options.isolation_level` setting on - :class:`_engine.Connection` and :func:`_sa.create_engine` - with the SQLite driver, - as this function necessarily will also alter the ".isolation_level" setting. - - -Above, we intercept a new pysqlite connection and disable any transactional -integration. 
Then, at the point at which SQLAlchemy knows that transaction -scope is to begin, we emit ``"BEGIN"`` ourselves. - -When we take control of ``"BEGIN"``, we can also control directly SQLite's -locking modes, introduced at -`BEGIN TRANSACTION `_, -by adding the desired locking mode to our ``"BEGIN"``:: - - @event.listens_for(engine, "begin") - def do_begin(conn): - conn.exec_driver_sql("BEGIN EXCLUSIVE") - -.. seealso:: - - `BEGIN TRANSACTION `_ - - on the SQLite site - - `sqlite3 SELECT does not BEGIN a transaction `_ - - on the Python bug tracker - - `sqlite3 module breaks transactions and potentially corrupts data `_ - - on the Python bug tracker .. _pysqlite_udfs: @@ -439,7 +392,6 @@ def connect(conn, rec): with engine.connect() as conn: print(conn.scalar(text("SELECT UDF()"))) - """ # noqa import math diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 843f970257a..f4205d89260 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -1,5 +1,5 @@ # engine/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py index 1cc5e8dea40..8536d53d779 100644 --- a/lib/sqlalchemy/engine/_py_processors.py +++ b/lib/sqlalchemy/engine/_py_processors.py @@ -1,5 +1,5 @@ -# sqlalchemy/processors.py -# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors +# engine/_py_processors.py +# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors # # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index 3358abd7848..38c60fcd276 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -1,3 +1,9 @@ +# engine/_py_row.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import operator diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py index 538c075a2b5..50badea2a94 100644 --- a/lib/sqlalchemy/engine/_py_util.py +++ b/lib/sqlalchemy/engine/_py_util.py @@ -1,3 +1,9 @@ +# engine/_py_util.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import typing diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0000e28103d..ad0e4b62435 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1,12 +1,10 @@ # engine/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`. 
- -""" +"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.""" from __future__ import annotations import contextlib @@ -70,12 +68,11 @@ from ..sql._typing import _InfoType from ..sql.compiler import Compiled from ..sql.ddl import ExecutableDDLElement - from ..sql.ddl import SchemaDropper - from ..sql.ddl import SchemaGenerator + from ..sql.ddl import InvokeDDLBase from ..sql.functions import FunctionElement from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr - from ..sql.schema import SchemaItem + from ..sql.schema import SchemaVisitable from ..sql.selectable import TypedReturnsRows @@ -109,6 +106,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ + dialect: Dialect dispatch: dispatcher[ConnectionEventsTarget] _sqla_logger_namespace = "sqlalchemy.engine.Connection" @@ -173,13 +171,9 @@ def __init__( if self._has_events or self.engine._has_events: self.dispatch.engine_connect(self) - @util.memoized_property - def _message_formatter(self) -> Any: - if "logging_token" in self._execution_options: - token = self._execution_options["logging_token"] - return lambda msg: "[%s] %s" % (token, msg) - else: - return None + # this can be assigned differently via + # characteristics.LoggingTokenCharacteristic + _message_formatter: Any = None def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter @@ -205,9 +199,9 @@ def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None: @property def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]: - schema_translate_map: Optional[ - SchemaTranslateMapType - ] = self._execution_options.get("schema_translate_map", None) + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) return schema_translate_map @@ -218,9 +212,9 @@ def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: """ name = obj.schema - schema_translate_map: Optional[ - SchemaTranslateMapType - ] = self._execution_options.get("schema_translate_map", None) + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) if ( schema_translate_map @@ -250,13 +244,12 @@ def execution_options( yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, **opt: Any, - ) -> Connection: - ... + ) -> Connection: ... @overload - def execution_options(self, **opt: Any) -> Connection: - ... + def execution_options(self, **opt: Any) -> Connection: ... def execution_options(self, **opt: Any) -> Connection: r"""Set non-SQL options for the connection which take effect @@ -382,12 +375,11 @@ def execution_options(self, **opt: Any) -> Connection: :param stream_results: Available on: :class:`_engine.Connection`, :class:`_sql.Executable`. - Indicate to the dialect that results should be - "streamed" and not pre-buffered, if possible. For backends - such as PostgreSQL, MySQL and MariaDB, this indicates the use of - a "server side cursor" as opposed to a client side cursor. - Other backends such as that of Oracle may already use server - side cursors by default. + Indicate to the dialect that results should be "streamed" and not + pre-buffered, if possible. For backends such as PostgreSQL, MySQL + and MariaDB, this indicates the use of a "server side cursor" as + opposed to a client side cursor. 
Other backends such as that of + Oracle Database may already use server side cursors by default. The usage of :paramref:`_engine.Connection.execution_options.stream_results` is @@ -492,6 +484,18 @@ def execution_options(self, **opt: Any) -> Connection: :ref:`schema_translating` + :param preserve_rowcount: Boolean; when True, the ``cursor.rowcount`` + attribute will be unconditionally memoized within the result and + made available via the :attr:`.CursorResult.rowcount` attribute. + Normally, this attribute is only preserved for UPDATE and DELETE + statements. Using this option, the DBAPIs rowcount value can + be accessed for other kinds of statements such as INSERT and SELECT, + to the degree that the DBAPI supports these statements. See + :attr:`.CursorResult.rowcount` for notes regarding the behavior + of this attribute. + + .. versionadded:: 2.0.28 + .. seealso:: :meth:`_engine.Engine.execution_options` @@ -793,7 +797,6 @@ def begin(self) -> RootTransaction: with conn.begin() as trans: conn.execute(table.insert(), {"username": "sandy"}) - The returned object is an instance of :class:`_engine.RootTransaction`. This object represents the "scope" of the transaction, which completes when either the :meth:`_engine.Transaction.rollback` @@ -899,7 +902,7 @@ def begin_nested(self) -> NestedTransaction: trans.rollback() # rollback to savepoint # outer transaction continues - connection.execute( ... ) + connection.execute(...) If :meth:`_engine.Connection.begin_nested` is called without first calling :meth:`_engine.Connection.begin` or @@ -909,11 +912,11 @@ def begin_nested(self) -> NestedTransaction: with engine.connect() as connection: # begin() wasn't called - with connection.begin_nested(): will auto-"begin()" first - connection.execute( ... ) + with connection.begin_nested(): # will auto-"begin()" first + connection.execute(...) # savepoint is released - connection.execute( ... ) + connection.execute(...) # explicitly commit outer transaction connection.commit() @@ -1262,8 +1265,7 @@ def scalar( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -1272,8 +1274,7 @@ def scalar( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -1311,8 +1312,7 @@ def scalars( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -1321,8 +1321,7 @@ def scalars( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, @@ -1356,8 +1355,7 @@ def execute( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[_T]: - ... + ) -> CursorResult[_T]: ... @overload def execute( @@ -1366,8 +1364,7 @@ def execute( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... 
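The ``preserve_rowcount`` option documented above can be exercised as in the following sketch (SQLAlchemy 2.0.28 or later assumed; the table is illustrative), where ``cursor.rowcount`` is memoized even for an INSERT::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")

    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE t (x INTEGER)"))
        result = conn.execution_options(preserve_rowcount=True).execute(
            text("INSERT INTO t (x) VALUES (1), (2), (3)")
        )
        # memoized before the cursor closes; meaningful only to the
        # degree the DBAPI reports it (sqlite3 reports 3 here)
        print(result.rowcount)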
def execute( self, @@ -1498,7 +1495,7 @@ def _execute_ddl( ) -> CursorResult[Any]: """Execute a schema.DDL object.""" - execution_options = ddl._execution_options.merge_with( + exec_opts = ddl._execution_options.merge_with( self._execution_options, execution_options ) @@ -1512,12 +1509,11 @@ def _execute_ddl( event_multiparams, event_params, ) = self._invoke_before_exec_event( - ddl, distilled_parameters, execution_options + ddl, distilled_parameters, exec_opts ) else: event_multiparams = event_params = None - exec_opts = self._execution_options.merge_with(execution_options) schema_translate_map = exec_opts.get("schema_translate_map", None) dialect = self.dialect @@ -1530,7 +1526,7 @@ def _execute_ddl( dialect.execution_ctx_cls._init_ddl, compiled, None, - execution_options, + exec_opts, compiled, ) if self._has_events or self.engine._has_events: @@ -1539,7 +1535,7 @@ def _execute_ddl( ddl, event_multiparams, event_params, - execution_options, + exec_opts, ret, ) return ret @@ -1737,21 +1733,20 @@ def exec_driver_sql( conn.exec_driver_sql( "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", - [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] + [{"id": 1, "value": "v1"}, {"id": 2, "value": "v2"}], ) Single dictionary:: conn.exec_driver_sql( "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", - dict(id=1, value="v1") + dict(id=1, value="v1"), ) Single tuple:: conn.exec_driver_sql( - "INSERT INTO table (id, value) VALUES (?, ?)", - (1, 'v1') + "INSERT INTO table (id, value) VALUES (?, ?)", (1, "v1") ) .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does @@ -1840,10 +1835,7 @@ def _execute_context( context.pre_exec() if context.execute_style is ExecuteStyle.INSERTMANYVALUES: - return self._exec_insertmany_context( - dialect, - context, - ) + return self._exec_insertmany_context(dialect, context) else: return self._exec_single_context( dialect, context, statement, parameters @@ -2018,16 +2010,22 @@ def _exec_insertmany_context( engine_events = self._has_events or self.engine._has_events if self.dialect._has_events: - do_execute_dispatch: Iterable[ - Any - ] = self.dialect.dispatch.do_execute + do_execute_dispatch: Iterable[Any] = ( + self.dialect.dispatch.do_execute + ) else: do_execute_dispatch = () if self._echo: stats = context._get_cache_stats() + " (insertmanyvalues)" + preserve_rowcount = context.execution_options.get( + "preserve_rowcount", False + ) + rowcount = 0 + for imv_batch in dialect._deliver_insertmanyvalues_batches( + self, cursor, str_statement, effective_parameters, @@ -2048,6 +2046,7 @@ def _exec_insertmany_context( imv_batch.replaced_parameters, None, context, + is_sub_exec=True, ) sub_stmt = imv_batch.replaced_statement @@ -2067,15 +2066,16 @@ def _exec_insertmany_context( if self._echo: self._log_info(sql_util._long_statement(sub_stmt)) - imv_stats = f""" { - imv_batch.batchnum}/{imv_batch.total_batches} ({ - 'ordered' - if imv_batch.rows_sorted else 'unordered' - }{ - '; batch not supported' - if imv_batch.is_downgraded - else '' - })""" + imv_stats = f""" {imv_batch.batchnum}/{ + imv_batch.total_batches + } ({ + 'ordered' + if imv_batch.rows_sorted else 'unordered' + }{ + '; batch not supported' + if imv_batch.is_downgraded + else '' + })""" if imv_batch.batchnum == 1: stats += imv_stats @@ -2136,9 +2136,15 @@ def _exec_insertmany_context( context.executemany, ) + if preserve_rowcount: + rowcount += imv_batch.current_batch_size + try: context.post_exec() + if preserve_rowcount: + context._rowcount = rowcount # type: ignore[attr-defined] + 
result = context._setup_result_proxy() except BaseException as e: @@ -2380,9 +2386,9 @@ def _handle_dbapi_exception_noconnection( None, cast(Exception, e), dialect.loaded_dbapi.Error, - hide_parameters=engine.hide_parameters - if engine is not None - else False, + hide_parameters=( + engine.hide_parameters if engine is not None else False + ), connection_invalidated=is_disconnect, dialect=dialect, ) @@ -2419,9 +2425,7 @@ def _handle_dbapi_exception_noconnection( break if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = ( - is_disconnect - ) = ctx.is_disconnect + sqlalchemy_exception.connection_invalidated = ctx.is_disconnect if newraise: raise newraise.with_traceback(exc_info[2]) from e @@ -2434,8 +2438,8 @@ def _handle_dbapi_exception_noconnection( def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: SchemaVisitable, **kwargs: Any, ) -> None: """run a DDL visitor. @@ -2444,7 +2448,9 @@ def _run_ddl_visitor( options given to the visitor so that "checkfirst" is skipped. """ - visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + visitorcallable( + dialect=self.dialect, connection=self, **kwargs + ).traverse_single(element) class ExceptionContextImpl(ExceptionContext): @@ -2502,6 +2508,7 @@ class Transaction(TransactionalContext): :class:`_engine.Connection`:: from sqlalchemy import create_engine + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") connection = engine.connect() trans = connection.begin() @@ -2990,7 +2997,7 @@ def clear_compiled_cache(self) -> None: This applies **only** to the built-in cache that is established via the :paramref:`_engine.create_engine.query_cache_size` parameter. It will not impact any dictionary caches that were passed via the - :paramref:`.Connection.execution_options.query_cache` parameter. + :paramref:`.Connection.execution_options.compiled_cache` parameter. .. versionadded:: 1.4 @@ -3029,12 +3036,10 @@ def execution_options( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> OptionEngine: - ... + ) -> OptionEngine: ... @overload - def execution_options(self, **opt: Any) -> OptionEngine: - ... + def execution_options(self, **opt: Any) -> OptionEngine: ... 
def execution_options(self, **opt: Any) -> OptionEngine: """Return a new :class:`_engine.Engine` that will provide @@ -3081,10 +3086,10 @@ def execution_options(self, **opt: Any) -> OptionEngine: shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"} + @event.listens_for(Engine, "before_cursor_execute") - def _switch_shard(conn, cursor, stmt, - params, context, executemany): - shard_id = conn.get_execution_options().get('shard_id', "default") + def _switch_shard(conn, cursor, stmt, params, context, executemany): + shard_id = conn.get_execution_options().get("shard_id", "default") current_shard = conn.info.get("current_shard", None) if current_shard != shard_id: @@ -3210,9 +3215,7 @@ def begin(self) -> Iterator[Connection]: E.g.:: with engine.begin() as conn: - conn.execute( - text("insert into table (x, y, z) values (1, 2, 3)") - ) + conn.execute(text("insert into table (x, y, z) values (1, 2, 3)")) conn.execute(text("my_special_procedure(5)")) Upon successful operation, the :class:`.Transaction` @@ -3228,15 +3231,15 @@ def begin(self) -> Iterator[Connection]: :meth:`_engine.Connection.begin` - start a :class:`.Transaction` for a particular :class:`_engine.Connection`. - """ + """ # noqa: E501 with self.connect() as conn: with conn.begin(): yield conn def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: SchemaVisitable, **kwargs: Any, ) -> None: with self.begin() as conn: diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py index c0feb000be1..322c28b5aa7 100644 --- a/lib/sqlalchemy/engine/characteristics.py +++ b/lib/sqlalchemy/engine/characteristics.py @@ -1,3 +1,9 @@ +# engine/characteristics.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import abc @@ -6,6 +12,7 @@ from typing import ClassVar if typing.TYPE_CHECKING: + from .base import Connection from .interfaces import DBAPIConnection from .interfaces import Dialect @@ -38,13 +45,30 @@ class ConnectionCharacteristic(abc.ABC): def reset_characteristic( self, dialect: Dialect, dbapi_conn: DBAPIConnection ) -> None: - """Reset the characteristic on the connection to its default value.""" + """Reset the characteristic on the DBAPI connection to its default + value.""" @abc.abstractmethod def set_characteristic( self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any ) -> None: - """set characteristic on the connection to a given value.""" + """set characteristic on the DBAPI connection to a given value.""" + + def set_connection_characteristic( + self, + dialect: Dialect, + conn: Connection, + dbapi_conn: DBAPIConnection, + value: Any, + ) -> None: + """set characteristic on the :class:`_engine.Connection` to a given + value. + + .. versionadded:: 2.0.30 - added to support elements that are local + to the :class:`_engine.Connection` itself. + + """ + self.set_characteristic(dialect, dbapi_conn, value) @abc.abstractmethod def get_characteristic( @@ -55,8 +79,22 @@ def get_characteristic( """ + def get_connection_characteristic( + self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection + ) -> Any: + """Given a :class:`_engine.Connection`, get the current value of the + characteristic. + + .. 
versionadded:: 2.0.30 - added to support elements that are local + to the :class:`_engine.Connection` itself. + + """ + return self.get_characteristic(dialect, dbapi_conn) + class IsolationLevelCharacteristic(ConnectionCharacteristic): + """Manage the isolation level on a DBAPI connection""" + transactional: ClassVar[bool] = True def reset_characteristic( @@ -73,3 +111,45 @@ def get_characteristic( self, dialect: Dialect, dbapi_conn: DBAPIConnection ) -> Any: return dialect.get_isolation_level(dbapi_conn) + + +class LoggingTokenCharacteristic(ConnectionCharacteristic): + """Manage the 'logging_token' option of a :class:`_engine.Connection`. + + .. versionadded:: 2.0.30 + + """ + + transactional: ClassVar[bool] = False + + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: + pass + + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: + raise NotImplementedError() + + def set_connection_characteristic( + self, + dialect: Dialect, + conn: Connection, + dbapi_conn: DBAPIConnection, + value: Any, + ) -> None: + if value: + conn._message_formatter = lambda msg: "[%s] %s" % (value, msg) + else: + del conn._message_formatter + + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: + raise NotImplementedError() + + def get_connection_characteristic( + self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection + ) -> Any: + return conn._execution_options.get("logging_token", None) diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 684550e558c..920f620bd48 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -1,5 +1,5 @@ # engine/create.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -82,13 +82,11 @@ def create_engine( query_cache_size: int = ..., use_insertmanyvalues: bool = ..., **kwargs: Any, -) -> Engine: - ... +) -> Engine: ... @overload -def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: - ... +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ... 
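The ``LoggingTokenCharacteristic`` added above backs the pre-existing ``logging_token`` execution option; a brief usage sketch (token string arbitrary, ``echo=True`` so that engine logging is visible)::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://", echo=True)

    with engine.connect() as conn:
        # SQL log lines for this connection are prefixed with "[shard1]"
        conn = conn.execution_options(logging_token="shard1")
        conn.execute(text("SELECT 1"))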
@util.deprecated_params( @@ -135,8 +133,11 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: and its underlying :class:`.Dialect` and :class:`_pool.Pool` constructs:: - engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname", - pool_recycle=3600, echo=True) + engine = create_engine( + "mysql+mysqldb://scott:tiger@hostname/dbname", + pool_recycle=3600, + echo=True, + ) The string form of the URL is ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where @@ -616,6 +617,14 @@ def pop_kwarg(key: str, default: Optional[Any] = None) -> Any: # assemble connection arguments (cargs_tup, cparams) = dialect.create_connect_args(u) cparams.update(pop_kwarg("connect_args", {})) + + if "async_fallback" in cparams and util.asbool(cparams["async_fallback"]): + util.warn_deprecated( + "The async_fallback dialect argument is deprecated and will be " + "removed in SQLAlchemy 2.1.", + "2.0", + ) + cargs = list(cargs_tup) # allow mutability # look for existing pool or create @@ -657,6 +666,17 @@ def connect( else: pool._dialect = dialect + if ( + hasattr(pool, "_is_asyncio") + and pool._is_asyncio is not dialect.is_async + ): + raise exc.ArgumentError( + f"Pool class {pool.__class__.__name__} cannot be " + f"used with {'non-' if not dialect.is_async else ''}" + "asyncio engine", + code="pcls", + ) + # create engine. if not pop_kwarg("future", True): raise exc.ArgumentError( @@ -816,13 +836,11 @@ def create_pool_from_url( timeout: float = ..., use_lifo: bool = ..., **kwargs: Any, -) -> Pool: - ... +) -> Pool: ... @overload -def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: - ... +def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ... def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 45af49afccb..8e2348efab5 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1,5 +1,5 @@ # engine/cursor.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -20,6 +20,7 @@ from typing import cast from typing import ClassVar from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Mapping @@ -120,7 +121,7 @@ List[Any], # MD_OBJECTS str, # MD_LOOKUP_KEY str, # MD_RENDERED_NAME - Optional["_ResultProcessorType"], # MD_PROCESSOR + Optional["_ResultProcessorType[Any]"], # MD_PROCESSOR Optional[str], # MD_UNTRANSLATED ] @@ -134,7 +135,7 @@ List[Any], str, str, - Optional["_ResultProcessorType"], + Optional["_ResultProcessorType[Any]"], str, ] @@ -151,7 +152,7 @@ class CursorResultMetaData(ResultMetaData): "_translated_indexes", "_safe_for_cache", "_unpickled", - "_key_to_index" + "_key_to_index", # don't need _unique_filters support here for now. Can be added # if a need arises. 
) @@ -225,9 +226,11 @@ def _splice_horizontally( { key: ( # int index should be None for ambiguous key - value[0] + offset - if value[0] is not None and key not in keymap - else None, + ( + value[0] + offset + if value[0] is not None and key not in keymap + else None + ), value[1] + offset, *value[2:], ) @@ -362,13 +365,11 @@ def __init__( ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = ( # type: ignore - cols_are_ordered - ) = ( + result_columns = cols_are_ordered = ( # type: ignore num_ctx_cols - ) = ( - ad_hoc_textual - ) = loose_column_name_matching = textual_ordered = False + ) = ad_hoc_textual = loose_column_name_matching = ( + textual_ordered + ) = False # merge cursor.description with the column info # present in the compiled structure, if any @@ -688,6 +689,7 @@ def _merge_textual_cols_by_position( % (num_ctx_cols, len(cursor_description)) ) seen = set() + for ( idx, colname, @@ -1161,7 +1163,7 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): result = conn.execution_options( stream_results=True, max_row_buffer=50 - ).execute(text("select * from table")) + ).execute(text("select * from table")) .. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows. @@ -1246,8 +1248,9 @@ def fetchmany(self, result, dbapi_cursor, size=None): if size is None: return self.fetchall(result, dbapi_cursor) - buf = list(self._rowbuffer) - lb = len(buf) + rb = self._rowbuffer + lb = len(rb) + close = False if size > lb: try: new = dbapi_cursor.fetchmany(size - lb) @@ -1255,13 +1258,15 @@ def fetchmany(self, result, dbapi_cursor, size=None): self.handle_exception(result, dbapi_cursor, e) else: if not new: - result._soft_close() + # defer closing since it may clear the row buffer + close = True else: - buf.extend(new) + rb.extend(new) - result = buf[0:size] - self._rowbuffer = collections.deque(buf[size:]) - return result + res = [rb.popleft() for _ in range(min(size, len(rb)))] + if close: + result._soft_close() + return res def fetchall(self, result, dbapi_cursor): try: @@ -1285,12 +1290,16 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): __slots__ = ("_rowbuffer", "alternate_cursor_description") def __init__( - self, dbapi_cursor, alternate_description=None, initial_buffer=None + self, + dbapi_cursor: Optional[DBAPICursor], + alternate_description: Optional[_DBAPICursorDescription] = None, + initial_buffer: Optional[Iterable[Any]] = None, ): self.alternate_cursor_description = alternate_description if initial_buffer is not None: self._rowbuffer = collections.deque(initial_buffer) else: + assert dbapi_cursor is not None self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) def yield_per(self, result, dbapi_cursor, num): @@ -1315,9 +1324,8 @@ def fetchmany(self, result, dbapi_cursor, size=None): if size is None: return self.fetchall(result, dbapi_cursor) - buf = list(self._rowbuffer) - rows = buf[0:size] - self._rowbuffer = collections.deque(buf[size:]) + rb = self._rowbuffer + rows = [rb.popleft() for _ in range(min(size, len(rb)))] if not rows: result._soft_close() return rows @@ -1350,15 +1358,15 @@ def _reduce(self, keys): self._we_dont_return_rows() @property - def _keymap(self): + def _keymap(self): # type: ignore[override] self._we_dont_return_rows() @property - def _key_to_index(self): + def _key_to_index(self): # type: ignore[override] self._we_dont_return_rows() @property - def _processors(self): + def _processors(self): # type: ignore[override] self._we_dont_return_rows() @property @@ -1438,6 +1446,7 @@ def 
__init__( metadata = self._init_metadata(context, cursor_description) + _make_row: Any _make_row = functools.partial( Row, metadata, @@ -1610,11 +1619,11 @@ def inserted_primary_key_rows(self): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " "expression construct." + "Statement is not a compiled expression construct." ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " "expression construct." + "Statement is not an insert() expression construct." ) elif self.context._is_explicit_returning: raise exc.InvalidRequestError( @@ -1681,11 +1690,11 @@ def last_updated_params(self): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " "expression construct." + "Statement is not a compiled expression construct." ) elif not self.context.isupdate: raise exc.InvalidRequestError( - "Statement is not an update() " "expression construct." + "Statement is not an update() expression construct." ) elif self.context.executemany: return self.context.compiled_parameters @@ -1703,11 +1712,11 @@ def last_inserted_params(self): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " "expression construct." + "Statement is not a compiled expression construct." ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " "expression construct." + "Statement is not an insert() expression construct." ) elif self.context.executemany: return self.context.compiled_parameters @@ -1752,11 +1761,9 @@ def splice_horizontally(self, other): r1 = connection.execute( users.insert().returning( - users.c.user_name, - users.c.user_id, - sort_by_parameter_order=True + users.c.user_name, users.c.user_id, sort_by_parameter_order=True ), - user_values + user_values, ) r2 = connection.execute( @@ -1764,19 +1771,16 @@ def splice_horizontally(self, other): addresses.c.address_id, addresses.c.address, addresses.c.user_id, - sort_by_parameter_order=True + sort_by_parameter_order=True, ), - address_values + address_values, ) rows = r1.splice_horizontally(r2).all() - assert ( - rows == - [ - ("john", 1, 1, "foo@bar.com", 1), - ("jack", 2, 2, "bar@bat.com", 2), - ] - ) + assert rows == [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] .. versionadded:: 2.0 @@ -1785,7 +1789,7 @@ def splice_horizontally(self, other): :meth:`.CursorResult.splice_vertically` - """ + """ # noqa: E501 clone = self._generate() total_rows = [ @@ -1920,7 +1924,7 @@ def postfetch_cols(self): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " "expression construct." + "Statement is not a compiled expression construct." ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( @@ -1943,7 +1947,7 @@ def prefetch_cols(self): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " "expression construct." + "Statement is not a compiled expression construct." ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( @@ -1974,8 +1978,28 @@ def supports_sane_multi_rowcount(self): def rowcount(self) -> int: """Return the 'rowcount' for this result. - The 'rowcount' reports the number of rows *matched* - by the WHERE criterion of an UPDATE or DELETE statement. 
+ The primary purpose of 'rowcount' is to report the number of rows + matched by the WHERE criterion of an UPDATE or DELETE statement + executed once (i.e. for a single parameter set), which may then be + compared to the number of rows expected to be updated or deleted as a + means of asserting data integrity. + + This attribute is transferred from the ``cursor.rowcount`` attribute + of the DBAPI before the cursor is closed, to support DBAPIs that + don't make this value available after cursor close. Some DBAPIs may + offer meaningful values for other kinds of statements, such as INSERT + and SELECT statements as well. In order to retrieve ``cursor.rowcount`` + for these statements, set the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option to True, which will cause the ``cursor.rowcount`` + value to be unconditionally memoized before any results are returned + or the cursor is closed, regardless of statement type. + + For cases where the DBAPI does not support rowcount for a particular + kind of statement and/or execution, the returned value will be ``-1``, + which is delivered directly from the DBAPI and is part of :pep:`249`. + All DBAPIs should support rowcount for single-parameter-set + UPDATE and DELETE statements, however. .. note:: @@ -1984,38 +2008,47 @@ def rowcount(self) -> int: * This attribute returns the number of rows *matched*, which is not necessarily the same as the number of rows - that were actually *modified* - an UPDATE statement, for example, + that were actually *modified*. For example, an UPDATE statement may have no net change on a given row if the SET values given are the same as those present in the row already. Such a row would be matched but not modified. On backends that feature both styles, such as MySQL, - rowcount is configured by default to return the match + rowcount is configured to return the match count in all cases. - * :attr:`_engine.CursorResult.rowcount` - is *only* useful in conjunction - with an UPDATE or DELETE statement. Contrary to what the Python - DBAPI says, it does *not* reliably return the - number of rows available from the results of a SELECT statement - as DBAPIs cannot support this functionality when rows are - unbuffered. - - * :attr:`_engine.CursorResult.rowcount` - may not be fully implemented by - all dialects. In particular, most DBAPIs do not support an - aggregate rowcount result from an executemany call. - The :meth:`_engine.CursorResult.supports_sane_rowcount` and - :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods - will report from the dialect if each usage is known to be - supported. - - * Statements that use RETURNING may not return a correct - rowcount. + * :attr:`_engine.CursorResult.rowcount` in the default case is + *only* useful in conjunction with an UPDATE or DELETE statement, + and only with a single set of parameters. For other kinds of + statements, SQLAlchemy will not attempt to pre-memoize the value + unless the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is used. Note that contrary to :pep:`249`, many + DBAPIs do not support rowcount values for statements that are not + UPDATE or DELETE, particularly when rows are being returned which + are not fully pre-buffered. DBAPIs that don't support rowcount + for a particular kind of statement should return the value ``-1`` + for such statements. + + * :attr:`_engine.CursorResult.rowcount` may not be meaningful + when executing a single statement with multiple parameter sets + (i.e.
an :term:`executemany`). Most DBAPIs do not sum "rowcount" + values across multiple parameter sets and will return ``-1`` + when accessed. + + * SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support + a correct population of :attr:`_engine.CursorResult.rowcount` + when the :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is set to True. + + * Statements that use RETURNING may not support rowcount, returning + a ``-1`` value instead. .. seealso:: :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` + :paramref:`.Connection.execution_options.preserve_rowcount` + """ # noqa: E501 try: return self.context.rowcount @@ -2109,8 +2142,7 @@ def _raw_row_iterator(self): def merge(self, *others: Result[Any]) -> MergedResult[Any]: merged_result = super().merge(*others) - setup_rowcounts = self.context._has_rowcount - if setup_rowcounts: + if self.context._has_rowcount: merged_result.rowcount = sum( cast("CursorResult[Any]", result).rowcount for result in (self,) + others diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 553d8f0bea1..57759f79cfc 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1,5 +1,5 @@ # engine/default.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,6 +58,7 @@ from ..sql import dml from ..sql import expression from ..sql import type_api +from ..sql import util as sql_util from ..sql._typing import is_tuple_type from ..sql.base import _NoArg from ..sql.compiler import DDLCompiler @@ -76,10 +77,13 @@ from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _MutableCoreSingleExecuteParams from .interfaces import _ParamStyle + from .interfaces import ConnectArgsType from .interfaces import DBAPIConnection + from .interfaces import DBAPIModule from .interfaces import IsolationLevel from .row import Row from .url import URL @@ -95,8 +99,10 @@ from ..sql.elements import BindParameter from ..sql.schema import Column from ..sql.type_api import _BindProcessorType + from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine + # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) @@ -167,7 +173,10 @@ class DefaultDialect(Dialect): tuple_in_values = False connection_characteristics = util.immutabledict( - {"isolation_level": characteristics.IsolationLevelCharacteristic()} + { + "isolation_level": characteristics.IsolationLevelCharacteristic(), + "logging_token": characteristics.LoggingTokenCharacteristic(), + } ) engine_config_types: Mapping[str, Any] = util.immutabledict( @@ -249,7 +258,7 @@ class DefaultDialect(Dialect): default_schema_name: Optional[str] = None # indicates symbol names are - # UPPERCASEd if they are case insensitive + # UPPERCASED if they are case insensitive # within the database. # if this is True, the methods normalize_name() # and denormalize_name() must be provided. @@ -387,7 +396,8 @@ def insert_executemany_returning(self): available if the dialect in use has opted into using the "use_insertmanyvalues" feature. 
If they haven't opted into that, then this attribute is False, unless the dialect in question overrides this - and provides some other implementation (such as the Oracle dialect). + and provides some other implementation (such as the Oracle Database + dialects). """ return self.insert_returning and self.use_insertmanyvalues @@ -410,7 +420,7 @@ def insert_executemany_returning_sort_by_parameter_order(self): If the dialect in use hasn't opted into that, then this attribute is False, unless the dialect in question overrides this and provides some - other implementation (such as the Oracle dialect). + other implementation (such as the Oracle Database dialects). """ return self.insert_returning and self.use_insertmanyvalues @@ -419,7 +429,7 @@ def insert_executemany_returning_sort_by_parameter_order(self): delete_executemany_returning = False @util.memoized_property - def loaded_dbapi(self) -> ModuleType: + def loaded_dbapi(self) -> DBAPIModule: if self.dbapi is None: raise exc.InvalidRequestError( f"Dialect {self} does not have a Python DBAPI established " @@ -431,7 +441,7 @@ def loaded_dbapi(self) -> ModuleType: def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS - def _ensure_has_table_connection(self, arg): + def _ensure_has_table_connection(self, arg: Connection) -> None: if not isinstance(arg, Connection): raise exc.ArgumentError( "The argument passed to Dialect.has_table() should be a " @@ -468,7 +478,7 @@ def _type_memos(self): return weakref.WeakKeyDictionary() @property - def dialect_description(self): + def dialect_description(self): # type: ignore[override] return self.name + "+" + self.driver @property @@ -509,7 +519,7 @@ def builtin_connect(dbapi_conn, conn_rec): else: return None - def initialize(self, connection): + def initialize(self, connection: Connection) -> None: try: self.server_version_info = self._get_server_version_info( connection @@ -545,7 +555,7 @@ def initialize(self, connection): % (self.label_length, self.max_identifier_length) ) - def on_connect(self): + def on_connect(self) -> Optional[Callable[[Any], None]]: # inherits the docstring from interfaces.Dialect.on_connect return None @@ -604,18 +614,18 @@ def has_schema( ) -> bool: return schema_name in self.get_schema_names(connection, **kw) - def validate_identifier(self, ident): + def validate_identifier(self, ident: str) -> None: if len(ident) > self.max_identifier_length: raise exc.IdentifierError( "Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length) ) - def connect(self, *cargs, **cparams): + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: # inherits the docstring from interfaces.Dialect.connect - return self.loaded_dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501 - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: # inherits the docstring from interfaces.Dialect.create_connect_args opts = url.translate_connect_args() opts.update(url.query) @@ -659,7 +669,7 @@ def _set_connection_characteristics(self, connection, characteristics): if connection.in_transaction(): trans_objs = [ (name, obj) - for name, obj, value in characteristic_values + for name, obj, _ in characteristic_values if obj.transactional ] if trans_objs: @@ -672,8 +682,10 @@ def _set_connection_characteristics(self, connection, characteristics): ) dbapi_connection = connection.connection.dbapi_connection - for 
name, characteristic, value in characteristic_values: - characteristic.set_characteristic(self, dbapi_connection, value) + for _, characteristic, value in characteristic_values: + characteristic.set_connection_characteristic( + self, connection, dbapi_connection, value + ) connection.connection._connection_record.finalize_callback.append( functools.partial(self._reset_characteristics, characteristics) ) @@ -728,8 +740,6 @@ def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: raise def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: - cursor = None - cursor = dbapi_connection.cursor() try: cursor.execute(self._dialect_specific_select_one) @@ -756,11 +766,25 @@ def do_release_savepoint(self, connection, name): connection.execute(expression.ReleaseSavepointClause(name)) def _deliver_insertmanyvalues_batches( - self, cursor, statement, parameters, generic_setinputsizes, context + self, + connection, + cursor, + statement, + parameters, + generic_setinputsizes, + context, ): context = cast(DefaultExecutionContext, context) compiled = cast(SQLCompiler, context.compiled) + _composite_sentinel_proc: Sequence[ + Optional[_ResultProcessorType[Any]] + ] = () + _scalar_sentinel_proc: Optional[_ResultProcessorType[Any]] = None + _sentinel_proc_initialized: bool = False + + compiled_parameters = context.compiled_parameters + imv = compiled._insertmanyvalues assert imv is not None @@ -769,7 +793,12 @@ def _deliver_insertmanyvalues_batches( "insertmanyvalues_page_size", self.insertmanyvalues_page_size ) - sentinel_value_resolvers = None + if compiled.schema_translate_map: + schema_translate_map = context.execution_options.get( + "schema_translate_map", {} + ) + else: + schema_translate_map = None if is_returning: result: Optional[List[Any]] = [] @@ -777,10 +806,6 @@ def _deliver_insertmanyvalues_batches( sort_by_parameter_order = imv.sort_by_parameter_order - if imv.num_sentinel_columns: - sentinel_value_resolvers = ( - compiled._imv_sentinel_value_resolvers - ) else: sort_by_parameter_order = False result = None @@ -788,14 +813,27 @@ def _deliver_insertmanyvalues_batches( for imv_batch in compiled._deliver_insertmanyvalues_batches( statement, parameters, + compiled_parameters, generic_setinputsizes, batch_size, sort_by_parameter_order, + schema_translate_map, ): yield imv_batch if is_returning: - rows = context.fetchall_for_returning(cursor) + + try: + rows = context.fetchall_for_returning(cursor) + except BaseException as be: + connection._handle_dbapi_exception( + be, + sql_util._long_statement(imv_batch.replaced_statement), + imv_batch.replaced_parameters, + None, + context, + is_sub_exec=True, + ) # I would have thought "is_returning: Final[bool]" # would have assured this but pylance thinks not @@ -815,11 +853,46 @@ def _deliver_insertmanyvalues_batches( # otherwise, create dictionaries to match up batches # with parameters assert imv.sentinel_param_keys + assert imv.sentinel_columns + _nsc = imv.num_sentinel_columns + + if not _sentinel_proc_initialized: + if composite_sentinel: + _composite_sentinel_proc = [ + col.type._cached_result_processor( + self, cursor_desc[1] + ) + for col, cursor_desc in zip( + imv.sentinel_columns, + cursor.description[-_nsc:], + ) + ] + else: + _scalar_sentinel_proc = ( + imv.sentinel_columns[0] + ).type._cached_result_processor( + self, cursor.description[-1][1] + ) + _sentinel_proc_initialized = True + + rows_by_sentinel: Union[ + Dict[Tuple[Any, ...], Any], + Dict[Any, Any], + ] if composite_sentinel: - _nsc = imv.num_sentinel_columns 
rows_by_sentinel = { - tuple(row[-_nsc:]): row for row in rows + tuple( + (proc(val) if proc else val) + for val, proc in zip( + row[-_nsc:], _composite_sentinel_proc + ) + ): row + for row in rows + } + elif _scalar_sentinel_proc: + rows_by_sentinel = { + _scalar_sentinel_proc(row[-1]): row for row in rows } else: rows_by_sentinel = {row[-1]: row for row in rows} @@ -838,61 +911,10 @@ def _deliver_insertmanyvalues_batches( ) try: - if composite_sentinel: - if sentinel_value_resolvers: - # composite sentinel (PK) with value resolvers - ordered_rows = [ - rows_by_sentinel[ - tuple( - _resolver(parameters[_spk]) # type: ignore # noqa: E501 - if _resolver - else parameters[_spk] # type: ignore # noqa: E501 - for _resolver, _spk in zip( - sentinel_value_resolvers, - imv.sentinel_param_keys, - ) - ) - ] - for parameters in imv_batch.batch - ] - else: - # composite sentinel (PK) with no value - # resolvers - ordered_rows = [ - rows_by_sentinel[ - tuple( - parameters[_spk] # type: ignore - for _spk in imv.sentinel_param_keys - ) - ] - for parameters in imv_batch.batch - ] - else: - _sentinel_param_key = imv.sentinel_param_keys[0] - if ( - sentinel_value_resolvers - and sentinel_value_resolvers[0] - ): - # single-column sentinel with value resolver - _sentinel_value_resolver = ( - sentinel_value_resolvers[0] - ) - ordered_rows = [ - rows_by_sentinel[ - _sentinel_value_resolver( - parameters[_sentinel_param_key] # type: ignore # noqa: E501 - ) - ] - for parameters in imv_batch.batch - ] - else: - # single-column sentinel with no value resolver - ordered_rows = [ - rows_by_sentinel[ - parameters[_sentinel_param_key] # type: ignore # noqa: E501 - ] - for parameters in imv_batch.batch - ] + ordered_rows = [ + rows_by_sentinel[sentinel_keys] + for sentinel_keys in imv_batch.sentinel_values + ] except KeyError as ke: # see test_insert_exec.py:: # IMVSentinelTest::test_sentinel_cant_match_keys @@ -924,7 +946,14 @@ def do_execute(self, cursor, statement, parameters, context=None): def do_execute_no_params(self, cursor, statement, context=None): cursor.execute(statement) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: DBAPIModule.Error, + connection: Union[ + pool.PoolProxiedConnection, interfaces.DBAPIConnection, None + ], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: return False @util.memoized_instancemethod @@ -1024,7 +1053,7 @@ def denormalize_name(self, name): name = name_upper return name - def get_driver_connection(self, connection): + def get_driver_connection(self, connection: DBAPIConnection) -> Any: return connection def _overrides_default(self, method): @@ -1196,7 +1225,7 @@ class DefaultExecutionContext(ExecutionContext): _soft_closed = False - _has_rowcount = False + _rowcount: Optional[int] = None # a hook for SQLite's translation of # result column names @@ -1453,9 +1482,11 @@ def _init_compiled( assert positiontup is not None for compiled_params in self.compiled_parameters: l_param: List[Any] = [ - flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] + ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) for key in positiontup ] core_positional_parameters.append( @@ -1476,18 +1507,20 @@ def _init_compiled( for compiled_params in self.compiled_parameters: if escaped_names: d_param = { - escaped_names.get(key, key): flattened_processors[key]( - compiled_params[key] + escaped_names.get(key, key): ( + 
flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] ) - if key in flattened_processors - else compiled_params[key] for key in compiled_params } else: d_param = { - key: flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] + key: ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) for key in compiled_params } @@ -1577,7 +1610,13 @@ def _get_cache_stats(self) -> str: elif ch is CACHE_MISS: return "generated in %.5fs" % (now - gen_time,) elif ch is CACHING_DISABLED: - return "caching disabled %.5fs" % (now - gen_time,) + if "_cache_disable_reason" in self.execution_options: + return "caching disabled (%s) %.5fs " % ( + self.execution_options["_cache_disable_reason"], + now - gen_time, + ) + else: + return "caching disabled %.5fs" % (now - gen_time,) elif ch is NO_DIALECT_SUPPORT: return "dialect %s+%s does not support caching %.5fs" % ( self.dialect.name, @@ -1588,7 +1627,7 @@ def _get_cache_stats(self) -> str: return "unknown" @property - def executemany(self): + def executemany(self): # type: ignore[override] return self.execute_style in ( ExecuteStyle.EXECUTEMANY, ExecuteStyle.INSERTMANYVALUES, @@ -1630,7 +1669,12 @@ def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: def no_parameters(self): return self.execution_options.get("no_parameters", False) - def _execute_scalar(self, stmt, type_, parameters=None): + def _execute_scalar( + self, + stmt: str, + type_: Optional[TypeEngine[Any]], + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: """Execute a string statement on the current cursor, returning a scalar result. @@ -1704,7 +1748,7 @@ def _use_server_side_cursor(self): return use_server_side - def create_cursor(self): + def create_cursor(self) -> DBAPICursor: if ( # inlining initial preference checks for SS cursors self.dialect.supports_server_side_cursors @@ -1725,10 +1769,10 @@ def create_cursor(self): def fetchall_for_returning(self, cursor): return cursor.fetchall() - def create_default_cursor(self): + def create_default_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor() - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: raise NotImplementedError() def pre_exec(self): @@ -1776,7 +1820,14 @@ def handle_dbapi_exception(self, e): @util.non_memoized_property def rowcount(self) -> int: - return self.cursor.rowcount + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + @property + def _has_rowcount(self): + return self._rowcount is not None def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount @@ -1787,9 +1838,13 @@ def supports_sane_multi_rowcount(self): def _setup_result_proxy(self): exec_opt = self.execution_options + if self._rowcount is None and exec_opt.get("preserve_rowcount", False): + self._rowcount = self.cursor.rowcount + + yp: Optional[Union[int, bool]] if self.is_crud or self.is_text: result = self._setup_dml_or_text_result() - yp = sr = False + yp = False else: yp = exec_opt.get("yield_per", None) sr = self._is_server_side or exec_opt.get("stream_results", False) @@ -1943,8 +1998,7 @@ def _setup_dml_or_text_result(self): if rows: self.returned_default_rows = rows - result.rowcount = len(rows) - self._has_rowcount = True + self._rowcount = len(rows) if self._is_supplemental_returning: result._rewind(rows) @@ -1958,12 +2012,12 @@ def 
_setup_dml_or_text_result(self): elif not result._metadata.returns_rows: # no results, get rowcount # (which requires open cursor on some drivers) - result.rowcount - self._has_rowcount = True + if self._rowcount is None: + self._rowcount = self.cursor.rowcount result._soft_close() elif self.isupdate or self.isdelete: - result.rowcount - self._has_rowcount = True + if self._rowcount is None: + self._rowcount = self.cursor.rowcount return result @util.memoized_property @@ -2012,10 +2066,11 @@ def _prepare_set_input_sizes( style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. - This method only called by those dialects which set - the :attr:`.Dialect.bind_typing` attribute to - :attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI - that requires setinputsizes(), pyodbc offers it as an option. + This method is only called by those dialects which set the + :attr:`.Dialect.bind_typing` attribute to + :attr:`.BindTyping.SETINPUTSIZES`. Python-oracledb and cx_Oracle are + the only DBAPIs that require setinputsizes(); pyodbc offers it as an + option. Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used for pg8000 and asyncpg, which has been changed to inline rendering @@ -2143,17 +2198,21 @@ def _exec_default_clause_element(self, column, default, type_): if compiled.positional: parameters = self.dialect.execute_sequence_format( [ - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] + ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) for key in compiled.positiontup or () ] ) else: parameters = { - key: processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] + key: ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) for key in compiled_params } return self._execute_scalar( diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index aac756d18a2..b759382cb27 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -1,5 +1,5 @@ -# sqlalchemy/engine/events.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# engine/events.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -54,19 +54,24 @@ class or instance, such as an :class:`_engine.Engine`, e.g.:: from sqlalchemy import event, create_engine - def before_cursor_execute(conn, cursor, statement, parameters, context, - executemany): + + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): log.info("Received statement: %s", statement) - engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test') + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") event.listen(engine, "before_cursor_execute", before_cursor_execute) or with a specific :class:`_engine.Connection`:: with engine.begin() as conn: + @event.listens_for(conn, 'before_cursor_execute') - def before_cursor_execute(conn, cursor, statement, parameters, - context, executemany): + @event.listens_for(conn, "before_cursor_execute") + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): log.info("Received statement: %s", statement) When the methods are called with a `statement` parameter, such as in @@ -84,9 +89,11 @@ def
before_cursor_execute(conn, cursor, statement, parameters, from sqlalchemy.engine import Engine from sqlalchemy import event + @event.listens_for(Engine, "before_cursor_execute", retval=True) - def comment_sql_calls(conn, cursor, statement, parameters, - context, executemany): + def comment_sql_calls( + conn, cursor, statement, parameters, context, executemany + ): statement = statement + " -- some comment" return statement, parameters @@ -316,8 +323,9 @@ def before_cursor_execute( returned as a two-tuple in this case:: @event.listens_for(Engine, "before_cursor_execute", retval=True) - def before_cursor_execute(conn, cursor, statement, - parameters, context, executemany): + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): # do something with statement, parameters return statement, parameters @@ -766,9 +774,9 @@ def handle_error( @event.listens_for(Engine, "handle_error") def handle_exception(context): - if isinstance(context.original_exception, - psycopg2.OperationalError) and \ - "failed" in str(context.original_exception): + if isinstance( + context.original_exception, psycopg2.OperationalError + ) and "failed" in str(context.original_exception): raise MySpecialException("failed operation") .. warning:: Because the @@ -791,10 +799,13 @@ def handle_exception(context): @event.listens_for(Engine, "handle_error", retval=True) def handle_exception(context): - if context.chained_exception is not None and \ - "special" in context.chained_exception.message: - return MySpecialException("failed", - cause=context.chained_exception) + if ( + context.chained_exception is not None + and "special" in context.chained_exception.message + ): + return MySpecialException( + "failed", cause=context.chained_exception + ) Handlers that return ``None`` may be used within the chain; when a handler returns ``None``, the previous exception instance, @@ -836,7 +847,8 @@ def do_connect( e = create_engine("postgresql+psycopg2://user@host/dbname") - @event.listens_for(e, 'do_connect') + + @event.listens_for(e, "do_connect") def receive_do_connect(dialect, conn_rec, cargs, cparams): cparams["password"] = "some_password" @@ -845,7 +857,8 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams): e = create_engine("postgresql+psycopg2://user@host/dbname") - @event.listens_for(e, 'do_connect') + + @event.listens_for(e, "do_connect") def receive_do_connect(dialect, conn_rec, cargs, cparams): return psycopg2.connect(*cargs, **cparams) @@ -928,7 +941,8 @@ def do_setinputsizes( The setinputsizes hook overall is only used for dialects which include the flag ``use_setinputsizes=True``. Dialects which use this - include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. + include python-oracledb, cx_Oracle, pg8000, asyncpg, and pyodbc + dialects. .. note:: diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index ea1f27d0629..fd99afafd09 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -1,5 +1,5 @@ # engine/interfaces.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,7 +10,6 @@ from __future__ import annotations from enum import Enum -from types import ModuleType from typing import Any from typing import Awaitable from typing import Callable @@ -34,7 +33,7 @@ from .. 
import util from ..event import EventTarget from ..pool import Pool -from ..pool import PoolProxiedConnection +from ..pool import PoolProxiedConnection as PoolProxiedConnection from ..sql.compiler import Compiled as Compiled from ..sql.compiler import Compiled # noqa from ..sql.compiler import TypeCompiler as TypeCompiler @@ -51,6 +50,7 @@ from .base import Engine from .cursor import CursorResult from .url import URL + from ..connectors.asyncio import AsyncIODBAPIConnection from ..event import _ListenerFnType from ..event import dispatcher from ..exc import StatementError @@ -70,6 +70,7 @@ from ..sql.sqltypes import Integer from ..sql.type_api import _TypeMemoDict from ..sql.type_api import TypeEngine + from ..util.langhelpers import generic_fn_descriptor ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]] @@ -106,6 +107,22 @@ class ExecuteStyle(Enum): """ +class DBAPIModule(Protocol): + class Error(Exception): + def __getattr__(self, key: str) -> Any: ... + + class OperationalError(Error): + pass + + class InterfaceError(Error): + pass + + class IntegrityError(Error): + pass + + def __getattr__(self, key: str) -> Any: ... + + class DBAPIConnection(Protocol): """protocol representing a :pep:`249` database connection. @@ -118,19 +135,17 @@ class DBAPIConnection(Protocol): """ # noqa: E501 - def close(self) -> None: - ... + def close(self) -> None: ... - def commit(self) -> None: - ... + def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: - ... + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... - def rollback(self) -> None: - ... + def rollback(self) -> None: ... + + def __getattr__(self, key: str) -> Any: ... - autocommit: bool + def __setattr__(self, key: str, value: Any) -> None: ... class DBAPIType(Protocol): @@ -174,53 +189,43 @@ def description( ... @property - def rowcount(self) -> int: - ... + def rowcount(self) -> int: ... arraysize: int lastrowid: int - def close(self) -> None: - ... + def close(self) -> None: ... def execute( self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: - ... + ) -> Any: ... def executemany( self, operation: Any, - parameters: Sequence[_DBAPIMultiExecuteParams], - ) -> Any: - ... + parameters: _DBAPIMultiExecuteParams, + ) -> Any: ... - def fetchone(self) -> Optional[Any]: - ... + def fetchone(self) -> Optional[Any]: ... - def fetchmany(self, size: int = ...) -> Sequence[Any]: - ... + def fetchmany(self, size: int = ...) -> Sequence[Any]: ... - def fetchall(self) -> Sequence[Any]: - ... + def fetchall(self) -> Sequence[Any]: ... - def setinputsizes(self, sizes: Sequence[Any]) -> None: - ... + def setinputsizes(self, sizes: Sequence[Any]) -> None: ... - def setoutputsize(self, size: Any, column: Any) -> None: - ... + def setoutputsize(self, size: Any, column: Any) -> None: ... - def callproc(self, procname: str, parameters: Sequence[Any] = ...) -> Any: - ... + def callproc( + self, procname: str, parameters: Sequence[Any] = ... + ) -> Any: ... - def nextset(self) -> Optional[bool]: - ... + def nextset(self) -> Optional[bool]: ... - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... 
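The interfaces.py hunk above replaces the loosely-typed `ModuleType` reference with structural protocols: `DBAPIModule`, plus a `DBAPIConnection` whose `__getattr__`/`__setattr__` escape hatches acknowledge that pep-249 drivers carry driver-specific attributes. Because these are `typing.Protocol` classes, any conforming driver satisfies them without importing or subclassing anything from SQLAlchemy. A minimal sketch of that idea, exercised against the stdlib `sqlite3` driver (the trimmed-down protocol below is written for this note, not SQLAlchemy's actual definition):

```python
# A sketch, not SQLAlchemy's actual definition: a trimmed-down structural
# protocol in the spirit of DBAPIModule above, exercised against the stdlib
# sqlite3 driver, which conforms without subclassing anything.
from __future__ import annotations

import sqlite3
from typing import Any, Protocol


class DBAPIModuleLike(Protocol):
    class Error(Exception): ...

    def __getattr__(self, key: str) -> Any: ...


def safe_ping(dbapi: DBAPIModuleLike, conn: sqlite3.Connection) -> bool:
    # dbapi.Error resolves structurally, so this works for any pep-249 driver
    try:
        conn.cursor().execute("SELECT 1")
        return True
    except dbapi.Error:
        return False


connection = sqlite3.connect(":memory:")
assert safe_ping(sqlite3, connection)  # the sqlite3 module itself conforms
```

This is the same pattern that lets methods in the patch such as `is_disconnect(e: DBAPIModule.Error, ...)` accept any driver's exception hierarchy without maintaining a registry of driver classes.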
_CoreSingleExecuteParams = Mapping[str, Any] @@ -284,6 +289,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False): yield_per: int insertmanyvalues_page_size: int schema_translate_map: Optional[SchemaTranslateMapType] + preserve_rowcount: bool _ExecuteOptions = immutabledict[str, Any] @@ -593,8 +599,8 @@ class BindTyping(Enum): """Use the pep-249 setinputsizes method. This is only implemented for DBAPIs that support this method and for which - the SQLAlchemy dialect has the appropriate infrastructure for that - dialect set up. Current dialects include cx_Oracle as well as + the SQLAlchemy dialect has the appropriate infrastructure for that dialect + set up. Current dialects include python-oracledb and cx_Oracle, as well as optional support for SQL Server using pyodbc. When using setinputsizes, dialects also have a means of only using the @@ -671,7 +677,7 @@ class Dialect(EventTarget): dialect_description: str - dbapi: Optional[ModuleType] + dbapi: Optional[DBAPIModule] """A reference to the DBAPI module object itself. SQLAlchemy dialects import DBAPI modules using the classmethod @@ -695,7 +701,7 @@ class Dialect(EventTarget): """ @util.non_memoized_property - def loaded_dbapi(self) -> ModuleType: + def loaded_dbapi(self) -> DBAPIModule: """same as .dbapi, but is never None; will raise an error if no DBAPI was set up. @@ -792,8 +798,14 @@ def loaded_dbapi(self) -> ModuleType: max_identifier_length: int """The maximum length of identifier names.""" - - supports_server_side_cursors: bool + max_index_name_length: Optional[int] + """The maximum length of index names if different from + ``max_identifier_length``.""" + max_constraint_name_length: Optional[int] + """The maximum length of constraint names if different from + ``max_identifier_length``.""" + + supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool] """indicates if the dialect supports server side cursors""" server_side_cursors: bool @@ -884,12 +896,12 @@ def loaded_dbapi(self) -> ModuleType: the statement multiple times for a series of batches when large numbers of rows are given. - The parameter is False for the default dialect, and is set to - True for SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, - SQL Server. It remains at False for Oracle, which provides native - "executemany with RETURNING" support and also does not support - ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL - dialects that don't support RETURNING will not report + The parameter is False for the default dialect, and is set to True for + SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, and SQL Server. + It remains at False for Oracle Database, which provides native "executemany + with RETURNING" support and also does not support + ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL dialects + that don't support RETURNING will not report ``insert_executemany_returning`` as True. .. versionadded:: 2.0 @@ -1073,11 +1085,7 @@ def loaded_dbapi(self) -> ModuleType: To implement, establish as a series of tuples, as in:: construct_arguments = [ - (schema.Index, { - "using": False, - "where": None, - "ops": None - }) + (schema.Index, {"using": False, "where": None, "ops": None}), ] If the above construct is established on the PostgreSQL dialect, @@ -1106,7 +1114,8 @@ def loaded_dbapi(self) -> ModuleType: established on a :class:`.Table` object which will be passed as "reflection options" when using :paramref:`.Table.autoload_with`.
- Current example is "oracle_resolve_synonyms" in the Oracle dialect. + Current example is "oracle_resolve_synonyms" in the Oracle Database + dialects. """ @@ -1130,7 +1139,7 @@ def loaded_dbapi(self) -> ModuleType: supports_constraint_comments: bool """Indicates if the dialect supports comment DDL on constraints. - .. versionadded: 2.0 + .. versionadded:: 2.0 """ _has_events = False @@ -1249,7 +1258,7 @@ def create_connect_args(self, url): raise NotImplementedError() @classmethod - def import_dbapi(cls) -> ModuleType: + def import_dbapi(cls) -> DBAPIModule: """Import the DBAPI module that is used by this dialect. The Python module object returned here will be assigned as an @@ -1266,8 +1275,7 @@ def import_dbapi(cls) -> ModuleType: """ raise NotImplementedError() - @classmethod - def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: + def type_descriptor(self, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -1299,12 +1307,9 @@ def initialize(self, connection: Connection) -> None: """ - pass - if TYPE_CHECKING: - def _overrides_default(self, method_name: str) -> bool: - ... + def _overrides_default(self, method_name: str) -> bool: ... def get_columns( self, @@ -1330,6 +1335,7 @@ def get_columns( def get_multi_columns( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1378,6 +1384,7 @@ def get_pk_constraint( def get_multi_pk_constraint( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1424,6 +1431,7 @@ def get_foreign_keys( def get_multi_foreign_keys( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1583,6 +1591,7 @@ def get_indexes( def get_multi_indexes( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1629,6 +1638,7 @@ def get_unique_constraints( def get_multi_unique_constraints( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1676,6 +1686,7 @@ def get_check_constraints( def get_multi_check_constraints( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1718,6 +1729,7 @@ def get_table_options( def get_multi_table_options( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1769,6 +1781,7 @@ def get_table_comment( def get_multi_table_comment( self, connection: Connection, + *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -2161,6 +2174,7 @@ def do_recover_twophase(self, connection: Connection) -> List[Any]: def _deliver_insertmanyvalues_batches( self, + connection: Connection, cursor: DBAPICursor, statement: str, parameters: _DBAPIMultiExecuteParams, @@ -2214,7 +2228,7 @@ def do_execute_no_params( def is_disconnect( self, - e: Exception, + e: DBAPIModule.Error, connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], cursor: Optional[DBAPICursor], ) -> bool: @@ -2318,7 +2332,7 @@ def do_on_connect(connection): """ return self.on_connect() - def on_connect(self) -> Optional[Callable[[Any], Any]]: + def on_connect(self) -> Optional[Callable[[Any], None]]: """return a callable which sets up a 
newly created DBAPI connection. The callable should accept a single argument "conn" which is the @@ -2491,7 +2505,7 @@ def get_default_isolation_level( def get_isolation_level_values( self, dbapi_conn: DBAPIConnection - ) -> List[IsolationLevel]: + ) -> Sequence[IsolationLevel]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -2504,7 +2518,7 @@ def get_isolation_level_values( ``REPEATABLE READ``. isolation level names will have underscores converted to spaces before being passed along to the dialect. * The names for the four standard isolation names to the extent that - they are supported by the backend should be ``READ UNCOMMITTED`` + they are supported by the backend should be ``READ UNCOMMITTED``, ``READ COMMITTED``, ``REPEATABLE READ``, ``SERIALIZABLE`` * if the dialect supports an autocommit option it should be provided using the isolation level name ``AUTOCOMMIT``. @@ -2665,6 +2679,9 @@ def get_dialect_pool_class(self, url: URL) -> Type[Pool]: """return a Pool class to use for a given URL""" raise NotImplementedError() + def validate_identifier(self, ident: str) -> None: + """Validates an identifier name, raising an exception if invalid.""" class CreateEnginePlugin: """A set of hooks intended to augment the construction of an @@ -2690,11 +2707,14 @@ class CreateEnginePlugin: from sqlalchemy.engine import CreateEnginePlugin from sqlalchemy import event + class LogCursorEventsPlugin(CreateEnginePlugin): def __init__(self, url, kwargs): # consume the parameter "log_cursor_logging_name" from the # URL query - logging_name = url.query.get("log_cursor_logging_name", "log_cursor") + logging_name = url.query.get( + "log_cursor_logging_name", "log_cursor" + ) self.log = logging.getLogger(logging_name) @@ -2706,7 +2726,6 @@ def engine_created(self, engine): "attach an event listener after the new Engine is constructed" event.listen(engine, "before_cursor_execute", self._log_event) - def _log_event( self, conn, @@ -2714,19 +2733,19 @@ def _log_event( statement, parameters, context, - executemany): + executemany, + ): self.log.info("Plugin logged cursor event: %s", statement) - - Plugins are registered using entry points in a similar way as that of dialects:: - entry_points={ - 'sqlalchemy.plugins': [ - 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin' + entry_points = { + "sqlalchemy.plugins": [ + "log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin" ] + } A plugin that uses the above names would be invoked from a database URL as in:: @@ -2743,15 +2762,16 @@ def _log_event( in the URL:: engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?" - "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three") + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=plugin_one&plugin=plugin_two&plugin=plugin_three" + ) The plugin names may also be passed directly to :func:`_sa.create_engine` using the :paramref:`_sa.create_engine.plugins` argument:: engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test", - plugins=["myplugin"]) + "mysql+pymysql://scott:tiger@localhost/test", plugins=["myplugin"] + ) ..
versionadded:: 1.2.3 plugin names can also be specified to :func:`_sa.create_engine` as a list @@ -2773,9 +2793,9 @@ def _log_event( class MyPlugin(CreateEnginePlugin): def __init__(self, url, kwargs): - self.my_argument_one = url.query['my_argument_one'] - self.my_argument_two = url.query['my_argument_two'] - self.my_argument_three = kwargs.pop('my_argument_three', None) + self.my_argument_one = url.query["my_argument_one"] + self.my_argument_two = url.query["my_argument_two"] + self.my_argument_three = kwargs.pop("my_argument_three", None) def update_url(self, url): return url.difference_update_query( @@ -2788,9 +2808,9 @@ def update_url(self, url): from sqlalchemy import create_engine engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?" - "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", - my_argument_three='bat' + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", + my_argument_three="bat", ) .. versionchanged:: 1.4 @@ -2809,15 +2829,15 @@ class MyPlugin(CreateEnginePlugin): def __init__(self, url, kwargs): if hasattr(CreateEnginePlugin, "update_url"): # detect the 1.4 API - self.my_argument_one = url.query['my_argument_one'] - self.my_argument_two = url.query['my_argument_two'] + self.my_argument_one = url.query["my_argument_one"] + self.my_argument_two = url.query["my_argument_two"] else: # detect the 1.3 and earlier API - mutate the # URL directly - self.my_argument_one = url.query.pop('my_argument_one') - self.my_argument_two = url.query.pop('my_argument_two') + self.my_argument_one = url.query.pop("my_argument_one") + self.my_argument_two = url.query.pop("my_argument_two") - self.my_argument_three = kwargs.pop('my_argument_three', None) + self.my_argument_three = kwargs.pop("my_argument_three", None) def update_url(self, url): # this method is only called in the 1.4 version @@ -2992,6 +3012,9 @@ class ExecutionContext: inline SQL expression value was fired off. Applies to inserts and updates.""" + execution_options: _ExecuteOptions + """Execution options associated with the current statement execution""" + @classmethod def _init_ddl( cls, @@ -3366,7 +3389,7 @@ class AdaptedConnection: __slots__ = ("_connection",) - _connection: Any + _connection: AsyncIODBAPIConnection @property def driver_connection(self) -> Any: @@ -3385,11 +3408,14 @@ def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T: engine = create_async_engine(...) + @event.listens_for(engine.sync_engine, "connect") - def register_custom_types(dbapi_connection, ...): + def register_custom_types( + dbapi_connection, # ... + ): dbapi_connection.run_async( lambda connection: connection.set_type_codec( - 'MyCustomType', encoder, decoder, ... + "MyCustomType", encoder, decoder, ... 
) ) diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 618ea1d85ef..a96af36ccda 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -1,5 +1,5 @@ # engine/mock.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -27,10 +27,9 @@ from .interfaces import Dialect from .url import URL from ..sql.base import Executable - from ..sql.ddl import SchemaDropper - from ..sql.ddl import SchemaGenerator + from ..sql.ddl import InvokeDDLBase from ..sql.schema import HasSchemaAttr - from ..sql.schema import SchemaItem + from ..sql.visitors import Visitable class MockConnection: @@ -53,12 +52,14 @@ def execution_options(self, **kw: Any) -> MockConnection: def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: Visitable, **kwargs: Any, ) -> None: kwargs["checkfirst"] = False - visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + visitorcallable( + dialect=self.dialect, connection=self, **kwargs + ).traverse_single(element) def execute( self, @@ -90,10 +91,12 @@ def create_mock_engine( from sqlalchemy import create_mock_engine + def dump(sql, *multiparams, **params): print(sql.compile(dialect=engine.dialect)) - engine = create_mock_engine('postgresql+psycopg2://', dump) + + engine = create_mock_engine("postgresql+psycopg2://", dump) metadata.create_all(engine, checkfirst=False) :param url: A string URL which typically needs to contain only the diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py index c01d3b74064..b3f9330842d 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -1,5 +1,5 @@ -# sqlalchemy/processors.py -# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors +# engine/processors.py +# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors # # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 6d2a8a29fd8..23009c64a4c 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -1,5 +1,5 @@ # engine/reflection.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -55,6 +55,7 @@ from ..sql import operators from ..sql import schema as sa_schema from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import quoted_name from ..sql.elements import TextClause from ..sql.type_api import TypeEngine from ..sql.visitors import InternalTraversal @@ -89,8 +90,16 @@ def cache( exclude = {"info_cache", "unreflectable"} key = ( fn.__name__, - tuple(a for a in args if isinstance(a, str)), - tuple((k, v) for k, v in kw.items() if k not in exclude), + tuple( + (str(a), a.quote) if isinstance(a, quoted_name) else a + for a in args + if isinstance(a, str) + ), + tuple( + (k, (str(v), v.quote) if isinstance(v, quoted_name) else v) + for k, v in kw.items() + if k not in exclude + ), ) ret: _R = info_cache.get(key) if ret is None: @@ -184,7 +193,8 @@ class Inspector(inspection.Inspectable["Inspector"]): or a :class:`_engine.Connection`:: from sqlalchemy import inspect, create_engine - engine = 
create_engine('...') + + engine = create_engine("...") insp = inspect(engine) Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated @@ -621,7 +631,7 @@ def get_temp_table_names(self, **kw: Any) -> List[str]: r"""Return a list of temporary table names for the current bind. This method is unsupported by most dialects; currently - only Oracle, PostgreSQL and SQLite implements it. + only Oracle Database, PostgreSQL and SQLite implement it. :param \**kw: Additional keyword argument to pass to the dialect specific implementation. See the documentation of the dialect @@ -657,7 +667,7 @@ def get_table_options( given name was created. This currently includes some options that apply to MySQL and Oracle - tables. + Database tables. :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -1483,9 +1493,9 @@ def reflect_table( from sqlalchemy import create_engine, MetaData, Table from sqlalchemy import inspect - engine = create_engine('...') + engine = create_engine("...") meta = MetaData() - user_table = Table('user', meta) + user_table = Table("user", meta) insp = inspect(engine) insp.reflect_table(user_table, None) @@ -1704,9 +1714,12 @@ def _reflect_pk( if pk in cols_by_orig_name and pk not in exclude_columns ] - # update pk constraint name and comment + # update pk constraint name, comment and dialect_kwargs table.primary_key.name = pk_cons.get("name") table.primary_key.comment = pk_cons.get("comment", None) + dialect_options = pk_cons.get("dialect_options") + if dialect_options: + table.primary_key.dialect_kwargs.update(dialect_options) # tell the PKConstraint to re-initialize # its column collection @@ -1843,7 +1856,7 @@ def _reflect_indexes( if not expressions: util.warn( f"Skipping {flavor} {name!r} because key " - f"{index+1} reflected as None but no " + f"{index + 1} reflected as None but no " "'expressions' were returned" ) break diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 132ae88b660..b84fb3d1cb5 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1,5 +1,5 @@ # engine/result.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -52,11 +52,11 @@ from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter if typing.TYPE_CHECKING: - from ..sql.schema import Column + from ..sql.elements import SQLCoreOperations from ..sql.type_api import _ResultProcessorType -_KeyType = Union[str, "Column[Any]"] -_KeyIndexType = Union[str, "Column[Any]", int] +_KeyType = Union[str, "SQLCoreOperations[Any]"] +_KeyIndexType = Union[_KeyType, int] # is overridden in cursor using _CursorKeyMapRecType _KeyMapRecType = Any @@ -64,7 +64,7 @@ _KeyMapType = Mapping[_KeyType, _KeyMapRecType] -_RowData = Union[Row, RowMapping, Any] +_RowData = Union[Row[Any], RowMapping, Any] """A generic form of "row" that accommodates for the different kinds of "rows" that different result objects return, including row, row mapping, and scalar values""" @@ -82,7 +82,7 @@ """ -_InterimSupportsScalarsRowType = Union[Row, Any] +_InterimSupportsScalarsRowType = Union[Row[Any], Any] _ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]] _TupleGetterType = Callable[[Sequence[Any]], Sequence[Any]] @@ -116,8 +116,7 @@ def _for_freeze(self) -> ResultMetaData: @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr:
Literal[True] = ... - ) -> NoReturn: - ... + ) -> NoReturn: ... @overload def _key_fallback( @@ -125,14 +124,12 @@ def _key_fallback( key: Any, err: Optional[Exception], raiseerr: Literal[False] = ..., - ) -> None: - ... + ) -> None: ... @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = ... - ) -> Optional[NoReturn]: - ... + ) -> Optional[NoReturn]: ... def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = True @@ -329,9 +326,6 @@ def __setstate__(self, state: Dict[str, Any]) -> None: _tuplefilter=_tuplefilter, ) - def _contains(self, value: Any, row: Row[Any]) -> bool: - return value in row._data - def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: if int in key.__class__.__mro__: key = self._keys[key] @@ -728,14 +722,21 @@ def manyrows( return manyrows + @overload + def _only_one_row( + self: ResultInternal[Row[Any]], + raise_for_second_row: bool, + raise_for_none: bool, + scalar: Literal[True], + ) -> Any: ... + @overload def _only_one_row( self, raise_for_second_row: bool, raise_for_none: Literal[True], scalar: bool, - ) -> _R: - ... + ) -> _R: ... @overload def _only_one_row( @@ -743,8 +744,7 @@ def _only_one_row( raise_for_second_row: bool, raise_for_none: bool, scalar: bool, - ) -> Optional[_R]: - ... + ) -> Optional[_R]: ... def _only_one_row( self, @@ -817,7 +817,6 @@ def _only_one_row( "was required" ) else: - next_row = _NO_ROW # if we checked for second row then that would have # closed us :) self._soft_close(hard=True) @@ -1107,17 +1106,15 @@ def columns(self, *col_expressions: _KeyIndexType) -> Self: statement = select(table.c.x, table.c.y, table.c.z) result = connection.execute(statement) - for z, y in result.columns('z', 'y'): - # ... - + for z, y in result.columns("z", "y"): + ... Example of using the column objects from the statement itself:: for z, y in result.columns( - statement.selected_columns.c.z, - statement.selected_columns.c.y + statement.selected_columns.c.z, statement.selected_columns.c.y ): - # ... + ... .. versionadded:: 1.4 @@ -1132,18 +1129,15 @@ def columns(self, *col_expressions: _KeyIndexType) -> Self: return self._column_slices(col_expressions) @overload - def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: - ... + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: ... @overload def scalars( self: Result[Tuple[_T]], index: Literal[0] - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload - def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: - ... + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: ... def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: """Return a :class:`_engine.ScalarResult` filtering object which @@ -1352,7 +1346,7 @@ def fetchone(self) -> Optional[Row[_TP]]: def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: """Fetch many rows. - When all rows are exhausted, returns an empty list. + When all rows are exhausted, returns an empty sequence. This method is provided for backwards compatibility with SQLAlchemy 1.x.x. @@ -1360,7 +1354,7 @@ def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: To fetch rows in groups, use the :meth:`_engine.Result.partitions` method. - :return: a list of :class:`_engine.Row` objects. + :return: a sequence of :class:`_engine.Row` objects. .. 
seealso:: @@ -1371,14 +1365,14 @@ def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: return self._manyrow_getter(self, size) def all(self) -> Sequence[Row[_TP]]: - """Return all rows in a list. + """Return all rows in a sequence. Closes the result set after invocation. Subsequent invocations - will return an empty list. + will return an empty sequence. .. versionadded:: 1.4 - :return: a list of :class:`_engine.Row` objects. + :return: a sequence of :class:`_engine.Row` objects. .. seealso:: @@ -1454,22 +1448,20 @@ def one_or_none(self) -> Optional[Row[_TP]]: ) @overload - def scalar_one(self: Result[Tuple[_T]]) -> _T: - ... + def scalar_one(self: Result[Tuple[_T]]) -> _T: ... @overload - def scalar_one(self) -> Any: - ... + def scalar_one(self) -> Any: ... def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` and - then :meth:`_engine.Result.one`. + then :meth:`_engine.ScalarResult.one`. .. seealso:: - :meth:`_engine.Result.one` + :meth:`_engine.ScalarResult.one` :meth:`_engine.Result.scalars` @@ -1479,22 +1471,20 @@ def scalar_one(self) -> Any: ) @overload - def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: - ... + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: ... @overload - def scalar_one_or_none(self) -> Optional[Any]: - ... + def scalar_one_or_none(self) -> Optional[Any]: ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. This is equivalent to calling :meth:`_engine.Result.scalars` and - then :meth:`_engine.Result.one_or_none`. + then :meth:`_engine.ScalarResult.one_or_none`. .. seealso:: - :meth:`_engine.Result.one_or_none` + :meth:`_engine.ScalarResult.one_or_none` :meth:`_engine.Result.scalars` @@ -1506,8 +1496,8 @@ def scalar_one_or_none(self) -> Optional[Any]: def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. - Raises :class:`.NoResultFound` if the result returns no - rows, or :class:`.MultipleResultsFound` if multiple rows + Raises :class:`_exc.NoResultFound` if the result returns no + rows, or :class:`_exc.MultipleResultsFound` if multiple rows would be returned. .. note:: This method returns one **row**, e.g. tuple, by default. @@ -1537,12 +1527,10 @@ def one(self) -> Row[_TP]: ) @overload - def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: - ... + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: ... @overload - def scalar(self) -> Any: - ... + def scalar(self) -> Any: ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -1776,7 +1764,7 @@ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: return self._manyrow_getter(self, size) def all(self) -> Sequence[_R]: - """Return all scalar values in a list. + """Return all scalar values in a sequence. Equivalent to :meth:`_engine.Result.all` except that scalar values, rather than :class:`_engine.Row` objects, @@ -1880,7 +1868,7 @@ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: ... def all(self) -> Sequence[_R]: # noqa: A001 - """Return all scalar values in a list. + """Return all scalar values in a sequence. Equivalent to :meth:`_engine.Result.all` except that tuple values, rather than :class:`_engine.Row` objects, @@ -1889,11 +1877,9 @@ def all(self) -> Sequence[_R]: # noqa: A001 """ ... - def __iter__(self) -> Iterator[_R]: - ... + def __iter__(self) -> Iterator[_R]: ... - def __next__(self) -> _R: - ... 
+ def __next__(self) -> _R: ... def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -1927,22 +1913,20 @@ def one(self) -> _R: ... @overload - def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: - ... + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: ... @overload - def scalar_one(self) -> Any: - ... + def scalar_one(self) -> Any: ... def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.Result.one`. + and then :meth:`_engine.ScalarResult.one`. .. seealso:: - :meth:`_engine.Result.one` + :meth:`_engine.ScalarResult.one` :meth:`_engine.Result.scalars` @@ -1950,22 +1934,22 @@ def scalar_one(self) -> Any: ... @overload - def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: - ... + def scalar_one_or_none( + self: TupleResult[Tuple[_T]], + ) -> Optional[_T]: ... @overload - def scalar_one_or_none(self) -> Optional[Any]: - ... + def scalar_one_or_none(self) -> Optional[Any]: ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.Result.one_or_none`. + and then :meth:`_engine.ScalarResult.one_or_none`. .. seealso:: - :meth:`_engine.Result.one_or_none` + :meth:`_engine.ScalarResult.one_or_none` :meth:`_engine.Result.scalars` @@ -1973,12 +1957,10 @@ def scalar_one_or_none(self) -> Optional[Any]: ... @overload - def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: - ... + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: ... @overload - def scalar(self) -> Any: - ... + def scalar(self) -> Any: ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result @@ -2086,7 +2068,7 @@ def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]: return self._manyrow_getter(self, size) def all(self) -> Sequence[RowMapping]: - """Return all scalar values in a list. + """Return all scalar values in a sequence. Equivalent to :meth:`_engine.Result.all` except that :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 9017537ab09..da7ae9af277 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -1,5 +1,5 @@ # engine/row.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -213,15 +213,12 @@ def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool: if TYPE_CHECKING: @overload - def __getitem__(self, index: int) -> Any: - ... + def __getitem__(self, index: int) -> Any: ... @overload - def __getitem__(self, index: slice) -> Sequence[Any]: - ... + def __getitem__(self, index: slice) -> Sequence[Any]: ... - def __getitem__(self, index: Union[int, slice]) -> Any: - ... + def __getitem__(self, index: Union[int, slice]) -> Any: ... 
def __lt__(self, other: Any) -> bool: return self._op(other, operator.lt) @@ -296,8 +293,8 @@ class ROMappingView(ABC): def __init__( self, mapping: Mapping["_KeyType", Any], items: Sequence[Any] ): - self._mapping = mapping - self._items = items + self._mapping = mapping # type: ignore[misc] + self._items = items # type: ignore[misc] def __len__(self) -> int: return len(self._items) @@ -321,11 +318,11 @@ def __ne__(self, other: Any) -> bool: class ROMappingKeysValuesView( ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any] ): - __slots__ = ("_items",) + __slots__ = ("_items",) # mapping slot is provided by KeysView class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]): - __slots__ = ("_items",) + __slots__ = ("_items",) # mapping slot is provided by ItemsView class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): @@ -343,12 +340,11 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): as iteration of keys, values, and items:: for row in result: - if 'a' in row._mapping: - print("Column 'a': %s" % row._mapping['a']) + if "a" in row._mapping: + print("Column 'a': %s" % row._mapping["a"]) print("Column b: %s" % row._mapping[table.c.b]) - .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the mapping-like access previously provided by a database result row, which now seeks to behave mostly like a named tuple. @@ -359,8 +355,7 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): if TYPE_CHECKING: - def __getitem__(self, key: _KeyType) -> Any: - ... + def __getitem__(self, key: _KeyType) -> Any: ... else: __getitem__ = BaseRow._get_by_key_impl_mapping diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index f884f203c9e..b4b8077ba05 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -1,14 +1,11 @@ # engine/strategies.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Deprecated mock engine strategy used by Alembic. - - -""" +"""Deprecated mock engine strategy used by Alembic.""" from __future__ import annotations diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 5cf5ec7b4b7..20079a6b535 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -1,5 +1,5 @@ # engine/url.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -32,6 +32,7 @@ from typing import Type from typing import Union from urllib.parse import parse_qsl +from urllib.parse import quote from urllib.parse import quote_plus from urllib.parse import unquote @@ -121,7 +122,9 @@ class URL(NamedTuple): for keys and either strings or tuples of strings for values, e.g.:: >>> from sqlalchemy.engine import make_url - >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url = make_url( + ... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" + ... ) >>> url.query immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) @@ -170,6 +173,11 @@ def create( :param password: database password. 
Is typically a string, but may also be an object that can be stringified with ``str()``. + .. note:: The password string should **not** be URL encoded when + passed as an argument to :meth:`_engine.URL.create`; the string + should contain the password characters exactly as they would be + typed. + .. note:: A password-producing object will be stringified only **once** per :class:`_engine.Engine` object. For dynamic password generation per connect, see :ref:`engines_dynamic_tokens`. @@ -247,14 +255,12 @@ def _str_dict( @overload def _assert_value( val: str, - ) -> str: - ... + ) -> str: ... @overload def _assert_value( val: Sequence[str], - ) -> Union[str, Tuple[str, ...]]: - ... + ) -> Union[str, Tuple[str, ...]]: ... def _assert_value( val: Union[str, Sequence[str]], @@ -367,7 +373,9 @@ def update_query_string( >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url = url.update_query_string( + ... "alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" + ... ) >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -403,7 +411,13 @@ def update_query_pairs( >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")]) + >>> url = url.update_query_pairs( + ... [ + ... ("alt_host", "host1"), + ... ("alt_host", "host2"), + ... ("ssl_cipher", "/path/to/crt"), + ... ] + ... ) >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -485,7 +499,9 @@ def update_query_dict( >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"}) + >>> url = url.update_query_dict( + ... {"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"} + ... ) >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -523,14 +539,14 @@ def difference_update_query(self, names: Iterable[str]) -> URL: E.g.:: - url = url.difference_update_query(['foo', 'bar']) + url = url.difference_update_query(["foo", "bar"]) Equivalent to using :meth:`_engine.URL.set` as follows:: url = url.set( query={ key: url.query[key] - for key in set(url.query).difference(['foo', 'bar']) + for key in set(url.query).difference(["foo", "bar"]) } ) @@ -579,7 +595,9 @@ def normalized_query(self) -> Mapping[str, Sequence[str]]: >>> from sqlalchemy.engine import make_url - >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url = make_url( + ... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" + ... 
) >>> url.query immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) >>> url.normalized_query @@ -621,17 +639,17 @@ def render_as_string(self, hide_password: bool = True) -> str: """ s = self.drivername + "://" if self.username is not None: - s += _sqla_url_quote(self.username) + s += quote(self.username, safe=" +") if self.password is not None: s += ":" + ( "***" if hide_password - else _sqla_url_quote(str(self.password)) + else quote(str(self.password), safe=" +") ) s += "@" if self.host is not None: if ":" in self.host: - s += "[%s]" % self.host + s += f"[{self.host}]" else: s += self.host if self.port is not None: @@ -642,7 +660,7 @@ def render_as_string(self, hide_password: bool = True) -> str: keys = list(self.query) keys.sort() s += "?" + "&".join( - "%s=%s" % (quote_plus(k), quote_plus(element)) + f"{quote_plus(k)}={quote_plus(element)}" for k in keys for element in util.to_list(self.query[k]) ) @@ -885,10 +903,10 @@ def _parse_url(name: str) -> URL: components["query"] = query if components["username"] is not None: - components["username"] = _sqla_url_unquote(components["username"]) + components["username"] = unquote(components["username"]) if components["password"] is not None: - components["password"] = _sqla_url_unquote(components["password"]) + components["password"] = unquote(components["password"]) ipv4host = components.pop("ipv4host") ipv6host = components.pop("ipv6host") @@ -902,12 +920,5 @@ def _parse_url(name: str) -> URL: else: raise exc.ArgumentError( - "Could not parse SQLAlchemy URL from string '%s'" % name + "Could not parse SQLAlchemy URL from given URL string" ) - - -def _sqla_url_quote(text: str) -> str: - return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) - - -_sqla_url_unquote = unquote diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 9b147a7014b..e499efa91aa 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -1,5 +1,5 @@ # engine/util.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -17,6 +17,7 @@ from .. import util from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Protocol +from ..util.typing import Self if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import _distill_params_20 as _distill_params_20 @@ -113,7 +114,7 @@ def _trans_ctx_check(cls, subject: _TConsSubject) -> None: "before emitting further commands." 
) - def __enter__(self) -> TransactionalContext: + def __enter__(self) -> Self: subject = self._get_subject() # none for outer transaction, may be non-None for nested diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index 20a20d18e61..309b7bd33fb 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -1,5 +1,5 @@ # event/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index bb1dbea0fc9..01dd4bdd1bf 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -1,13 +1,11 @@ # event/api.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Public API functions for the event system. - -""" +"""Public API functions for the event system.""" from __future__ import annotations from typing import Any @@ -51,15 +49,14 @@ def listen( from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint + def unique_constraint_name(const, table): - const.name = "uq_%s_%s" % ( - table.name, - list(const.columns)[0].name - ) + const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name) + + event.listen( - UniqueConstraint, - "after_parent_attach", - unique_constraint_name) + UniqueConstraint, "after_parent_attach", unique_constraint_name + ) :param bool insert: The default behavior for event handlers is to append the decorated user defined function to an internal list of registered @@ -132,19 +129,17 @@ def listens_for( The :func:`.listens_for` decorator is part of the primary interface for the SQLAlchemy event system, documented at :ref:`event_toplevel`. - This function generally shares the same kwargs as :func:`.listens`. + This function generally shares the same kwargs as :func:`.listen`. e.g.:: from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint + @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): - const.name = "uq_%s_%s" % ( - table.name, - list(const.columns)[0].name - ) + const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name) A given function can also be invoked for only the first invocation of the event using the ``once`` argument:: @@ -153,7 +148,6 @@ def unique_constraint_name(const, table): def on_config(): do_config() - .. warning:: The ``once`` argument does not imply automatic de-registration of the listener function after it has been invoked a first time; a listener entry will remain associated with the target object. @@ -189,6 +183,7 @@ def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: def my_listener_function(*arg): pass + # ... 
it's removed like this event.remove(SomeMappedClass, "before_insert", my_listener_function) diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 0aa34198305..ec5d5822f1c 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -1,5 +1,5 @@ # event/attr.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -391,20 +391,23 @@ def __bool__(self) -> bool: class _MutexProtocol(Protocol): - def __enter__(self) -> bool: - ... + def __enter__(self) -> bool: ... def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - ... + ) -> Optional[bool]: ... class _CompoundListener(_InstanceLevelDispatch[_ET]): - __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" + __slots__ = ( + "_exec_once_mutex", + "_exec_once", + "_exec_w_sync_once", + "_is_asyncio", + ) _exec_once_mutex: _MutexProtocol parent_listeners: Collection[_ListenerFnType] @@ -412,11 +415,18 @@ class _CompoundListener(_InstanceLevelDispatch[_ET]): _exec_once: bool _exec_w_sync_once: bool + def __init__(self, *arg: Any, **kw: Any): + super().__init__(*arg, **kw) + self._is_asyncio = False + def _set_asyncio(self) -> None: - self._exec_once_mutex = AsyncAdaptedLock() + self._is_asyncio = True def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: - return threading.Lock() + if self._is_asyncio: + return AsyncAdaptedLock() + else: + return threading.Lock() def _exec_once_impl( self, retry_on_exception: bool, *args: Any, **kw: Any @@ -525,6 +535,7 @@ class _ListenerCollection(_CompoundListener[_ET]): propagate: Set[_ListenerFnType] def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): + super().__init__() if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self._exec_once = False @@ -564,6 +575,9 @@ def _update( existing_listeners.extend(other_listeners) + if other._is_asyncio: + self._set_asyncio() + to_associate = other.propagate.union(other_listeners) registry._stored_in_collection_multi(self, other, to_associate) diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index f92b2ede3cd..66dc12996bc 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -1,5 +1,5 @@ # event/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -42,9 +42,9 @@ from .. import util from ..util.typing import Literal -_registrars: MutableMapping[ - str, List[Type[_HasEventsDispatch[Any]]] -] = util.defaultdict(list) +_registrars: MutableMapping[str, List[Type[_HasEventsDispatch[Any]]]] = ( + util.defaultdict(list) +) def _is_event_name(name: str) -> bool: @@ -191,13 +191,8 @@ def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: :class:`._Dispatch` objects. 
""" - if "_joined_dispatch_cls" not in self.__class__.__dict__: - cls = type( - "Joined%s" % self.__class__.__name__, - (_JoinedDispatcher,), - {"__slots__": self._event_names}, - ) - self.__class__._joined_dispatch_cls = cls + assert "_joined_dispatch_cls" in self.__class__.__dict__ + return self._joined_dispatch_cls(self, other) def __reduce__(self) -> Union[str, Tuple[Any, ...]]: @@ -240,8 +235,7 @@ class _HasEventsDispatch(Generic[_ET]): if typing.TYPE_CHECKING: - def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: - ... + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: ... def __init_subclass__(cls) -> None: """Intercept new Event subclasses and create associated _Dispatch @@ -329,6 +323,51 @@ def _create_dispatcher_class( else: dispatch_target_cls.dispatch = dispatcher(cls) + klass = type( + "Joined%s" % dispatch_cls.__name__, + (_JoinedDispatcher,), + {"__slots__": event_names}, + ) + dispatch_cls._joined_dispatch_cls = klass + + # establish pickle capability by adding it to this module + globals()[klass.__name__] = klass + + +class _JoinedDispatcher(_DispatchCommon[_ET]): + """Represent a connection between two _Dispatch objects.""" + + __slots__ = "local", "parent", "_instance_cls" + + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): + self.local = local + self.parent = parent + self._instance_cls = self.local._instance_cls + + def __reduce__(self) -> Any: + return (self.__class__, (self.local, self.parent)) + + def __getattr__(self, name: str) -> _JoinedListener[_ET]: + # Assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects. + ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + return self.parent._events + class Events(_HasEventsDispatch[_ET]): """Define event listening functions for a particular target type.""" @@ -341,9 +380,11 @@ def dispatch_is(*types: Type[Any]) -> bool: return all(isinstance(target.dispatch, t) for t in types) def dispatch_parent_is(t: Type[Any]) -> bool: - return isinstance( - cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t - ) + parent = cast("_JoinedDispatcher[_ET]", target.dispatch).parent + while isinstance(parent, _JoinedDispatcher): + parent = cast("_JoinedDispatcher[_ET]", parent).parent + + return isinstance(parent, t) # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. @@ -383,38 +424,6 @@ def _clear(cls) -> None: cls.dispatch._clear() -class _JoinedDispatcher(_DispatchCommon[_ET]): - """Represent a connection between two _Dispatch objects.""" - - __slots__ = "local", "parent", "_instance_cls" - - local: _DispatchCommon[_ET] - parent: _DispatchCommon[_ET] - _instance_cls: Optional[Type[_ET]] - - def __init__( - self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] - ): - self.local = local - self.parent = parent - self._instance_cls = self.local._instance_cls - - def __getattr__(self, name: str) -> _JoinedListener[_ET]: - # Assign _JoinedListeners as attributes on demand - # to reduce startup time for new dispatch objects. 
- ls = getattr(self.local, name) - jl = _JoinedListener(self.parent, ls.name, ls) - setattr(self, ls.name, jl) - return jl - - def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: - return self.parent._listen(event_key, **kw) - - @property - def _events(self) -> Type[_HasEventsDispatch[_ET]]: - return self.parent._events - - class dispatcher(Generic[_ET]): """Descriptor used by target classes to deliver the _Dispatch class at the class level @@ -430,12 +439,10 @@ def __init__(self, events: Type[_HasEventsDispatch[_ET]]): @overload def __get__( self, obj: Literal[None], cls: Type[Any] - ) -> Type[_Dispatch[_ET]]: - ... + ) -> Type[_Dispatch[_ET]]: ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: - ... + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index f3a7d04acee..e60fd9a5e17 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -1,5 +1,5 @@ # event/legacy.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -147,9 +147,9 @@ def _standard_listen_example( ) text %= { - "current_since": " (arguments as of %s)" % current_since - if current_since - else "", + "current_since": ( + " (arguments as of %s)" % current_since if current_since else "" + ), "event_name": fn.__name__, "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(dispatch_collection.arg_names), @@ -177,9 +177,9 @@ def _legacy_listen_examples( % { "since": since, "event_name": fn.__name__, - "has_kw_arguments": " **kw" - if dispatch_collection.has_kw - else "", + "has_kw_arguments": ( + " **kw" if dispatch_collection.has_kw else "" + ), "named_event_arguments": ", ".join(args), "sample_target": sample_target, } diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index fb2fed815f1..d7e4b321553 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -1,5 +1,5 @@ # event/registry.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -66,9 +66,9 @@ class EventTarget: "weakref.ref[_ListenerFnType]", ] -_key_to_collection: Dict[ - _EventKeyTupleType, _RefCollectionToListenerType -] = collections.defaultdict(dict) +_key_to_collection: Dict[_EventKeyTupleType, _RefCollectionToListenerType] = ( + collections.defaultdict(dict) +) """ Given an original listen() argument, can locate all listener collections and the listener fn contained @@ -154,7 +154,11 @@ def _removed_from_collection( if owner_ref in _collection_to_key: listener_to_key = _collection_to_key[owner_ref] - listener_to_key.pop(listen_ref) + # see #12216 - this guards against a removal that already occurred + # here. 
however, I cannot come up with a test that shows any negative + # side effects occurring from this removal happening, even though an + # event key may still be referenced from a clsleveldispatch here + listener_to_key.pop(listen_ref, None) def _stored_in_collection_multi( diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 2f7b23db4e3..ce832439516 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -1,5 +1,5 @@ -# sqlalchemy/events.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# events.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index a5a66de877f..71e5dd81e0b 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -1,5 +1,5 @@ -# sqlalchemy/exc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# exc.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -432,14 +432,16 @@ class DontWrapMixin: from sqlalchemy.exc import DontWrapMixin + class MyCustomException(Exception, DontWrapMixin): pass + class MySpecialType(TypeDecorator): impl = String def process_bind_param(self, value, dialect): - if value == 'invalid': + if value == "invalid": raise MyCustomException("invalid!") """ @@ -571,8 +573,7 @@ def instance( connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> StatementError: - ... + ) -> StatementError: ... @overload @classmethod @@ -586,8 +587,7 @@ def instance( connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> DontWrapMixin: - ... + ) -> DontWrapMixin: ... @overload @classmethod @@ -601,8 +601,7 @@ def instance( connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> BaseException: - ... + ) -> BaseException: ... @classmethod def instance( diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index e3af738b7ce..2751bcf938a 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -1,5 +1,5 @@ # ext/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 31df1345348..8f2c19b8764 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1,5 +1,5 @@ # ext/associationproxy.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -98,6 +98,7 @@ def association_proxy( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 ) -> AssociationProxy[Any]: r"""Return a Python property implementing a view of a target attribute which references an attribute on members of the @@ -198,6 +199,13 @@ def association_proxy( .. 
versionadded:: 2.0.0b4 + :param hash: Specific to + :ref:`orm_declarative_native_dataclasses`, controls if this field + is included when generating the ``__hash__()`` method for the mapped + class. + + .. versionadded:: 2.0.36 + :param info: optional, will be assigned to :attr:`.AssociationProxy.info` if present. @@ -237,7 +245,7 @@ def association_proxy( cascade_scalar_deletes=cascade_scalar_deletes, create_on_none_assignment=create_on_none_assignment, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), ) @@ -254,45 +262,39 @@ class AssociationProxyExtensionType(InspectionAttrExtensionType): class _GetterProtocol(Protocol[_T_co]): - def __call__(self, instance: Any) -> _T_co: - ... + def __call__(self, instance: Any) -> _T_co: ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _SetterProtocol(Protocol): - ... +class _SetterProtocol(Protocol): ... class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, value: _T_con) -> None: - ... + def __call__(self, instance: Any, value: _T_con) -> None: ... class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, key: Any, value: _T_con) -> None: - ... + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _CreatorProtocol(Protocol): - ... +class _CreatorProtocol(Protocol): ... class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, value: _T_con) -> Any: - ... + def __call__(self, value: _T_con) -> Any: ... class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, key: Any, value: Optional[_T_con]) -> Any: - ... + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ... class _LazyCollectionProtocol(Protocol[_T]): def __call__( self, - ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: - ... + ) -> Union[ + MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T] + ]: ... class _GetSetFactoryProtocol(Protocol): @@ -300,8 +302,7 @@ def __call__( self, collection_class: Optional[Type[Any]], assoc_instance: AssociationProxyInstance[Any], - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: - ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... class _ProxyFactoryProtocol(Protocol): @@ -311,15 +312,13 @@ def __call__( creator: _CreatorProtocol, value_attr: str, parent: AssociationProxyInstance[Any], - ) -> Any: - ... + ) -> Any: ... class _ProxyBulkSetProtocol(Protocol): def __call__( self, proxy: _AssociationCollection[Any], collection: Iterable[Any] - ) -> None: - ... + ) -> None: ... class _AssociationProxyProtocol(Protocol[_T]): @@ -337,18 +336,15 @@ class _AssociationProxyProtocol(Protocol[_T]): proxy_bulk_set: Optional[_ProxyBulkSetProtocol] @util.ro_memoized_property - def info(self) -> _InfoType: - ... + def info(self) -> _InfoType: ... def for_class( self, class_: Type[Any], obj: Optional[object] = None - ) -> AssociationProxyInstance[_T]: - ... + ) -> AssociationProxyInstance[_T]: ... def _default_getset( self, collection_class: Any - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: - ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... 
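A minimal sketch of how the new ``hash`` parameter documented above might be used with :ref:`orm_declarative_native_dataclasses`; the mapping below is illustrative only, not part of this change, and assumes SQLAlchemy 2.0.36+::

    from typing import List

    from sqlalchemy import ForeignKey
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.ext.associationproxy import AssociationProxy
    from sqlalchemy.orm import DeclarativeBase
    from sqlalchemy.orm import Mapped
    from sqlalchemy.orm import MappedAsDataclass
    from sqlalchemy.orm import mapped_column
    from sqlalchemy.orm import relationship


    class Base(MappedAsDataclass, DeclarativeBase):
        pass


    class Keyword(Base):
        __tablename__ = "keyword"

        id: Mapped[int] = mapped_column(primary_key=True, init=False)
        user_id: Mapped[int] = mapped_column(
            ForeignKey("user_account.id"), init=False
        )
        word: Mapped[str]


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True, init=False)
        name: Mapped[str]
        kws: Mapped[List[Keyword]] = relationship(init=False)

        # hash=False omits this field when the dataclass-generated
        # __hash__() is produced for the class
        keywords: AssociationProxy[List[str]] = association_proxy(
            "kws", "word", init=False, hash=False
        )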
class AssociationProxy( @@ -419,18 +415,17 @@ def __init__( self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS @overload - def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self: - ... + def __get__( + self, instance: Literal[None], owner: Literal[None] + ) -> Self: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> AssociationProxyInstance[_T]: - ... + ) -> AssociationProxyInstance[_T]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( self, instance: object, owner: Any @@ -463,7 +458,7 @@ def for_class( class User(Base): # ... - keywords = association_proxy('kws', 'keyword') + keywords = association_proxy("kws", "keyword") If we access this :class:`.AssociationProxy` from :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the @@ -783,9 +778,9 @@ def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]: :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: stmt = ( - select(Parent). - join(Parent.proxied.local_attr). - join(Parent.proxied.remote_attr) + select(Parent) + .join(Parent.proxied.local_attr) + .join(Parent.proxied.remote_attr) ) A future release may seek to provide a more succinct join pattern @@ -861,12 +856,10 @@ def info(self) -> _InfoType: return self.parent.info @overload - def get(self: _Self, obj: Literal[None]) -> _Self: - ... + def get(self: _Self, obj: Literal[None]) -> _Self: ... @overload - def get(self, obj: Any) -> _T: - ... + def get(self, obj: Any) -> _T: ... def get( self, obj: Any @@ -1089,7 +1082,7 @@ def any( and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar " "attributes. Use has()." + "'any()' not implemented for scalar attributes. Use has()." ) return self._criterion_exists( criterion=criterion, is_has=False, **kwargs @@ -1113,7 +1106,7 @@ def has( or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. " "Use any()." + "'has()' not implemented for collections. Use any()." ) return self._criterion_exists( criterion=criterion, is_has=True, **kwargs @@ -1432,12 +1425,10 @@ def _set(self, object_: Any, value: _T) -> None: self.setter(object_, value) @overload - def __getitem__(self, index: int) -> _T: - ... + def __getitem__(self, index: int) -> _T: ... @overload - def __getitem__(self, index: slice) -> MutableSequence[_T]: - ... + def __getitem__(self, index: slice) -> MutableSequence[_T]: ... def __getitem__( self, index: Union[int, slice] @@ -1448,12 +1439,10 @@ def __getitem__( return [self._get(member) for member in self.col[index]] @overload - def __setitem__(self, index: int, value: _T) -> None: - ... + def __setitem__(self, index: int, value: _T) -> None: ... @overload - def __setitem__(self, index: slice, value: Iterable[_T]) -> None: - ... + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... def __setitem__( self, index: Union[int, slice], value: Union[_T, Iterable[_T]] @@ -1492,12 +1481,10 @@ def __setitem__( self._set(self.col[i], item) @overload - def __delitem__(self, index: int) -> None: - ... + def __delitem__(self, index: int) -> None: ... @overload - def __delitem__(self, index: slice) -> None: - ... + def __delitem__(self, index: slice) -> None: ... 
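Most of the churn in this and the neighboring hunks is mechanical: black 24.x collapses "dummy" bodies consisting only of ``...`` onto the signature line. A small before/after illustration, using a hypothetical function rather than code from this patch::

    from typing import Union, overload

    # previously formatted as:
    #
    #     @overload
    #     def coerce(value: str) -> str:
    #         ...
    #
    # black 24.x emits the compact stub form seen throughout this diff:


    @overload
    def coerce(value: str) -> str: ...
    @overload
    def coerce(value: int) -> int: ...
    def coerce(value: Union[str, int]) -> Union[str, int]:
        # the single runtime implementation behind both overload stubs
        return value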
def __delitem__(self, index: Union[slice, int]) -> None: del self.col[index] @@ -1624,8 +1611,9 @@ def __imul__(self, n: SupportsIndex) -> Self: if typing.TYPE_CHECKING: # TODO: no idea how to do this without separate "stub" - def index(self, value: Any, start: int = ..., stop: int = ...) -> int: - ... + def index( + self, value: Any, start: int = ..., stop: int = ... + ) -> int: ... else: @@ -1701,12 +1689,10 @@ def __repr__(self) -> str: return repr(dict(self)) @overload - def get(self, __key: _KT) -> Optional[_VT]: - ... + def get(self, __key: _KT) -> Optional[_VT]: ... @overload - def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: - ... + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... def get( self, key: _KT, default: Optional[Union[_VT, _T]] = None @@ -1738,12 +1724,12 @@ def values(self) -> ValuesView[_VT]: return ValuesView(self) @overload - def pop(self, __key: _KT) -> _VT: - ... + def pop(self, __key: _KT) -> _VT: ... @overload - def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]: - ... + def pop( + self, __key: _KT, default: Union[_VT, _T] = ... + ) -> Union[_VT, _T]: ... def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: member = self.col.pop(__key, *arg, **kw) @@ -1756,16 +1742,15 @@ def popitem(self) -> Tuple[_KT, _VT]: @overload def update( self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT - ) -> None: - ... + ) -> None: ... @overload - def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: - ... + def update( + self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... @overload - def update(self, **kwargs: _VT) -> None: - ... + def update(self, **kwargs: _VT) -> None: ... def update(self, *a: Any, **kw: Any) -> None: up: Dict[_KT, _VT] = {} diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 8564db6f22e..7d8a04bd789 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -1,5 +1,5 @@ # ext/asyncio/__init__.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 251f5212542..72a617f4e22 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,5 +1,5 @@ # ext/asyncio/base.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -44,12 +44,10 @@ class ReversibleProxy(Generic[_PT]): __slots__ = ("__weakref__",) @overload - def _assign_proxied(self, target: _PT) -> _PT: - ... + def _assign_proxied(self, target: _PT) -> _PT: ... @overload - def _assign_proxied(self, target: None) -> None: - ... + def _assign_proxied(self, target: None) -> None: ... def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: @@ -73,28 +71,26 @@ def _target_gced( cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target(cls, target: _PT) -> Self: + def _regenerate_proxy_for_target( + cls, target: _PT, **additional_kw: Any + ) -> Self: raise NotImplementedError() @overload @classmethod def _retrieve_proxy_for_target( - cls, - target: _PT, - regenerate: Literal[True] = ..., - ) -> Self: - ... 
+ cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any + ) -> Self: ... @overload @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True - ) -> Optional[Self]: - ... + cls, target: _PT, regenerate: bool = True, **additional_kw: Any + ) -> Optional[Self]: ... @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True + cls, target: _PT, regenerate: bool = True, **additional_kw: Any ) -> Optional[Self]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] @@ -106,7 +102,7 @@ def _retrieve_proxy_for_target( return proxy # type: ignore if regenerate: - return cls._regenerate_proxy_for_target(target) + return cls._regenerate_proxy_for_target(target, **additional_kw) else: return None @@ -182,7 +178,7 @@ async def __aexit__( # tell if we get the same exception back value = typ() try: - await util.athrow(self.gen, typ, value, traceback) + await self.gen.athrow(value) except StopAsyncIteration as exc: # Suppress StopIteration *unless* it's the same exception that # was passed to throw(). This prevents a StopIteration @@ -219,7 +215,7 @@ async def __aexit__( def asyncstartablecontext( - func: Callable[..., AsyncIterator[_T_co]] + func: Callable[..., AsyncIterator[_T_co]], ) -> Callable[..., GeneratorStartableContext[_T_co]]: """@asyncstartablecontext decorator. @@ -228,7 +224,9 @@ def asyncstartablecontext( ``@contextlib.asynccontextmanager`` supports, and the usage pattern is different as well. - Typical usage:: + Typical usage: + + .. sourcecode:: text @asyncstartablecontext async def some_async_generator(): diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index bf968cc3884..d4ecbdac986 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -1,5 +1,5 @@ # ext/asyncio/engine.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -41,6 +41,8 @@ from ...engine.base import Transaction from ...exc import ArgumentError from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec if TYPE_CHECKING: from ...engine.cursor import CursorResult @@ -61,6 +63,7 @@ from ...sql.base import Executable from ...sql.selectable import TypedReturnsRows +_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -195,6 +198,7 @@ class AsyncConnection( method of :class:`_asyncio.AsyncEngine`:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") async with engine.connect() as conn: @@ -251,7 +255,7 @@ def __init__( @classmethod def _regenerate_proxy_for_target( - cls, target: Connection + cls, target: Connection, **additional_kw: Any # noqa: U100 ) -> AsyncConnection: return AsyncConnection( AsyncEngine._retrieve_proxy_for_target(target.engine), target @@ -414,13 +418,12 @@ async def execution_options( yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, **opt: Any, - ) -> AsyncConnection: - ... + ) -> AsyncConnection: ... @overload - async def execution_options(self, **opt: Any) -> AsyncConnection: - ... + async def execution_options(self, **opt: Any) -> AsyncConnection: ... 
async def execution_options(self, **opt: Any) -> AsyncConnection: r"""Set non-SQL options for the connection which take effect @@ -518,8 +521,7 @@ def stream( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[_T]]: - ... + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... @overload def stream( @@ -528,8 +530,7 @@ def stream( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[Any]]: - ... + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... @asyncstartablecontext async def stream( @@ -544,7 +545,7 @@ async def stream( E.g.:: - result = await conn.stream(stmt): + result = await conn.stream(stmt) async for row in result: print(f"{row}") @@ -573,6 +574,11 @@ async def stream( :meth:`.AsyncConnection.stream_scalars` """ + if not self.dialect.supports_server_side_cursors: + raise exc.InvalidRequestError( + "Can't use `stream` or `stream_scalars` with the current " + "dialect since it does not support server-side cursors." + ) result = await greenlet_spawn( self._proxied.execute, @@ -600,8 +606,7 @@ async def execute( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[_T]: - ... + ) -> CursorResult[_T]: ... @overload async def execute( @@ -610,8 +615,7 @@ async def execute( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... async def execute( self, @@ -667,8 +671,7 @@ async def scalar( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -677,8 +680,7 @@ async def scalar( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -709,8 +711,7 @@ async def scalars( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -719,8 +720,7 @@ async def scalars( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -752,8 +752,7 @@ def stream_scalars( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: - ... + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... @overload def stream_scalars( @@ -762,8 +761,7 @@ def stream_scalars( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: - ... + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... 
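The guard added above surfaces an explicit error up front instead of a driver-level failure. For reference, a sketch of the streaming pattern it protects, assuming a reachable PostgreSQL database (the URL is illustrative); asyncpg provides the server-side cursors that ``stream()`` requires::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test"
        )
        async with engine.connect() as conn:
            # rows arrive from a server-side cursor in batches rather
            # than being buffered fully in memory
            result = await conn.stream(
                text("SELECT generate_series(1, 10000)")
            )
            async for row in result:
                pass
        await engine.dispose()


    asyncio.run(main())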
@asyncstartablecontext async def stream_scalars( @@ -819,9 +817,12 @@ async def stream_scalars( yield result.scalars() async def run_sync( - self, fn: Callable[..., _T], *arg: Any, **kw: Any + self, + fn: Callable[Concatenate[Connection, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, ) -> _T: - """Invoke the given synchronous (i.e. not async) callable, + '''Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_engine.Connection` as the first argument. @@ -831,26 +832,26 @@ async def run_sync( E.g.:: def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str: - '''A synchronous function that does not require awaiting + """A synchronous function that does not require awaiting :param conn: a Core SQLAlchemy Connection, used synchronously :return: an optional return value is supported - ''' - conn.execute( - some_table.insert().values(int_col=arg1, str_col=arg2) - ) + """ + conn.execute(some_table.insert().values(int_col=arg1, str_col=arg2)) return "success" async def do_something_async(async_engine: AsyncEngine) -> None: - '''an async function that uses awaiting''' + """an async function that uses awaiting""" async with async_engine.begin() as async_conn: # run do_something_with_core() with a sync-style # Connection, proxied into an awaitable - return_code = await async_conn.run_sync(do_something_with_core, 5, "strval") + return_code = await async_conn.run_sync( + do_something_with_core, 5, "strval" + ) print(return_code) This method maintains the asyncio event loop all the way through @@ -881,9 +882,11 @@ async def do_something_async(async_engine: AsyncEngine) -> None: :ref:`session_run_sync` - """ # noqa: E501 + ''' # noqa: E501 - return await greenlet_spawn(fn, self._proxied, *arg, **kw) + return await greenlet_spawn( + fn, self._proxied, *arg, _require_await=False, **kw + ) def __await__(self) -> Generator[Any, None, AsyncConnection]: return self.start().__await__() @@ -928,7 +931,7 @@ def invalidated(self) -> Any: return self._proxied.invalidated @property - def dialect(self) -> Any: + def dialect(self) -> Dialect: r"""Proxy for the :attr:`_engine.Connection.dialect` attribute on behalf of the :class:`_asyncio.AsyncConnection` class. @@ -937,7 +940,7 @@ def dialect(self) -> Any: return self._proxied.dialect @dialect.setter - def dialect(self, attr: Any) -> None: + def dialect(self, attr: Dialect) -> None: self._proxied.dialect = attr @property @@ -998,6 +1001,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): :func:`_asyncio.create_async_engine` function:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") .. versionadded:: 1.4 @@ -1037,7 +1041,9 @@ def _proxied(self) -> Engine: return self.sync_engine @classmethod - def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: + def _regenerate_proxy_for_target( + cls, target: Engine, **additional_kw: Any # noqa: U100 + ) -> AsyncEngine: return AsyncEngine(target) @contextlib.asynccontextmanager @@ -1054,7 +1060,6 @@ async def begin(self) -> AsyncIterator[AsyncConnection]: ) await conn.execute(text("my_special_procedure(5)")) - """ conn = self.connect() @@ -1100,12 +1105,10 @@ def execution_options( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> AsyncEngine: - ... + ) -> AsyncEngine: ... @overload - def execution_options(self, **opt: Any) -> AsyncEngine: - ... 
+ def execution_options(self, **opt: Any) -> AsyncEngine: ... def execution_options(self, **opt: Any) -> AsyncEngine: """Return a new :class:`_asyncio.AsyncEngine` that will provide @@ -1160,7 +1163,7 @@ def clear_compiled_cache(self) -> None: This applies **only** to the built-in cache that is established via the :paramref:`_engine.create_engine.query_cache_size` parameter. It will not impact any dictionary caches that were passed via the - :paramref:`.Connection.execution_options.query_cache` parameter. + :paramref:`.Connection.execution_options.compiled_cache` parameter. .. versionadded:: 1.4 @@ -1343,7 +1346,7 @@ def __init__(self, connection: AsyncConnection, nested: bool = False): @classmethod def _regenerate_proxy_for_target( - cls, target: Transaction + cls, target: Transaction, **additional_kw: Any # noqa: U100 ) -> AsyncTransaction: sync_connection = target.connection sync_transaction = target @@ -1418,19 +1421,17 @@ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: @overload -def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: - ... +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ... @overload def _get_sync_engine_or_connection( async_engine: AsyncConnection, -) -> Connection: - ... +) -> Connection: ... def _get_sync_engine_or_connection( - async_engine: Union[AsyncEngine, AsyncConnection] + async_engine: Union[AsyncEngine, AsyncConnection], ) -> Union[Engine, Connection]: if isinstance(async_engine, AsyncConnection): return async_engine._proxied diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py index 3f937679b93..558187c0b41 100644 --- a/lib/sqlalchemy/ext/asyncio/exc.py +++ b/lib/sqlalchemy/ext/asyncio/exc.py @@ -1,5 +1,5 @@ # ext/asyncio/exc.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index a13e106ff31..8003f66afe2 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -1,5 +1,5 @@ # ext/asyncio/result.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -93,6 +93,7 @@ def __init__(self, real_result: Result[_TP]): self._metadata = real_result._metadata self._unique_filter_state = real_result._unique_filter_state + self._source_supports_scalars = real_result._source_supports_scalars self._post_creational_filter = None # BaseCursorResult pre-generates the "_row_getter". Use that @@ -324,22 +325,20 @@ async def one_or_none(self) -> Optional[Row[_TP]]: return await greenlet_spawn(self._only_one_row, True, False, False) @overload - async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: - ... + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ... @overload - async def scalar_one(self) -> Any: - ... + async def scalar_one(self) -> Any: ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and - then :meth:`_asyncio.AsyncResult.one`. + then :meth:`_asyncio.AsyncScalarResult.one`. .. 
seealso:: - :meth:`_asyncio.AsyncResult.one` + :meth:`_asyncio.AsyncScalarResult.one` :meth:`_asyncio.AsyncResult.scalars` @@ -349,22 +348,20 @@ async def scalar_one(self) -> Any: @overload async def scalar_one_or_none( self: AsyncResult[Tuple[_T]], - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: - ... + async def scalar_one_or_none(self) -> Optional[Any]: ... async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and - then :meth:`_asyncio.AsyncResult.one_or_none`. + then :meth:`_asyncio.AsyncScalarResult.one_or_none`. .. seealso:: - :meth:`_asyncio.AsyncResult.one_or_none` + :meth:`_asyncio.AsyncScalarResult.one_or_none` :meth:`_asyncio.AsyncResult.scalars` @@ -403,12 +400,10 @@ async def one(self) -> Row[_TP]: return await greenlet_spawn(self._only_one_row, True, True, False) @overload - async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: - ... + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ... @overload - async def scalar(self) -> Any: - ... + async def scalar(self) -> Any: ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -452,16 +447,13 @@ async def freeze(self) -> FrozenResult[_TP]: @overload def scalars( self: AsyncResult[Tuple[_T]], index: Literal[0] - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload - def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: - ... + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ... @overload - def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: - ... + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ... def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which @@ -833,11 +825,9 @@ async def all(self) -> Sequence[_R]: # noqa: A001 """ ... - async def __aiter__(self) -> AsyncIterator[_R]: - ... + async def __aiter__(self) -> AsyncIterator[_R]: ... - async def __anext__(self) -> _R: - ... + async def __anext__(self) -> _R: ... async def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -871,22 +861,20 @@ async def one(self) -> _R: ... @overload - async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: - ... + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ... @overload - async def scalar_one(self) -> Any: - ... + async def scalar_one(self) -> Any: ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.Result.one`. + and then :meth:`_engine.AsyncScalarResult.one`. .. seealso:: - :meth:`_engine.Result.one` + :meth:`_engine.AsyncScalarResult.one` :meth:`_engine.Result.scalars` @@ -896,22 +884,20 @@ async def scalar_one(self) -> Any: @overload async def scalar_one_or_none( self: AsyncTupleResult[Tuple[_T]], - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: - ... + async def scalar_one_or_none(self) -> Optional[Any]: ... async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.Result.one_or_none`. 
+ and then :meth:`_engine.AsyncScalarResult.one_or_none`. .. seealso:: - :meth:`_engine.Result.one_or_none` + :meth:`_engine.AsyncScalarResult.one_or_none` :meth:`_engine.Result.scalars` @@ -919,12 +905,12 @@ async def scalar_one_or_none(self) -> Optional[Any]: ... @overload - async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: - ... + async def scalar( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... @overload - async def scalar(self) -> Any: - ... + async def scalar(self) -> Any: ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 4c68f53ffa8..d2a9a51b231 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -1,5 +1,5 @@ # ext/asyncio/scoping.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -364,7 +364,7 @@ def begin(self) -> AsyncSessionTransaction: object is entered:: async with async_session.begin(): - # .. ORM transaction is begun + ... # ORM transaction is begun Note that database IO will not normally occur when the session-level transaction is begun, as database transactions begin on an @@ -536,8 +536,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: - ... + ) -> Result[_T]: ... @overload async def execute( @@ -549,8 +548,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... @overload async def execute( @@ -562,8 +560,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: - ... + ) -> Result[Any]: ... 
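For orientation amid these ``async_scoped_session`` hunks, a minimal construction sketch; the engine URL and the choice of ``current_task`` as the scope function are illustrative, not part of this change::

    from asyncio import current_task

    from sqlalchemy.ext.asyncio import async_scoped_session
    from sqlalchemy.ext.asyncio import async_sessionmaker
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("sqlite+aiosqlite:///example.db")
    factory = async_sessionmaker(engine, expire_on_commit=False)

    # one AsyncSession per asyncio task
    AsyncScopedSession = async_scoped_session(factory, scopefunc=current_task)


    async def do_work() -> None:
        session = AsyncScopedSession()
        async with session.begin():
            ...  # ORM transaction is begun; commits on exit
        await AsyncScopedSession.remove()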
async def execute( self, @@ -811,28 +808,28 @@ def get_bind( # construct async engines w/ async drivers engines = { - 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), - 'other':create_async_engine("sqlite+aiosqlite:///other.db"), - 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), - 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + "leader": create_async_engine("sqlite+aiosqlite:///leader.db"), + "other": create_async_engine("sqlite+aiosqlite:///other.db"), + "follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"), + "follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"), } + class RoutingSession(Session): def get_bind(self, mapper=None, clause=None, **kw): # within get_bind(), return sync engines if mapper and issubclass(mapper.class_, MyOtherClass): - return engines['other'].sync_engine + return engines["other"].sync_engine elif self._flushing or isinstance(clause, (Update, Delete)): - return engines['leader'].sync_engine + return engines["leader"].sync_engine else: return engines[ - random.choice(['follower1','follower2']) + random.choice(["follower1", "follower2"]) ].sync_engine + # apply to AsyncSession using sync_session_class - AsyncSessionMaker = async_sessionmaker( - sync_session_class=RoutingSession - ) + AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession) The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, implicitly non-blocking context in the same manner as ORM event hooks @@ -867,7 +864,7 @@ def is_modified( This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously committed value, if any. + value to its previously flushed or committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1015,8 +1012,7 @@ async def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -1027,8 +1023,7 @@ async def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -1070,8 +1065,7 @@ async def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -1082,8 +1076,7 @@ async def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -1182,8 +1175,7 @@ async def get_one( Proxied for the :class:`_asyncio.AsyncSession` class on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects - no rows. + Raises :class:`_exc.NoResultFound` if the query selects no rows. ..versionadded: 2.0.22 @@ -1213,8 +1205,7 @@ async def stream( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[_T]: - ... + ) -> AsyncResult[_T]: ... 
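As a usage note for the ``get_one()`` docstring amended above, a short sketch of the raising behavior; the helper and its names are hypothetical::

    from sqlalchemy.exc import NoResultFound
    from sqlalchemy.ext.asyncio import AsyncSession


    async def load_or_fail(session: AsyncSession, entity: type, pk: int):
        try:
            # unlike get(), get_one() never returns None
            return await session.get_one(entity, pk)
        except NoResultFound:
            raise ValueError(f"no {entity.__name__} with primary key {pk}")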
@overload async def stream( @@ -1225,8 +1216,7 @@ async def stream( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Any]: - ... + ) -> AsyncResult[Any]: ... async def stream( self, @@ -1265,8 +1255,7 @@ async def stream_scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload async def stream_scalars( @@ -1277,8 +1266,7 @@ async def stream_scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: - ... + ) -> AsyncScalarResult[Any]: ... async def stream_scalars( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 30232e59cbb..68cbb59bfd6 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1,5 +1,5 @@ # ext/asyncio/session.py -# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -38,6 +38,9 @@ from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + if TYPE_CHECKING: from .engine import AsyncConnection @@ -71,6 +74,7 @@ _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] +_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -332,9 +336,12 @@ async def refresh( ) async def run_sync( - self, fn: Callable[..., _T], *arg: Any, **kw: Any + self, + fn: Callable[Concatenate[Session, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, ) -> _T: - """Invoke the given synchronous (i.e. not async) callable, + '''Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_orm.Session` as the first argument. 
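The retyped ``run_sync()`` above relies on ``Concatenate`` and ``ParamSpec`` so that type checkers validate the arguments forwarded to the sync callable. A standalone sketch of the pattern (Python 3.10+ typing; names hypothetical)::

    from typing import Callable, Concatenate, ParamSpec, TypeVar

    _P = ParamSpec("_P")
    _T = TypeVar("_T")


    class Runner:
        def run_sync(
            self,
            fn: Callable[Concatenate["Runner", _P], _T],
            *arg: _P.args,
            **kw: _P.kwargs,
        ) -> _T:
            # Runner supplies the first positional argument; *arg/**kw
            # are checked against the remainder of fn's signature
            return fn(self, *arg, **kw)


    def greet(runner: Runner, name: str, punct: str = "!") -> str:
        return f"hello {name}{punct}"


    print(Runner().run_sync(greet, "world"))  # hello world!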
@@ -344,25 +351,27 @@ async def run_sync( E.g.:: def some_business_method(session: Session, param: str) -> str: - '''A synchronous function that does not require awaiting + """A synchronous function that does not require awaiting :param session: a SQLAlchemy Session, used synchronously :return: an optional return value is supported - ''' + """ session.add(MyObject(param=param)) session.flush() return "success" async def do_something_async(async_engine: AsyncEngine) -> None: - '''an async function that uses awaiting''' + """an async function that uses awaiting""" with AsyncSession(async_engine) as async_session: # run some_business_method() with a sync-style # Session, proxied into an awaitable - return_code = await async_session.run_sync(some_business_method, param="param1") + return_code = await async_session.run_sync( + some_business_method, param="param1" + ) print(return_code) This method maintains the asyncio event loop all the way through @@ -384,9 +393,11 @@ async def do_something_async(async_engine: AsyncEngine) -> None: :meth:`.AsyncConnection.run_sync` :ref:`session_run_sync` - """ # noqa: E501 + ''' # noqa: E501 - return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + return await greenlet_spawn( + fn, self.sync_session, *arg, _require_await=False, **kw + ) @overload async def execute( @@ -398,8 +409,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: - ... + ) -> Result[_T]: ... @overload async def execute( @@ -411,8 +421,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... @overload async def execute( @@ -424,8 +433,7 @@ async def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: - ... + ) -> Result[Any]: ... async def execute( self, @@ -471,8 +479,7 @@ async def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -483,8 +490,7 @@ async def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -528,8 +534,7 @@ async def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -540,8 +545,7 @@ async def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -624,8 +628,7 @@ async def get_one( """Return an instance based on the given primary key identifier, or raise an exception if not found. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects - no rows. + Raises :class:`_exc.NoResultFound` if the query selects no rows. ..versionadded: 2.0.22 @@ -655,8 +658,7 @@ async def stream( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[_T]: - ... 
+ ) -> AsyncResult[_T]: ... @overload async def stream( @@ -667,8 +669,7 @@ async def stream( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Any]: - ... + ) -> AsyncResult[Any]: ... async def stream( self, @@ -710,8 +711,7 @@ async def stream_scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload async def stream_scalars( @@ -722,8 +722,7 @@ async def stream_scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: - ... + ) -> AsyncScalarResult[Any]: ... async def stream_scalars( self, @@ -812,7 +811,9 @@ def get_transaction(self) -> Optional[AsyncSessionTransaction]: """ trans = self.sync_session.get_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + return AsyncSessionTransaction._retrieve_proxy_for_target( + trans, async_session=self + ) else: return None @@ -828,7 +829,9 @@ def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: trans = self.sync_session.get_nested_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + return AsyncSessionTransaction._retrieve_proxy_for_target( + trans, async_session=self + ) else: return None @@ -879,28 +882,28 @@ def get_bind( # construct async engines w/ async drivers engines = { - 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), - 'other':create_async_engine("sqlite+aiosqlite:///other.db"), - 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), - 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + "leader": create_async_engine("sqlite+aiosqlite:///leader.db"), + "other": create_async_engine("sqlite+aiosqlite:///other.db"), + "follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"), + "follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"), } + class RoutingSession(Session): def get_bind(self, mapper=None, clause=None, **kw): # within get_bind(), return sync engines if mapper and issubclass(mapper.class_, MyOtherClass): - return engines['other'].sync_engine + return engines["other"].sync_engine elif self._flushing or isinstance(clause, (Update, Delete)): - return engines['leader'].sync_engine + return engines["leader"].sync_engine else: return engines[ - random.choice(['follower1','follower2']) + random.choice(["follower1", "follower2"]) ].sync_engine + # apply to AsyncSession using sync_session_class - AsyncSessionMaker = async_sessionmaker( - sync_session_class=RoutingSession - ) + AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession) The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, implicitly non-blocking context in the same manner as ORM event hooks @@ -956,7 +959,7 @@ def begin(self) -> AsyncSessionTransaction: object is entered:: async with async_session.begin(): - # .. ORM transaction is begun + ... 
# ORM transaction is begun Note that database IO will not normally occur when the session-level transaction is begun, as database transactions begin on an @@ -1309,7 +1312,7 @@ def is_modified( This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously committed value, if any. + value to its previously flushed or committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1633,16 +1636,22 @@ class async_sessionmaker(Generic[_AS]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import async_sessionmaker - async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None: + + async def run_some_sql( + async_session: async_sessionmaker[AsyncSession], + ) -> None: async with async_session() as session: session.add(SomeObject(data="object")) session.add(SomeOtherObject(name="other object")) await session.commit() + async def main() -> None: # an AsyncEngine, which the AsyncSession will use for connection # resources - engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/') + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/" + ) # create a reusable factory for new AsyncSession instances async_session = async_sessionmaker(engine) @@ -1686,8 +1695,7 @@ def __init__( expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... @overload def __init__( @@ -1698,8 +1706,7 @@ def __init__( expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... def __init__( self, @@ -1743,7 +1750,6 @@ async def main(): # commits transaction, closes session - """ session = self() @@ -1776,7 +1782,7 @@ def configure(self, **new_kw: Any) -> None: AsyncSession = async_sessionmaker(some_engine) - AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + AsyncSession.configure(bind=create_async_engine("sqlite+aiosqlite://")) """ # noqa E501 self.kw.update(new_kw) @@ -1862,12 +1868,27 @@ async def commit(self) -> None: await greenlet_spawn(self._sync_transaction().commit) + @classmethod + def _regenerate_proxy_for_target( # type: ignore[override] + cls, + target: SessionTransaction, + async_session: AsyncSession, + **additional_kw: Any, # noqa: U100 + ) -> AsyncSessionTransaction: + sync_transaction = target + nested = target.nested + obj = cls.__new__(cls) + obj.session = async_session + obj.sync_transaction = obj._assign_proxied(sync_transaction) + obj.nested = nested + return obj + async def start( self, is_ctxmanager: bool = False ) -> AsyncSessionTransaction: self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.session.sync_session.begin_nested # type: ignore + self.session.sync_session.begin_nested if self.nested else self.session.sync_session.begin ) diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 18568c7f28f..817f91d267b 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -1,5 +1,5 @@ # ext/automap.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,7 +11,7 @@ It is hoped that the :class:`.AutomapBase` system provides a quick and modernized solution to the problem that the very famous -`SQLSoup `_ +`SQLSoup `_ also tries to solve, that of generating a 
quick and rudimentary object model from an existing database on the fly. By addressing the issue strictly at the mapper configuration level, and integrating fully with existing @@ -64,7 +64,7 @@ # collection-based relationships are by default named # "_collection" u1 = session.query(User).first() - print (u1.address_collection) + print(u1.address_collection) Above, calling :meth:`.AutomapBase.prepare` while passing along the :paramref:`.AutomapBase.prepare.reflect` parameter indicates that the @@ -101,6 +101,7 @@ from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey from sqlalchemy.ext.automap import automap_base + engine = create_engine("sqlite:///mydatabase.db") # produce our own MetaData object @@ -108,13 +109,15 @@ # we can reflect it ourselves from a database, using options # such as 'only' to limit what tables we look at... - metadata.reflect(engine, only=['user', 'address']) + metadata.reflect(engine, only=["user", "address"]) # ... or just define our own Table objects with it (or combine both) - Table('user_order', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', ForeignKey('user.id')) - ) + Table( + "user_order", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", ForeignKey("user.id")), + ) # we can then produce a set of mappings from this MetaData. Base = automap_base(metadata=metadata) @@ -123,8 +126,9 @@ Base.prepare() # mapped classes are ready - User, Address, Order = Base.classes.user, Base.classes.address,\ - Base.classes.user_order + User = Base.classes.user + Address = Base.classes.address + Order = Base.classes.user_order .. _automap_by_module: @@ -177,18 +181,23 @@ Base.metadata.create_all(e) + def module_name_for_table(cls, tablename, table): if table.schema is not None: return f"mymodule.{table.schema}" else: return f"mymodule.default" + Base = automap_base() Base.prepare(e, modulename_for_table=module_name_for_table) - Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table) - Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table) - + Base.prepare( + e, schema="test_schema", modulename_for_table=module_name_for_table + ) + Base.prepare( + e, schema="test_schema_2", modulename_for_table=module_name_for_table + ) The same named-classes are organized into a hierarchical collection available at :attr:`.AutomapBase.by_module`. This collection is traversed using the @@ -251,12 +260,13 @@ class name. # automap base Base = automap_base() + # pre-declare User for the 'user' table class User(Base): - __tablename__ = 'user' + __tablename__ = "user" # override schema elements like Columns - user_name = Column('name', String) + user_name = Column("name", String) # override relationships too, if desired. # we must use the same name that automap would use for the @@ -264,6 +274,7 @@ class User(Base): # generate for "address" address_collection = relationship("address", collection_class=set) + # reflect engine = create_engine("sqlite:///mydatabase.db") Base.prepare(autoload_with=engine) @@ -274,11 +285,11 @@ class User(Base): Address = Base.classes.address u1 = session.query(User).first() - print (u1.address_collection) + print(u1.address_collection) # the backref is still there: a1 = session.query(Address).first() - print (a1.user) + print(a1.user) Above, one of the more intricate details is that we illustrated overriding one of the :func:`_orm.relationship` objects that automap would have created. 
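Following the ``module_name_for_table`` example above, the resulting classes could then be retrieved by dotted traversal of :attr:`.AutomapBase.by_module`; the module and schema names below are the hypothetical ones from that example, and the table name is illustrative::

    # classes are keyed under the dotted module name returned by the hook
    SomeDefaultClass = Base.by_module.mymodule.default.some_table
    SomeSchemaClass = Base.by_module.mymodule.test_schema.some_table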
@@ -305,35 +316,49 @@ class User(Base): import re import inflect + def camelize_classname(base, tablename, table): - "Produce a 'camelized' class name, e.g. " + "Produce a 'camelized' class name, e.g." "'words_and_underscores' -> 'WordsAndUnderscores'" - return str(tablename[0].upper() + \ - re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:])) + return str( + tablename[0].upper() + + re.sub( + r"_([a-z])", + lambda m: m.group(1).upper(), + tablename[1:], + ) + ) + _pluralizer = inflect.engine() + + def pluralize_collection(base, local_cls, referred_cls, constraint): - "Produce an 'uncamelized', 'pluralized' class name, e.g. " + "Produce an 'uncamelized', 'pluralized' class name, e.g." "'SomeTerm' -> 'some_terms'" referred_name = referred_cls.__name__ - uncamelized = re.sub(r'[A-Z]', - lambda m: "_%s" % m.group(0).lower(), - referred_name)[1:] + uncamelized = re.sub( + r"[A-Z]", + lambda m: "_%s" % m.group(0).lower(), + referred_name, + )[1:] pluralized = _pluralizer.plural(uncamelized) return pluralized + from sqlalchemy.ext.automap import automap_base Base = automap_base() engine = create_engine("sqlite:///mydatabase.db") - Base.prepare(autoload_with=engine, - classname_for_table=camelize_classname, - name_for_collection_relationship=pluralize_collection - ) + Base.prepare( + autoload_with=engine, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection, + ) From the above mapping, we would now have classes ``User`` and ``Address``, where the collection from ``User`` to ``Address`` is called @@ -422,16 +447,21 @@ def pluralize_collection(base, local_cls, referred_cls, constraint): options along to all one-to-many relationships:: from sqlalchemy.ext.automap import generate_relationship + from sqlalchemy.orm import interfaces + - def _gen_relationship(base, direction, return_fn, - attrname, local_cls, referred_cls, **kw): + def _gen_relationship( + base, direction, return_fn, attrname, local_cls, referred_cls, **kw + ): if direction is interfaces.ONETOMANY: - kw['cascade'] = 'all, delete-orphan' - kw['passive_deletes'] = True + kw["cascade"] = "all, delete-orphan" + kw["passive_deletes"] = True # make use of the built-in function to actually return # the result. 
- return generate_relationship(base, direction, return_fn, - attrname, local_cls, referred_cls, **kw) + return generate_relationship( + base, direction, return_fn, attrname, local_cls, referred_cls, **kw + ) + from sqlalchemy.ext.automap import automap_base from sqlalchemy import create_engine @@ -440,8 +470,7 @@ def _gen_relationship(base, direction, return_fn, Base = automap_base() engine = create_engine("sqlite:///mydatabase.db") - Base.prepare(autoload_with=engine, - generate_relationship=_gen_relationship) + Base.prepare(autoload_with=engine, generate_relationship=_gen_relationship) Many-to-Many relationships -------------------------- @@ -482,18 +511,20 @@ def _gen_relationship(base, direction, return_fn, classes given as follows:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) type = Column(String(50)) __mapper_args__ = { - 'polymorphic_identity':'employee', 'polymorphic_on': type + "polymorphic_identity": "employee", + "polymorphic_on": type, } + class Engineer(Employee): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) __mapper_args__ = { - 'polymorphic_identity':'engineer', + "polymorphic_identity": "engineer", } The foreign key from ``Engineer`` to ``Employee`` is used not for a @@ -508,25 +539,28 @@ class Engineer(Employee): SQLAlchemy can guess:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) type = Column(String(50)) __mapper_args__ = { - 'polymorphic_identity':'employee', 'polymorphic_on':type + "polymorphic_identity": "employee", + "polymorphic_on": type, } + class Engineer(Employee): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) - favorite_employee_id = Column(Integer, ForeignKey('employee.id')) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey("employee.id")) - favorite_employee = relationship(Employee, - foreign_keys=favorite_employee_id) + favorite_employee = relationship( + Employee, foreign_keys=favorite_employee_id + ) __mapper_args__ = { - 'polymorphic_identity':'engineer', - 'inherit_condition': id == Employee.id + "polymorphic_identity": "engineer", + "inherit_condition": id == Employee.id, } Handling Simple Naming Conflicts @@ -559,20 +593,24 @@ class Engineer(Employee): We can resolve this conflict by using an underscore as follows:: - def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + def name_for_scalar_relationship( + base, local_cls, referred_cls, constraint + ): name = referred_cls.__name__.lower() local_table = local_cls.__table__ if name in local_table.columns: newname = name + "_" warnings.warn( - "Already detected name %s present. using %s" % - (name, newname)) + "Already detected name %s present. using %s" % (name, newname) + ) return newname return name - Base.prepare(autoload_with=engine, - name_for_scalar_relationship=name_for_scalar_relationship) + Base.prepare( + autoload_with=engine, + name_for_scalar_relationship=name_for_scalar_relationship, + ) Alternatively, we can change the name on the column side. 
The columns that are mapped can be modified using the technique described at @@ -581,12 +619,13 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): Base = automap_base() + class TableB(Base): - __tablename__ = 'table_b' - _table_a = Column('table_a', ForeignKey('table_a.id')) + __tablename__ = "table_b" + _table_a = Column("table_a", ForeignKey("table_a.id")) - Base.prepare(autoload_with=engine) + Base.prepare(autoload_with=engine) Using Automap with Explicit Declarations ======================================== @@ -603,26 +642,29 @@ class TableB(Base): Base = automap_base() + class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id = Column(Integer, primary_key=True) name = Column(String) + class Address(Base): - __tablename__ = 'address' + __tablename__ = "address" id = Column(Integer, primary_key=True) email = Column(String) - user_id = Column(ForeignKey('user.id')) + user_id = Column(ForeignKey("user.id")) + # produce relationships Base.prepare() # mapping is complete, with "address_collection" and # "user" relationships - a1 = Address(email='u1') - a2 = Address(email='u2') + a1 = Address(email="u1") + a2 = Address(email="u2") u1 = User(address_collection=[a1, a2]) assert a1.user is u1 @@ -651,7 +693,8 @@ class Address(Base): @event.listens_for(Base.metadata, "column_reflect") def column_reflect(inspector, table, column_info): # set column.key = "attr_" - column_info['key'] = "attr_%s" % column_info['name'].lower() + column_info["key"] = "attr_%s" % column_info["name"].lower() + # run reflection Base.prepare(autoload_with=engine) @@ -715,8 +758,9 @@ def column_reflect(inspector, table, column_info): class PythonNameForTableType(Protocol): - def __call__(self, base: Type[Any], tablename: str, table: Table) -> str: - ... + def __call__( + self, base: Type[Any], tablename: str, table: Table + ) -> str: ... def classname_for_table( @@ -763,8 +807,7 @@ def __call__( local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: - ... + ) -> str: ... def name_for_scalar_relationship( @@ -804,8 +847,7 @@ def __call__( local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: - ... + ) -> str: ... def name_for_collection_relationship( @@ -850,8 +892,7 @@ def __call__( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Relationship[Any]: - ... + ) -> Relationship[Any]: ... @overload def __call__( @@ -863,8 +904,7 @@ def __call__( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> ORMBackrefArgument: - ... + ) -> ORMBackrefArgument: ... def __call__( self, @@ -877,8 +917,7 @@ def __call__( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Union[ORMBackrefArgument, Relationship[Any]]: - ... + ) -> Union[ORMBackrefArgument, Relationship[Any]]: ... @overload @@ -890,8 +929,7 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> Relationship[Any]: - ... +) -> Relationship[Any]: ... @overload @@ -903,8 +941,7 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> ORMBackrefArgument: - ... +) -> ORMBackrefArgument: ... 
def generate_relationship( @@ -1008,6 +1045,12 @@ class that is produced by the :func:`.declarative.declarative_base` User, Address = Base.classes.User, Base.classes.Address + For class names that overlap with a method name of + :class:`.util.Properties`, such as ``items()``, the getitem form + is also supported:: + + Item = Base.classes["items"] + """ by_module: ClassVar[ByModuleProperties] diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 64c9ce6ec26..cd3e087931e 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -1,5 +1,5 @@ -# sqlalchemy/ext/baked.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# ext/baked.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -258,23 +258,19 @@ def to_query(self, query_or_session): is passed to the lambda:: sub_bq = self.bakery(lambda s: s.query(User.name)) - sub_bq += lambda q: q.filter( - User.id == Address.user_id).correlate(Address) + sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address) main_bq = self.bakery(lambda s: s.query(Address)) - main_bq += lambda q: q.filter( - sub_bq.to_query(q).exists()) + main_bq += lambda q: q.filter(sub_bq.to_query(q).exists()) In the case where the subquery is used in the first callable against a :class:`.Session`, the :class:`.Session` is also accepted:: sub_bq = self.bakery(lambda s: s.query(User.name)) - sub_bq += lambda q: q.filter( - User.id == Address.user_id).correlate(Address) + sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address) main_bq = self.bakery( - lambda s: s.query( - Address.id, sub_bq.to_query(q).scalar_subquery()) + lambda s: s.query(Address.id, sub_bq.to_query(q).scalar_subquery()) ) :param query_or_session: a :class:`_query.Query` object or a class @@ -285,7 +281,7 @@ def to_query(self, query_or_session): .. versionadded:: 1.3 - """ + """ # noqa: E501 if isinstance(query_or_session, Session): session = query_or_session diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 39a55410305..cc64477ed47 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -1,10 +1,9 @@ # ext/compiler.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r"""Provides an API for creation of custom ClauseElements and compilers. @@ -18,9 +17,11 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import ColumnClause + class MyColumn(ColumnClause): inherit_cache = True + @compiles(MyColumn) def compile_mycolumn(element, compiler, **kw): return "[%s]" % element.name @@ -32,10 +33,12 @@ def compile_mycolumn(element, compiler, **kw): from sqlalchemy import select - s = select(MyColumn('x'), MyColumn('y')) + s = select(MyColumn("x"), MyColumn("y")) print(str(s)) -Produces:: +Produces: + +.. sourcecode:: sql SELECT [x], [y] @@ -47,6 +50,7 @@ def compile_mycolumn(element, compiler, **kw): from sqlalchemy.schema import DDLElement + class AlterColumn(DDLElement): inherit_cache = False @@ -54,14 +58,18 @@ def __init__(self, column, cmd): self.column = column self.cmd = cmd + @compiles(AlterColumn) def visit_alter_column(element, compiler, **kw): return "ALTER COLUMN %s ..." 
% element.column.name - @compiles(AlterColumn, 'postgresql') + + @compiles(AlterColumn, "postgresql") def visit_alter_column(element, compiler, **kw): - return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, - element.column.name) + return "ALTER TABLE %s ALTER COLUMN %s ..." % ( + element.table.name, + element.column.name, + ) The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. @@ -81,6 +89,7 @@ def visit_alter_column(element, compiler, **kw): from sqlalchemy.sql.expression import Executable, ClauseElement + class InsertFromSelect(Executable, ClauseElement): inherit_cache = False @@ -88,20 +97,27 @@ def __init__(self, table, select): self.table = table self.select = select + @compiles(InsertFromSelect) def visit_insert_from_select(element, compiler, **kw): return "INSERT INTO %s (%s)" % ( compiler.process(element.table, asfrom=True, **kw), - compiler.process(element.select, **kw) + compiler.process(element.select, **kw), ) - insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) + + insert = InsertFromSelect(t1, select(t1).where(t1.c.x > 5)) print(insert) -Produces:: +Produces (formatted for readability): - "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z - FROM mytable WHERE mytable.x > :x_1)" +.. sourcecode:: sql + + INSERT INTO mytable ( + SELECT mytable.x, mytable.y, mytable.z + FROM mytable + WHERE mytable.x > :x_1 + ) .. note:: @@ -121,11 +137,10 @@ def visit_insert_from_select(element, compiler, **kw): @compiles(MyConstraint) def compile_my_constraint(constraint, ddlcompiler, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "CONSTRAINT %s CHECK (%s)" % ( constraint.name, - ddlcompiler.sql_compiler.process( - constraint.expression, **kw) + ddlcompiler.sql_compiler.process(constraint.expression, **kw), ) Above, we add an additional flag to the process step as called by @@ -153,6 +168,7 @@ def compile_my_constraint(constraint, ddlcompiler, **kw): from sqlalchemy.sql.expression import Insert + @compiles(Insert) def prefix_inserts(insert, compiler, **kw): return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) @@ -168,17 +184,16 @@ def prefix_inserts(insert, compiler, **kw): ``compiler`` works for types, too, such as below where we implement the MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: - @compiles(String, 'mssql') - @compiles(VARCHAR, 'mssql') + @compiles(String, "mssql") + @compiles(VARCHAR, "mssql") def compile_varchar(element, compiler, **kw): - if element.length == 'max': + if element.length == "max": return "VARCHAR('max')" else: return compiler.visit_VARCHAR(element, **kw) - foo = Table('foo', metadata, - Column('data', VARCHAR('max')) - ) + + foo = Table("foo", metadata, Column("data", VARCHAR("max"))) Subclassing Guidelines ====================== @@ -216,18 +231,23 @@ class timestamp(ColumnElement): from sqlalchemy.sql.expression import FunctionElement + class coalesce(FunctionElement): - name = 'coalesce' + name = "coalesce" inherit_cache = True + @compiles(coalesce) def compile(element, compiler, **kw): return "coalesce(%s)" % compiler.process(element.clauses, **kw) - @compiles(coalesce, 'oracle') + + @compiles(coalesce, "oracle") def compile(element, compiler, **kw): if len(element.clauses) > 2: - raise TypeError("coalesce only supports two arguments on Oracle") + raise TypeError( + "coalesce only supports two arguments on " "Oracle Database" + ) return "nvl(%s)" % compiler.process(element.clauses, **kw) * :class:`.ExecutableDDLElement` - The root of all DDL 
expressions, @@ -281,6 +301,7 @@ def compile(element, compiler, **kw): class MyColumn(ColumnClause): inherit_cache = True + @compiles(MyColumn) def compile_mycolumn(element, compiler, **kw): return "[%s]" % element.name @@ -319,11 +340,12 @@ def __init__(self, table, select): self.table = table self.select = select + @compiles(InsertFromSelect) def visit_insert_from_select(element, compiler, **kw): return "INSERT INTO %s (%s)" % ( compiler.process(element.table, asfrom=True, **kw), - compiler.process(element.select, **kw) + compiler.process(element.select, **kw), ) While it is also possible that the above ``InsertFromSelect`` could be made to @@ -359,28 +381,32 @@ def visit_insert_from_select(element, compiler, **kw): from sqlalchemy.ext.compiler import compiles from sqlalchemy.types import DateTime + class utcnow(expression.FunctionElement): type = DateTime() inherit_cache = True - @compiles(utcnow, 'postgresql') + + @compiles(utcnow, "postgresql") def pg_utcnow(element, compiler, **kw): return "TIMEZONE('utc', CURRENT_TIMESTAMP)" - @compiles(utcnow, 'mssql') + + @compiles(utcnow, "mssql") def ms_utcnow(element, compiler, **kw): return "GETUTCDATE()" Example usage:: - from sqlalchemy import ( - Table, Column, Integer, String, DateTime, MetaData - ) + from sqlalchemy import Table, Column, Integer, String, DateTime, MetaData + metadata = MetaData() - event = Table("event", metadata, + event = Table( + "event", + metadata, Column("id", Integer, primary_key=True), Column("description", String(50), nullable=False), - Column("timestamp", DateTime, server_default=utcnow()) + Column("timestamp", DateTime, server_default=utcnow()), ) "GREATEST" function @@ -395,30 +421,30 @@ def ms_utcnow(element, compiler, **kw): from sqlalchemy.ext.compiler import compiles from sqlalchemy.types import Numeric + class greatest(expression.FunctionElement): type = Numeric() - name = 'greatest' + name = "greatest" inherit_cache = True + @compiles(greatest) def default_greatest(element, compiler, **kw): return compiler.visit_function(element) - @compiles(greatest, 'sqlite') - @compiles(greatest, 'mssql') - @compiles(greatest, 'oracle') + + @compiles(greatest, "sqlite") + @compiles(greatest, "mssql") + @compiles(greatest, "oracle") def case_greatest(element, compiler, **kw): arg1, arg2 = list(element.clauses) return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw) Example usage:: - Session.query(Account).\ - filter( - greatest( - Account.checking_balance, - Account.savings_balance) > 10000 - ) + Session.query(Account).filter( + greatest(Account.checking_balance, Account.savings_balance) > 10000 + ) "false" expression ------------------ @@ -429,16 +455,19 @@ def case_greatest(element, compiler, **kw): from sqlalchemy.sql import expression from sqlalchemy.ext.compiler import compiles + class sql_false(expression.ColumnElement): inherit_cache = True + @compiles(sql_false) def default_false(element, compiler, **kw): return "false" - @compiles(sql_false, 'mssql') - @compiles(sql_false, 'mysql') - @compiles(sql_false, 'oracle') + + @compiles(sql_false, "mssql") + @compiles(sql_false, "mysql") + @compiles(sql_false, "oracle") def int_false(element, compiler, **kw): return "0" @@ -448,19 +477,33 @@ def int_false(element, compiler, **kw): exp = union_all( select(users.c.name, sql_false().label("enrolled")), - select(customers.c.name, customers.c.enrolled) + select(customers.c.name, customers.c.enrolled), ) """ +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing 
import Dict +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + from .. import exc from ..sql import sqltypes +if TYPE_CHECKING: + from ..sql.compiler import SQLCompiler + +_F = TypeVar("_F", bound=Callable[..., Any]) + -def compiles(class_, *specs): +def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]: """Register a function as a compiler for a given :class:`_expression.ClauseElement` type.""" - def decorate(fn): + def decorate(fn: _F) -> _F: # get an existing @compiles handler existing = class_.__dict__.get("_compiler_dispatcher", None) @@ -473,7 +516,9 @@ def decorate(fn): if existing_dispatch: - def _wrap_existing_dispatch(element, compiler, **kw): + def _wrap_existing_dispatch( + element: Any, compiler: SQLCompiler, **kw: Any + ) -> Any: try: return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError as uce: @@ -505,7 +550,7 @@ def _wrap_existing_dispatch(element, compiler, **kw): return decorate -def deregister(class_): +def deregister(class_: Type[Any]) -> None: """Remove all custom compilers associated with a given :class:`_expression.ClauseElement` type. @@ -517,10 +562,10 @@ def deregister(class_): class _dispatcher: - def __init__(self): - self.specs = {} + def __init__(self) -> None: + self.specs: Dict[str, Callable[..., Any]] = {} - def __call__(self, element, compiler, **kw): + def __call__(self, element: Any, compiler: SQLCompiler, **kw: Any) -> Any: # TODO: yes, this could also switch off of DBAPI in use. fn = self.specs.get(compiler.dialect.name, None) if not fn: diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 2f6b2f23fa8..0383f9d34f8 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -1,5 +1,5 @@ # ext/declarative/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index acc9d08cfbf..3dc6bf698c4 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -1,5 +1,5 @@ # ext/declarative/extensions.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -50,23 +50,26 @@ class ConcreteBase: from sqlalchemy.ext.declarative import ConcreteBase + class Employee(ConcreteBase, Base): - __tablename__ = 'employee' + __tablename__ = "employee" employee_id = Column(Integer, primary_key=True) name = Column(String(50)) __mapper_args__ = { - 'polymorphic_identity':'employee', - 'concrete':True} + "polymorphic_identity": "employee", + "concrete": True, + } + class Manager(Employee): - __tablename__ = 'manager' + __tablename__ = "manager" employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True} - + "polymorphic_identity": "manager", + "concrete": True, + } The name of the discriminator column used by :func:`.polymorphic_union` defaults to the name ``type``. 
To suit the use case of a mapping where an @@ -75,7 +78,7 @@ class Manager(Employee): ``_concrete_discriminator_name`` attribute:: class Employee(ConcreteBase, Base): - _concrete_discriminator_name = '_concrete_discriminator' + _concrete_discriminator_name = "_concrete_discriminator" .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` attribute to :class:`_declarative.ConcreteBase` so that the @@ -168,23 +171,27 @@ class AbstractConcreteBase(ConcreteBase): from sqlalchemy.orm import DeclarativeBase from sqlalchemy.ext.declarative import AbstractConcreteBase + class Base(DeclarativeBase): pass + class Employee(AbstractConcreteBase, Base): pass + class Manager(Employee): - __tablename__ = 'manager' + __tablename__ = "manager" employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True + "polymorphic_identity": "manager", + "concrete": True, } + Base.registry.configure() The abstract base class is handled by declarative in a special way; @@ -200,10 +207,12 @@ class Manager(Employee): from sqlalchemy.ext.declarative import AbstractConcreteBase + class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) + class Employee(AbstractConcreteBase, Base): strict_attrs = True @@ -211,31 +220,31 @@ class Employee(AbstractConcreteBase, Base): @declared_attr def company_id(cls): - return Column(ForeignKey('company.id')) + return Column(ForeignKey("company.id")) @declared_attr def company(cls): return relationship("Company") + class Manager(Employee): - __tablename__ = 'manager' + __tablename__ = "manager" name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True + "polymorphic_identity": "manager", + "concrete": True, } + Base.registry.configure() When we make use of our mappings however, both ``Manager`` and ``Employee`` will have an independently usable ``.company`` attribute:: - session.execute( - select(Employee).filter(Employee.company.has(id=5)) - ) + session.execute(select(Employee).filter(Employee.company.has(id=5))) :param strict_attrs: when specified on the base class, "strict" attribute mode is enabled which attempts to limit ORM mapped attributes on the @@ -366,10 +375,12 @@ class DeferredReflection: from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + class MyClass(DeferredReflection, Base): - __tablename__ = 'mytable' + __tablename__ = "mytable" Above, ``MyClass`` is not yet mapped. After a series of classes have been defined in the above fashion, all tables @@ -391,17 +402,22 @@ class MyClass(DeferredReflection, Base): class ReflectedOne(DeferredReflection, Base): __abstract__ = True + class ReflectedTwo(DeferredReflection, Base): __abstract__ = True + class MyClass(ReflectedOne): - __tablename__ = 'mytable' + __tablename__ = "mytable" + class MyOtherClass(ReflectedOne): - __tablename__ = 'myothertable' + __tablename__ = "myothertable" + class YetAnotherClass(ReflectedTwo): - __tablename__ = 'yetanothertable' + __tablename__ = "yetanothertable" + # ... etc. 
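Once an Engine is available, the deferred mappings above are completed by calling ``prepare()`` on each deferred base; a minimal sketch, with the database URL as a placeholder::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite:///mydatabase.db")

    # reflects 'mytable' / 'myothertable' for ReflectedOne's subclasses,
    # then 'yetanothertable' for ReflectedTwo's
    ReflectedOne.prepare(engine)
    ReflectedTwo.prepare(engine)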
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 963bd005a4b..3ea3304eb30 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -1,5 +1,5 @@ # ext/horizontal_shard.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -83,8 +83,7 @@ def __call__( mapper: Optional[Mapper[_T]], instance: Any, clause: Optional[ClauseElement], - ) -> Any: - ... + ) -> Any: ... class IdentityChooser(Protocol): @@ -97,8 +96,7 @@ def __call__( execution_options: OrmExecuteOptionsParameter, bind_arguments: _BindArguments, **kw: Any, - ) -> Any: - ... + ) -> Any: ... class ShardedQuery(Query[_T]): @@ -127,12 +125,9 @@ def set_shard(self, shard_id: ShardIdentifier) -> Self: The shard_id can be passed for a 2.0 style execution to the bind_arguments dictionary of :meth:`.Session.execute`:: - results = session.execute( - stmt, - bind_arguments={"shard_id": "my_shard"} - ) + results = session.execute(stmt, bind_arguments={"shard_id": "my_shard"}) - """ + """ # noqa: E501 return self.execution_options(_sa_shard_id=shard_id) @@ -323,7 +318,7 @@ def _choose_shard_and_assign( state.identity_token = shard_id return shard_id - def connection_callable( # type: ignore [override] + def connection_callable( self, mapper: Optional[Mapper[_T]] = None, instance: Optional[Any] = None, @@ -384,9 +379,9 @@ class set_shard_id(ORMOption): the :meth:`_sql.Executable.options` method of any executable statement:: stmt = ( - select(MyObject). - where(MyObject.name == 'some name'). - options(set_shard_id("shard1")) + select(MyObject) + .where(MyObject.name == "some name") + .options(set_shard_id("shard1")) ) Above, the statement when invoked will limit to the "shard1" shard diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 615f166b479..c1c46e7c5f5 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -1,5 +1,5 @@ # ext/hybrid.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -34,8 +34,9 @@ class level and at the instance level. class Base(DeclarativeBase): pass + class Interval(Base): - __tablename__ = 'interval' + __tablename__ = "interval" id: Mapped[int] = mapped_column(primary_key=True) start: Mapped[int] @@ -57,7 +58,6 @@ def contains(self, point: int) -> bool: def intersects(self, other: Interval) -> bool: return self.contains(other.start) | self.contains(other.end) - Above, the ``length`` property returns the difference between the ``end`` and ``start`` attributes. With an instance of ``Interval``, this subtraction occurs in Python, using normal Python descriptor @@ -150,6 +150,7 @@ def intersects(self, other: Interval) -> bool: from sqlalchemy import func from sqlalchemy import type_coerce + class Interval(Base): # ... @@ -214,6 +215,7 @@ def _radius_expression(cls) -> ColumnElement[float]: # correct use, however is not accepted by pep-484 tooling + class Interval(Base): # ... @@ -256,6 +258,7 @@ def radius(cls): # correct use which is also accepted by pep-484 tooling + class Interval(Base): # ... 
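To make the dual instance-level / class-level behavior concrete, a brief sketch using the ``Interval`` mapping from the hunks above (assumes the mapping is configured as shown; values are illustrative)::

    from sqlalchemy import select

    i1 = Interval(start=5, end=15)
    print(i1.length)  # evaluated in plain Python: 10

    # the same attribute at the class level renders SQL as "end - start"
    stmt = select(Interval).where(Interval.length > 10)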
@@ -330,6 +333,7 @@ def _length_setter(self, value: int) -> None: ``Interval.start``, this could be substituted directly:: from sqlalchemy import update + stmt = update(Interval).values({Interval.start_point: 10}) However, when using a composite hybrid like ``Interval.length``, this @@ -340,6 +344,7 @@ def _length_setter(self, value: int) -> None: from typing import List, Tuple, Any + class Interval(Base): # ... @@ -352,10 +357,10 @@ def _length_setter(self, value: int) -> None: self.end = self.start + value @length.inplace.update_expression - def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: - return [ - (cls.end, cls.start + value) - ] + def _length_update_expression( + cls, value: Any + ) -> List[Tuple[Any, Any]]: + return [(cls.end, cls.start + value)] Above, if we use ``Interval.length`` in an UPDATE expression, we get a hybrid SET expression: @@ -412,15 +417,16 @@ class Base(DeclarativeBase): class SavingsAccount(Base): - __tablename__ = 'account' + __tablename__ = "account" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) owner: Mapped[User] = relationship(back_populates="accounts") + class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) @@ -448,7 +454,10 @@ def _balance_setter(self, value: Optional[Decimal]) -> None: @balance.inplace.expression @classmethod def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: - return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance) + return cast( + "SQLColumnExpression[Optional[Decimal]]", + SavingsAccount.balance, + ) The above hybrid property ``balance`` works with the first ``SavingsAccount`` entry in the list of accounts for this user. The @@ -471,8 +480,11 @@ def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: .. sourcecode:: pycon+sql >>> from sqlalchemy import select - >>> print(select(User, User.balance). - ... join(User.accounts).filter(User.balance > 5000)) + >>> print( + ... select(User, User.balance) + ... .join(User.accounts) + ... .filter(User.balance > 5000) + ... ) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance FROM "user" JOIN account ON "user".id = account.user_id @@ -487,8 +499,11 @@ def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: >>> from sqlalchemy import select >>> from sqlalchemy import or_ - >>> print (select(User, User.balance).outerjoin(User.accounts). - ... filter(or_(User.balance < 5000, User.balance == None))) + >>> print( + ... select(User, User.balance) + ... .outerjoin(User.accounts) + ... .filter(or_(User.balance < 5000, User.balance == None)) + ... 
) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id @@ -528,15 +543,16 @@ class Base(DeclarativeBase): class SavingsAccount(Base): - __tablename__ = 'account' + __tablename__ = "account" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) owner: Mapped[User] = relationship(back_populates="accounts") + class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) @@ -546,7 +562,9 @@ class User(Base): @hybrid_property def balance(self) -> Decimal: - return sum((acc.balance for acc in self.accounts), start=Decimal("0")) + return sum( + (acc.balance for acc in self.accounts), start=Decimal("0") + ) @balance.inplace.expression @classmethod @@ -557,7 +575,6 @@ def _balance_expression(cls) -> SQLColumnExpression[Decimal]: .label("total_balance") ) - The above recipe will give us the ``balance`` column which renders a correlated SELECT: @@ -604,6 +621,7 @@ def _balance_expression(cls) -> SQLColumnExpression[Decimal]: from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column + class Base(DeclarativeBase): pass @@ -612,8 +630,9 @@ class CaseInsensitiveComparator(Comparator[str]): def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 return func.lower(self.__clause_element__()) == func.lower(other) + class SearchWord(Base): - __tablename__ = 'searchword' + __tablename__ = "searchword" id: Mapped[int] = mapped_column(primary_key=True) word: Mapped[str] @@ -675,6 +694,7 @@ def name(self) -> str: def _name_setter(self, value: str) -> None: self.first_name = value + class FirstNameLastName(FirstNameOnly): # ... @@ -684,11 +704,11 @@ class FirstNameLastName(FirstNameOnly): # of FirstNameOnly.name that is local to FirstNameLastName @FirstNameOnly.name.getter def name(self) -> str: - return self.first_name + ' ' + self.last_name + return self.first_name + " " + self.last_name @name.inplace.setter def _name_setter(self, value: str) -> None: - self.first_name, self.last_name = value.split(' ', 1) + self.first_name, self.last_name = value.split(" ", 1) Above, the ``FirstNameLastName`` class refers to the hybrid from ``FirstNameOnly.name`` to repurpose its getter and setter for the subclass. @@ -709,8 +729,7 @@ class FirstNameLastName(FirstNameOnly): @FirstNameOnly.name.overrides.expression @classmethod def name(cls): - return func.concat(cls.first_name, ' ', cls.last_name) - + return func.concat(cls.first_name, " ", cls.last_name) Hybrid Value Objects -------------------- @@ -751,7 +770,7 @@ def __clause_element__(self): def __str__(self): return self.word - key = 'word' + key = "word" "Label to apply to Query tuple results" Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may @@ -762,7 +781,7 @@ def __str__(self): ``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: class SearchWord(Base): - __tablename__ = 'searchword' + __tablename__ = "searchword" id: Mapped[int] = mapped_column(primary_key=True) word: Mapped[str] @@ -904,13 +923,11 @@ class HybridExtensionType(InspectionAttrExtensionType): class _HybridGetterType(Protocol[_T_co]): - def __call__(s, self: Any) -> _T_co: - ... 
+ def __call__(s, self: Any) -> _T_co: ... class _HybridSetterType(Protocol[_T_con]): - def __call__(s, self: Any, value: _T_con) -> None: - ... + def __call__(s, self: Any, value: _T_con) -> None: ... class _HybridUpdaterType(Protocol[_T_con]): @@ -918,25 +935,21 @@ def __call__( s, cls: Any, value: Union[_T_con, _ColumnExpressionArgument[_T_con]], - ) -> List[Tuple[_DMLColumnArgument, Any]]: - ... + ) -> List[Tuple[_DMLColumnArgument, Any]]: ... class _HybridDeleterType(Protocol[_T_co]): - def __call__(s, self: Any) -> None: - ... + def __call__(s, self: Any) -> None: ... class _HybridExprCallableType(Protocol[_T_co]): def __call__( s, cls: Any - ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]: - ... + ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... class _HybridComparatorCallableType(Protocol[_T]): - def __call__(self, cls: Any) -> Comparator[_T]: - ... + def __call__(self, cls: Any) -> Comparator[_T]: ... class _HybridClassLevelAccessor(QueryableAttribute[_T]): @@ -947,23 +960,24 @@ class _HybridClassLevelAccessor(QueryableAttribute[_T]): if TYPE_CHECKING: - def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: - ... + def getter( + self, fget: _HybridGetterType[_T] + ) -> hybrid_property[_T]: ... - def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: - ... + def setter( + self, fset: _HybridSetterType[_T] + ) -> hybrid_property[_T]: ... - def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: - ... + def deleter( + self, fdel: _HybridDeleterType[_T] + ) -> hybrid_property[_T]: ... @property - def overrides(self) -> hybrid_property[_T]: - ... + def overrides(self) -> hybrid_property[_T]: ... def update_expression( self, meth: _HybridUpdaterType[_T] - ) -> hybrid_property[_T]: - ... + ) -> hybrid_property[_T]: ... class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): @@ -988,6 +1002,7 @@ def __init__( from sqlalchemy.ext.hybrid import hybrid_method + class SomeClass: @hybrid_method def value(self, x, y): @@ -1025,14 +1040,12 @@ def inplace(self) -> Self: @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> Callable[_P, SQLCoreOperations[_R]]: - ... + ) -> Callable[_P, SQLCoreOperations[_R]]: ... @overload def __get__( self, instance: object, owner: Type[object] - ) -> Callable[_P, _R]: - ... + ) -> Callable[_P, _R]: ... def __get__( self, instance: Optional[object], owner: Type[object] @@ -1087,6 +1100,7 @@ def __init__( from sqlalchemy.ext.hybrid import hybrid_property + class SomeClass: @hybrid_property def value(self): @@ -1103,21 +1117,18 @@ def value(self, value): self.expr = _unwrap_classmethod(expr) self.custom_comparator = _unwrap_classmethod(custom_comparator) self.update_expr = _unwrap_classmethod(update_expr) - util.update_wrapper(self, fget) + util.update_wrapper(self, fget) # type: ignore[arg-type] @overload - def __get__(self, instance: Any, owner: Literal[None]) -> Self: - ... + def __get__(self, instance: Any, owner: Literal[None]) -> Self: ... @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> _HybridClassLevelAccessor[_T]: - ... + ) -> _HybridClassLevelAccessor[_T]: ... @overload - def __get__(self, instance: object, owner: Type[object]) -> _T: - ... + def __get__(self, instance: object, owner: Type[object]) -> _T: ... def __get__( self, instance: Optional[object], owner: Optional[Type[object]] @@ -1168,6 +1179,7 @@ class SuperClass: def foobar(self): return self._foobar + class SubClass(SuperClass): # ... 
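A usage sketch of the subclass-reuse pattern documented above, assuming the ``FirstNameOnly`` / ``FirstNameLastName`` mappings shown earlier are complete; the declarative constructor routes the ``name`` keyword through the hybrid setter::

    p = FirstNameLastName(name="Jane Doe")
    assert p.first_name == "Jane"
    assert p.last_name == "Doe"
    print(p.name)  # "Jane Doe", via the getter local to FirstNameLastName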
@@ -1377,10 +1389,7 @@ def fullname(self): @fullname.update_expression def fullname(cls, value): fname, lname = value.split(" ", 1) - return [ - (cls.first_name, fname), - (cls.last_name, lname) - ] + return [(cls.first_name, fname), (cls.last_name, lname)] .. versionadded:: 1.2 @@ -1447,7 +1456,7 @@ class Comparator(interfaces.PropComparator[_T]): classes for usage with hybrids.""" def __init__( - self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]] + self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]] ): self.expression = expression @@ -1482,7 +1491,7 @@ class ExprComparator(Comparator[_T]): def __init__( self, cls: Type[Any], - expression: Union[_HasClauseElement, SQLColumnExpression[_T]], + expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]], hybrid: hybrid_property[_T], ): self.cls = cls diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index dbaad3c4077..886069ce000 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -1,5 +1,5 @@ -# ext/index.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# ext/indexable.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -36,19 +36,19 @@ Base = declarative_base() + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) data = Column(JSON) - name = index_property('data', 'name') - + name = index_property("data", "name") Above, the ``name`` attribute now behaves like a mapped column. We can compose a new ``Person`` and set the value of ``name``:: - >>> person = Person(name='Alchemist') + >>> person = Person(name="Alchemist") The value is now accessible:: @@ -59,11 +59,11 @@ class Person(Base): and the field was set:: >>> person.data - {"name": "Alchemist'} + {'name': 'Alchemist'} The field is mutable in place:: - >>> person.name = 'Renamed' + >>> person.name = "Renamed" >>> person.name 'Renamed' >>> person.data @@ -87,18 +87,17 @@ class Person(Base): >>> person = Person() >>> person.name - ... AttributeError: 'name' Unless you set a default value:: >>> class Person(Base): - >>> __tablename__ = 'person' - >>> - >>> id = Column(Integer, primary_key=True) - >>> data = Column(JSON) - >>> - >>> name = index_property('data', 'name', default=None) # See default + ... __tablename__ = "person" + ... + ... id = Column(Integer, primary_key=True) + ... data = Column(JSON) + ... + ... 
name = index_property("data", "name", default=None) # See default >>> person = Person() >>> print(person.name) @@ -111,11 +110,11 @@ class Person(Base): >>> from sqlalchemy.orm import Session >>> session = Session() - >>> query = session.query(Person).filter(Person.name == 'Alchemist') + >>> query = session.query(Person).filter(Person.name == "Alchemist") The above query is equivalent to:: - >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist') + >>> query = session.query(Person).filter(Person.data["name"] == "Alchemist") Multiple :class:`.index_property` objects can be chained to produce multiple levels of indexing:: @@ -126,22 +125,25 @@ class Person(Base): Base = declarative_base() + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) data = Column(JSON) - birthday = index_property('data', 'birthday') - year = index_property('birthday', 'year') - month = index_property('birthday', 'month') - day = index_property('birthday', 'day') + birthday = index_property("data", "birthday") + year = index_property("birthday", "year") + month = index_property("birthday", "month") + day = index_property("birthday", "day") Above, a query such as:: - q = session.query(Person).filter(Person.year == '1980') + q = session.query(Person).filter(Person.year == "1980") -On a PostgreSQL backend, the above query will render as:: +On a PostgreSQL backend, the above query will render as: + +.. sourcecode:: sql SELECT person.id, person.data FROM person @@ -198,13 +200,14 @@ def expr(self, model): Base = declarative_base() + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) data = Column(JSON) - age = pg_json_property('data', 'age', Integer) + age = pg_json_property("data", "age", Integer) The ``age`` attribute at the instance level works as before; however when rendering SQL, PostgreSQL's ``->>`` operator will be used @@ -212,7 +215,8 @@ class Person(Base): >>> query = session.query(Person).filter(Person.age < 20) -The above query will render:: +The above query will render: +.. 
sourcecode:: sql SELECT person.id, person.data FROM person diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 688c762e72b..8bb01985ecc 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -1,5 +1,5 @@ # ext/instrumentation.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -214,9 +214,9 @@ def dict_of(self, instance): )(instance) -orm_instrumentation._instrumentation_factory = ( - _instrumentation_factory -) = ExtendedInstrumentationRegistry() +orm_instrumentation._instrumentation_factory = _instrumentation_factory = ( + ExtendedInstrumentationRegistry() +) orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -436,17 +436,15 @@ def _install_lookups(lookups): instance_dict = lookups["instance_dict"] manager_of_class = lookups["manager_of_class"] opt_manager_of_class = lookups["opt_manager_of_class"] - orm_base.instance_state = ( - attributes.instance_state - ) = orm_instrumentation.instance_state = instance_state - orm_base.instance_dict = ( - attributes.instance_dict - ) = orm_instrumentation.instance_dict = instance_dict - orm_base.manager_of_class = ( - attributes.manager_of_class - ) = orm_instrumentation.manager_of_class = manager_of_class - orm_base.opt_manager_of_class = ( - orm_util.opt_manager_of_class - ) = ( + orm_base.instance_state = attributes.instance_state = ( + orm_instrumentation.instance_state + ) = instance_state + orm_base.instance_dict = attributes.instance_dict = ( + orm_instrumentation.instance_dict + ) = instance_dict + orm_base.manager_of_class = attributes.manager_of_class = ( + orm_instrumentation.manager_of_class + ) = manager_of_class + orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = ( attributes.opt_manager_of_class ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 0f82518aaa1..3d568fc9892 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -1,5 +1,5 @@ # ext/mutable.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -21,6 +21,7 @@ from sqlalchemy.types import TypeDecorator, VARCHAR import json + class JSONEncodedDict(TypeDecorator): "Represents an immutable structure as a json-encoded string." 
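For reference, the ``JSONEncodedDict`` recipe named above is the short ``TypeDecorator`` from the mutation-tracking documentation, reproduced here in full so the hunks that follow read standalone::

    import json

    from sqlalchemy.types import TypeDecorator, VARCHAR


    class JSONEncodedDict(TypeDecorator):
        "Represents an immutable structure as a json-encoded string."

        impl = VARCHAR

        def process_bind_param(self, value, dialect):
            if value is not None:
                value = json.dumps(value)
            return value

        def process_result_value(self, value, dialect):
            if value is not None:
                value = json.loads(value)
            return value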
@@ -48,6 +49,7 @@ def process_result_value(self, value, dialect): from sqlalchemy.ext.mutable import Mutable + class MutableDict(Mutable, dict): @classmethod def coerce(cls, key, value): @@ -101,9 +103,11 @@ class and associates a listener that will detect all future mappings from sqlalchemy import Table, Column, Integer - my_data = Table('my_data', metadata, - Column('id', Integer, primary_key=True), - Column('data', MutableDict.as_mutable(JSONEncodedDict)) + my_data = Table( + "my_data", + metadata, + Column("id", Integer, primary_key=True), + Column("data", MutableDict.as_mutable(JSONEncodedDict)), ) Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` @@ -115,13 +119,17 @@ class and associates a listener that will detect all future mappings from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column + class Base(DeclarativeBase): pass + class MyDataClass(Base): - __tablename__ = 'my_data' + __tablename__ = "my_data" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + data: Mapped[dict[str, str]] = mapped_column( + MutableDict.as_mutable(JSONEncodedDict) + ) The ``MyDataClass.data`` member will now be notified of in place changes to its value. @@ -132,11 +140,11 @@ class MyDataClass(Base): >>> from sqlalchemy.orm import Session >>> sess = Session(some_engine) - >>> m1 = MyDataClass(data={'value1':'foo'}) + >>> m1 = MyDataClass(data={"value1": "foo"}) >>> sess.add(m1) >>> sess.commit() - >>> m1.data['value1'] = 'bar' + >>> m1.data["value1"] = "bar" >>> assert m1 in sess.dirty True @@ -153,15 +161,16 @@ class MyDataClass(Base): MutableDict.associate_with(JSONEncodedDict) + class Base(DeclarativeBase): pass + class MyDataClass(Base): - __tablename__ = 'my_data' + __tablename__ = "my_data" id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict) - Supporting Pickling -------------------- @@ -180,7 +189,7 @@ class MyDataClass(Base): class MyMutableType(Mutable): def __getstate__(self): d = self.__dict__.copy() - d.pop('_parents', None) + d.pop("_parents", None) return d With our dictionary example, we need to return the contents of the dict itself @@ -213,13 +222,18 @@ def __setstate__(self, state): from sqlalchemy.orm import mapped_column from sqlalchemy import event + class Base(DeclarativeBase): pass + class MyDataClass(Base): - __tablename__ = 'my_data' + __tablename__ = "my_data" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + data: Mapped[dict[str, str]] = mapped_column( + MutableDict.as_mutable(JSONEncodedDict) + ) + @event.listens_for(MyDataClass.data, "modified") def modified_json(instance, initiator): @@ -247,6 +261,7 @@ class introduced in :ref:`mapper_composite` to include import dataclasses from sqlalchemy.ext.mutable import MutableComposite + @dataclasses.dataclass class Point(MutableComposite): x: int @@ -261,7 +276,6 @@ def __setattr__(self, key, value): # alert all parents to the change self.changed() - The :class:`.MutableComposite` class makes use of class mapping events to automatically establish listeners for any usage of :func:`_orm.composite` that specifies our ``Point`` type. 
Below, when ``Point`` is mapped to the ``Vertex`` @@ -271,6 +285,7 @@ def __setattr__(self, key, value): from sqlalchemy.orm import DeclarativeBase, Mapped from sqlalchemy.orm import composite, mapped_column + class Base(DeclarativeBase): pass @@ -280,8 +295,12 @@ class Vertex(Base): id: Mapped[int] = mapped_column(primary_key=True) - start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1")) - end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2")) + start: Mapped[Point] = composite( + mapped_column("x1"), mapped_column("y1") + ) + end: Mapped[Point] = composite( + mapped_column("x2"), mapped_column("y2") + ) def __repr__(self): return f"Vertex(start={self.start}, end={self.end})" @@ -378,6 +397,7 @@ def __setstate__(self, state): from .. import event from .. import inspect from .. import types +from .. import util from ..orm import Mapper from ..orm._typing import _ExternalEntityType from ..orm._typing import _O @@ -390,6 +410,7 @@ def __setstate__(self, state): from ..orm.decl_api import DeclarativeAttributeIntercept from ..orm.state import InstanceState from ..orm.unitofwork import UOWTransaction +from ..sql._typing import _TypeEngineArgument from ..sql.base import SchemaEventTarget from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -503,6 +524,7 @@ def load(state: InstanceState[_O], *args: Any) -> None: if val is not None: if coerce: val = cls.coerce(key, val) + assert val is not None state.dict[key] = val val._parents[state] = key @@ -637,7 +659,7 @@ def listen_for_type(mapper: Mapper[_O], class_: type) -> None: event.listen(Mapper, "mapper_configured", listen_for_type) @classmethod - def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: + def as_mutable(cls, sqltype: _TypeEngineArgument[_T]) -> TypeEngine[_T]: """Associate a SQL type with this mutable Python type. This establishes listeners that will detect ORM mappings against @@ -646,9 +668,11 @@ def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: The type is returned, unconditionally as an instance, so that :meth:`.as_mutable` can be used inline:: - Table('mytable', metadata, - Column('id', Integer, primary_key=True), - Column('data', MyMutableType.as_mutable(PickleType)) + Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", MyMutableType.as_mutable(PickleType)), ) Note that the returned type is always an instance, even if a class @@ -799,15 +823,12 @@ def __setitem__(self, key: _KT, value: _VT) -> None: @overload def setdefault( self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - def setdefault(self, key: _KT, value: _VT) -> _VT: - ... + def setdefault(self, key: _KT, value: _VT) -> _VT: ... - def setdefault(self, key: _KT, value: object = None) -> object: - ... + def setdefault(self, key: _KT, value: object = None) -> object: ... else: @@ -828,17 +849,14 @@ def update(self, *a: Any, **kw: _VT) -> None: if TYPE_CHECKING: @overload - def pop(self, __key: _KT) -> _VT: - ... + def pop(self, __key: _KT) -> _VT: ... @overload - def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: - ... + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: ... def pop( self, __key: _KT, __default: _VT | _T | None = None - ) -> _VT | _T: - ... + ) -> _VT | _T: ... 
else: @@ -909,10 +927,10 @@ def __setstate__(self, state: Iterable[_T]) -> None: self[:] = state def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: - return not isinstance(value, Iterable) + return not util.is_non_string_iterable(value) def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: - return isinstance(value, Iterable) + return util.is_non_string_iterable(value) def __setitem__( self, index: SupportsIndex | slice, value: _T | Iterable[_T] diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py index e69de29bb2d..b5827cb8d36 100644 --- a/lib/sqlalchemy/ext/mypy/__init__.py +++ b/lib/sqlalchemy/ext/mypy/__init__.py @@ -0,0 +1,6 @@ +# ext/mypy/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 1bfaf1d7b0b..02908cc14b4 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -1,5 +1,5 @@ # ext/mypy/apply.py -# Copyright (C) 2021 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -161,9 +161,9 @@ def re_apply_declarative_assignments( # update the SQLAlchemyAttribute with the better # information - mapped_attr_lookup[ - stmt.lvalues[0].name - ].type = python_type_for_type + mapped_attr_lookup[stmt.lvalues[0].name].type = ( + python_type_for_type + ) update_cls_metadata = True @@ -199,11 +199,15 @@ class User(Base): To one that describes the final Python behavior to Mypy:: + ... format: off + class User(Base): # ... attrname : Mapped[Optional[int]] = + ... 
format: on + """ left_node = lvalue.node assert isinstance(left_node, Var) @@ -223,9 +227,11 @@ class User(Base): lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, - [AnyType(TypeOfAny.special_form)] - if python_type_for_type is None - else [python_type_for_type], + ( + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type] + ), ) # so to have it skip the right side totally, we can do this: diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 9c7b44b7586..2ce7ad56ccc 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -1,5 +1,5 @@ # ext/mypy/decl_class.py -# Copyright (C) 2021 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,9 +58,9 @@ def scan_declarative_assignments_and_apply_types( elif cls.fullname.startswith("builtins"): return None - mapped_attributes: Optional[ - List[util.SQLAlchemyAttribute] - ] = util.get_mapped_attributes(info, api) + mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = ( + util.get_mapped_attributes(info, api) + ) # used by assign.add_additional_orm_attributes among others util.establish_as_sqlalchemy(info) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index e8345d09ae3..26a83cca836 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -1,5 +1,5 @@ # ext/mypy/infer.py -# Copyright (C) 2021 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -385,9 +385,9 @@ class MyClass: class MyClass: # ... 
- a : Mapped[int] + a: Mapped[int] - b : Mapped[str] + b: Mapped[str] c: Mapped[int] diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index ae55ca47b01..1eaef775953 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -1,5 +1,5 @@ # ext/mypy/names.py -# Copyright (C) 2021 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,6 +58,14 @@ NAMED_TYPE_BUILTINS_LIST = "builtins.list" NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" +_RelFullNames = { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.relationships._RelationshipDeclared", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", +} + _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( COLUMN, @@ -66,24 +74,9 @@ "sqlalchemy.sql.Column", }, ), - "Relationship": ( - RELATIONSHIP, - { - "sqlalchemy.orm.relationships.Relationship", - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.Relationship", - "sqlalchemy.orm.RelationshipProperty", - }, - ), - "RelationshipProperty": ( - RELATIONSHIP, - { - "sqlalchemy.orm.relationships.Relationship", - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.Relationship", - "sqlalchemy.orm.RelationshipProperty", - }, - ), + "Relationship": (RELATIONSHIP, _RelFullNames), + "RelationshipProperty": (RELATIONSHIP, _RelFullNames), + "_RelationshipDeclared": (RELATIONSHIP, _RelFullNames), "registry": ( REGISTRY, { @@ -304,7 +297,7 @@ def type_id_for_callee(callee: Expression) -> Optional[int]: def type_id_for_named_node( - node: Union[NameExpr, MemberExpr, SymbolNode] + node: Union[NameExpr, MemberExpr, SymbolNode], ) -> Optional[int]: type_id, fullnames = _lookup.get(node.name, (None, None)) diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 862d7d2166f..1ec2c02b9cf 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -1,5 +1,5 @@ # ext/mypy/plugin.py -# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 238c82a54f2..16761b9ab39 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,5 +1,5 @@ # ext/mypy/util.py -# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -80,7 +80,7 @@ def serialize(self) -> JsonDict: "name": self.name, "line": self.line, "column": self.column, - "type": self.type.serialize(), + "type": serialize_type(self.type), } def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: @@ -212,8 +212,7 @@ def add_global( @overload def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, expr_types: None = ... -) -> Optional[Union[CallExpr, NameExpr]]: - ... +) -> Optional[Union[CallExpr, NameExpr]]: ... @overload @@ -222,8 +221,7 @@ def get_callexpr_kwarg( name: str, *, expr_types: Tuple[TypingType[_TArgType], ...], -) -> Optional[_TArgType]: - ... +) -> Optional[_TArgType]: ... 
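# --- illustrative sketch, not part of the patch; names are hypothetical ---
# Many hunks in this patch reformat @overload stubs to the inline ``...``
# body style, as in get_callexpr_kwarg() above. A minimal, self-contained
# example of that overload pattern, where the return type narrows when the
# optional ``types`` argument is supplied:
from typing import Optional, Tuple, Type, TypeVar, overload

_T = TypeVar("_T")


@overload
def first_of(values: list, *, types: None = ...) -> Optional[object]: ...


@overload
def first_of(values: list, *, types: Tuple[Type[_T], ...]) -> Optional[_T]: ...


def first_of(values, *, types=None):
    # runtime behavior is identical for both overloads; only the static
    # return type differs for the type checker
    for v in values:
        if types is None or isinstance(v, types):
            return v
    return None


assert first_of([1, "a"], types=(str,)) == "a"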
def get_callexpr_kwarg( @@ -315,9 +313,11 @@ def unbound_to_instance( return Instance( bound_type, [ - unbound_to_instance(api, arg) - if isinstance(arg, UnboundType) - else arg + ( + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg + ) for arg in typ.args ], ) @@ -336,3 +336,22 @@ def info_for_cls( return sym.node return cls.info + + +def serialize_type(typ: Type) -> Union[str, JsonDict]: + try: + return typ.serialize() + except Exception: + pass + if hasattr(typ, "args"): + typ.args = tuple( + ( + a.resolve_string_annotation() + if hasattr(a, "resolve_string_annotation") + else a + ) + for a in typ.args + ) + elif hasattr(typ, "resolve_string_annotation"): + typ = typ.resolve_string_annotation() + return typ.serialize() diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index a6c42ff0936..3cc67b18964 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -1,5 +1,5 @@ # ext/orderinglist.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -26,18 +26,20 @@ Base = declarative_base() + class Slide(Base): - __tablename__ = 'slide' + __tablename__ = "slide" id = Column(Integer, primary_key=True) name = Column(String) bullets = relationship("Bullet", order_by="Bullet.position") + class Bullet(Base): - __tablename__ = 'bullet' + __tablename__ = "bullet" id = Column(Integer, primary_key=True) - slide_id = Column(Integer, ForeignKey('slide.id')) + slide_id = Column(Integer, ForeignKey("slide.id")) position = Column(Integer) text = Column(String) @@ -57,19 +59,24 @@ class Bullet(Base): Base = declarative_base() + class Slide(Base): - __tablename__ = 'slide' + __tablename__ = "slide" id = Column(Integer, primary_key=True) name = Column(String) - bullets = relationship("Bullet", order_by="Bullet.position", - collection_class=ordering_list('position')) + bullets = relationship( + "Bullet", + order_by="Bullet.position", + collection_class=ordering_list("position"), + ) + class Bullet(Base): - __tablename__ = 'bullet' + __tablename__ = "bullet" id = Column(Integer, primary_key=True) - slide_id = Column(Integer, ForeignKey('slide.id')) + slide_id = Column(Integer, ForeignKey("slide.id")) position = Column(Integer) text = Column(String) @@ -151,14 +158,18 @@ def ordering_list( from sqlalchemy.ext.orderinglist import ordering_list + class Slide(Base): - __tablename__ = 'slide' + __tablename__ = "slide" id = Column(Integer, primary_key=True) name = Column(String) - bullets = relationship("Bullet", order_by="Bullet.position", - collection_class=ordering_list('position')) + bullets = relationship( + "Bullet", + order_by="Bullet.position", + collection_class=ordering_list("position"), + ) :param attr: Name of the mapped attribute to use for storage and retrieval of diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index 706bff29fb0..b7032b65959 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -1,5 +1,5 @@ # ext/serializer.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,13 +28,17 @@ Usage is nearly the same as that of the standard Python pickle module:: from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) 
Session = scoped_session(sessionmaker()) # ... define mappers - query = Session.query(MyClass). - filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + query = ( + Session.query(MyClass) + .filter(MyClass.somedata == "foo") + .order_by(MyClass.sortkey) + ) # pickle the query serialized = dumps(query) @@ -42,7 +46,7 @@ # unpickle. Pass in metadata + scoped_session query2 = loads(serialized, metadata, Session) - print query2.all() + print(query2.all()) Similar restrictions as when using raw pickle apply; mapped classes must be themselves be pickleable, meaning they are importable from a module-level @@ -82,10 +86,9 @@ __all__ = ["Serializer", "Deserializer", "dumps", "loads"] -def Serializer(*args, **kw): - pickler = pickle.Pickler(*args, **kw) +class Serializer(pickle.Pickler): - def persistent_id(obj): + def persistent_id(self, obj): # print "serializing:", repr(obj) if isinstance(obj, Mapper) and not obj.non_primary: id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) @@ -113,9 +116,6 @@ def persistent_id(obj): return None return id_ - pickler.persistent_id = persistent_id - return pickler - our_ids = re.compile( r"(mapperprop|mapper|mapper_selectable|table|column|" @@ -123,20 +123,23 @@ def persistent_id(obj): ) -def Deserializer(file, metadata=None, scoped_session=None, engine=None): - unpickler = pickle.Unpickler(file) +class Deserializer(pickle.Unpickler): - def get_engine(): - if engine: - return engine - elif scoped_session and scoped_session().bind: - return scoped_session().bind - elif metadata and metadata.bind: - return metadata.bind + def __init__(self, file, metadata=None, scoped_session=None, engine=None): + super().__init__(file) + self.metadata = metadata + self.scoped_session = scoped_session + self.engine = engine + + def get_engine(self): + if self.engine: + return self.engine + elif self.scoped_session and self.scoped_session().bind: + return self.scoped_session().bind else: return None - def persistent_load(id_): + def persistent_load(self, id_): m = our_ids.match(str(id_)) if not m: return None @@ -157,20 +160,17 @@ def persistent_load(id_): cls = pickle.loads(b64decode(mapper)) return class_mapper(cls).attrs[keyname] elif type_ == "table": - return metadata.tables[args] + return self.metadata.tables[args] elif type_ == "column": table, colname = args.split(":") - return metadata.tables[table].c[colname] + return self.metadata.tables[table].c[colname] elif type_ == "session": - return scoped_session() + return self.scoped_session() elif type_ == "engine": - return get_engine() + return self.get_engine() else: raise Exception("Unknown token: %s" % type_) - unpickler.persistent_load = persistent_load - return unpickler - def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL): buf = BytesIO() diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py index bfc31d42676..ef9afb1a52b 100644 --- a/lib/sqlalchemy/future/__init__.py +++ b/lib/sqlalchemy/future/__init__.py @@ -1,5 +1,5 @@ -# sql/future/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# future/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index 1984f34ca75..0449c3d9f31 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -1,5 +1,5 @@ -# sql/future/engine.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# future/engine.py +# 
Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 7d8479b5ecf..2e5b2201814 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -1,5 +1,5 @@ -# sqlalchemy/inspect.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# inspection.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -74,8 +74,7 @@ class _InspectableTypeProtocol(Protocol[_TCov]): """ - def _sa_inspect_type(self) -> _TCov: - ... + def _sa_inspect_type(self) -> _TCov: ... class _InspectableProtocol(Protocol[_TCov]): @@ -84,35 +83,31 @@ class _InspectableProtocol(Protocol[_TCov]): """ - def _sa_inspect_instance(self) -> _TCov: - ... + def _sa_inspect_instance(self) -> _TCov: ... @overload def inspect( subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True -) -> _IN: - ... +) -> _IN: ... @overload -def inspect(subject: _InspectableProtocol[_IN], raiseerr: bool = True) -> _IN: - ... +def inspect( + subject: _InspectableProtocol[_IN], raiseerr: bool = True +) -> _IN: ... @overload -def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: - ... +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: ... @overload -def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: - ... +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... @overload -def inspect(subject: Any, raiseerr: bool = True) -> Any: - ... +def inspect(subject: Any, raiseerr: bool = True) -> Any: ... def inspect(subject: Any, raiseerr: bool = True) -> Any: @@ -162,9 +157,7 @@ def _inspects( def decorate(fn_or_cls: _F) -> _F: for type_ in types: if type_ in _registrars: - raise AssertionError( - "Type %s is already " "registered" % type_ - ) + raise AssertionError("Type %s is already registered" % type_) _registrars[type_] = fn_or_cls return fn_or_cls @@ -176,6 +169,6 @@ def decorate(fn_or_cls: _F) -> _F: def _self_inspects(cls: _TT) -> _TT: if cls in _registrars: - raise AssertionError("Type %s is already " "registered" % cls) + raise AssertionError("Type %s is already registered" % cls) _registrars[cls] = True return cls diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 8de6d188cee..849a0bfa078 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -1,5 +1,5 @@ -# sqlalchemy/log.py -# Copyright (C) 2006-2023 the SQLAlchemy authors and contributors +# log.py +# Copyright (C) 2006-2025 the SQLAlchemy authors and contributors # # Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk # @@ -269,14 +269,12 @@ class echo_property: @overload def __get__( self, instance: Literal[None], owner: Type[Identified] - ) -> echo_property: - ... + ) -> echo_property: ... @overload def __get__( self, instance: Identified, owner: Type[Identified] - ) -> _EchoFlagType: - ... + ) -> _EchoFlagType: ... 
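# --- illustrative sketch, not part of the patch; names are hypothetical ---
# The ext/serializer rewrite a few hunks above turns the Serializer /
# Deserializer factory functions into pickle.Pickler / pickle.Unpickler
# subclasses, overriding persistent_id() / persistent_load() as methods
# instead of attaching closures to instances. The same pattern in miniature:
import pickle
from io import BytesIO

REGISTRY = {"widget:1": {"name": "example"}}


class MyPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # externalize registry-backed objects as string tokens
        for token, value in REGISTRY.items():
            if value is obj:
                return token
        return None  # everything else pickles normally


class MyUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        return REGISTRY[pid]


buf = BytesIO()
MyPickler(buf).dump(["payload", REGISTRY["widget:1"]])
buf.seek(0)
restored = MyUnpickler(buf).load()
assert restored[1] is REGISTRY["widget:1"]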
def __get__( self, instance: Optional[Identified], owner: Type[Identified] diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index f6888aeee45..7771de47eb2 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -1,5 +1,5 @@ # orm/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index df36c386416..d9e3ec37ba2 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -1,5 +1,5 @@ # orm/_orm_constructors.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,8 +28,8 @@ from .properties import MappedSQLExpression from .query import AliasOption from .relationships import _RelationshipArgumentType +from .relationships import _RelationshipDeclared from .relationships import _RelationshipSecondaryArgument -from .relationships import Relationship from .relationships import RelationshipProperty from .session import Session from .util import _ORMJoin @@ -70,7 +70,7 @@ from ..sql._typing import _TypeEngineArgument from ..sql.elements import ColumnElement from ..sql.schema import _ServerDefaultArgument - from ..sql.schema import FetchedValue + from ..sql.schema import _ServerOnUpdateArgument from ..sql.selectable import Alias from ..sql.selectable import Subquery @@ -108,6 +108,7 @@ def mapped_column( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -127,7 +128,7 @@ def mapped_column( onupdate: Optional[Any] = None, insert_default: Optional[Any] = _NoArg.NO_ARG, server_default: Optional[_ServerDefaultArgument] = None, - server_onupdate: Optional[FetchedValue] = None, + server_onupdate: Optional[_ServerOnUpdateArgument] = None, active_history: bool = False, quote: Optional[bool] = None, system: bool = False, @@ -255,12 +256,28 @@ def mapped_column( be used instead**. This is necessary to disambiguate the callable from being interpreted as a dataclass level default. + .. seealso:: + + :ref:`defaults_default_factory_insert_default` + + :paramref:`_orm.mapped_column.insert_default` + + :paramref:`_orm.mapped_column.default_factory` + :param insert_default: Passed directly to the :paramref:`_schema.Column.default` parameter; will supersede the value of :paramref:`_orm.mapped_column.default` when present, however :paramref:`_orm.mapped_column.default` will always apply to the constructor default for a dataclasses mapping. + .. seealso:: + + :ref:`defaults_default_factory_insert_default` + + :paramref:`_orm.mapped_column.default` + + :paramref:`_orm.mapped_column.default_factory` + :param sort_order: An integer that indicates how this mapped column should be sorted compared to the others when the ORM is creating a :class:`_schema.Table`. 
Among mapped columns that have the same @@ -295,6 +312,15 @@ def mapped_column( specifies a default-value generation function that will take place as part of the ``__init__()`` method as generated by the dataclass process. + + .. seealso:: + + :ref:`defaults_default_factory_insert_default` + + :paramref:`_orm.mapped_column.default` + + :paramref:`_orm.mapped_column.insert_default` + :param compare: Specific to :ref:`orm_declarative_native_dataclasses`, indicates if this field should be included in comparison operations when generating the @@ -306,6 +332,13 @@ def mapped_column( :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. + :param hash: Specific to + :ref:`orm_declarative_native_dataclasses`, controls if this field + is included when generating the ``__hash__()`` method for the mapped + class. + + .. versionadded:: 2.0.36 + :param \**kw: All remaining keyword arguments are passed through to the constructor for the :class:`_schema.Column`. @@ -320,7 +353,7 @@ def mapped_column( autoincrement=autoincrement, insert_default=insert_default, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), doc=doc, key=key, @@ -385,9 +418,9 @@ def orm_insert_sentinel( return mapped_column( name=name, - default=default - if default is not None - else _InsertSentinelColumnDefault(), + default=( + default if default is not None else _InsertSentinelColumnDefault() + ), _omit_from_statements=omit_from_statements, insert_sentinel=True, use_existing_column=True, @@ -415,12 +448,13 @@ def column_property( deferred: bool = False, raiseload: bool = False, comparator_factory: Optional[Type[PropComparator[_T]]] = None, - init: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + init: Union[_NoArg, bool] = _NoArg.NO_ARG, repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -509,13 +543,43 @@ def column_property( :ref:`orm_queryguide_deferred_raiseload` - :param init: + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + + .. seealso:: - :param default: + :ref:`defaults_default_factory_insert_default` - :param default_factory: + :paramref:`_orm.mapped_column.default` - :param kw_only: + :paramref:`_orm.mapped_column.insert_default` + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. 
versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + :param hash: Specific to + :ref:`orm_declarative_native_dataclasses`, controls if this field + is included when generating the ``__hash__()`` method for the mapped + class. + + .. versionadded:: 2.0.36 """ return MappedSQLExpression( @@ -528,6 +592,7 @@ def column_property( default_factory, compare, kw_only, + hash, ), group=group, deferred=deferred, @@ -556,11 +621,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[Any]: - ... +) -> Composite[Any]: ... @overload @@ -578,11 +643,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: - ... +) -> Composite[_CC]: ... @overload @@ -600,11 +665,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: - ... +) -> Composite[_CC]: ... def composite( @@ -623,6 +688,7 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, @@ -697,6 +763,12 @@ def composite( :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. + :param hash: Specific to + :ref:`orm_declarative_native_dataclasses`, controls if this field + is included when generating the ``__hash__()`` method for the mapped + class. + + .. 
versionadded:: 2.0.36 """ if __kw: raise _no_kw() @@ -705,7 +777,7 @@ def composite( _class_or_attr, *attrs, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), group=group, deferred=deferred, @@ -719,7 +791,10 @@ def composite( def with_loader_criteria( entity_or_base: _EntityType[Any], - where_criteria: _ColumnExpressionArgument[bool], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, @@ -748,7 +823,7 @@ def with_loader_criteria( stmt = select(User).options( selectinload(User.addresses), - with_loader_criteria(Address, Address.email_address != 'foo')) + with_loader_criteria(Address, Address.email_address != "foo"), ) Above, the "selectinload" for ``User.addresses`` will apply the @@ -758,8 +833,10 @@ def with_loader_criteria( ON clause of the join, in this example using :term:`1.x style` queries:: - q = session.query(User).outerjoin(User.addresses).options( - with_loader_criteria(Address, Address.email_address != 'foo')) + q = ( + session.query(User) + .outerjoin(User.addresses) + .options(with_loader_criteria(Address, Address.email_address != "foo")) ) The primary purpose of :func:`_orm.with_loader_criteria` is to use @@ -772,6 +849,7 @@ def with_loader_criteria( session = Session(bind=engine) + @event.listens_for("do_orm_execute", session) def _add_filtering_criteria(execute_state): @@ -783,8 +861,8 @@ def _add_filtering_criteria(execute_state): execute_state.statement = execute_state.statement.options( with_loader_criteria( SecurityRole, - lambda cls: cls.role.in_(['some_role']), - include_aliases=True + lambda cls: cls.role.in_(["some_role"]), + include_aliases=True, ) ) @@ -821,16 +899,19 @@ def _add_filtering_criteria(execute_state): ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria` option will affect the way in which the JOIN is rendered:: - stmt = select(A).join(A.bs).options( - contains_eager(A.bs), - with_loader_criteria(B, B.flag == 1) + stmt = ( + select(A) + .join(A.bs) + .options(contains_eager(A.bs), with_loader_criteria(B, B.flag == 1)) ) Above, the given :func:`_orm.with_loader_criteria` option will affect the ON clause of the JOIN that is specified by ``.join(A.bs)``, so is applied as expected. The :func:`_orm.contains_eager` option has the effect that columns from - ``B`` are added to the columns clause:: + ``B`` are added to the columns clause: + + .. sourcecode:: sql SELECT b.id, b.a_id, b.data, b.flag, @@ -896,7 +977,7 @@ class of a particular set of mapped classes, to which the rule .. versionadded:: 1.4.0b2 - """ + """ # noqa: E501 return LoaderCriteriaOption( entity_or_base, where_criteria, @@ -930,6 +1011,7 @@ def relationship( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -950,7 +1032,7 @@ def relationship( omit_join: Literal[None, False] = None, sync_backref: Optional[bool] = None, **kw: Any, -) -> Relationship[Any]: +) -> _RelationshipDeclared[Any]: """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. 
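# --- illustrative usage sketch, assuming SQLAlchemy 2.0.36+ where this
# patch threads the new ``hash`` dataclass parameter through
# mapped_column(), column_property(), composite(), relationship(), etc.;
# the class and column names here are hypothetical ---
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    MappedAsDataclass,
    mapped_column,
)


class Base(MappedAsDataclass, DeclarativeBase):
    pass


class Account(Base, unsafe_hash=True):
    __tablename__ = "account"

    # explicitly included in the generated __hash__()
    id: Mapped[int] = mapped_column(primary_key=True, hash=True)
    # excluded from the generated __hash__()
    nickname: Mapped[str] = mapped_column(hash=False)


a = Account(id=1, nickname="n")
assert isinstance(hash(a), int)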
@@ -1688,19 +1770,10 @@ class that will be synchronized with this one. It is usually the full set of related objects, to prevent modifications of the collection from resulting in persistence operations. - When using the :paramref:`_orm.relationship.viewonly` flag in - conjunction with backrefs, the originating relationship for a - particular state change will not produce state changes within the - viewonly relationship. This is the behavior implied by - :paramref:`_orm.relationship.sync_backref` being set to False. - - .. versionchanged:: 1.3.17 - the - :paramref:`_orm.relationship.sync_backref` flag is set to False - when using viewonly in conjunction with backrefs. - .. seealso:: - :paramref:`_orm.relationship.sync_backref` + :ref:`relationship_viewonly_notes` - more details on best practices + when using :paramref:`_orm.relationship.viewonly`. :param sync_backref: A boolean that enables the events used to synchronize the in-Python @@ -1762,10 +1835,15 @@ class that will be synchronized with this one. It is usually :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. + :param hash: Specific to + :ref:`orm_declarative_native_dataclasses`, controls if this field + is included when generating the ``__hash__()`` method for the mapped + class. + .. versionadded:: 2.0.36 """ - return Relationship( + return _RelationshipDeclared( argument, secondary=secondary, uselist=uselist, @@ -1780,7 +1858,7 @@ class that will be synchronized with this one. It is usually cascade=cascade, viewonly=viewonly, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), lazy=lazy, passive_deletes=passive_deletes, @@ -1815,6 +1893,7 @@ def synonym( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, ) -> Synonym[Any]: @@ -1825,14 +1904,13 @@ def synonym( e.g.:: class MyClass(Base): - __tablename__ = 'my_table' + __tablename__ = "my_table" id = Column(Integer, primary_key=True) job_status = Column(String(50)) status = synonym("job_status") - :param name: the name of the existing mapped property. 
This can refer to the string name ORM-mapped attribute configured on the class, including column-bound attributes @@ -1860,11 +1938,13 @@ class MyClass(Base): :paramref:`.synonym.descriptor` parameter:: my_table = Table( - "my_table", metadata, - Column('id', Integer, primary_key=True), - Column('job_status', String(50)) + "my_table", + metadata, + Column("id", Integer, primary_key=True), + Column("job_status", String(50)), ) + class MyClass: @property def _job_status_descriptor(self): @@ -1872,11 +1952,15 @@ def _job_status_descriptor(self): mapper( - MyClass, my_table, properties={ + MyClass, + my_table, + properties={ "job_status": synonym( - "_job_status", map_column=True, - descriptor=MyClass._job_status_descriptor) - } + "_job_status", + map_column=True, + descriptor=MyClass._job_status_descriptor, + ) + }, ) Above, the attribute named ``_job_status`` is automatically @@ -1925,7 +2009,7 @@ def _job_status_descriptor(self): descriptor=descriptor, comparator_factory=comparator_factory, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), doc=doc, info=info, @@ -2026,8 +2110,7 @@ def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: E.g.:: - 'items':relationship( - SomeItem, backref=backref('parent', lazy='subquery')) + "items": relationship(SomeItem, backref=backref("parent", lazy="subquery")) The :paramref:`_orm.relationship.backref` parameter is generally considered to be legacy; for modern applications, using @@ -2039,7 +2122,7 @@ def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: :ref:`relationships_backref` - background on backrefs - """ + """ # noqa: E501 return (name, kwargs) @@ -2056,6 +2139,7 @@ def deferred( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -2090,7 +2174,7 @@ def deferred( column, *additional_columns, attribute_options=_AttributeOptions( - init, repr, default, default_factory, compare, kw_only + init, repr, default, default_factory, compare, kw_only, hash ), group=group, deferred=True, @@ -2133,6 +2217,7 @@ def query_expression( _NoArg.NO_ARG, compare, _NoArg.NO_ARG, + _NoArg.NO_ARG, ), expire_on_flush=expire_on_flush, info=info, @@ -2186,8 +2271,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedType[_O]: - ... +) -> AliasedType[_O]: ... @overload @@ -2197,8 +2281,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedClass[_O]: - ... +) -> AliasedClass[_O]: ... @overload @@ -2208,8 +2291,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> FromClause: - ... +) -> FromClause: ... def aliased( @@ -2282,6 +2364,16 @@ def aliased( supported by all modern databases with regards to right-nested joins and generally produces more efficient queries. + When :paramref:`_orm.aliased.flat` is combined with + :paramref:`_orm.aliased.name`, the resulting joins will alias individual + tables using a naming scheme similar to ``_``. This + naming scheme is for visibility / debugging purposes only and the + specific scheme is subject to change without notice. + + .. 
versionadded:: 2.0.32 added support for combining + :paramref:`_orm.aliased.name` with :paramref:`_orm.aliased.flat`. + Previously, this would raise ``NotImplementedError``. + :param adapt_on_names: if True, more liberal "matching" will be used when mapping the mapped columns of the ORM entity to those of the given selectable - a name-based match will be performed if the @@ -2291,17 +2383,21 @@ def aliased( aggregate functions:: class UnitPrice(Base): - __tablename__ = 'unit_price' + __tablename__ = "unit_price" ... unit_id = Column(Integer) price = Column(Numeric) - aggregated_unit_price = Session.query( - func.sum(UnitPrice.price).label('price') - ).group_by(UnitPrice.unit_id).subquery() - aggregated_unit_price = aliased(UnitPrice, - alias=aggregated_unit_price, adapt_on_names=True) + aggregated_unit_price = ( + Session.query(func.sum(UnitPrice.price).label("price")) + .group_by(UnitPrice.unit_id) + .subquery() + ) + + aggregated_unit_price = aliased( + UnitPrice, alias=aggregated_unit_price, adapt_on_names=True + ) Above, functions on ``aggregated_unit_price`` which refer to ``.price`` will return the @@ -2329,6 +2425,7 @@ def with_polymorphic( aliased: bool = False, innerjoin: bool = False, adapt_on_names: bool = False, + name: Optional[str] = None, _use_mapper_path: bool = False, ) -> AliasedClass[_O]: """Produce an :class:`.AliasedClass` construct which specifies @@ -2400,6 +2497,10 @@ def with_polymorphic( .. versionadded:: 1.4.33 + :param name: Name given to the generated :class:`.AliasedClass`. + + .. versionadded:: 2.0.31 + """ return AliasedInsp._with_polymorphic_factory( base, @@ -2410,6 +2511,7 @@ def with_polymorphic( adapt_on_names=adapt_on_names, aliased=aliased, innerjoin=innerjoin, + name=name, _use_mapper_path=_use_mapper_path, ) @@ -2441,16 +2543,21 @@ def join( :meth:`_sql.Select.select_from` method, as in:: from sqlalchemy.orm import join - stmt = select(User).\ - select_from(join(User, Address, User.addresses)).\ - filter(Address.email_address=='foo@bar.com') + + stmt = ( + select(User) + .select_from(join(User, Address, User.addresses)) + .filter(Address.email_address == "foo@bar.com") + ) In modern SQLAlchemy the above join can be written more succinctly as:: - stmt = select(User).\ - join(User.addresses).\ - filter(Address.email_address=='foo@bar.com') + stmt = ( + select(User) + .join(User.addresses) + .filter(Address.email_address == "foo@bar.com") + ) .. warning:: using :func:`_orm.join` directly may not work properly with modern ORM options such as :func:`_orm.with_loader_criteria`. diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 3085351ba3b..ccb8413b524 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -1,5 +1,5 @@ # orm/_typing.py -# Copyright (C) 2022 the SQLAlchemy authors and contributors +# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -78,7 +78,7 @@ _ORMColumnExprArgument = Union[ ColumnElement[_T], - _HasClauseElement, + _HasClauseElement[_T], roles.ExpressionElementRole[_T], ] @@ -108,13 +108,13 @@ class _ORMAdapterProto(Protocol): """ - def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: - ... + def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: ... class _LoaderCallable(Protocol): - def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: - ... + def __call__( + self, state: InstanceState[Any], passive: PassiveFlag + ) -> Any: ... 
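# --- illustrative usage sketch for the aliased() behavior documented
# above (combining flat=True with name); per the note this combination is
# new in 2.0.32 and the name-prefixed per-table aliasing matters for
# multi-table (joined inheritance) entities. Node below is a hypothetical
# self-referential mapped class ---
from typing import List, Optional

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    aliased,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Node(Base):
    __tablename__ = "node"

    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("node.id"))
    children: Mapped[List["Node"]] = relationship()


# flat=True renders plain table aliases rather than a parenthesized
# subquery; name="n1" seeds the (debugging-oriented) alias naming scheme
n1 = aliased(Node, flat=True, name="n1")
print(select(Node).join(Node.children.of_type(n1)))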
def is_orm_option( @@ -138,39 +138,33 @@ def is_composite_class(obj: Any) -> bool: if TYPE_CHECKING: - def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: - ... + def insp_is_mapper_property( + obj: Any, + ) -> TypeGuard[MapperProperty[Any]]: ... - def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: - ... + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: ... - def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: - ... + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: ... def insp_is_attribute( obj: InspectionAttr, - ) -> TypeGuard[QueryableAttribute[Any]]: - ... + ) -> TypeGuard[QueryableAttribute[Any]]: ... def attr_is_internal_proxy( obj: InspectionAttr, - ) -> TypeGuard[QueryableAttribute[Any]]: - ... + ) -> TypeGuard[QueryableAttribute[Any]]: ... def prop_is_relationship( prop: MapperProperty[Any], - ) -> TypeGuard[RelationshipProperty[Any]]: - ... + ) -> TypeGuard[RelationshipProperty[Any]]: ... def is_collection_impl( impl: AttributeImpl, - ) -> TypeGuard[CollectionAttributeImpl]: - ... + ) -> TypeGuard[CollectionAttributeImpl]: ... def is_has_collection_adapter( impl: AttributeImpl, - ) -> TypeGuard[HasCollectionAdapter]: - ... + ) -> TypeGuard[HasCollectionAdapter]: ... else: insp_is_mapper_property = operator.attrgetter("is_property") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 1098359ecaa..3c4f3164514 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1,5 +1,5 @@ # orm/attributes.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -401,7 +401,7 @@ def adapt_to_entity(self, adapt_to_entity: AliasedInsp[Any]) -> Self: parententity=adapt_to_entity, ) - def of_type(self, entity: _EntityType[Any]) -> QueryableAttribute[_T]: + def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -462,6 +462,9 @@ def hasparent( ) -> bool: return self.impl.hasparent(state, optimistic=optimistic) is not False + def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: + return (self,) + def __getattr__(self, key: str) -> Any: try: return util.MemoizedSlots.__getattr__(self, key) @@ -503,7 +506,7 @@ def _queryable_attribute_unreduce( return getattr(entity, key) -class InstrumentedAttribute(QueryableAttribute[_T]): +class InstrumentedAttribute(QueryableAttribute[_T_co]): """Class bound instrumented attribute which adds basic :term:`descriptor` methods. @@ -542,16 +545,16 @@ def __delete__(self, instance: object) -> None: self.impl.delete(instance_state(instance), instance_dict(instance)) @overload - def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: - ... + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... 
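# --- self-contained sketch of the TypeGuard pattern used by the typing
# stubs above (requires Python 3.10+ for typing.TypeGuard); FakeMapper and
# insp_is_mapper_like are hypothetical stand-ins ---
from typing import Any, TypeGuard


class FakeMapper:
    is_mapper = True


def insp_is_mapper_like(obj: Any) -> TypeGuard[FakeMapper]:
    # the runtime check doubles as a static narrowing hint
    return getattr(obj, "is_mapper", False) is True


subject: Any = FakeMapper()
if insp_is_mapper_like(subject):
    # the checker now treats ``subject`` as FakeMapper in this branch
    assert subject.is_mapper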
def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: if instance is None: return self @@ -595,7 +598,7 @@ def create_proxied_attribute( # TODO: can move this to descriptor_props if the need for this # function is removed from ext/hybrid.py - class Proxy(QueryableAttribute[Any]): + class Proxy(QueryableAttribute[_T_co]): """Presents the :class:`.QueryableAttribute` interface as a proxy on top of a Python descriptor / :class:`.PropComparator` combination. @@ -610,13 +613,13 @@ class Proxy(QueryableAttribute[Any]): def __init__( self, - class_, - key, - descriptor, - comparator, - adapt_to_entity=None, - doc=None, - original_property=None, + class_: _ExternalEntityType[Any], + key: str, + descriptor: Any, + comparator: interfaces.PropComparator[_T_co], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + doc: Optional[str] = None, + original_property: Optional[QueryableAttribute[_T_co]] = None, ): self.class_ = class_ self.key = key @@ -627,11 +630,11 @@ def __init__( self._doc = self.__doc__ = doc @property - def _parententity(self): + def _parententity(self): # type: ignore[override] return inspection.inspect(self.class_, raiseerr=False) @property - def parent(self): + def parent(self): # type: ignore[override] return inspection.inspect(self.class_, raiseerr=False) _is_internal_proxy = True @@ -641,6 +644,13 @@ def parent(self): ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), ] + def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: + prop = self.original_property + if prop is None: + return () + else: + return prop._column_strategy_attrs() + @property def _impl_uses_objects(self): return ( @@ -1538,8 +1548,7 @@ def get_collection( dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -1548,8 +1557,7 @@ def get_collection( dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -1560,8 +1568,7 @@ def get_collection( passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... def get_collection( self, @@ -1592,8 +1599,7 @@ def set( def _is_collection_attribute_impl( impl: AttributeImpl, - ) -> TypeGuard[CollectionAttributeImpl]: - ... + ) -> TypeGuard[CollectionAttributeImpl]: ... else: _is_collection_attribute_impl = operator.attrgetter("collection") @@ -2049,8 +2055,7 @@ def get_collection( dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -2059,8 +2064,7 @@ def get_collection( dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -2071,8 +2075,7 @@ def get_collection( passive: PassiveFlag = PASSIVE_OFF, ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... def get_collection( self, @@ -2670,7 +2673,7 @@ def init_collection(obj: object, key: str) -> CollectionAdapter: This function is used to provide direct access to collection internals for a previously unloaded attribute. 
e.g.:: - collection_adapter = init_collection(someobject, 'elements') + collection_adapter = init_collection(someobject, "elements") for elem in values: collection_adapter.append_without_event(elem) diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 362346cc2a8..b9f8d32be96 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -1,13 +1,11 @@ # orm/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Constants and rudimental functions used throughout the ORM. - -""" +"""Constants and rudimental functions used throughout the ORM.""" from __future__ import annotations @@ -21,6 +19,7 @@ from typing import no_type_check from typing import Optional from typing import overload +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -144,7 +143,7 @@ class PassiveFlag(FastIntFlag): """ NO_AUTOFLUSH = 64 - """Loader callables should disable autoflush.""", + """Loader callables should disable autoflush.""" NO_RAISE = 128 """Loader callables should not raise any assertions""" @@ -282,6 +281,8 @@ class NotExtension(InspectionAttrExtensionType): _none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT]) +_none_only_set = frozenset([None]) + _SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED") _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") @@ -308,29 +309,23 @@ def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: if TYPE_CHECKING: - def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: - ... + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: ... @overload - def opt_manager_of_class(cls: AliasedClass[Any]) -> None: - ... + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: ... @overload def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: - ... + ) -> Optional[ClassManager[_O]]: ... def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: - ... + ) -> Optional[ClassManager[_O]]: ... - def instance_state(instance: _O) -> InstanceState[_O]: - ... + def instance_state(instance: _O) -> InstanceState[_O]: ... - def instance_dict(instance: object) -> Dict[str, Any]: - ... + def instance_dict(instance: object) -> Dict[str, Any]: ... else: # these can be replaced by sqlalchemy.ext.instrumentation @@ -438,7 +433,7 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: def _class_to_mapper( - class_or_mapper: Union[Mapper[_T], Type[_T]] + class_or_mapper: Union[Mapper[_T], Type[_T]], ) -> Mapper[_T]: # can't get mypy to see an overload for this insp = inspection.inspect(class_or_mapper, False) @@ -450,7 +445,7 @@ def _class_to_mapper( def _mapper_or_none( - entity: Union[Type[_T], _InternalEntityType[_T]] + entity: Union[Type[_T], _InternalEntityType[_T]], ) -> Optional[Mapper[_T]]: """Return the :class:`_orm.Mapper` for the given class or None if the class is not mapped. @@ -512,8 +507,7 @@ def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: if TYPE_CHECKING: - def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: - ... + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: ... else: _state_mapper = util.dottedgetter("manager.mapper") @@ -586,7 +580,7 @@ class InspectionAttr: """ - __slots__ = () + __slots__: Tuple[str, ...] 
= () is_selectable = False """Return True if this object is an instance of @@ -684,27 +678,25 @@ class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): if typing.TYPE_CHECKING: - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: - ... + def of_type( + self, class_: _EntityType[Any] + ) -> PropComparator[_T_co]: ... def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> PropComparator[bool]: - ... + ) -> PropComparator[bool]: ... def any( # noqa: A001 self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def has( self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... class ORMDescriptor(Generic[_T_co], TypingOnly): @@ -718,23 +710,19 @@ class ORMDescriptor(Generic[_T_co], TypingOnly): @overload def __get__( self, instance: Any, owner: Literal[None] - ) -> ORMDescriptor[_T_co]: - ... + ) -> ORMDescriptor[_T_co]: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLCoreOperations[_T_co]: - ... + ) -> SQLCoreOperations[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: - ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: object, owner: Any - ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: - ... + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... class _MappedAnnotationBase(Generic[_T_co], TypingOnly): @@ -820,29 +808,23 @@ class Mapped( @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: - ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], _T_co]: - ... + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... @classmethod - def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: - ... + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... def __set__( self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] - ) -> None: - ... + ) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... class _MappedAttribute(Generic[_T_co], TypingOnly): @@ -919,24 +901,20 @@ class User(Base): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> AppenderQuery[_T_co]: - ... + ) -> AppenderQuery[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: - ... + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: - ... + ) -> None: ... class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): @@ -975,21 +953,19 @@ class User(Base): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> WriteOnlyCollection[_T_co]: - ... + ) -> WriteOnlyCollection[_T_co]: ... 
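# --- illustrative sketch (hypothetical names) of the descriptor protocol
# the Mapped / DynamicMapped / WriteOnlyMapped stubs above describe:
# class access returns the descriptor itself, instance access returns the
# underlying value ---
from __future__ import annotations

from typing import Any, Optional, Union, overload


class Attr:
    @overload
    def __get__(self, instance: None, owner: Any) -> Attr: ...

    @overload
    def __get__(self, instance: object, owner: Any) -> int: ...

    def __get__(
        self, instance: Optional[object], owner: Any
    ) -> Union[Attr, int]:
        if instance is None:
            return self  # accessed on the class
        return instance.__dict__.get("_value", 0)

    def __set__(self, instance: Any, value: int) -> None:
        instance.__dict__["_value"] = value


class Thing:
    value = Attr()


t = Thing()
t.value = 5
assert t.value == 5
assert isinstance(Thing.value, Attr)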
def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]: - ... + ) -> Union[ + InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co] + ]: ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: - ... + ) -> None: ... diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 31caedc3785..402d7bede6d 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -1,5 +1,5 @@ # orm/bulk_persistence.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -76,13 +76,13 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Literal[None] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> None: - ... +) -> None: ... @overload @@ -90,19 +90,20 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Optional[dml.Insert] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> cursor.CursorResult[Any]: - ... +) -> cursor.CursorResult[Any]: ... def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -118,13 +119,35 @@ def _bulk_insert( ) if isstates: + if TYPE_CHECKING: + mappings = cast(Iterable[InstanceState[_O]], mappings) + if return_defaults: + # list of states allows us to attach .key for return_defaults case states = [(state, state.dict) for state in mappings] mappings = [dict_ for (state, dict_) in states] else: mappings = [state.dict for state in mappings] else: - mappings = [dict(m) for m in mappings] + if TYPE_CHECKING: + mappings = cast(Iterable[Dict[str, Any]], mappings) + + if return_defaults: + # use dictionaries given, so that newly populated defaults + # can be delivered back to the caller (see #11661). This is **not** + # compatible with other use cases such as a session-executed + # insert() construct, as this will confuse the case of + # insert-per-subclass for joined inheritance cases (see + # test_bulk_statements.py::BulkDMLReturningJoinedInhTest). 
+ # + # So in this conditional, we have **only** called + # session.bulk_insert_mappings() which does not have this + # requirement + mappings = list(mappings) + else: + # for all other cases we need to establish a local dictionary + # so that the incoming dictionaries aren't mutated + mappings = [dict(m) for m in mappings] _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) @@ -220,6 +243,7 @@ def _bulk_insert( state.key = ( identity_cls, tuple([dict_[key] for key in identity_props]), + None, ) if use_orm_insert_stmt is not None: @@ -232,12 +256,12 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., enable_check_rowcount: bool = True, -) -> None: - ... +) -> None: ... @overload @@ -245,18 +269,19 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., enable_check_rowcount: bool = True, -) -> _result.Result[Any]: - ... +) -> _result.Result[Any]: ... def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, @@ -377,14 +402,16 @@ def _get_orm_crud_kv_pairs( if desc is NO_VALUE: yield ( coercions.expect(roles.DMLColumnRole, k), - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ) - if needs_to_be_cacheable - else v, + ( + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v + ), ) else: yield from core_get_crud_kv_pairs( @@ -405,13 +432,15 @@ def _get_orm_crud_kv_pairs( else: yield ( k, - v - if not needs_to_be_cacheable - else coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) ), ) @@ -528,9 +557,9 @@ def _setup_orm_returning( fs = fs.execution_options(**orm_level_statement._execution_options) fs = fs.options(*orm_level_statement._with_options) self.select_statement = fs - self.from_statement_ctx = ( - fsc - ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + self.from_statement_ctx = fsc = ( + ORMFromStatementCompileState.create_for_statement(fs, compiler) + ) fsc.setup_dml_returning_compile_state(dml_mapper) dml_level_statement = dml_level_statement._generate() @@ -590,6 +619,7 @@ def _return_orm_returning( querycontext = QueryContext( compile_state.from_statement_ctx, compile_state.select_statement, + statement, params, session, load_options, @@ -614,6 +644,7 @@ class default_update_options(Options): _eval_condition = None _matched_rows = None _identity_token = None + _populate_existing: bool = False @classmethod def can_use_returning( @@ -646,6 +677,7 @@ def orm_pre_session_exec( { "synchronize_session", "autoflush", + "populate_existing", "identity_token", "is_delete_using", "is_update_from", @@ -830,53 +862,39 @@ def _adjust_for_extra_criteria(cls, global_attributes, ext_info): return return_crit @classmethod - def 
_interpret_returning_rows(cls, mapper, rows): - """translate from local inherited table columns to base mapper - primary key columns. - - Joined inheritance mappers always establish the primary key in terms of - the base table. When we UPDATE a sub-table, we can only get - RETURNING for the sub-table's columns. + def _interpret_returning_rows(cls, result, mapper, rows): + """return rows that indicate PK cols in mapper.primary_key position + for RETURNING rows. - Here, we create a lookup from the local sub table's primary key - columns to the base table PK columns so that we can get identity - key values from RETURNING that's against the joined inheritance - sub-table. + Prior to 2.0.36, this method seemed to be written for some kind of + inheritance scenario but the scenario was unused for actual joined + inheritance, and the function instead seemed to perform some kind of + partial translation that would remove non-PK cols if the PK cols + happened to be first in the row, but not otherwise. The joined + inheritance walk feature here seems to have never been used as it was + always skipped by the "local_table" check. - the complexity here is to support more than one level deep of - inheritance, where we have to link columns to each other across - the inheritance hierarchy. + As of 2.0.36 the function strips away non-PK cols and provides the + PK cols for the table in mapper PK order. """ - if mapper.local_table is not mapper.base_mapper.local_table: - return rows - - # this starts as a mapping of - # local_pk_col: local_pk_col. - # we will then iteratively rewrite the "value" of the dict with - # each successive superclass column - local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} - - for mp in mapper.iterate_to_root(): - if mp.inherits is None: - break - elif mp.local_table is mp.inherits.local_table: - continue - - t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) - col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} - for pk, super_ in local_pk_to_base_pk.items(): - local_pk_to_base_pk[pk] = col_to_col[super_] + try: + if mapper.local_table is not mapper.base_mapper.local_table: + # TODO: dive more into how a local table PK is used for fetch + # sync, not clear if this is correct as it depends on the + # downstream routine to fetch rows using + # local_table.primary_key order + pk_keys = result._tuple_getter(mapper.local_table.primary_key) + else: + pk_keys = result._tuple_getter(mapper.primary_key) + except KeyError: + # can't use these rows, they don't have PK cols in them + # this is an unusual case where the user would have used + # .return_defaults() + return [] - lookup = { - local_pk_to_base_pk[lpk]: idx - for idx, lpk in enumerate(mapper.local_table.primary_key) - } - primary_key_convert = [ - lookup[bpk] for bpk in mapper.base_mapper.primary_key - ] - return [tuple(row[idx] for idx in primary_key_convert) for row in rows] + return [pk_keys(row) for row in rows] @classmethod def _get_matched_objects_on_criteria(cls, update_options, states): @@ -1439,6 +1457,9 @@ def _setup_for_orm_update(self, statement, compiler, **kw): new_stmt = statement._clone() + if new_stmt.table._annotations["parententity"] is mapper: + new_stmt.table = mapper.local_table + # note if the statement has _multi_values, these # are passed through to the new statement, which will then raise # InvalidRequestError because UPDATE doesn't support multi_values @@ -1557,10 +1578,20 @@ def orm_execute_statement( bind_arguments: _BindArguments, conn: Connection, ) -> 
_result.Result: + update_options = execution_options.get( "_sa_orm_update_options", cls.default_update_options ) + if update_options._populate_existing: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + load_options += {"_populate_existing": True} + execution_options = execution_options.union( + {"_sa_orm_load_options": load_options} + ) + if update_options._dml_strategy not in ( "orm", "auto", @@ -1716,7 +1747,10 @@ def _do_post_synchronize_evaluate( session, update_options, statement, + result.context.compiled_parameters[0], [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + result.prefetch_cols(), + result.postfetch_cols(), ) @classmethod @@ -1728,9 +1762,8 @@ def _do_post_synchronize_fetch( returned_defaults_rows = result.returned_defaults_rows if returned_defaults_rows: pk_rows = cls._interpret_returning_rows( - target_mapper, returned_defaults_rows + result, target_mapper, returned_defaults_rows ) - matched_rows = [ tuple(row) + (update_options._identity_token,) for row in pk_rows @@ -1761,6 +1794,7 @@ def _do_post_synchronize_fetch( session, update_options, statement, + result.context.compiled_parameters[0], [ ( obj, @@ -1769,16 +1803,26 @@ def _do_post_synchronize_fetch( ) for obj in objs ], + result.prefetch_cols(), + result.postfetch_cols(), ) @classmethod def _apply_update_set_values_to_objects( - cls, session, update_options, statement, matched_objects + cls, + session, + update_options, + statement, + effective_params, + matched_objects, + prefetch_cols, + postfetch_cols, ): """apply values to objects derived from an update statement, e.g. UPDATE..SET """ + mapper = update_options._subject_mapper target_cls = mapper.class_ evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) @@ -1801,7 +1845,35 @@ def _apply_update_set_values_to_objects( attrib = {k for k, v in resolved_keys_as_propnames} states = set() + + to_prefetch = { + c + for c in prefetch_cols + if c.key in effective_params + and c in mapper._columntoproperty + and c.key not in evaluated_keys + } + to_expire = { + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + }.difference(evaluated_keys) + + prefetch_transfer = [ + (mapper._columntoproperty[c].key, c.key) for c in to_prefetch + ] + for obj, state, dict_ in matched_objects: + + dict_.update( + { + col_to_prop: effective_params[c_key] + for col_to_prop, c_key in prefetch_transfer + } + ) + + state._expire_attributes(state.dict, to_expire) + to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: @@ -1858,6 +1930,9 @@ def create_for_statement(cls, statement, compiler, **kw): new_stmt = statement._clone() + if new_stmt.table._annotations["parententity"] is mapper: + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( self.global_attributes, mapper ) @@ -2018,7 +2093,7 @@ def _do_post_synchronize_fetch( if returned_defaults_rows: pk_rows = cls._interpret_returning_rows( - target_mapper, returned_defaults_rows + result, target_mapper, returned_defaults_rows ) matched_rows = [ diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 10f1db03b65..fd4828e8559 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -1,5 +1,5 @@ -# ext/declarative/clsregistry.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# orm/clsregistry.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This 
module is part of SQLAlchemy and is released under @@ -72,7 +72,7 @@ def add_class( # class already exists. existing = decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): - existing = decl_class_registry[classname] = _MultipleClassMarker( + decl_class_registry[classname] = _MultipleClassMarker( [cls, cast("Type[Any]", existing)] ) else: @@ -83,9 +83,9 @@ def add_class( _ModuleMarker, decl_class_registry["_sa_module_registry"] ) except KeyError: - decl_class_registry[ - "_sa_module_registry" - ] = root_module = _ModuleMarker("_sa_module_registry", None) + decl_class_registry["_sa_module_registry"] = root_module = ( + _ModuleMarker("_sa_module_registry", None) + ) tokens = cls.__module__.split(".") @@ -239,10 +239,10 @@ def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: def add_item(self, item: Type[Any]) -> None: # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, - # [ticket:3208] + # [ticket:3208] and [ticket:10782] modules = { cls.__module__ - for cls in [ref() for ref in self.contents] + for cls in [ref() for ref in list(self.contents)] if cls is not None } if item.__module__ in modules: @@ -287,8 +287,9 @@ def __getitem__(self, name: str) -> ClsRegistryToken: def _remove_item(self, name: str) -> None: self.contents.pop(name, None) - if not self.contents and self.parent is not None: - self.parent._remove_item(self.name) + if not self.contents: + if self.parent is not None: + self.parent._remove_item(self.name) _registries.discard(self) def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: @@ -316,7 +317,7 @@ def add_class(self, name: str, cls: Type[Any]) -> None: else: raise else: - existing = self.contents[name] = _MultipleClassMarker( + self.contents[name] = _MultipleClassMarker( [cls], on_remove=lambda: self._remove_item(name) ) @@ -542,9 +543,7 @@ def __call__(self) -> Any: _fallback_dict: Mapping[str, Any] = None # type: ignore -def _resolver( - cls: Type[Any], prop: RelationshipProperty[Any] -) -> Tuple[ +def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], Callable[[str, bool], _class_resolver], ]: diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 3a4964c4609..336b1133d99 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -1,5 +1,5 @@ # orm/collections.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -21,6 +21,8 @@ and return values to events:: from sqlalchemy.orm.collections import collection + + class MyClass: # ... @@ -32,7 +34,6 @@ def store(self, item): def pop(self): return self.data.pop() - The second approach is a bundle of targeted decorators that wrap appropriate append and remove notifiers around the mutation methods present in the standard Python ``list``, ``set`` and ``dict`` interfaces. These could be @@ -73,10 +74,11 @@ class InstrumentedList(list): method that's already instrumented. For example:: class QueueIsh(list): - def push(self, item): - self.append(item) - def shift(self): - return self.pop(0) + def push(self, item): + self.append(item) + + def shift(self): + return self.pop(0) There's no need to decorate these methods. ``append`` and ``pop`` are already instrumented as part of the ``list`` interface. 
Decorating them would fire @@ -148,10 +150,12 @@ def shift(self): "keyfunc_mapping", "column_keyed_dict", "attribute_keyed_dict", - "column_keyed_dict", - "attribute_keyed_dict", - "MappedCollection", "KeyFuncDict", + # old names in < 2.0 + "mapped_collection", + "column_mapped_collection", + "attribute_mapped_collection", + "MappedCollection", ] __instrumentation_mutex = threading.Lock() @@ -167,8 +171,7 @@ def shift(self): class _CollectionConverterProtocol(Protocol): - def __call__(self, collection: _COL) -> _COL: - ... + def __call__(self, collection: _COL) -> _COL: ... class _AdaptedCollectionProtocol(Protocol): @@ -194,9 +197,10 @@ def append(self, append): ... The recipe decorators all require parens, even those that take no arguments:: - @collection.adds('entity') + @collection.adds("entity") def insert(self, position, entity): ... + @collection.removes_return() def popitem(self): ... @@ -216,11 +220,13 @@ def appender(fn): @collection.appender def add(self, append): ... + # or, equivalently @collection.appender @collection.adds(1) def add(self, append): ... + # for mapping type, an 'append' may kick out a previous value # that occupies that slot. consider d['a'] = 'foo'- any previous # value in d['a'] is discarded. @@ -260,10 +266,11 @@ def remover(fn): @collection.remover def zap(self, entity): ... + # or, equivalently @collection.remover @collection.removes_return() - def zap(self, ): ... + def zap(self): ... If the value to remove is not present in the collection, you may raise an exception or return None to ignore the error. @@ -363,7 +370,8 @@ def adds(arg): @collection.adds(1) def push(self, item): ... - @collection.adds('entity') + + @collection.adds("entity") def do_stuff(self, thing, entity=None): ... """ @@ -548,9 +556,9 @@ def _reset_empty(self) -> None: self.empty ), "This collection adapter is not in the 'empty' state" self.empty = False - self.owner_state.dict[ - self._key - ] = self.owner_state._empty_collections.pop(self._key) + self.owner_state.dict[self._key] = ( + self.owner_state._empty_collections.pop(self._key) + ) def _refuse_empty(self) -> NoReturn: raise sa_exc.InvalidRequestError( @@ -1554,14 +1562,14 @@ class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation: util.immutabledict[ - Any, _CollectionFactoryType -] = util.immutabledict( - { - list: InstrumentedList, - set: InstrumentedSet, - dict: InstrumentedDict, - } +__canned_instrumentation: util.immutabledict[Any, _CollectionFactoryType] = ( + util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } + ) ) __interfaces: util.immutabledict[ diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 79b43f5fe7d..30b05948a51 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -1,5 +1,5 @@ # orm/context.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -104,6 +104,7 @@ class QueryContext: "top_level_context", "compile_state", "query", + "user_passed_query", "params", "load_options", "bind_arguments", @@ -147,7 +148,12 @@ class default_load_options(Options): def __init__( self, compile_state: CompileState, - statement: Union[Select[Any], FromStatement[Any]], + statement: Union[Select[Any], FromStatement[Any], UpdateBase], + user_passed_query: Union[ + Select[Any], + FromStatement[Any], + UpdateBase, + 
], params: _CoreSingleExecuteParams, session: Session, load_options: Union[ @@ -162,6 +168,13 @@ def __init__( self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state self.query = statement + + # the query that the end user passed to Session.execute() or similar. + # this is usually the same as .query, except in the bulk_persistence + # routines where a separate FromStatement is manufactured in the + # compile stage; this allows differentiation in that case. + self.user_passed_query = user_passed_query + self.session = session self.loaders_require_buffering = False self.loaders_require_uniquing = False @@ -169,7 +182,7 @@ def __init__( self.top_level_context = load_options._sa_top_level_orm_context cached_options = compile_state.select_statement._with_options - uncached_options = statement._with_options + uncached_options = user_passed_query._with_options # see issue #7447 , #8399 for some background # propagated loader options will be present on loaded InstanceState @@ -218,7 +231,7 @@ def _init_global_attributes( if compiler is None: # this is the legacy / testing only ORM _compile_state() use case. # there is no need to apply criteria options for this. - self.global_attributes = ga = {} + self.global_attributes = {} assert toplevel return else: @@ -252,10 +265,10 @@ def _init_global_attributes( @classmethod def create_for_statement( cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], + statement: Executable, + compiler: SQLCompiler, **kw: Any, - ) -> AbstractORMCompileState: + ) -> CompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -401,8 +414,8 @@ class default_compile_options(CacheableOptions): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select[Any], FromStatement[Any]] - select_statement: Union[Select[Any], FromStatement[Any]] + statement: Union[Select[Any], FromStatement[Any], UpdateBase] + select_statement: Union[Select[Any], FromStatement[Any], UpdateBase] _entities: List[_QueryEntity] _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ @@ -424,16 +437,30 @@ class default_compile_options(CacheableOptions): def __init__(self, *arg, **kw): raise NotImplementedError() - if TYPE_CHECKING: + @classmethod + def create_for_statement( + cls, + statement: Executable, + compiler: SQLCompiler, + **kw: Any, + ) -> ORMCompileState: + return cls._create_orm_context( + cast("Union[Select, FromStatement]", statement), + toplevel=not compiler.stack, + compiler=compiler, + **kw, + ) - @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - ... + @classmethod + def _create_orm_context( + cls, + statement: Union[Select, FromStatement], + *, + toplevel: bool, + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + raise NotImplementedError() def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns @@ -517,15 +544,14 @@ def orm_pre_session_exec( and len(statement._compile_options._current_path) > 10 and execution_options.get("compiled_cache", True) is not None ): - util.warn( - "Loader depth for query is excessively deep; caching will " - "be disabled for additional loaders. Consider using the " - "recursion_depth feature for deeply nested recursive eager " - "loaders. 
Use the compiled_cache=None execution option to " - "skip this warning." - ) - execution_options = execution_options.union( - {"compiled_cache": None} + execution_options: util.immutabledict[str, Any] = ( + execution_options.union( + { + "compiled_cache": None, + "_cache_disable_reason": "excess depth for " + "ORM loader options", + } + ) ) bind_arguments["clause"] = statement @@ -580,6 +606,7 @@ def orm_setup_cursor_result( querycontext = QueryContext( compile_state, statement, + statement, params, session, load_options, @@ -643,8 +670,8 @@ def _create_entities_collection(cls, query, legacy): ) -class DMLReturningColFilter: - """an adapter used for the DML RETURNING case. +class _DMLReturningColFilter: + """a base for an adapter used for the DML RETURNING cases. Has a subset of the interface used by :class:`.ORMAdapter` and is used for :class:`._QueryEntity` @@ -678,6 +705,21 @@ def __call__(self, col, as_filter): else: return None + def adapt_check_present(self, col): + raise NotImplementedError() + + +class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter): + """an adapter used for the DML RETURNING case specifically + for ORM bulk insert (or any hypothetical DML that splits out a class + hierarchy among multiple DML statements; ORM bulk insert is the only + example right now). + + its main job is to limit the columns in a RETURNING to only a specific + mapped table in a hierarchy. + + """ + def adapt_check_present(self, col): mapper = self.mapper prop = mapper._columntoproperty.get(col, None) @@ -686,6 +728,30 @@ def adapt_check_present(self, col): return mapper.local_table.c.corresponding_column(col) +class _DMLUpdateDeleteReturningColFilter(_DMLReturningColFilter): + """an adapter used for the DML RETURNING case specifically + for ORM-enabled UPDATE/DELETE. + + its main job is to limit the columns in a RETURNING to direct + persisted columns from the immediate selectable, excluding expressions + such as column_property(), while still allowing columns from other + mappers for the UPDATE..FROM use case. 
+ + """ + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is not None: + # if the col is from the immediate mapper, only return a persisted + # column, not any kind of column_property expression + return mapper.persist_selectable.c.corresponding_column(col) + + # if the col is from some other mapper, just return it, assume the + # user knows what they are doing + return col + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -704,12 +770,16 @@ class ORMFromStatementCompileState(ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def create_for_statement( + def _create_orm_context( cls, - statement_container: Union[Select, FromStatement], + statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMFromStatementCompileState: + statement_container = statement + assert isinstance(statement_container, FromStatement) if compiler is not None and compiler.stack: @@ -751,9 +821,11 @@ def create_for_statement( self.statement = statement self._label_convention = self._column_naming_convention( - statement._label_style - if not statement._is_textual and not statement.is_dml - else LABEL_STYLE_NONE, + ( + statement._label_style + if not statement._is_textual and not statement.is_dml + else LABEL_STYLE_NONE + ), self.use_legacy_query_style, ) @@ -799,9 +871,9 @@ def create_for_statement( for entity in self._entities: entity.setup_compile_state(self) - compiler._ordered_columns = ( - compiler._textual_ordered_columns - ) = False + compiler._ordered_columns = compiler._textual_ordered_columns = ( + False + ) # enable looser result column matching. this is shown to be # needed by test_query.py::TextTest @@ -838,14 +910,24 @@ def _get_current_adapter(self): return None def setup_dml_returning_compile_state(self, dml_mapper): - """used by BulkORMInsert (and Update / Delete?) 
to set up a handler + """used by BulkORMInsert, Update, Delete to set up a handler for RETURNING to return ORM objects and expressions """ target_mapper = self.statement._propagate_attrs.get( "plugin_subject", None ) - adapter = DMLReturningColFilter(target_mapper, dml_mapper) + + if self.statement.is_insert: + adapter = _DMLBulkInsertReturningColFilter( + target_mapper, dml_mapper + ) + elif self.statement.is_update or self.statement.is_delete: + adapter = _DMLUpdateDeleteReturningColFilter( + target_mapper, dml_mapper + ) + else: + adapter = None if self.compile_options._is_star and (len(self._entities) != 1): raise sa_exc.CompileError( @@ -888,6 +970,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): ("_compile_options", InternalTraversal.dp_has_cache_key) ] + is_from_statement = True + def __init__( self, entities: Iterable[_ColumnsClauseArgument[Any]], @@ -905,6 +989,10 @@ def __init__( ] self.element = element self.is_dml = element.is_dml + self.is_select = element.is_select + self.is_delete = element.is_delete + self.is_insert = element.is_insert + self.is_update = element.is_update self._label_style = ( element._label_style if is_select_base(element) else None ) @@ -998,21 +1086,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _having_criteria = () @classmethod - def create_for_statement( + def _create_orm_context( cls, statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMSelectCompileState: - """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) - if compiler is not None: - toplevel = not compiler.stack - else: - toplevel = True - select_statement = statement # if we are a select() that was never a legacy Query, we won't @@ -1368,11 +1452,15 @@ def all_selected_columns(cls, statement): def get_columns_clause_froms(cls, statement): return cls._normalize_froms( itertools.chain.from_iterable( - element._from_objects - if "parententity" not in element._annotations - else [ - element._annotations["parententity"].__clause_element__() - ] + ( + element._from_objects + if "parententity" not in element._annotations + else [ + element._annotations[ + "parententity" + ].__clause_element__() + ] + ) for element in statement._raw_columns ) ) @@ -1501,9 +1589,11 @@ def _compound_eager_statement(self): # the original expressions outside of the label references # in order to have them render. 
unwrapped_order_by = [ - elem.element - if isinstance(elem, sql.elements._label_reference) - else elem + ( + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem + ) for elem in self.order_by ] @@ -1545,10 +1635,10 @@ def _compound_eager_statement(self): ) statement._label_style = self.label_style - # Oracle however does not allow FOR UPDATE on the subquery, - # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL - # we expect that all elements of the row are locked, so also put it - # on the outside (except in the case of PG when OF is used) + # Oracle Database however does not allow FOR UPDATE on the subquery, + # and the Oracle Database dialects ignore it, plus for PostgreSQL, + # MySQL we expect that all elements of the row are locked, so also put + # it on the outside (except in the case of PG when OF is used) if ( self._for_update_arg is not None and self._for_update_arg.of is None @@ -1774,8 +1864,6 @@ def _join(self, args, entities_collection): "selectable/table as join target" ) - of_type = None - if isinstance(onclause, interfaces.PropComparator): # descriptor/property given (or determined); this tells us # explicitly what the expected "left" side of the join is. @@ -2422,9 +2510,12 @@ def _column_descriptions( "type": ent.type, "aliased": getattr(insp_ent, "is_aliased_class", False), "expr": ent.expr, - "entity": getattr(insp_ent, "entity", None) - if ent.entity_zero is not None and not insp_ent.is_clause_element - else None, + "entity": ( + getattr(insp_ent, "entity", None) + if ent.entity_zero is not None + and not insp_ent.is_clause_element + else None + ), } for ent, insp_ent in [ (_ent, _ent.entity_zero) for _ent in ctx._entities @@ -2434,7 +2525,7 @@ def _column_descriptions( def _legacy_filter_by_entity_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]] + query_or_augmented_select: Union[Query[Any], Select[Any]], ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if self._setup_joins: @@ -2449,7 +2540,7 @@ def _legacy_filter_by_entity_zero( def _entity_from_pre_ent_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]] + query_or_augmented_select: Union[Query[Any], Select[Any]], ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if not self._raw_columns: @@ -2507,7 +2598,7 @@ def setup_compile_state(self, compile_state: ORMCompileState) -> None: def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: DMLReturningColFilter, + adapter: Optional[_DMLReturningColFilter], ) -> None: raise NotImplementedError() @@ -2709,7 +2800,7 @@ def row_processor(self, context, result): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: DMLReturningColFilter, + adapter: Optional[_DMLReturningColFilter], ) -> None: loading._setup_entity_query( compile_state, @@ -2865,6 +2956,13 @@ def setup_compile_state(self, compile_state): for ent in self._entities: ent.setup_compile_state(compile_state) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: Optional[_DMLReturningColFilter], + ) -> None: + return self.setup_compile_state(compile_state) + def row_processor(self, context, result): procs, labels, extra = zip( *[ent.row_processor(context, result) for ent in self._entities] @@ -3028,7 +3126,10 @@ def __init__( if not is_current_entities or column._is_text_clause: self._label_name = None else: - self._label_name = compile_state._label_convention(column) + if parent_bundle: + 
self._label_name = column._proxy_key + else: + self._label_name = compile_state._label_convention(column) if parent_bundle: parent_bundle._entities.append(self) @@ -3048,7 +3149,7 @@ def corresponds_to(self, entity): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: DMLReturningColFilter, + adapter: Optional[_DMLReturningColFilter], ) -> None: return self.setup_compile_state(compile_state) @@ -3122,9 +3223,12 @@ def __init__( self.raw_column_index = raw_column_index if is_current_entities: - self._label_name = compile_state._label_convention( - column, col_name=orm_key - ) + if parent_bundle: + self._label_name = orm_key if orm_key else column._proxy_key + else: + self._label_name = compile_state._label_convention( + column, col_name=orm_key + ) else: self._label_name = None @@ -3162,10 +3266,13 @@ def corresponds_to(self, entity): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: DMLReturningColFilter, + adapter: Optional[_DMLReturningColFilter], ) -> None: - self._fetch_column = self.column - column = adapter(self.column, False) + + self._fetch_column = column = self.column + if adapter: + column = adapter(column, False) + if column is not None: compile_state.dedupe_columns.add(column) compile_state.primary_columns.append(column) diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 80c85f13ad3..60468237ee0 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -1,5 +1,5 @@ -# orm/declarative/api.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# orm/decl_api.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -14,7 +14,6 @@ import typing from typing import Any from typing import Callable -from typing import cast from typing import ClassVar from typing import Dict from typing import FrozenSet @@ -72,12 +71,16 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing +from ..util import warn_deprecated from ..util.typing import CallableReference +from ..util.typing import de_optionalize_union_types from ..util.typing import flatten_newtype from ..util.typing import is_generic from ..util.typing import is_literal from ..util.typing import is_newtype +from ..util.typing import is_pep695 from ..util.typing import Literal +from ..util.typing import LITERAL_TYPES from ..util.typing import Self if TYPE_CHECKING: @@ -206,7 +209,7 @@ def synonym_for( :paramref:`.orm.synonym.descriptor` parameter:: class MyClass(Base): - __tablename__ = 'my_table' + __tablename__ = "my_table" id = Column(Integer, primary_key=True) _job_status = Column("job_status", String(50)) @@ -312,17 +315,13 @@ def __init__( self, fn: Callable[..., _T], cascading: bool = False, - ): - ... + ): ... - def __get__(self, instance: Optional[object], owner: Any) -> _T: - ... + def __get__(self, instance: Optional[object], owner: Any) -> _T: ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... def __call__(self, fn: Callable[..., _TT]) -> _declared_directive[_TT]: # extensive fooling of mypy underway... 
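The ``_proxy_key`` change to ``_ColumnEntity`` above keys columns placed inside a :class:`.Bundle` by their proxy key (typically the mapped attribute name) rather than by the compile-state label convention. A minimal sketch of the resulting access pattern — the ``User`` mapping and the in-memory SQLite engine are illustrative assumptions, not part of this patch::

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import (
        Bundle,
        DeclarativeBase,
        Mapped,
        Session,
        mapped_column,
    )

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="spongebob"))
        session.commit()

        # columns inside a Bundle come back keyed by their proxy
        # key, i.e. the attribute names "id" and "name"
        b = Bundle("user_info", User.id, User.name)
        for row in session.execute(select(b)):
            info = row.user_info
            print(info.id, info.name)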
@@ -376,20 +375,21 @@ def __tablename__(cls) -> str: for subclasses:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id: Mapped[int] = mapped_column(primary_key=True) type: Mapped[str] = mapped_column(String(50)) @declared_attr.directive def __mapper_args__(cls) -> Dict[str, Any]: - if cls.__name__ == 'Employee': + if cls.__name__ == "Employee": return { - "polymorphic_on":cls.type, - "polymorphic_identity":"Employee" + "polymorphic_on": cls.type, + "polymorphic_identity": "Employee", } else: - return {"polymorphic_identity":cls.__name__} + return {"polymorphic_identity": cls.__name__} + class Engineer(Employee): pass @@ -427,14 +427,11 @@ def __init__( self, fn: _DeclaredAttrDecorated[_T], cascading: bool = False, - ): - ... + ): ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... # this is the Mapped[] API where at class descriptor get time we want # the type checker to see InstrumentedAttribute[_T]. However the @@ -443,17 +440,14 @@ def __delete__(self, instance: Any) -> None: @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: - ... + ) -> InstrumentedAttribute[_T]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: - ... + ) -> Union[InstrumentedAttribute[_T], _T]: ... @hybridmethod def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: @@ -494,6 +488,7 @@ def declarative_mixin(cls: Type[_T]) -> Type[_T]: from sqlalchemy.orm import declared_attr from sqlalchemy.orm import declarative_mixin + @declarative_mixin class MyMixin: @@ -501,10 +496,11 @@ class MyMixin: def __tablename__(cls): return cls.__name__.lower() - __table_args__ = {'mysql_engine': 'InnoDB'} - __mapper_args__= {'always_refresh': True} + __table_args__ = {"mysql_engine": "InnoDB"} + __mapper_args__ = {"always_refresh": True} + + id = Column(Integer, primary_key=True) - id = Column(Integer, primary_key=True) class MyModel(MyMixin, Base): name = Column(String(1000)) @@ -594,6 +590,7 @@ def __init_subclass__( dataclass_callable: Union[ _NoArg, Callable[..., Type[Any]] ] = _NoArg.NO_ARG, + **kw: Any, ) -> None: apply_dc_transforms: _DataclassArguments = { "init": init, @@ -618,11 +615,11 @@ def __init_subclass__( for k, v in apply_dc_transforms.items() } else: - cls._sa_apply_dc_transforms = ( - current_transforms - ) = apply_dc_transforms + cls._sa_apply_dc_transforms = current_transforms = ( + apply_dc_transforms + ) - super().__init_subclass__() + super().__init_subclass__(**kw) if not _is_mapped_class(cls): new_anno = ( @@ -646,10 +643,10 @@ class DeclarativeBase( from sqlalchemy.orm import DeclarativeBase + class Base(DeclarativeBase): pass - The above ``Base`` class is now usable as the base for new declarative mappings. The superclass makes use of the ``__init_subclass__()`` method to set up new classes and metaclasses aren't used. 
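The ``**kw`` forwarding added to ``__init_subclass__()`` in this module allows class keyword arguments to pass through the declarative base to other bases cooperating in the MRO, while mapping continues to occur at class creation time with no metaclass. A hedged sketch of what this enables — ``HasLabel`` and ``Widget`` are hypothetical names, not from this patch::

    from typing import Any

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class HasLabel:
        def __init_subclass__(cls, label: str = "", **kw: Any) -> None:
            cls.label = label
            super().__init_subclass__(**kw)

    class Base(DeclarativeBase):
        pass

    # the "label" keyword is forwarded through
    # DeclarativeBase.__init_subclass__() to HasLabel; without the
    # **kw forwarding shown above, this would raise TypeError
    class Widget(Base, HasLabel, label="widgets"):
        __tablename__ = "widget"
        id: Mapped[int] = mapped_column(primary_key=True)

    assert Widget.label == "widgets"
    assert type(Widget) is type  # no metaclass is involved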
@@ -672,11 +669,12 @@ class Base(DeclarativeBase): bigint = Annotated[int, "bigint"] my_metadata = MetaData() + class Base(DeclarativeBase): metadata = my_metadata type_annotation_map = { str: String().with_variant(String(255), "mysql", "mariadb"), - bigint: BigInteger() + bigint: BigInteger(), } Class-level attributes which may be specified include: @@ -751,11 +749,9 @@ def __init__(self, id=None, name=None): if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: - ... + def _sa_inspect_type(self) -> Mapper[Self]: ... - def _sa_inspect_instance(self) -> InstanceState[Self]: - ... + def _sa_inspect_instance(self) -> InstanceState[Self]: ... _sa_registry: ClassVar[_RegistryType] @@ -836,16 +832,15 @@ def _sa_inspect_instance(self) -> InstanceState[Self]: """ - def __init__(self, **kw: Any): - ... + def __init__(self, **kw: Any): ... - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBase in cls.__bases__: _check_not_declarative(cls, DeclarativeBase) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) - super().__init_subclass__() + super().__init_subclass__(**kw) def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: @@ -922,11 +917,9 @@ class DeclarativeBaseNoMeta( if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: - ... + def _sa_inspect_type(self) -> Mapper[Self]: ... - def _sa_inspect_instance(self) -> InstanceState[Self]: - ... + def _sa_inspect_instance(self) -> InstanceState[Self]: ... __tablename__: Any """String name to assign to the generated @@ -961,15 +954,15 @@ def _sa_inspect_instance(self) -> InstanceState[Self]: """ - def __init__(self, **kw: Any): - ... + def __init__(self, **kw: Any): ... - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBaseNoMeta in cls.__bases__: _check_not_declarative(cls, DeclarativeBaseNoMeta) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__(**kw) def add_mapped_attribute( @@ -1234,38 +1227,34 @@ def update_type_annotation_map( self.type_annotation_map.update( { - sub_type: sqltype + de_optionalize_union_types(typ): sqltype for typ, sqltype in type_annotation_map.items() - for sub_type in compat_typing.expand_unions( - typ, include_union=True, discard_none=True - ) } ) def _resolve_type( - self, python_type: _MatchedOnType + self, python_type: _MatchedOnType, _do_fallbacks: bool = True ) -> Optional[sqltypes.TypeEngine[Any]]: - search: Iterable[Tuple[_MatchedOnType, Type[Any]]] python_type_type: Type[Any] + search: Iterable[Tuple[_MatchedOnType, Type[Any]]] if is_generic(python_type): if is_literal(python_type): - python_type_type = cast("Type[Any]", python_type) + python_type_type = python_type # type: ignore[assignment] - search = ( # type: ignore[assignment] + search = ( (python_type, python_type_type), - (Literal, python_type_type), + *((lt, python_type_type) for lt in LITERAL_TYPES), ) else: python_type_type = python_type.__origin__ search = ((python_type, python_type_type),) - elif is_newtype(python_type): - python_type_type = flatten_newtype(python_type) - search = ((python_type, python_type_type),) - else: - python_type_type = cast("Type[Any]", python_type) - flattened = None + elif isinstance(python_type, type): + python_type_type = python_type search = ((pt, pt) for pt in python_type_type.__mro__) + else: + python_type_type = python_type # type: ignore[assignment] + search = ((python_type, 
python_type_type),) for pt, flattened in search: # we search through full __mro__ for types. however... @@ -1289,6 +1278,39 @@ def _resolve_type( if resolved_sql_type is not None: return resolved_sql_type + # 2.0 fallbacks + if _do_fallbacks: + python_type_to_check: Any = None + kind = None + if is_pep695(python_type): + # NOTE: assume there aren't type alias types of new types. + python_type_to_check = python_type + while is_pep695(python_type_to_check): + python_type_to_check = python_type_to_check.__value__ + python_type_to_check = de_optionalize_union_types( + python_type_to_check + ) + kind = "TypeAliasType" + if is_newtype(python_type): + python_type_to_check = flatten_newtype(python_type) + kind = "NewType" + + if python_type_to_check is not None: + res_after_fallback = self._resolve_type( + python_type_to_check, False + ) + if res_after_fallback is not None: + assert kind is not None + warn_deprecated( + f"Matching the provided {kind} '{python_type}' on " + "its resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to " + "the type_annotation_map to allow it to match " + "explicitly.", + "2.0", + ) + return res_after_fallback + return None @property @@ -1481,6 +1503,7 @@ def generate_base( Base = mapper_registry.generate_base() + class MyClass(Base): __tablename__ = "my_table" id = Column(Integer, primary_key=True) @@ -1493,6 +1516,7 @@ class MyClass(Base): mapper_registry = registry() + class Base(metaclass=DeclarativeMeta): __abstract__ = True registry = mapper_registry @@ -1578,8 +1602,7 @@ def __class_getitem__(cls: Type[_T], key: Any) -> Type[_T]: ), ) @overload - def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: - ... + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: ... @overload def mapped_as_dataclass( @@ -1594,8 +1617,7 @@ def mapped_as_dataclass( match_args: Union[_NoArg, bool] = ..., kw_only: Union[_NoArg, bool] = ..., dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., - ) -> Callable[[Type[_O]], Type[_O]]: - ... + ) -> Callable[[Type[_O]], Type[_O]]: ... def mapped_as_dataclass( self, @@ -1660,9 +1682,10 @@ def mapped(self, cls: Type[_O]) -> Type[_O]: mapper_registry = registry() + @mapper_registry.mapped class Foo: - __tablename__ = 'some_table' + __tablename__ = "some_table" id = Column(Integer, primary_key=True) name = Column(String) @@ -1702,15 +1725,17 @@ def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]: mapper_registry = registry() + @mapper_registry.as_declarative_base() class Base: @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) - class MyMappedClass(Base): - # ... + + class MyMappedClass(Base): ... 
All keyword arguments passed to :meth:`_orm.registry.as_declarative_base` are passed @@ -1740,12 +1765,14 @@ def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]: mapper_registry = registry() + class Foo: - __tablename__ = 'some_table' + __tablename__ = "some_table" id = Column(Integer, primary_key=True) name = Column(String) + mapper = mapper_registry.map_declaratively(Foo) This function is more conveniently invoked indirectly via either the @@ -1798,12 +1825,14 @@ def map_imperatively( my_table = Table( "my_table", mapper_registry.metadata, - Column('id', Integer, primary_key=True) + Column("id", Integer, primary_key=True), ) + class MyClass: pass + mapper_registry.map_imperatively(MyClass, my_table) See the section :ref:`orm_imperative_mapping` for complete background @@ -1850,15 +1879,17 @@ def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: from sqlalchemy.orm import as_declarative + @as_declarative() class Base: @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) - class MyMappedClass(Base): - # ... + + class MyMappedClass(Base): ... .. seealso:: diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index d5ef3db470a..d0a78764cc8 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -1,5 +1,5 @@ -# ext/declarative/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# orm/decl_base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -65,11 +65,11 @@ from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType +from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ClassDict @@ -98,12 +98,12 @@ class MappedClassProtocol(Protocol[_O]): __mapper__: Mapper[_O] __table__: FromClause - def __call__(self, **kw: Any) -> _O: - ... + def __call__(self, **kw: Any) -> _O: ... class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): "Internal more detailed version of ``MappedClassProtocol``." + metadata: MetaData __tablename__: str __mapper_args__: _MapperKwArgs @@ -111,11 +111,9 @@ class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): _sa_apply_dc_transforms: Optional[_DataclassArguments] - def __declare_first__(self) -> None: - ... + def __declare_first__(self) -> None: ... - def __declare_last__(self) -> None: - ... + def __declare_last__(self) -> None: ... 
class _DataclassArguments(TypedDict): @@ -434,7 +432,7 @@ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: class _CollectedAnnotation(NamedTuple): raw_annotation: _AnnotationScanType mapped_container: Optional[Type[Mapped[Any]]] - extracted_mapped_annotation: Union[Type[Any], str] + extracted_mapped_annotation: Union[_AnnotationScanType, str] is_dataclass: bool attr_value: Any originating_module: str @@ -456,6 +454,7 @@ class _ClassScanMapperConfig(_MapperConfig): "tablename", "mapper_args", "mapper_args_fn", + "table_fn", "inherits", "single", "allow_dataclass_fields", @@ -762,7 +761,7 @@ def _scan_attributes(self) -> None: _include_dunders = self._include_dunders mapper_args_fn = None table_args = inherited_table_args = None - + table_fn = None tablename = None fixed_table = "__table__" in clsdict_view @@ -843,6 +842,22 @@ def _mapper_args_fn() -> Dict[str, Any]: ) if not tablename and (not class_mapped or check_decl): tablename = cls_as_Decl.__tablename__ + elif name == "__table__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + # if a @declared_attr using "__table__" is detected, + # wrap up a callable to look for "__table__" from + # the final concrete class when we set up a table. + # this was fixed by + # #11509, regression in 2.0 from version 1.4. + if check_decl and not table_fn: + # don't even invoke __table__ until we're ready + def _table_fn() -> FromClause: + return cls_as_Decl.__table__ + + table_fn = _table_fn + elif name == "__table_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -859,9 +874,10 @@ def _mapper_args_fn() -> Dict[str, Any]: if base is not cls: inherited_table_args = True else: - # skip all other dunder names, which at the moment - # should only be __table__ - continue + # any other dunder names; should not be here + # as we have tested for all four names in + # _include_dunders + assert False elif class_mapped: if _is_declarative_props(obj) and not obj._quiet: util.warn( @@ -908,9 +924,9 @@ def _mapper_args_fn() -> Dict[str, Any]: "@declared_attr.cascading; " "skipping" % (name, cls) ) - collected_attributes[name] = column_copies[ - obj - ] = ret = obj.__get__(obj, cls) + collected_attributes[name] = column_copies[obj] = ( + ret + ) = obj.__get__(obj, cls) setattr(cls, name, ret) else: if is_dataclass_field: @@ -947,9 +963,9 @@ def _mapper_args_fn() -> Dict[str, Any]: ): ret = ret.descriptor - collected_attributes[name] = column_copies[ - obj - ] = ret + collected_attributes[name] = column_copies[obj] = ( + ret + ) if ( isinstance(ret, (Column, MapperProperty)) @@ -1034,6 +1050,7 @@ def _mapper_args_fn() -> Dict[str, Any]: self.table_args = table_args self.tablename = tablename self.mapper_args_fn = mapper_args_fn + self.table_fn = table_fn def _setup_dataclasses_transforms(self) -> None: dataclass_setup_arguments = self.dataclass_setup_arguments @@ -1051,6 +1068,16 @@ def _setup_dataclasses_transforms(self) -> None: "'@registry.mapped_as_dataclass'" ) + # can't create a dataclass if __table__ is already there. This would + # fail an assertion when calling _get_arguments_for_make_dataclass: + # assert False, "Mapped[] received without a mapping declaration" + if "__table__" in self.cls.__dict__: + raise exc.InvalidRequestError( + f"Class {self.cls} already defines a '__table__'. 
" + "ORM Annotated Dataclasses do not support a pre-existing " + "'__table__' element" + ) + warn_for_non_dc_attrs = collections.defaultdict(list) def _allow_dataclass_field( @@ -1130,9 +1157,9 @@ def _allow_dataclass_field( defaults = {} for item in field_list: if len(item) == 2: - name, tp = item # type: ignore + name, tp = item elif len(item) == 3: - name, tp, spec = item # type: ignore + name, tp, spec = item defaults[name] = spec else: assert False @@ -1270,8 +1297,6 @@ def _collect_annotation( or isinstance(attr_value, _MappedAttribute) ) ) - else: - is_dataclass_field = False is_dataclass_field = False extracted = _extract_mapped_subtype( @@ -1282,10 +1307,8 @@ def _collect_annotation( type(attr_value), required=False, is_dataclass_field=is_dataclass_field, - expect_mapped=expect_mapped - and not is_dataclass, # self.allow_dataclass_fields, + expect_mapped=expect_mapped and not is_dataclass, ) - if extracted is None: # ClassVar can come out here return None @@ -1293,9 +1316,9 @@ def _collect_annotation( extracted_mapped_annotation, mapped_container = extracted if attr_value is None and not is_literal(extracted_mapped_annotation): - for elem in typing_get_args(extracted_mapped_annotation): - if isinstance(elem, str) or is_fwd_ref( - elem, check_generic=True + for elem in get_args(extracted_mapped_annotation): + if is_fwd_ref( + elem, check_generic=True, check_for_plain_string=True ): elem = de_stringify_annotation( self.cls, @@ -1553,7 +1576,7 @@ def _extract_mappable_attributes(self) -> None: is_dataclass, ) except NameError as ne: - raise exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f"Could not resolve all types within mapped " f'annotation: "{annotation}". Ensure all ' f"types are written correctly and are " @@ -1690,7 +1713,11 @@ def _setup_table(self, table: Optional[FromClause] = None) -> None: manager = attributes.manager_of_class(cls) - if "__table__" not in clsdict_view and table is None: + if ( + self.table_fn is None + and "__table__" not in clsdict_view + and table is None + ): if hasattr(cls, "__table_cls__"): table_cls = cast( Type[Table], @@ -1736,7 +1763,12 @@ def _setup_table(self, table: Optional[FromClause] = None) -> None: ) else: if table is None: - table = cls_as_Decl.__table__ + if self.table_fn: + table = self.set_cls_attribute( + "__table__", self.table_fn() + ) + else: + table = cls_as_Decl.__table__ if declared_columns: for c in declared_columns: if not table.c.contains_column(c): @@ -1985,8 +2017,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: pass - # mypy disallows plain property override of variable - @property # type: ignore + @property def cls(self) -> Type[Any]: return self._cls() # type: ignore diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index e941dbcbf47..a8cafdd0b7a 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -1,5 +1,5 @@ # orm/dependency.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,9 +7,7 @@ # mypy: ignore-errors -"""Relationship dependencies. 
- -""" +"""Relationship dependencies.""" from __future__ import annotations @@ -167,9 +165,11 @@ def per_state_flush_actions(self, uow, states, isdelete): sum_ = state.manager[self.key].impl.get_all_pending( state, state.dict, - self._passive_delete_flag - if isdelete - else attributes.PASSIVE_NO_INITIALIZE, + ( + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE + ), ) if not sum_: @@ -1052,7 +1052,7 @@ def presort_saves(self, uowcommit, states): # so that prop_has_changes() returns True for state in states: if self._pks_changed(uowcommit, state): - history = uowcommit.get_attribute_history( + uowcommit.get_attribute_history( state, self.key, attributes.PASSIVE_OFF ) diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index c1fe9de85ca..2d1ec13f19e 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -1,5 +1,5 @@ # orm/descriptor_props.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -53,9 +53,10 @@ from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter +from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 -from ..util.typing import typing_get_args + if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -98,6 +99,11 @@ class DescriptorProperty(MapperProperty[_T]): descriptor: DescriptorReference[Any] + def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: + raise NotImplementedError( + "This MapperProperty does not implement column loader strategies" + ) + def get_history( self, state: InstanceState[Any], @@ -364,7 +370,7 @@ def declarative_scan( argument = extracted_mapped_annotation if is_pep593(argument): - argument = typing_get_args(argument)[0] + argument = get_args(argument)[0] if argument and self.composite_class is None: if isinstance(argument, str) or is_fwd_ref( @@ -419,13 +425,13 @@ def _init_accessor(self) -> None: and self.composite_class not in _composite_getters ): if self._generated_composite_accessor is not None: - _composite_getters[ - self.composite_class - ] = self._generated_composite_accessor + _composite_getters[self.composite_class] = ( + self._generated_composite_accessor + ) elif hasattr(self.composite_class, "__composite_values__"): - _composite_getters[ - self.composite_class - ] = lambda obj: obj.__composite_values__() + _composite_getters[self.composite_class] = ( + lambda obj: obj.__composite_values__() + ) @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") @@ -499,6 +505,9 @@ def props(self) -> Sequence[MapperProperty[Any]]: props.append(prop) return props + def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: + return self._comparable_elements + @util.non_memoized_property @util.preload_module("orm.properties") def columns(self) -> Sequence[Column[Any]]: @@ -781,7 +790,9 @@ def _bulk_update_tuples( elif isinstance(self.prop.composite_class, type) and isinstance( value, self.prop.composite_class ): - values = self.prop._composite_values_from_instance(value) + values = self.prop._composite_values_from_instance( + value # type: ignore[arg-type] + ) else: raise sa_exc.ArgumentError( "Can't UPDATE composite attribute %s to %r" @@ -996,6 +1007,9 @@ def _proxied_object( ) return attr.property + def 
_column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: + return (getattr(self.parent.class_, self.name),) + def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: prop = self._proxied_object diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 1d0c03606c8..3c81c396f6e 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -1,5 +1,5 @@ # orm/dynamic.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -161,10 +161,12 @@ def _iter(self) -> Union[result.ScalarResult[_T], result.Result[_T]]: return result.IteratorResult( result.SimpleResultMetaData([self.attr.class_.__name__]), - self.attr._get_collection_history( # type: ignore[arg-type] - attributes.instance_state(self.instance), - PassiveFlag.PASSIVE_NO_INITIALIZE, - ).added_items, + iter( + self.attr._get_collection_history( + attributes.instance_state(self.instance), + PassiveFlag.PASSIVE_NO_INITIALIZE, + ).added_items + ), _source_supports_scalars=True, ).scalars() else: @@ -172,8 +174,7 @@ def _iter(self) -> Union[result.ScalarResult[_T], result.Result[_T]]: if TYPE_CHECKING: - def __iter__(self) -> Iterator[_T]: - ... + def __iter__(self) -> Iterator[_T]: ... def __getitem__(self, index: Any) -> Union[_T, List[_T]]: sess = self.session diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index f3796f03d1e..57aae5a3c49 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -1,5 +1,5 @@ # orm/evaluator.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,6 +28,7 @@ from .. import inspect from ..sql import and_ from ..sql import operators +from ..sql.sqltypes import Concatenable from ..sql.sqltypes import Integer from ..sql.sqltypes import Numeric from ..util import warn_deprecated @@ -311,6 +312,16 @@ def visit_not_in_op_binary_op( def visit_concat_op_binary_op( self, operator, eval_left, eval_right, clause ): + + if not issubclass( + clause.left.type._type_affinity, Concatenable + ) or not issubclass(clause.right.type._type_affinity, Concatenable): + raise UnevaluatableError( + f"Cannot evaluate concatenate operator " + f'"{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + return self._straight_evaluate( lambda a, b: a + b, eval_left, eval_right, clause ) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index e7e3e32a7ff..5af78fc6b76 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1,13 +1,11 @@ # orm/events.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""ORM event interfaces. 
- -""" +"""ORM event interfaces.""" from __future__ import annotations from typing import Any @@ -207,10 +205,12 @@ class InstanceEvents(event.Events[ClassManager[Any]]): from sqlalchemy import event + def my_load_listener(target, context): print("on load!") - event.listen(SomeClass, 'load', my_load_listener) + + event.listen(SomeClass, "load", my_load_listener) Available targets include: @@ -466,8 +466,7 @@ def load(self, target: _O, context: QueryContext) -> None: the existing loading context is maintained for the object after the event is called:: - @event.listens_for( - SomeClass, "load", restore_load_context=True) + @event.listens_for(SomeClass, "load", restore_load_context=True) def on_load(instance, context): instance.some_unloaded_attribute @@ -494,15 +493,15 @@ def on_load(instance, context): .. seealso:: + :ref:`mapped_class_load_events` + :meth:`.InstanceEvents.init` :meth:`.InstanceEvents.refresh` :meth:`.SessionEvents.loaded_as_persistent` - :ref:`mapping_constructors` - - """ + """ # noqa: E501 def refresh( self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]] @@ -534,6 +533,8 @@ def refresh( .. seealso:: + :ref:`mapped_class_load_events` + :meth:`.InstanceEvents.load` """ @@ -577,6 +578,8 @@ def refresh_flush( .. seealso:: + :ref:`mapped_class_load_events` + :ref:`orm_server_defaults` :ref:`metadata_defaults_toplevel` @@ -725,9 +728,9 @@ def populate( class _InstanceEventsHold(_EventsHold[_ET]): - all_holds: weakref.WeakKeyDictionary[ - Any, Any - ] = weakref.WeakKeyDictionary() + all_holds: weakref.WeakKeyDictionary[Any, Any] = ( + weakref.WeakKeyDictionary() + ) def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) @@ -745,6 +748,7 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy import event + def my_before_insert_listener(mapper, connection, target): # execute a stored procedure upon INSERT, # apply the value to the row to be inserted @@ -752,10 +756,10 @@ def my_before_insert_listener(mapper, connection, target): text("select my_special_function(%d)" % target.special_number) ).scalar() + # associate the listener function with SomeClass, # to execute during the "before_insert" hook - event.listen( - SomeClass, 'before_insert', my_before_insert_listener) + event.listen(SomeClass, "before_insert", my_before_insert_listener) Available targets include: @@ -921,9 +925,10 @@ class overall, or to any un-mapped class which serves as a base Base = declarative_base() + @event.listens_for(Base, "instrument_class", propagate=True) def on_new_class(mapper, cls_): - " ... " + "..." :param mapper: the :class:`_orm.Mapper` which is the target of this event. @@ -979,7 +984,7 @@ def before_mapper_configured( symbol which indicates to the :func:`.configure_mappers` call that this particular mapper (or hierarchy of mappers, if ``propagate=True`` is used) should be skipped in the current configuration run. When one or - more mappers are skipped, the he "new mappers" flag will remain set, + more mappers are skipped, the "new mappers" flag will remain set, meaning the :func:`.configure_mappers` function will continue to be called when mappers are used, to continue to try to configure all available mappers. 
@@ -988,7 +993,7 @@ def before_mapper_configured( :meth:`.MapperEvents.before_configured`, :meth:`.MapperEvents.after_configured`, and :meth:`.MapperEvents.mapper_configured`, the - :meth;`.MapperEvents.before_mapper_configured` event provides for a + :meth:`.MapperEvents.before_mapper_configured` event provides for a meaningful return value when it is registered with the ``retval=True`` parameter. @@ -1002,13 +1007,16 @@ def before_mapper_configured( DontConfigureBase = declarative_base() + @event.listens_for( DontConfigureBase, - "before_mapper_configured", retval=True, propagate=True) + "before_mapper_configured", + retval=True, + propagate=True, + ) def dont_configure(mapper, cls): return EXT_SKIP - .. seealso:: :meth:`.MapperEvents.before_configured` @@ -1090,9 +1098,9 @@ def before_configured(self) -> None: from sqlalchemy.orm import Mapper + @event.listens_for(Mapper, "before_configured") - def go(): - ... + def go(): ... Contrast this event to :meth:`.MapperEvents.after_configured`, which is invoked after the series of mappers has been configured, @@ -1110,10 +1118,9 @@ def go(): from sqlalchemy.orm import mapper - @event.listens_for(mapper, "before_configured", once=True) - def go(): - ... + @event.listens_for(mapper, "before_configured", once=True) + def go(): ... .. seealso:: @@ -1150,9 +1157,9 @@ def after_configured(self) -> None: from sqlalchemy.orm import Mapper + @event.listens_for(Mapper, "after_configured") - def go(): - # ... + def go(): ... Theoretically this event is called once per application, but is actually called any time new mappers @@ -1164,9 +1171,9 @@ def go(): from sqlalchemy.orm import mapper + @event.listens_for(mapper, "after_configured", once=True) - def go(): - # ... + def go(): ... .. seealso:: @@ -1553,9 +1560,11 @@ class SessionEvents(event.Events[Session]): from sqlalchemy import event from sqlalchemy.orm import sessionmaker + def my_before_commit(session): print("before commit!") + Session = sessionmaker() event.listen(Session, "before_commit", my_before_commit) @@ -1591,7 +1600,7 @@ def my_before_commit(session): _dispatch_target = Session def _lifecycle_event( # type: ignore [misc] - fn: Callable[[SessionEvents, Session, Any], None] + fn: Callable[[SessionEvents, Session, Any], None], ) -> Callable[[SessionEvents, Session, Any], None]: _sessionevents_lifecycle_event_names.add(fn.__name__) return fn @@ -1775,7 +1784,7 @@ def after_transaction_create( @event.listens_for(session, "after_transaction_create") def after_transaction_create(session, transaction): if transaction.parent is None: - # work with top-level transaction + ... # work with top-level transaction To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the :attr:`.SessionTransaction.nested` attribute:: @@ -1783,8 +1792,7 @@ def after_transaction_create(session, transaction): @event.listens_for(session, "after_transaction_create") def after_transaction_create(session, transaction): if transaction.nested: - # work with SAVEPOINT transaction - + ... # work with SAVEPOINT transaction .. seealso:: @@ -1816,7 +1824,7 @@ def after_transaction_end( @event.listens_for(session, "after_transaction_create") def after_transaction_end(session, transaction): if transaction.parent is None: - # work with top-level transaction + ... 
# work with top-level transaction To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the :attr:`.SessionTransaction.nested` attribute:: @@ -1824,8 +1832,7 @@ def after_transaction_end(session, transaction): @event.listens_for(session, "after_transaction_create") def after_transaction_end(session, transaction): if transaction.nested: - # work with SAVEPOINT transaction - + ... # work with SAVEPOINT transaction .. seealso:: @@ -1935,7 +1942,7 @@ def after_soft_rollback( @event.listens_for(Session, "after_soft_rollback") def do_something(session, previous_transaction): if session.is_active: - session.execute("select * from some_table") + session.execute(text("select * from some_table")) :param session: The target :class:`.Session`. :param previous_transaction: The :class:`.SessionTransaction` @@ -2035,7 +2042,14 @@ def after_begin( transaction: SessionTransaction, connection: Connection, ) -> None: - """Execute after a transaction is begun on a connection + """Execute after a transaction is begun on a connection. + + .. note:: This event is called within the process of the + :class:`_orm.Session` modifying its own internal state. + To invoke SQL operations within this hook, use the + :class:`_engine.Connection` provided to the event; + do not run SQL operations using the :class:`_orm.Session` + directly. :param session: The target :class:`.Session`. :param transaction: The :class:`.SessionTransaction`. @@ -2444,11 +2458,11 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): from sqlalchemy import event - @event.listens_for(MyClass.collection, 'append', propagate=True) + + @event.listens_for(MyClass.collection, "append", propagate=True) def my_append_listener(target, value, initiator): print("received append event for target: %s" % target) - Listeners have the option to return a possibly modified version of the value, when the :paramref:`.AttributeEvents.retval` flag is passed to :func:`.event.listen` or :func:`.event.listens_for`, such as below, @@ -2457,11 +2471,12 @@ def my_append_listener(target, value, initiator): def validate_phone(target, value, oldvalue, initiator): "Strip non-numeric characters from a phone number" - return re.sub(r'\D', '', value) + return re.sub(r"\D", "", value) + # setup listener on UserContact.phone attribute, instructing # it to use the return value - listen(UserContact.phone, 'set', validate_phone, retval=True) + listen(UserContact.phone, "set", validate_phone, retval=True) A validation function like the above can also raise an exception such as :exc:`ValueError` to halt the operation. @@ -2471,7 +2486,7 @@ def validate_phone(target, value, oldvalue, initiator): as when using mapper inheritance patterns:: - @event.listens_for(MySuperClass.attr, 'set', propagate=True) + @event.listens_for(MySuperClass.attr, "set", propagate=True) def receive_set(target, value, initiator): print("value set: %s" % target) @@ -2704,10 +2719,12 @@ def bulk_replace( from sqlalchemy.orm.attributes import OP_BULK_REPLACE + @event.listens_for(SomeObject.collection, "bulk_replace") def process_collection(target, values, initiator): values[:] = [_make_value(value) for value in values] + @event.listens_for(SomeObject.collection, "append", retval=True) def process_collection(target, value, initiator): # make sure bulk_replace didn't already do it @@ -2855,16 +2872,18 @@ def init_scalar( SOME_CONSTANT = 3.1415926 + class MyClass(Base): # ... 
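The `AttributeEvents` hunks above modernize the quoting in the `validate_phone` example; the same recipe in self-contained, runnable form:

```python
import re

from sqlalchemy import event
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class UserContact(Base):
    __tablename__ = "user_contact"
    id: Mapped[int] = mapped_column(primary_key=True)
    phone: Mapped[str]


@event.listens_for(UserContact.phone, "set", retval=True)
def validate_phone(target, value, oldvalue, initiator):
    # with retval=True, the returned value replaces the incoming one
    return re.sub(r"\D", "", value)


contact = UserContact(id=1, phone="(555) 123-4567")
print(contact.phone)  # 5551234567
```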
some_attribute = Column(Numeric, default=SOME_CONSTANT) + @event.listens_for( - MyClass.some_attribute, "init_scalar", - retval=True, propagate=True) + MyClass.some_attribute, "init_scalar", retval=True, propagate=True + ) def _init_some_attribute(target, dict_, value): - dict_['some_attribute'] = SOME_CONSTANT + dict_["some_attribute"] = SOME_CONSTANT return SOME_CONSTANT Above, we initialize the attribute ``MyClass.some_attribute`` to the @@ -2900,9 +2919,10 @@ def _init_some_attribute(target, dict_, value): SOME_CONSTANT = 3.1415926 + @event.listens_for( - MyClass.some_attribute, "init_scalar", - retval=True, propagate=True) + MyClass.some_attribute, "init_scalar", retval=True, propagate=True + ) def _init_some_attribute(target, dict_, value): # will also fire off attribute set events target.some_attribute = SOME_CONSTANT @@ -2939,7 +2959,7 @@ def _init_some_attribute(target, dict_, value): :ref:`examples_instrumentation` - see the ``active_column_defaults.py`` example. - """ + """ # noqa: E501 def init_collection( self, @@ -3077,8 +3097,8 @@ def before_compile(self, query: Query[Any]) -> None: @event.listens_for(Query, "before_compile", retval=True) def no_deleted(query): for desc in query.column_descriptions: - if desc['type'] is User: - entity = desc['entity'] + if desc["type"] is User: + entity = desc["entity"] query = query.filter(entity.deleted == False) return query @@ -3094,12 +3114,11 @@ def no_deleted(query): re-establish the query being cached, apply the event adding the ``bake_ok`` flag:: - @event.listens_for( - Query, "before_compile", retval=True, bake_ok=True) + @event.listens_for(Query, "before_compile", retval=True, bake_ok=True) def my_event(query): for desc in query.column_descriptions: - if desc['type'] is User: - entity = desc['entity'] + if desc["type"] is User: + entity = desc["entity"] query = query.filter(entity.deleted == False) return query @@ -3120,7 +3139,7 @@ def my_event(query): :ref:`baked_with_before_compile` - """ + """ # noqa: E501 def before_compile_update( self, query: Query[Any], update_context: BulkUpdate @@ -3140,11 +3159,13 @@ def before_compile_update( @event.listens_for(Query, "before_compile_update", retval=True) def no_deleted(query, update_context): for desc in query.column_descriptions: - if desc['type'] is User: - entity = desc['entity'] + if desc["type"] is User: + entity = desc["entity"] query = query.filter(entity.deleted == False) - update_context.values['timestamp'] = datetime.utcnow() + update_context.values["timestamp"] = datetime.datetime.now( + datetime.UTC + ) return query The ``.values`` dictionary of the "update context" object can also @@ -3172,7 +3193,7 @@ def no_deleted(query, update_context): :meth:`.QueryEvents.before_compile_delete` - """ + """ # noqa: E501 def before_compile_delete( self, query: Query[Any], delete_context: BulkDelete @@ -3191,8 +3212,8 @@ def before_compile_delete( @event.listens_for(Query, "before_compile_delete", retval=True) def no_deleted(query, delete_context): for desc in query.column_descriptions: - if desc['type'] is User: - entity = desc['entity'] + if desc["type"] is User: + entity = desc["entity"] query = query.filter(entity.deleted == False) return query diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index f30e50350ba..a2f7c9f78a3 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -1,5 +1,5 @@ # orm/exc.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This 
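Note that the `before_compile_update` example above also swaps the deprecated `datetime.utcnow()` for the timezone-aware spelling; the difference in isolation:

```python
import datetime

naive = datetime.datetime.utcnow()  # naive; deprecated since Python 3.12
aware = datetime.datetime.now(datetime.UTC)  # aware; datetime.UTC added in 3.11

print(naive.tzinfo)  # None
print(aware.tzinfo)  # UTC
```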
module is part of SQLAlchemy and is released under @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING from typing import TypeVar +from .util import _mapper_property_as_plain_name from .. import exc as sa_exc from .. import util from ..exc import MultipleResultsFound # noqa @@ -64,6 +65,15 @@ class FlushError(sa_exc.SQLAlchemyError): """A invalid condition was detected during flush().""" +class MappedAnnotationError(sa_exc.ArgumentError): + """Raised when ORM annotated declarative cannot interpret the + expression present inside of the :class:`.Mapped` construct. + + .. versionadded:: 2.0.40 + + """ + + class UnmappedError(sa_exc.InvalidRequestError): """Base for exceptions that involve expected mappings not present.""" @@ -191,8 +201,8 @@ def __init__( % ( util.clsname_as_plain_name(actual_strategy_type), requesting_property, - util.clsname_as_plain_name(applied_to_property_type), - util.clsname_as_plain_name(applies_to), + _mapper_property_as_plain_name(applied_to_property_type), + _mapper_property_as_plain_name(applies_to), ), ) diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 81140a94ef5..1808b2d5e59 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -1,5 +1,5 @@ # orm/identity.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index b12d80ac4f7..f87023f1809 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -1,5 +1,5 @@ # orm/instrumentation.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -85,13 +85,11 @@ def __call__( state: state.InstanceState[Any], toload: Set[str], passive: base.PassiveFlag, - ) -> None: - ... + ) -> None: ... class _ManagerFactory(Protocol): - def __call__(self, class_: Type[_O]) -> ClassManager[_O]: - ... + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: ... class ClassManager( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index a118b2aa854..b4462e54593 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -1,5 +1,5 @@ # orm/interfaces.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -115,7 +115,7 @@ class ORMStatementRole(roles.StatementRole): __slots__ = () _role_name = ( - "Executable SQL or text() construct, including ORM " "aware objects" + "Executable SQL or text() construct, including ORM aware objects" ) @@ -149,13 +149,17 @@ class ORMColumnDescription(TypedDict): class _IntrospectsAnnotations: __slots__ = () + @classmethod + def _mapper_property_name(cls) -> str: + return cls.__name__ + def found_in_pep593_annotated(self) -> Any: """return a copy of this object to use in declarative when the object is found inside of an Annotated object.""" raise NotImplementedError( - f"Use of the {self.__class__} construct inside of an " - f"Annotated object is not yet supported." + f"Use of the {self._mapper_property_name()!r} " + "construct inside of an Annotated object is not yet supported." 
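The new `dataclasses_hash` slot in `_AttributeOptions` above is forwarded to `dataclasses.field(hash=...)` when the ORM generates dataclass fields; what that flag controls, shown with plain stdlib dataclasses:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    x: int = dataclasses.field(hash=True)
    y: int = dataclasses.field(hash=False)  # excluded from __hash__


# y still participates in equality, but not in hashing
print(Point(1, 2) == Point(1, 99))              # False
print(hash(Point(1, 2)) == hash(Point(1, 99)))  # True
```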
) def declarative_scan( @@ -181,7 +185,8 @@ def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: raise sa_exc.ArgumentError( f"Python typing annotation is required for attribute " f'"{cls.__name__}.{key}" when primary argument(s) for ' - f'"{self.__class__.__name__}" construct are None or not present' + f'"{self._mapper_property_name()}" ' + "construct are None or not present" ) @@ -201,6 +206,7 @@ class _AttributeOptions(NamedTuple): dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] dataclasses_compare: Union[_NoArg, bool] dataclasses_kw_only: Union[_NoArg, bool] + dataclasses_hash: Union[_NoArg, bool, None] def _as_dataclass_field(self, key: str) -> Any: """Return a ``dataclasses.Field`` object given these arguments.""" @@ -218,6 +224,8 @@ def _as_dataclass_field(self, key: str) -> Any: kw["compare"] = self.dataclasses_compare if self.dataclasses_kw_only is not _NoArg.NO_ARG: kw["kw_only"] = self.dataclasses_kw_only + if self.dataclasses_hash is not _NoArg.NO_ARG: + kw["hash"] = self.dataclasses_hash if "default" in kw and callable(kw["default"]): # callable defaults are ambiguous. deprecate them in favour of @@ -297,6 +305,7 @@ def _get_arguments_for_make_dataclass( _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, + _NoArg.NO_ARG, ) _DEFAULT_READONLY_ATTRIBUTE_OPTIONS = _AttributeOptions( @@ -306,6 +315,7 @@ def _get_arguments_for_make_dataclass( _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, + _NoArg.NO_ARG, ) @@ -675,27 +685,37 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): # definition of custom PropComparator subclasses - from sqlalchemy.orm.properties import \ - ColumnProperty,\ - Composite,\ - Relationship + from sqlalchemy.orm.properties import ( + ColumnProperty, + Composite, + Relationship, + ) + class MyColumnComparator(ColumnProperty.Comparator): def __eq__(self, other): return self.__clause_element__() == other + class MyRelationshipComparator(Relationship.Comparator): def any(self, expression): "define the 'any' operation" # ... 
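The reworded `_raise_for_required` message above fires when a `mapped_column()` has neither a type argument nor a `Mapped[]` annotation to infer one from. A sketch of triggering it, assuming (as declarative mapping implies) that the error surfaces at class-construction time:

```python
from sqlalchemy.exc import ArgumentError
from sqlalchemy.orm import DeclarativeBase, mapped_column


class Base(DeclarativeBase):
    pass


try:

    class Broken(Base):
        __tablename__ = "broken"
        # no Mapped[int] annotation and no type argument given
        id = mapped_column(primary_key=True)

except ArgumentError as err:
    print(err)  # "Python typing annotation is required for attribute ..."
```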
+ class MyCompositeComparator(Composite.Comparator): def __gt__(self, other): "redefine the 'greater than' operation" - return sql.and_(*[a>b for a, b in - zip(self.__clause_element__().clauses, - other.__composite_values__())]) + return sql.and_( + *[ + a > b + for a, b in zip( + self.__clause_element__().clauses, + other.__composite_values__(), + ) + ] + ) # application of custom PropComparator subclasses @@ -703,17 +723,22 @@ def __gt__(self, other): from sqlalchemy.orm import column_property, relationship, composite from sqlalchemy import Column, String + class SomeMappedClass(Base): - some_column = column_property(Column("some_column", String), - comparator_factory=MyColumnComparator) + some_column = column_property( + Column("some_column", String), + comparator_factory=MyColumnComparator, + ) - some_relationship = relationship(SomeOtherClass, - comparator_factory=MyRelationshipComparator) + some_relationship = relationship( + SomeOtherClass, comparator_factory=MyRelationshipComparator + ) some_composite = composite( - Column("a", String), Column("b", String), - comparator_factory=MyCompositeComparator - ) + Column("a", String), + Column("b", String), + comparator_factory=MyCompositeComparator, + ) Note that for column-level operator redefinition, it's usually simpler to define the operators at the Core level, using the @@ -735,6 +760,7 @@ class SomeMappedClass(Base): :attr:`.TypeEngine.comparator_factory` """ + __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" @@ -754,7 +780,7 @@ def __init__( self._adapt_to_entity = adapt_to_entity @util.non_memoized_property - def property(self) -> MapperProperty[_T]: + def property(self) -> MapperProperty[_T_co]: """Return the :class:`.MapperProperty` associated with this :class:`.PropComparator`. @@ -784,7 +810,7 @@ def _bulk_update_tuples( def adapt_to_entity( self, adapt_to_entity: AliasedInsp[Any] - ) -> PropComparator[_T]: + ) -> PropComparator[_T_co]: """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ @@ -838,15 +864,13 @@ def _of_type_op(a: Any, class_: Any) -> Any: def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -856,8 +880,9 @@ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: e.g.:: - query.join(Company.employees.of_type(Engineer)).\ - filter(Engineer.name=='foo') + query.join(Company.employees.of_type(Engineer)).filter( + Engineer.name == "foo" + ) :param \class_: a class or mapper indicating that criterion will be against this specific subclass. @@ -883,11 +908,11 @@ def and_( stmt = select(User).join( - User.addresses.and_(Address.email_address != 'foo') + User.addresses.and_(Address.email_address != "foo") ) stmt = select(User).options( - joinedload(User.addresses.and_(Address.email_address != 'foo')) + joinedload(User.addresses.and_(Address.email_address != "foo")) ) .. 
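The `of_type()` / `and_()` docstrings above show relationship join criteria being augmented; a self-contained version of the `and_()` form, with the `User`/`Address` mapping invented to match the docstring names:

```python
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    addresses: Mapped[list["Address"]] = relationship()


class Address(Base):
    __tablename__ = "address"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
    email_address: Mapped[str]


# and_() appends criteria to the relationship's natural ON clause
stmt = select(User).join(
    User.addresses.and_(Address.email_address != "foo")
)
print(stmt)  # ... ON user_account.id = address.user_id
             #     AND address.email_address != :email_address_1
```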
versionadded:: 1.4 diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index cae6f0be21c..679286f5466 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -1,5 +1,5 @@ # orm/loading.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -149,9 +149,11 @@ def go(obj): raise sa_exc.InvalidRequestError( "Can't apply uniqueness to row tuple containing value of " - f"""type {datatype!r}; {'the values returned appear to be' - if uncertain else 'this datatype produces'} """ - "non-hashable values" + f"""type {datatype!r}; { + 'the values returned appear to be' + if uncertain + else 'this datatype produces' + } non-hashable values""" ) return go @@ -179,20 +181,22 @@ def go(obj): return go unique_filters = [ - _no_unique - if context.yield_per - else _not_hashable( - ent.column.type, # type: ignore - legacy=context.load_options._legacy_uniquing, - uncertain=ent._null_column_type, - ) - if ( - not ent.use_id_for_hash - and (ent._non_hashable_value or ent._null_column_type) + ( + _no_unique + if context.yield_per + else ( + _not_hashable( + ent.column.type, # type: ignore + legacy=context.load_options._legacy_uniquing, + uncertain=ent._null_column_type, + ) + if ( + not ent.use_id_for_hash + and (ent._non_hashable_value or ent._null_column_type) + ) + else id if ent.use_id_for_hash else None + ) ) - else id - if ent.use_id_for_hash - else None for ent in context.compile_state._entities ] @@ -1006,21 +1010,38 @@ def _instance_processor( # loading does not apply assert only_load_props is None - callable_ = _load_subclass_via_in( - context, - path, - selectin_load_via, - _polymorphic_from, - option_entities, - ) - PostLoad.callable_for_path( - context, - load_path, - selectin_load_via.mapper, - selectin_load_via, - callable_, - selectin_load_via, - ) + if selectin_load_via.is_mapper: + _load_supers = [] + _endmost_mapper = selectin_load_via + while ( + _endmost_mapper + and _endmost_mapper is not _polymorphic_from + ): + _load_supers.append(_endmost_mapper) + _endmost_mapper = _endmost_mapper.inherits + else: + _load_supers = [selectin_load_via] + + for _selectinload_entity in _load_supers: + if PostLoad.path_exists( + context, load_path, _selectinload_entity + ): + continue + callable_ = _load_subclass_via_in( + context, + path, + _selectinload_entity, + _polymorphic_from, + option_entities, + ) + PostLoad.callable_for_path( + context, + load_path, + _selectinload_entity.mapper, + _selectinload_entity, + callable_, + _selectinload_entity, + ) post_load = PostLoad.for_context(context, load_path, only_load_props) diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 9e479d0d308..ca085c40376 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -1,5 +1,5 @@ -# orm/collections.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# orm/mapped_collection.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,6 +29,8 @@ from ..sql import coercions from ..sql import expression from ..sql import roles +from ..util.langhelpers import Missing +from ..util.langhelpers import MissingOr from ..util.typing import Literal if TYPE_CHECKING: @@ -40,8 +42,6 @@ _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) -_F = 
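The `_instance_processor` change above walks intermediate superclass mappers when resolving `polymorphic_load="selectin"` hierarchies; the mapping option it serves, in a minimal two-level sketch:

```python
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Employee(Base):
    __tablename__ = "employee"
    id: Mapped[int] = mapped_column(primary_key=True)
    type: Mapped[str]
    __mapper_args__ = {
        "polymorphic_on": "type",
        "polymorphic_identity": "employee",
    }


class Manager(Employee):
    __tablename__ = "manager"
    id: Mapped[int] = mapped_column(
        ForeignKey("employee.id"), primary_key=True
    )
    __mapper_args__ = {
        "polymorphic_identity": "manager",
        # subclass rows are completed after the base query with an
        # additional SELECT ... WHERE id IN (...) against "manager"
        "polymorphic_load": "selectin",
    }
```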
TypeVar("_F", bound=Callable[[Any], Any]) - class _PlainColumnGetter(Generic[_KT]): """Plain column getter, stores collection of Column objects @@ -70,7 +70,7 @@ def __reduce__( def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: return self.cols - def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: + def __call__(self, value: _KT) -> MissingOr[Union[_KT, Tuple[_KT, ...]]]: state = base.instance_state(value) m = base._state_mapper(state) @@ -83,7 +83,7 @@ def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: else: obj = key[0] if obj is None: - return _UNMAPPED_AMBIGUOUS_NONE + return Missing else: return obj @@ -117,9 +117,7 @@ def __reduce__( return self.__class__, (self.colkeys,) @classmethod - def _reduce_from_cols( - cls, cols: Sequence[ColumnElement[_KT]] - ) -> Tuple[ + def _reduce_from_cols(cls, cols: Sequence[ColumnElement[_KT]]) -> Tuple[ Type[_SerializableColumnGetterV2[_KT]], Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], ]: @@ -200,9 +198,6 @@ def column_keyed_dict( ) -_UNMAPPED_AMBIGUOUS_NONE = object() - - class _AttrGetter: __slots__ = ("attr_name", "getter") @@ -219,9 +214,9 @@ def __call__(self, mapped_object: Any) -> Any: dict_ = state.dict obj = dict_.get(self.attr_name, base.NO_VALUE) if obj is None: - return _UNMAPPED_AMBIGUOUS_NONE + return Missing else: - return _UNMAPPED_AMBIGUOUS_NONE + return Missing return obj @@ -231,7 +226,7 @@ def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]: def attribute_keyed_dict( attr_name: str, *, ignore_unpopulated_attribute: bool = False -) -> Type[KeyFuncDict[_KT, _KT]]: +) -> Type[KeyFuncDict[Any, Any]]: """A dictionary-based collection type with attribute-based keying. .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to @@ -279,7 +274,7 @@ def attribute_keyed_dict( def keyfunc_mapping( - keyfunc: _F, + keyfunc: Callable[[Any], Any], *, ignore_unpopulated_attribute: bool = False, ) -> Type[KeyFuncDict[_KT, Any]]: @@ -355,7 +350,7 @@ class KeyFuncDict(Dict[_KT, _VT]): def __init__( self, - keyfunc: _F, + keyfunc: Callable[[Any], Any], *dict_args: Any, ignore_unpopulated_attribute: bool = False, ) -> None: @@ -379,7 +374,7 @@ def __init__( @classmethod def _unreduce( cls, - keyfunc: _F, + keyfunc: Callable[[Any], Any], values: Dict[_KT, _KT], adapter: Optional[CollectionAdapter] = None, ) -> "KeyFuncDict[_KT, _KT]": @@ -466,7 +461,7 @@ def set( ) else: return - elif key is _UNMAPPED_AMBIGUOUS_NONE: + elif key is Missing: if not self.ignore_unpopulated_attribute: self._raise_for_unpopulated( value, _sa_initiator, warn_only=True @@ -494,7 +489,7 @@ def remove( value, _sa_initiator, warn_only=False ) return - elif key is _UNMAPPED_AMBIGUOUS_NONE: + elif key is Missing: if not self.ignore_unpopulated_attribute: self._raise_for_unpopulated( value, _sa_initiator, warn_only=True @@ -516,7 +511,7 @@ def remove( def _mapped_collection_cls( - keyfunc: _F, ignore_unpopulated_attribute: bool + keyfunc: Callable[[Any], Any], ignore_unpopulated_attribute: bool ) -> Type[KeyFuncDict[_KT, _KT]]: class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]): def __init__(self, *dict_args: Any) -> None: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c66d876e087..ae7f8f24fc4 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,5 +1,5 @@ # orm/mapper.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is 
released under @@ -132,9 +132,9 @@ ] -_mapper_registries: weakref.WeakKeyDictionary[ - _RegistryType, bool -] = weakref.WeakKeyDictionary() +_mapper_registries: weakref.WeakKeyDictionary[_RegistryType, bool] = ( + weakref.WeakKeyDictionary() +) def _all_registries() -> Set[registry]: @@ -296,6 +296,17 @@ class will overwrite all data within object instances that already particular primary key value. A "partial primary key" can occur if one has mapped to an OUTER JOIN, for example. + The :paramref:`.orm.Mapper.allow_partial_pks` parameter also + indicates to the ORM relationship lazy loader, when loading a + many-to-one related object, if a composite primary key that has + partial NULL values should result in an attempt to load from the + database, or if a load attempt is not necessary. + + .. versionadded:: 2.0.36 :paramref:`.orm.Mapper.allow_partial_pks` + is consulted by the relationship lazy loader strategy, such that + when set to False, a SELECT for a composite primary key that + has partial NULL values will not be emitted. + :param batch: Defaults to ``True``, indicating that save operations of multiple entities can be batched together for efficiency. Setting to False indicates @@ -318,7 +329,7 @@ class will overwrite all data within object instances that already class User(Base): __table__ = user_table - __mapper_args__ = {'column_prefix':'_'} + __mapper_args__ = {"column_prefix": "_"} The above mapping will assign the ``user_id``, ``user_name``, and ``password`` columns to attributes named ``_user_id``, @@ -442,7 +453,7 @@ class User(Base): mapping of the class to an alternate selectable, for loading only. - .. seealso:: + .. seealso:: :ref:`relationship_aliased_class` - the new pattern that removes the need for the :paramref:`_orm.Mapper.non_primary` flag. @@ -534,14 +545,14 @@ class User(Base): base-most mapped :class:`.Table`:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] = mapped_column(String(50)) __mapper_args__ = { - "polymorphic_on":discriminator, - "polymorphic_identity":"employee" + "polymorphic_on": discriminator, + "polymorphic_identity": "employee", } It may also be specified @@ -550,17 +561,18 @@ class Employee(Base): approach:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] = mapped_column(String(50)) __mapper_args__ = { - "polymorphic_on":case( + "polymorphic_on": case( (discriminator == "EN", "engineer"), (discriminator == "MA", "manager"), - else_="employee"), - "polymorphic_identity":"employee" + else_="employee", + ), + "polymorphic_identity": "employee", } It may also refer to any attribute using its string name, @@ -568,14 +580,14 @@ class Employee(Base): configurations:: class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] __mapper_args__ = { "polymorphic_on": "discriminator", - "polymorphic_identity": "employee" + "polymorphic_identity": "employee", } When setting ``polymorphic_on`` to reference an @@ -592,6 +604,7 @@ class Employee(Base): from sqlalchemy import event from sqlalchemy.orm import object_mapper + @event.listens_for(Employee, "init", propagate=True) def set_identity(instance, *arg, **kw): mapper = object_mapper(instance) @@ -1043,7 +1056,7 @@ def entity(self): """ - primary_key: Tuple[Column[Any], ...] 
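A minimal sketch of where the newly documented :paramref:`.orm.Mapper.allow_partial_pks` flag is set; the composite-PK OUTER JOIN mapping it actually matters for is elided here, and the introspection attribute name is assumed to match the parameter:

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Thing(Base):
    __tablename__ = "thing"
    id: Mapped[int] = mapped_column(primary_key=True)
    # with False, an identity whose composite primary key is partially
    # NULL is treated as nonexistent, so (per the 2.0.36 note above) the
    # many-to-one lazy loader will not emit a SELECT for it
    __mapper_args__ = {"allow_partial_pks": False}


print(Thing.__mapper__.allow_partial_pks)  # False
```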
+ primary_key: Tuple[ColumnElement[Any], ...] """An iterable containing the collection of :class:`_schema.Column` objects which comprise the 'primary key' of the mapped table, from the @@ -1606,9 +1619,11 @@ def _configure_pks(self) -> None: if self._primary_key_argument: coerced_pk_arg = [ - self._str_arg_to_mapped_col("primary_key", c) - if isinstance(c, str) - else c + ( + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c + ) for c in ( coercions.expect( roles.DDLConstraintColumnRole, @@ -2465,9 +2480,11 @@ def __str__(self) -> str: return "Mapper[%s%s(%s)]" % ( self.class_.__name__, self.non_primary and " (non-primary)" or "", - self.local_table.description - if self.local_table is not None - else self.persist_selectable.description, + ( + self.local_table.description + if self.local_table is not None + else self.persist_selectable.description + ), ) def _is_orphan(self, state: InstanceState[_O]) -> bool: @@ -2537,7 +2554,7 @@ def _mappers_from_spec( if spec == "*": mappers = list(self.self_and_descendants) elif spec: - mapper_set = set() + mapper_set: Set[Mapper[Any]] = set() for m in util.to_list(spec): m = _class_to_mapper(m) if not m.isa(self): @@ -3244,14 +3261,9 @@ def _equivalent_columns(self) -> _EquivalentColumnMap: The resulting structure is a dictionary of columns mapped to lists of equivalent columns, e.g.:: - { - tablea.col1: - {tableb.col1, tablec.col1}, - tablea.col2: - {tabled.col2} - } + {tablea.col1: {tableb.col1, tablec.col1}, tablea.col2: {tabled.col2}} - """ + """ # noqa: E501 result: _EquivalentColumnMap = {} def visit_binary(binary): @@ -3416,9 +3428,11 @@ def primary_base_mapper(self) -> Mapper[Any]: return self.class_manager.mapper.base_mapper def _result_has_identity_key(self, result, adapter=None): - pk_cols: Sequence[ColumnClause[Any]] = self.primary_key - if adapter: - pk_cols = [adapter.columns[c] for c in pk_cols] + pk_cols: Sequence[ColumnElement[Any]] + if adapter is not None: + pk_cols = [adapter.columns[c] for c in self.primary_key] + else: + pk_cols = self.primary_key rk = result.keys() for col in pk_cols: if col not in rk: @@ -3428,7 +3442,7 @@ def _result_has_identity_key(self, result, adapter=None): def identity_key_from_row( self, - row: Optional[Union[Row[Any], RowMapping]], + row: Union[Row[Any], RowMapping], identity_token: Optional[Any] = None, adapter: Optional[ORMAdapter] = None, ) -> _IdentityKeyType[_O]: @@ -3443,18 +3457,21 @@ def identity_key_from_row( for the "row" argument """ - pk_cols: Sequence[ColumnClause[Any]] = self.primary_key - if adapter: - pk_cols = [adapter.columns[c] for c in pk_cols] + pk_cols: Sequence[ColumnElement[Any]] + if adapter is not None: + pk_cols = [adapter.columns[c] for c in self.primary_key] + else: + pk_cols = self.primary_key + mapping: RowMapping if hasattr(row, "_mapping"): - mapping = row._mapping # type: ignore + mapping = row._mapping else: - mapping = cast("Mapping[Any, Any]", row) + mapping = row # type: ignore[assignment] return ( self._identity_class, - tuple(mapping[column] for column in pk_cols), # type: ignore + tuple(mapping[column] for column in pk_cols), identity_token, ) @@ -3724,14 +3741,15 @@ def _would_selectin_load_only_from_given_mapper(self, super_mapper): given:: - class A: - ... + class A: ... + class B(A): __mapper_args__ = {"polymorphic_load": "selectin"} - class C(B): - ... + + class C(B): ... 
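`identity_key_from_row()` above builds the ``(class, pk_tuple, token)`` triples used by the identity map; their shape can be seen with the public `identity_key()` helper (mapping invented):

```python
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    identity_key,
    mapped_column,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)


# (mapped class, tuple of primary key values, identity token)
print(identity_key(User, 5))  # (<class '...User'>, (5,), None)
```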
+ class D(B): __mapper_args__ = {"polymorphic_load": "selectin"} @@ -3801,6 +3819,7 @@ def _subclass_load_via_in(self, entity, polymorphic_from): this subclass as a SELECT with IN. """ + strategy_options = util.preloaded.orm_strategy_options assert self.inherits @@ -3824,7 +3843,7 @@ def _subclass_load_via_in(self, entity, polymorphic_from): classes_to_include.add(m) m = m.inherits - for prop in self.attrs: + for prop in self.column_attrs + self.relationships: # skip prop keys that are not instrumented on the mapped class. # this is primarily the "_sa_polymorphic_on" property that gets # created for an ad-hoc polymorphic_on SQL expression, issue #8704 @@ -4289,7 +4308,7 @@ def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None: reg._new_mappers = False -def reconstructor(fn): +def reconstructor(fn: _Fn) -> _Fn: """Decorate a method as the 'reconstructor' hook. Designates a single method as the "reconstructor", an ``__init__``-like @@ -4315,7 +4334,7 @@ def reconstructor(fn): :meth:`.InstanceEvents.load` """ - fn.__sa_reconstructor__ = True + fn.__sa_reconstructor__ = True # type: ignore[attr-defined] return fn diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 354552a5a40..bb03e53d2b1 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -1,12 +1,10 @@ # orm/path_registry.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Path tracking utilities, representing mapper graph traversals. - -""" +"""Path tracking utilities, representing mapper graph traversals.""" from __future__ import annotations @@ -35,7 +33,7 @@ if TYPE_CHECKING: from ._typing import _InternalEntityType - from .interfaces import MapperProperty + from .interfaces import StrategizedProperty from .mapper import Mapper from .relationships import RelationshipProperty from .util import AliasedInsp @@ -45,11 +43,9 @@ from ..util.typing import _LiteralStar from ..util.typing import TypeGuard - def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: - ... + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: ... - def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: - ... + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: ... else: is_root = operator.attrgetter("is_root") @@ -59,13 +55,13 @@ def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: _SerializedPath = List[Any] _StrPathToken = str _PathElementType = Union[ - _StrPathToken, "_InternalEntityType[Any]", "MapperProperty[Any]" + _StrPathToken, "_InternalEntityType[Any]", "StrategizedProperty[Any]" ] # the representation is in fact # a tuple with alternating: -# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], -# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# [_InternalEntityType[Any], Union[str, StrategizedProperty[Any]], +# _InternalEntityType[Any], Union[str, StrategizedProperty[Any]], ...] # this might someday be a tuple of 2-tuples instead, but paths can be # chopped at odd intervals as well so this is less flexible _PathRepresentation = Tuple[_PathElementType, ...] 
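The newly typed `reconstructor` decorator above is the documented hook for re-establishing transient state on load, since ``__init__`` is bypassed for database loads; a short sketch:

```python
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    reconstructor,
)


class Base(DeclarativeBase):
    pass


class MyEntity(Base):
    __tablename__ = "entity"
    id: Mapped[int] = mapped_column(primary_key=True)

    def __init__(self, **kw):
        super().__init__(**kw)
        self.init_on_load()

    @reconstructor
    def init_on_load(self):
        # runs after a database load, when __init__ is not called
        self.cache = {}
```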
@@ -73,7 +69,7 @@ def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: # NOTE: these names are weird since the array is 0-indexed, # the "_Odd" entries are at 0, 2, 4, etc _OddPathRepresentation = Sequence["_InternalEntityType[Any]"] -_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] +_EvenPathRepresentation = Sequence[Union["StrategizedProperty[Any]", str]] log = logging.getLogger(__name__) @@ -185,26 +181,23 @@ def __hash__(self) -> int: return id(self) @overload - def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: - ... + def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: ... @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... + def __getitem__(self, entity: int) -> _PathElementType: ... @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... + def __getitem__(self, entity: slice) -> _PathRepresentation: ... @overload def __getitem__( self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... + ) -> AbstractEntityRegistry: ... @overload - def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: - ... + def __getitem__( + self, entity: StrategizedProperty[Any] + ) -> PropRegistry: ... def __getitem__( self, @@ -213,7 +206,7 @@ def __getitem__( int, slice, _InternalEntityType[Any], - MapperProperty[Any], + StrategizedProperty[Any], ], ) -> Union[ TokenRegistry, @@ -232,7 +225,7 @@ def length(self) -> int: def pairs( self, ) -> Iterator[ - Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + Tuple[_InternalEntityType[Any], Union[str, StrategizedProperty[Any]]] ]: odd_path = cast(_OddPathRepresentation, self.path) even_path = cast(_EvenPathRepresentation, odd_path) @@ -320,13 +313,11 @@ def deserialize(cls, path: _SerializedPath) -> PathRegistry: @overload @classmethod - def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: - ... + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: ... @overload @classmethod - def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: - ... + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: ... @classmethod def per_mapper( @@ -540,15 +531,16 @@ class PropRegistry(PathRegistry): inherit_cache = True is_property = True - prop: MapperProperty[Any] + prop: StrategizedProperty[Any] mapper: Optional[Mapper[Any]] entity: Optional[_InternalEntityType[Any]] def __init__( - self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + self, parent: AbstractEntityRegistry, prop: StrategizedProperty[Any] ): + # restate this path in terms of the - # given MapperProperty's parent. + # given StrategizedProperty's parent. insp = cast("_InternalEntityType[Any]", parent[-1]) natural_parent: AbstractEntityRegistry = parent @@ -572,7 +564,7 @@ def __init__( # entities are used. # # here we are trying to distinguish between a path that starts - # on a the with_polymorhpic entity vs. one that starts on a + # on a with_polymorphic entity vs. one that starts on a # normal entity that introduces a with_polymorphic() in the # middle using of_type(): # @@ -808,11 +800,9 @@ def _getitem(self, entity: Any) -> Any: def path_is_entity( path: PathRegistry, - ) -> TypeGuard[AbstractEntityRegistry]: - ... + ) -> TypeGuard[AbstractEntityRegistry]: ... - def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: - ... + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: ... 
else: path_is_entity = operator.attrgetter("is_entity") diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 6729b479f90..cbe8557add9 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1,5 +1,5 @@ # orm/persistence.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -140,11 +140,13 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): state_dict, sub_mapper, connection, - mapper._get_committed_state_attr_by_column( - state, state_dict, mapper.version_id_col - ) - if mapper.version_id_col is not None - else None, + ( + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) + if mapper.version_id_col is not None + else None + ), ) for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table @@ -559,7 +561,8 @@ def _collect_update_commands( f"No primary key value supplied for column(s) " f"""{ ', '.join( - str(c) for c in pks if pk_params[c._label] is None) + str(c) for c in pks if pk_params[c._label] is None + ) }; """ "per-row ORM Bulk UPDATE by Primary Key requires that " "records contain primary key values", @@ -702,10 +705,10 @@ def _collect_delete_commands( params = {} for col in mapper._pks_by_table[table]: - params[ - col.key - ] = value = mapper._get_committed_state_attr_by_column( - state, state_dict, col + params[col.key] = value = ( + mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) ) if value is None: raise orm_exc.FlushError( @@ -933,9 +936,11 @@ def update_stmt(existing_stmt=None): c.context.compiled_parameters[0], value_params, True, - c.returned_defaults - if not c.context.executemany - else None, + ( + c.returned_defaults + if not c.context.executemany + else None + ), ) if check_rowcount: @@ -1068,9 +1073,11 @@ def _emit_insert_statements( last_inserted_params, value_params, False, - result.returned_defaults - if not result.context.executemany - else None, + ( + result.returned_defaults + if not result.context.executemany + else None + ), ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1260,9 +1267,11 @@ def _emit_insert_statements( result.context.compiled_parameters[0], value_params, False, - result.returned_defaults - if not result.context.executemany - else None, + ( + result.returned_defaults + if not result.context.executemany + else None + ), ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1569,16 +1578,25 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch_post_update( mapper, uowtransaction, table, state, dict_, result, params ): - if uowtransaction.is_deleted(state): - return - - prefetch_cols = result.context.compiled.prefetch - postfetch_cols = result.context.compiled.postfetch - - if ( + needs_version_id = ( mapper.version_id_col is not None and mapper.version_id_col in mapper._cols_by_table[table] - ): + ) + + if not uowtransaction.is_deleted(state): + # post updating after a regular INSERT or UPDATE, do a full postfetch + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + elif needs_version_id: + # post updating before a DELETE with a version_id_col, need to + # postfetch just version_id_col + prefetch_cols = postfetch_cols = () + else: + # post updating before a DELETE without a 
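The `_postfetch_post_update` rework above special-cases mappers that carry a version counter; the mapping that opts into one, following the standard versioning pattern (names invented):

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):
    __tablename__ = "widget"
    id: Mapped[int] = mapped_column(primary_key=True)
    version_id: Mapped[int] = mapped_column(nullable=False)
    # each UPDATE renders SET version_id = :new WHERE version_id = :old,
    # raising StaleDataError if the row changed concurrently
    __mapper_args__ = {"version_id_col": version_id}
```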
version_id_col, + # don't need to postfetch + return + + if needs_version_id: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1658,9 +1676,18 @@ def _postfetch( for c in prefetch_cols: if c.key in params and c in mapper._columntoproperty: - dict_[mapper._columntoproperty[c].key] = params[c.key] + pkey = mapper._columntoproperty[c].key + + # set prefetched value in dict and also pop from committed_state, + # since this is new database state that replaces whatever might + # have previously been fetched (see #10800). this is essentially a + # shorthand version of set_committed_value(), which could also be + # used here directly (with more overhead) + dict_[pkey] = params[c.key] + state.committed_state.pop(pkey, None) + if refresh_flush: - load_evt_attrs.append(mapper._columntoproperty[c].key) + load_evt_attrs.append(pkey) if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4bb396edc5d..164ae009b25 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1,5 +1,5 @@ # orm/properties.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,6 +28,7 @@ from typing import Union from . import attributes +from . import exc as orm_exc from . import strategy_options from .base import _DeclarativeMapped from .base import class_mapper @@ -43,7 +44,6 @@ from .interfaces import StrategizedProperty from .relationships import RelationshipProperty from .util import de_stringify_annotation -from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. import log @@ -55,12 +55,13 @@ from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types +from ..util.typing import get_args +from ..util.typing import includes_none +from ..util.typing import is_a_type from ..util.typing import is_fwd_ref -from ..util.typing import is_optional_union from ..util.typing import is_pep593 -from ..util.typing import is_union +from ..util.typing import is_pep695 from ..util.typing import Self -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -233,7 +234,7 @@ def _memoized_attr__renders_in_subqueries(self) -> bool: return self.strategy._have_default_expression # type: ignore return ("deferred", True) not in self.strategy_key or ( - self not in self.parent._readonly_props # type: ignore + self not in self.parent._readonly_props ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") @@ -279,8 +280,8 @@ class File(Base): name = Column(String(64)) extension = Column(String(8)) - filename = column_property(name + '.' + extension) - path = column_property('C:/' + filename.expression) + filename = column_property(name + "." + extension) + path = column_property("C:/" + filename.expression) .. seealso:: @@ -429,8 +430,7 @@ def _orm_annotate_column(self, column: _NC) -> _NC: if TYPE_CHECKING: - def __clause_element__(self) -> NamedColumn[_PT]: - ... + def __clause_element__(self) -> NamedColumn[_PT]: ... 
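A runnable 2.0-declarative variant of the `File` docstring example reformatted above; composing a `column_property()` from sibling `mapped_column()` objects in the same class body is the documented pattern:

```python
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    column_property,
    mapped_column,
)


class Base(DeclarativeBase):
    pass


class File(Base):
    __tablename__ = "file"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column()
    extension: Mapped[str] = mapped_column()
    filename: Mapped[str] = column_property(name + "." + extension)


# renders as a SQL concatenation, e.g. file.name || :param || file.extension
print(File.filename.expression)
```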
def _memoized_method___clause_element__( self, @@ -636,9 +636,11 @@ def columns_to_assign(self) -> List[Tuple[Column[Any], int]]: return [ ( self.column, - self._sort_order - if self._sort_order is not _NoArg.NO_ARG - else 0, + ( + self._sort_order + if self._sort_order is not _NoArg.NO_ARG + else 0 + ), ) ] @@ -687,7 +689,7 @@ def declarative_scan( supercls_mapper = class_mapper(decl_scan.inherits, False) colname = column.name if column.name is not None else key - column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + column = self.column = supercls_mapper.local_table.c.get( # type: ignore[assignment] # noqa: E501 colname, column ) @@ -736,47 +738,44 @@ def _init_column_for_annotation( ) -> None: sqltype = self.column.type - if isinstance(argument, str) or is_fwd_ref( - argument, check_generic=True + if is_fwd_ref( + argument, check_generic=True, check_for_plain_string=True ): assert originating_module is not None argument = de_stringify_annotation( cls, argument, originating_module, include_generic=True ) - if is_union(argument): - assert originating_module is not None - argument = de_stringify_union_elements( - cls, argument, originating_module - ) - - nullable = is_optional_union(argument) + nullable = includes_none(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - use_args_from = None + find_mapped_in: Tuple[Any, ...] = () + our_type_is_pep593 = False + raw_pep_593_type = None if is_pep593(our_type): our_type_is_pep593 = True - pep_593_components = typing_get_args(our_type) + pep_593_components = get_args(our_type) raw_pep_593_type = pep_593_components[0] - if is_optional_union(raw_pep_593_type): + if nullable: raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type) - - nullable = True - if not self._has_nullable: - self.column.nullable = nullable - for elem in pep_593_components[1:]: - if isinstance(elem, MappedColumn): - use_args_from = elem - break + find_mapped_in = pep_593_components[1:] + elif is_pep695(argument) and is_pep593(argument.__value__): + # do not support nested annotation inside unions ets + find_mapped_in = get_args(argument.__value__)[1:] + + use_args_from: Optional[MappedColumn[Any]] + for elem in find_mapped_in: + if isinstance(elem, MappedColumn): + use_args_from = elem + break else: - our_type_is_pep593 = False - raw_pep_593_type = None + use_args_from = None if use_args_from is not None: if ( @@ -848,8 +847,7 @@ def _init_column_for_annotation( ) if sqltype._isnull and not self.column.foreign_keys: - new_sqltype = None - + checks: List[Any] if our_type_is_pep593: checks = [our_type, raw_pep_593_type] else: @@ -864,16 +862,23 @@ def _init_column_for_annotation( isinstance(our_type, type) and issubclass(our_type, TypeEngine) ): - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f"The type provided inside the {self.column.key!r} " "attribute Mapped annotation is the SQLAlchemy type " f"{our_type}. Expected a Python type instead" ) - else: - raise sa_exc.ArgumentError( + elif is_a_type(our_type): + raise orm_exc.MappedAnnotationError( "Could not locate SQLAlchemy Core type for Python " f"type {our_type} inside the {self.column.key!r} " "attribute Mapped annotation" ) + else: + raise orm_exc.MappedAnnotationError( + f"The object provided inside the {self.column.key!r} " + "attribute Mapped annotation is not a Python type, " + f"it's the object {our_type!r}. Expected a Python " + "type." 
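The `_init_column_for_annotation` rework above extracts a `MappedColumn` found inside ``Annotated[...]`` (pep 593) and now also unwraps pep 695 type aliases; the pep-593 form it resolves, runnable:

```python
from typing import Annotated, Optional

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# a reusable annotated type bundling a Python type with column configuration
str50 = Annotated[str, mapped_column(String(50))]


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str50]
    nickname: Mapped[Optional[str50]]  # Optional[...] makes it nullable


print(Account.__table__.c.name.type)          # VARCHAR(50)
print(Account.__table__.c.nickname.nullable)  # True
```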
+ ) self.column._set_type(new_sqltype) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5da7ee9b228..3489c15fd6f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,5 +1,5 @@ # orm/query.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -134,6 +134,7 @@ from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import CacheableOptions from ..sql.base import ExecutableOption + from ..sql.dml import UpdateBase from ..sql.elements import ColumnElement from ..sql.elements import Label from ..sql.selectable import _ForUpdateOfArgument @@ -166,7 +167,6 @@ class Query( Executable, Generic[_T], ): - """ORM-level SQL construction object. .. legacy:: The ORM :class:`.Query` object is a legacy construct @@ -205,9 +205,9 @@ class Query( _memoized_select_entities = () - _compile_options: Union[ - Type[CacheableOptions], CacheableOptions - ] = ORMCompileState.default_compile_options + _compile_options: Union[Type[CacheableOptions], CacheableOptions] = ( + ORMCompileState.default_compile_options + ) _with_options: Tuple[ExecutableOption, ...] load_options = QueryContext.default_load_options + { @@ -493,7 +493,7 @@ def _get_select_statement_only(self) -> Select[_T]: return cast("Select[_T]", self.statement) @property - def statement(self) -> Union[Select[_T], FromStatement[_T]]: + def statement(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: """The full SELECT statement represented by this Query. The statement by default will not have disambiguating labels @@ -521,6 +521,8 @@ def statement(self) -> Union[Select[_T], FromStatement[_T]]: # from there, it starts to look much like Query itself won't be # passed into the execute process and won't generate its own cache # key; this will all occur in terms of the ORM-enabled Select. + stmt: Union[Select[_T], FromStatement[_T], UpdateBase] + if not self._compile_options._set_base_alias: # if we don't have legacy top level aliasing features in use # then convert to a future select() directly @@ -673,41 +675,38 @@ def cte( from sqlalchemy.orm import aliased + class Part(Base): - __tablename__ = 'part' + __tablename__ = "part" part = Column(String, primary_key=True) sub_part = Column(String, primary_key=True) quantity = Column(Integer) - included_parts = session.query( - Part.sub_part, - Part.part, - Part.quantity).\ - filter(Part.part=="our part").\ - cte(name="included_parts", recursive=True) + + included_parts = ( + session.query(Part.sub_part, Part.part, Part.quantity) + .filter(Part.part == "our part") + .cte(name="included_parts", recursive=True) + ) incl_alias = aliased(included_parts, name="pr") parts_alias = aliased(Part, name="p") included_parts = included_parts.union_all( session.query( - parts_alias.sub_part, - parts_alias.part, - parts_alias.quantity).\ - filter(parts_alias.part==incl_alias.c.sub_part) - ) + parts_alias.sub_part, parts_alias.part, parts_alias.quantity + ).filter(parts_alias.part == incl_alias.c.sub_part) + ) q = session.query( - included_parts.c.sub_part, - func.sum(included_parts.c.quantity). - label('total_quantity') - ).\ - group_by(included_parts.c.sub_part) + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label("total_quantity"), + ).group_by(included_parts.c.sub_part) .. seealso:: :meth:`_sql.Select.cte` - v2 equivalent method. 
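The `Query.cte()` docstring above points at `Select.cte()` as the v2 equivalent; the start of the same recursive CTE in 2.0 style, reusing the `Part` shape from the docstring:

```python
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Part(Base):
    __tablename__ = "part"
    part: Mapped[str] = mapped_column(primary_key=True)
    sub_part: Mapped[str] = mapped_column(primary_key=True)
    quantity: Mapped[int]


included_parts = (
    select(Part.sub_part, Part.part, Part.quantity)
    .where(Part.part == "our part")
    .cte(name="included_parts", recursive=True)
)
print(select(included_parts.c.sub_part))  # WITH RECURSIVE included_parts ...
```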
- """ + """ # noqa: E501 return ( self.enable_eagerloads(False) ._get_select_statement_only() @@ -732,20 +731,17 @@ def label(self, name: Optional[str]) -> Label[Any]: ) @overload - def as_scalar( + def as_scalar( # type: ignore[overload-overlap] self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[_MAYBE_ENTITY]: - ... + ) -> ScalarSelect[_MAYBE_ENTITY]: ... @overload def as_scalar( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def as_scalar(self) -> ScalarSelect[Any]: - ... + def as_scalar(self) -> ScalarSelect[Any]: ... @util.deprecated( "1.4", @@ -763,18 +759,15 @@ def as_scalar(self) -> ScalarSelect[Any]: @overload def scalar_subquery( self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[Any]: - ... + ) -> ScalarSelect[Any]: ... @overload def scalar_subquery( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... def scalar_subquery(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this @@ -799,7 +792,7 @@ def scalar_subquery(self) -> ScalarSelect[Any]: ) @property - def selectable(self) -> Union[Select[_T], FromStatement[_T]]: + def selectable(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: """Return the :class:`_expression.Select` object emitted by this :class:`_query.Query`. @@ -810,7 +803,9 @@ def selectable(self) -> Union[Select[_T], FromStatement[_T]]: """ return self.__clause_element__() - def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: + def __clause_element__( + self, + ) -> Union[Select[_T], FromStatement[_T], UpdateBase]: return ( self._with_compile_options( _enable_eagerloads=False, _render_for_subquery=True @@ -822,14 +817,12 @@ def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: @overload def only_return_tuples( self: Query[_O], value: Literal[True] - ) -> RowReturningQuery[Tuple[_O]]: - ... + ) -> RowReturningQuery[Tuple[_O]]: ... @overload def only_return_tuples( self: Query[_O], value: Literal[False] - ) -> Query[_O]: - ... + ) -> Query[_O]: ... @_generative def only_return_tuples(self, value: bool) -> Query[Any]: @@ -950,9 +943,7 @@ def set_label_style(self, style: SelectLabelStyle) -> Self: :attr:`_query.Query.statement` using :meth:`.Session.execute`:: result = session.execute( - query - .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) - .statement + query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).statement ) .. versionadded:: 1.4 @@ -1061,8 +1052,7 @@ def get(self, ident: _PKIdentityArgument) -> Optional[Any]: some_object = session.query(VersionedFoo).get((5, 10)) - some_object = session.query(VersionedFoo).get( - {"id": 5, "version_id": 10}) + some_object = session.query(VersionedFoo).get({"id": 5, "version_id": 10}) :meth:`_query.Query.get` is special in that it provides direct access to the identity map of the owning :class:`.Session`. @@ -1128,7 +1118,7 @@ def get(self, ident: _PKIdentityArgument) -> Optional[Any]: :return: The object instance, or ``None``. - """ + """ # noqa: E501 self._no_criterion_assertion("get", order_by=False, distinct=False) # we still implement _get_impl() so that baked query can override @@ -1475,15 +1465,13 @@ def value(self, column: _ColumnExpressionArgument[Any]) -> Any: return None @overload - def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: - ... 
+ def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def with_entities( self, _colexpr: roles.TypedColumnsClauseRole[_T], - ) -> RowReturningQuery[Tuple[_T]]: - ... + ) -> RowReturningQuery[Tuple[_T]]: ... # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 @@ -1493,14 +1481,12 @@ def with_entities( @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... @overload def with_entities( @@ -1509,8 +1495,7 @@ def with_entities( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def with_entities( @@ -1520,8 +1505,7 @@ def with_entities( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def with_entities( @@ -1532,8 +1516,7 @@ def with_entities( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def with_entities( @@ -1545,8 +1528,7 @@ def with_entities( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def with_entities( @@ -1559,16 +1541,14 @@ def with_entities( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... # END OVERLOADED FUNCTIONS self.with_entities @overload def with_entities( self, *entities: _ColumnsClauseArgument[Any] - ) -> Query[Any]: - ... + ) -> Query[Any]: ... @_generative def with_entities( @@ -1582,19 +1562,22 @@ def with_entities( # Users, filtered on some arbitrary criterion # and then ordered by related email address - q = session.query(User).\ - join(User.address).\ - filter(User.name.like('%ed%')).\ - order_by(Address.email) + q = ( + session.query(User) + .join(User.address) + .filter(User.name.like("%ed%")) + .order_by(Address.email) + ) # given *only* User.id==5, Address.email, and 'q', what # would the *next* User in the result be ? - subq = q.with_entities(Address.email).\ - order_by(None).\ - filter(User.id==5).\ - subquery() - q = q.join((subq, subq.c.email < Address.email)).\ - limit(1) + subq = ( + q.with_entities(Address.email) + .order_by(None) + .filter(User.id == 5) + .subquery() + ) + q = q.join((subq, subq.c.email < Address.email)).limit(1) .. seealso:: @@ -1690,9 +1673,11 @@ def with_transformation( def filter_something(criterion): def transform(q): return q.filter(criterion) + return transform - q = q.with_transformation(filter_something(x==5)) + + q = q.with_transformation(filter_something(x == 5)) This allows ad-hoc recipes to be created for :class:`_query.Query` objects. 
@@ -1729,13 +1714,12 @@ def execution_options( schema_translate_map: Optional[SchemaTranslateMapType] = ..., populate_existing: bool = False, autoflush: bool = False, + preserve_rowcount: bool = False, **opt: Any, - ) -> Self: - ... + ) -> Self: ... @overload - def execution_options(self, **opt: Any) -> Self: - ... + def execution_options(self, **opt: Any) -> Self: ... @_generative def execution_options(self, **kwargs: Any) -> Self: @@ -1810,9 +1794,15 @@ def with_for_update( E.g.:: - q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User) + q = ( + sess.query(User) + .populate_existing() + .with_for_update(nowait=True, of=User) + ) - The above query on a PostgreSQL backend will render like:: + The above query on a PostgreSQL backend will render like: + + .. sourcecode:: sql SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT @@ -1890,14 +1880,13 @@ def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: e.g.:: - session.query(MyClass).filter(MyClass.name == 'some name') + session.query(MyClass).filter(MyClass.name == "some name") Multiple criteria may be specified as comma separated; the effect is that they will be joined together using the :func:`.and_` function:: - session.query(MyClass).\ - filter(MyClass.name == 'some name', MyClass.id > 5) + session.query(MyClass).filter(MyClass.name == "some name", MyClass.id > 5) The criterion is any SQL expression object applicable to the WHERE clause of a select. String expressions are coerced @@ -1910,7 +1899,7 @@ def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: :meth:`_sql.Select.where` - v2 equivalent method. - """ + """ # noqa: E501 for crit in list(criterion): crit = coercions.expect( roles.WhereHavingRole, crit, apply_propagate_attrs=self @@ -1978,14 +1967,13 @@ def filter_by(self, **kwargs: Any) -> Self: e.g.:: - session.query(MyClass).filter_by(name = 'some name') + session.query(MyClass).filter_by(name="some name") Multiple criteria may be specified as comma separated; the effect is that they will be joined together using the :func:`.and_` function:: - session.query(MyClass).\ - filter_by(name = 'some name', id = 5) + session.query(MyClass).filter_by(name="some name", id=5) The keyword expressions are extracted from the primary entity of the query, or the last entity that was the @@ -2112,10 +2100,12 @@ def having(self, *having: _ColumnExpressionArgument[bool]) -> Self: HAVING criterion makes it possible to use filters on aggregate functions like COUNT, SUM, AVG, MAX, and MIN, eg.:: - q = session.query(User.id).\ - join(User.addresses).\ - group_by(User.id).\ - having(func.count(Address.id) > 2) + q = ( + session.query(User.id) + .join(User.addresses) + .group_by(User.id) + .having(func.count(Address.id) > 2) + ) .. seealso:: @@ -2139,8 +2129,8 @@ def union(self, *q: Query[Any]) -> Self: e.g.:: - q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') - q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') + q1 = sess.query(SomeClass).filter(SomeClass.foo == "bar") + q2 = sess.query(SomeClass).filter(SomeClass.bar == "foo") q3 = q1.union(q2) @@ -2149,7 +2139,9 @@ def union(self, *q: Query[Any]) -> Self: x.union(y).union(z).all() - will nest on each ``union()``, and produces:: + will nest on each ``union()``, and produces: + + .. sourcecode:: sql SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z) @@ -2158,7 +2150,9 @@ def union(self, *q: Query[Any]) -> Self: x.union(y, z).all() - produces:: + produces: + + .. 
sourcecode:: sql SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z) @@ -2270,7 +2264,9 @@ def join( q = session.query(User).join(User.addresses) Where above, the call to :meth:`_query.Query.join` along - ``User.addresses`` will result in SQL approximately equivalent to:: + ``User.addresses`` will result in SQL approximately equivalent to: + + .. sourcecode:: sql SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id @@ -2283,10 +2279,12 @@ def join( calls may be used. The relationship-bound attribute implies both the left and right side of the join at once:: - q = session.query(User).\ - join(User.orders).\ - join(Order.items).\ - join(Item.keywords) + q = ( + session.query(User) + .join(User.orders) + .join(Order.items) + .join(Item.keywords) + ) .. note:: as seen in the above example, **the order in which each call to the join() method occurs is important**. Query would not, @@ -2325,7 +2323,7 @@ def join( as the ON clause to be passed explicitly. An example that includes a SQL expression as the ON clause is as follows:: - q = session.query(User).join(Address, User.id==Address.user_id) + q = session.query(User).join(Address, User.id == Address.user_id) The above form may also use a relationship-bound attribute as the ON clause as well:: @@ -2340,11 +2338,13 @@ def join( a1 = aliased(Address) a2 = aliased(Address) - q = session.query(User).\ - join(a1, User.addresses).\ - join(a2, User.addresses).\ - filter(a1.email_address=='ed@foo.com').\ - filter(a2.email_address=='ed@bar.com') + q = ( + session.query(User) + .join(a1, User.addresses) + .join(a2, User.addresses) + .filter(a1.email_address == "ed@foo.com") + .filter(a2.email_address == "ed@bar.com") + ) The relationship-bound calling form can also specify a target entity using the :meth:`_orm.PropComparator.of_type` method; a query @@ -2353,11 +2353,13 @@ def join( a1 = aliased(Address) a2 = aliased(Address) - q = session.query(User).\ - join(User.addresses.of_type(a1)).\ - join(User.addresses.of_type(a2)).\ - filter(a1.email_address == 'ed@foo.com').\ - filter(a2.email_address == 'ed@bar.com') + q = ( + session.query(User) + .join(User.addresses.of_type(a1)) + .join(User.addresses.of_type(a2)) + .filter(a1.email_address == "ed@foo.com") + .filter(a2.email_address == "ed@bar.com") + ) **Augmenting Built-in ON Clauses** @@ -2368,7 +2370,7 @@ def join( with the default criteria using AND:: q = session.query(User).join( - User.addresses.and_(Address.email_address != 'foo@bar.com') + User.addresses.and_(Address.email_address != "foo@bar.com") ) ..
versionadded:: 1.4 @@ -2381,29 +2383,28 @@ def join( appropriate ``.subquery()`` method in order to make a subquery out of a query:: - subq = session.query(Address).\ - filter(Address.email_address == 'ed@foo.com').\ - subquery() + subq = ( + session.query(Address) + .filter(Address.email_address == "ed@foo.com") + .subquery() + ) - q = session.query(User).join( - subq, User.id == subq.c.user_id - ) + q = session.query(User).join(subq, User.id == subq.c.user_id) Joining to a subquery in terms of a specific relationship and/or target entity may be achieved by linking the subquery to the entity using :func:`_orm.aliased`:: - subq = session.query(Address).\ - filter(Address.email_address == 'ed@foo.com').\ - subquery() + subq = ( + session.query(Address) + .filter(Address.email_address == "ed@foo.com") + .subquery() + ) address_subq = aliased(Address, subq) - q = session.query(User).join( - User.addresses.of_type(address_subq) - ) - + q = session.query(User).join(User.addresses.of_type(address_subq)) **Controlling what to Join From** @@ -2411,11 +2412,16 @@ def join( :class:`_query.Query` is not in line with what we want to join from, the :meth:`_query.Query.select_from` method may be used:: - q = session.query(Address).select_from(User).\ - join(User.addresses).\ - filter(User.name == 'ed') + q = ( + session.query(Address) + .select_from(User) + .join(User.addresses) + .filter(User.name == "ed") + ) + + Which will produce SQL similar to: - Which will produce SQL similar to:: + .. sourcecode:: sql SELECT address.* FROM user JOIN address ON user.id=address.user_id @@ -2519,11 +2525,16 @@ def select_from(self, *from_obj: _FromClauseArgument) -> Self: A typical example:: - q = session.query(Address).select_from(User).\ - join(User.addresses).\ - filter(User.name == 'ed') + q = ( + session.query(Address) + .select_from(User) + .join(User.addresses) + .filter(User.name == "ed") + ) - Which produces SQL equivalent to:: + Which produces SQL equivalent to: + + .. sourcecode:: sql SELECT address.* FROM user JOIN address ON user.id=address.user_id @@ -2776,11 +2787,10 @@ def one_or_none(self) -> Optional[_T]: def one(self) -> _T: """Return exactly one result or raise an exception. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects - no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` - if multiple object identities are returned, or if multiple - rows are returned for a query that returns only scalar values - as opposed to full identity-mapped entities. + Raises :class:`_exc.NoResultFound` if the query selects no rows. + Raises :class:`_exc.MultipleResultsFound` if multiple object identities + are returned, or if multiple rows are returned for a query that returns + only scalar values as opposed to full identity-mapped entities. Calling :meth:`.one` results in an execution of the underlying query. @@ -2800,7 +2810,7 @@ def one(self) -> _T: def scalar(self) -> Any: """Return the first element of the first result or None if no rows present. If multiple rows are returned, - raises MultipleResultsFound. + raises :class:`_exc.MultipleResultsFound`. 
>>> session.query(Item).scalar() @@ -2886,7 +2896,7 @@ def column_descriptions(self) -> List[ORMColumnDescription]: Format is a list of dictionaries:: - user_alias = aliased(User, name='user2') + user_alias = aliased(User, name="user2") q = sess.query(User, User.id, user_alias) # this expression: @@ -2895,26 +2905,26 @@ def column_descriptions(self) -> List[ORMColumnDescription]: # would return: [ { - 'name':'User', - 'type':User, - 'aliased':False, - 'expr':User, - 'entity': User + "name": "User", + "type": User, + "aliased": False, + "expr": User, + "entity": User, }, { - 'name':'id', - 'type':Integer(), - 'aliased':False, - 'expr':User.id, - 'entity': User + "name": "id", + "type": Integer(), + "aliased": False, + "expr": User.id, + "entity": User, }, { - 'name':'user2', - 'type':User, - 'aliased':True, - 'expr':user_alias, - 'entity': user_alias - } + "name": "user2", + "type": User, + "aliased": True, + "expr": user_alias, + "entity": user_alias, + }, ] .. seealso:: @@ -2959,6 +2969,7 @@ def instances( context = QueryContext( compile_state, compile_state.statement, + compile_state.statement, self._params, self.session, self.load_options, @@ -3022,10 +3033,12 @@ def exists(self) -> Exists: e.g.:: - q = session.query(User).filter(User.name == 'fred') + q = session.query(User).filter(User.name == "fred") session.query(q.exists()) - Producing SQL similar to:: + Producing SQL similar to: + + .. sourcecode:: sql SELECT EXISTS ( SELECT 1 FROM users WHERE users.name = :name_1 @@ -3074,7 +3087,9 @@ def count(self) -> int: r"""Return a count of rows the SQL formed by this :class:`Query` would return. - This generates the SQL for this Query as follows:: + This generates the SQL for this Query as follows: + + .. sourcecode:: sql SELECT count(1) AS count_1 FROM ( SELECT @@ -3114,8 +3129,7 @@ def count(self) -> int: # return count of user "id" grouped # by "name" - session.query(func.count(User.id)).\ - group_by(User.name) + session.query(func.count(User.id)).group_by(User.name) from sqlalchemy import distinct # count distinct "name" values session.query(func.count(distinct(User.name))) @@ -3133,7 +3147,9 @@ def count(self) -> int: ) def delete( - self, synchronize_session: SynchronizeSessionArgument = "auto" + self, + synchronize_session: SynchronizeSessionArgument = "auto", + delete_args: Optional[Dict[Any, Any]] = None, ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. @@ -3141,11 +3157,11 @@ def delete( E.g.:: - sess.query(User).filter(User.age == 25).\ - delete(synchronize_session=False) + sess.query(User).filter(User.age == 25).delete(synchronize_session=False) - sess.query(User).filter(User.age == 25).\ - delete(synchronize_session='evaluate') + sess.query(User).filter(User.age == 25).delete( + synchronize_session="evaluate" + ) .. warning:: @@ -3158,6 +3174,13 @@ def delete( :ref:`orm_expression_update_delete` for a discussion of these strategies. + :param delete_args: Optional dictionary, if present will be passed + to the underlying :func:`_expression.delete` construct as the ``**kw`` + for the object. May be used to pass dialect-specific arguments such + as ``mysql_limit``. + + .. versionadded:: 2.0.37 + :return: the count of rows matched as returned by the database's "row count" feature.
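The new ``delete_args`` parameter documented above is easiest to see with a dialect-specific argument. A hedged sketch, assuming a MySQL backend and an already-mapped ``User`` class (both illustrative here, not part of the patch)::

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    # hypothetical MySQL URL; mysql_limit is a MySQL-dialect argument
    engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")

    with Session(engine) as session:
        # delete_args is forwarded to the underlying delete() construct,
        # so on MySQL this renders roughly:
        #   DELETE FROM user_account WHERE user_account.age = %s LIMIT 10
        deleted = (
            session.query(User)
            .filter(User.age == 25)
            .delete(
                synchronize_session=False,
                delete_args={"mysql_limit": 10},
            )
        )
        session.commit()
        print(deleted)  # rows matched, per the "row count" note above

Dialect options are namespaced, so ``mysql_limit`` only takes effect when compiling against MySQL; it mirrors the ``update_args`` pattern already documented for :meth:`_query.Query.update`.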
@@ -3165,9 +3188,9 @@ def delete( :ref:`orm_expression_update_delete` - """ + """ # noqa: E501 - bulk_del = BulkDelete(self) + bulk_del = BulkDelete(self, delete_args) if self.dispatch.before_compile_delete: for fn in self.dispatch.before_compile_delete: new_query = fn(bulk_del.query, bulk_del) @@ -3177,6 +3200,10 @@ def delete( self = bulk_del.query delete_ = sql.delete(*self._raw_columns) # type: ignore + + if delete_args: + delete_ = delete_.with_dialect_options(**delete_args) + delete_._where_criteria = self._where_criteria result: CursorResult[Any] = self.session.execute( delete_, @@ -3203,11 +3230,13 @@ def update( E.g.:: - sess.query(User).filter(User.age == 25).\ - update({User.age: User.age - 10}, synchronize_session=False) + sess.query(User).filter(User.age == 25).update( + {User.age: User.age - 10}, synchronize_session=False + ) - sess.query(User).filter(User.age == 25).\ - update({"age": User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter(User.age == 25).update( + {"age": User.age - 10}, synchronize_session="evaluate" + ) .. warning:: @@ -3230,9 +3259,8 @@ def update( strategies. :param update_args: Optional dictionary, if present will be passed - to the underlying :func:`_expression.update` - construct as the ``**kw`` for - the object. May be used to pass dialect-specific arguments such + to the underlying :func:`_expression.update` construct as the ``**kw`` + for the object. May be used to pass dialect-specific arguments such as ``mysql_limit``, as well as other special arguments such as :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`. @@ -3311,13 +3339,16 @@ def _compile_state( ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) - return compile_state_cls.create_for_statement(stmt, None) + return compile_state_cls._create_orm_context( + stmt, toplevel=True, compiler=None + ) def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) context = QueryContext( compile_state, compile_state.statement, + compile_state.statement, self._params, self.session, self.load_options, @@ -3406,6 +3437,14 @@ def __init__( class BulkDelete(BulkUD): """BulkUD which handles DELETEs.""" + def __init__( + self, + query: Query[Any], + delete_kwargs: Optional[Dict[Any, Any]], + ): + super().__init__(query) + self.delete_kwargs = delete_kwargs + class RowReturningQuery(Query[Row[_TP]]): if TYPE_CHECKING: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 7ea30d7b180..15b63d1b4b4 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1,5 +1,5 @@ # orm/relationships.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -19,6 +19,7 @@ from collections import abc import dataclasses import inspect as _py_inspect +import itertools import re import typing from typing import Any @@ -26,6 +27,7 @@ from typing import cast from typing import Collection from typing import Dict +from typing import FrozenSet from typing import Generic from typing import Iterable from typing import Iterator @@ -179,7 +181,10 @@ ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionElement = Union[ - ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole, "Mapped[Any]" + ColumnClause[Any], + _HasClauseElement[Any], + roles.DMLColumnRole, + 
"Mapped[Any]", ] _ORMColCollectionArgument = Union[ str, @@ -481,8 +486,7 @@ def __init__( else: self._overlaps = () - # mypy ignoring the @property setter - self.cascade = cascade # type: ignore + self.cascade = cascade self.back_populates = back_populates @@ -704,12 +708,16 @@ def in_(self, other: Any) -> NoReturn: def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``==`` operator. - In a many-to-one context, such as:: + In a many-to-one context, such as: + + .. sourcecode:: text MyClass.some_prop == this will typically produce a - clause such as:: + clause such as: + + .. sourcecode:: text mytable.related_id == @@ -872,11 +880,12 @@ def any( An expression like:: session.query(MyClass).filter( - MyClass.somereference.any(SomeRelated.x==2) + MyClass.somereference.any(SomeRelated.x == 2) ) + Will produce a query like: - Will produce a query like:: + .. sourcecode:: sql SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id @@ -890,11 +899,11 @@ def any( :meth:`~.Relationship.Comparator.any` is particularly useful for testing for empty collections:: - session.query(MyClass).filter( - ~MyClass.somereference.any() - ) + session.query(MyClass).filter(~MyClass.somereference.any()) - will produce:: + will produce: + + .. sourcecode:: sql SELECT * FROM my_table WHERE NOT (EXISTS (SELECT 1 FROM related WHERE @@ -925,11 +934,12 @@ def has( An expression like:: session.query(MyClass).filter( - MyClass.somereference.has(SomeRelated.x==2) + MyClass.somereference.has(SomeRelated.x == 2) ) + Will produce a query like: - Will produce a query like:: + .. sourcecode:: sql SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE @@ -948,7 +958,7 @@ def has( """ if self.property.uselist: raise sa_exc.InvalidRequestError( - "'has()' not implemented for collections. " "Use any()." + "'has()' not implemented for collections. Use any()." ) return self._criterion_exists(criterion, **kwargs) @@ -968,7 +978,9 @@ def contains( MyClass.contains(other) - Produces a clause like:: + Produces a clause like: + + .. sourcecode:: sql mytable.id == @@ -988,7 +1000,9 @@ def contains( query(MyClass).filter(MyClass.contains(other)) - Produces a query like:: + Produces a query like: + + .. sourcecode:: sql SELECT * FROM my_table, my_association_table AS my_association_table_1 WHERE @@ -1084,11 +1098,15 @@ def adapt(col: _CE) -> _CE: def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``!=`` operator. - In a many-to-one context, such as:: + In a many-to-one context, such as: + + .. sourcecode:: text MyClass.some_prop != - This will typically produce a clause such as:: + This will typically produce a clause such as: + + .. 
sourcecode:: sql mytable.related_id != @@ -1304,9 +1322,11 @@ def _go() -> Any: state, dict_, column, - passive=PassiveFlag.PASSIVE_OFF - if state.persistent - else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, + passive=( + PassiveFlag.PASSIVE_OFF + if state.persistent + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK + ), ) if current_value is LoaderCallableStatus.NEVER_SET: @@ -1737,8 +1757,6 @@ def declarative_scan( extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - argument = extracted_mapped_annotation - if extracted_mapped_annotation is None: if self.argument is None: self._raise_for_required(key, cls) @@ -1748,19 +1766,17 @@ def declarative_scan( argument = extracted_mapped_annotation assert originating_module is not None - is_write_only = mapped_container is not None and issubclass( - mapped_container, WriteOnlyMapped - ) - if is_write_only: - self.lazy = "write_only" - self.strategy_key = (("lazy", self.lazy),) - - is_dynamic = mapped_container is not None and issubclass( - mapped_container, DynamicMapped - ) - if is_dynamic: - self.lazy = "dynamic" - self.strategy_key = (("lazy", self.lazy),) + if mapped_container is not None: + is_write_only = issubclass(mapped_container, WriteOnlyMapped) + is_dynamic = issubclass(mapped_container, DynamicMapped) + if is_write_only: + self.lazy = "write_only" + self.strategy_key = (("lazy", self.lazy),) + elif is_dynamic: + self.lazy = "dynamic" + self.strategy_key = (("lazy", self.lazy),) + else: + is_write_only = is_dynamic = False argument = de_optionalize_union_types(argument) @@ -1811,15 +1827,12 @@ def declarative_scan( argument, originating_module ) - # we don't allow the collection class to be a - # __forward_arg__ right now, so if we see a forward arg here, - # we know there was no collection class either - if ( - self.collection_class is None - and not is_write_only - and not is_dynamic - ): - self.uselist = False + if ( + self.collection_class is None + and not is_write_only + and not is_dynamic + ): + self.uselist = False # ticket #8759 # if a lead argument was given to relationship(), like @@ -1999,9 +2012,11 @@ def _check_cascade_settings(self, cascade: CascadeOptions) -> None: "the single_parent=True flag." % { "rel": self, - "direction": "many-to-one" - if self.direction is MANYTOONE - else "many-to-many", + "direction": ( + "many-to-one" + if self.direction is MANYTOONE + else "many-to-many" + ), "clsname": self.parent.class_.__name__, "relatedcls": self.mapper.class_.__name__, }, @@ -2894,9 +2909,6 @@ def _check_foreign_cols( ) -> None: """Check the foreign key columns collected and emit error messages.""" - - can_sync = False - foreign_cols = self._gather_columns_with_annotation( join_condition, "foreign" ) @@ -3052,9 +3064,9 @@ def _deannotate_pairs( def _setup_pairs(self) -> None: sync_pairs: _MutableColumnPairs = [] - lrp: util.OrderedSet[ - Tuple[ColumnElement[Any], ColumnElement[Any]] - ] = util.OrderedSet([]) + lrp: util.OrderedSet[Tuple[ColumnElement[Any], ColumnElement[Any]]] = ( + util.OrderedSet([]) + ) secondary_sync_pairs: _MutableColumnPairs = [] def go( @@ -3131,9 +3143,9 @@ def _warn_for_conflicting_sync_targets(self) -> None: # level configuration that benefits from this warning. 
if to_ not in self._track_overlapping_sync_targets: - self._track_overlapping_sync_targets[ - to_ - ] = weakref.WeakKeyDictionary({self.prop: from_}) + self._track_overlapping_sync_targets[to_] = ( + weakref.WeakKeyDictionary({self.prop: from_}) + ) else: other_props = [] prop_to_from = self._track_overlapping_sync_targets[to_] @@ -3231,6 +3243,15 @@ def _gather_columns_with_annotation( if annotation_set.issubset(col._annotations) } + @util.memoized_property + def _secondary_lineage_set(self) -> FrozenSet[ColumnElement[Any]]: + if self.secondary is not None: + return frozenset( + itertools.chain(*[c.proxy_set for c in self.secondary.c]) + ) + else: + return util.EMPTY_SET + def join_targets( self, source_selectable: Optional[FromClause], @@ -3281,23 +3302,25 @@ def join_targets( if extra_criteria: - def mark_unrelated_columns_as_ok_to_adapt( + def mark_exclude_cols( elem: SupportsAnnotations, annotations: _AnnotationDict ) -> SupportsAnnotations: - """note unrelated columns in the "extra criteria" as OK - to adapt, even though they are not part of our "local" - or "remote" side. + """note unrelated columns in the "extra criteria" as either + should be adapted or not adapted, even though they are not + part of our "local" or "remote" side. - see #9779 for this case + see #9779 for this case, as well as #11010 for a follow up """ parentmapper_for_element = elem._annotations.get( "parentmapper", None ) + if ( parentmapper_for_element is not self.prop.parent and parentmapper_for_element is not self.prop.mapper + and elem not in self._secondary_lineage_set ): return _safe_annotate(elem, annotations) else: @@ -3306,8 +3329,8 @@ def mark_unrelated_columns_as_ok_to_adapt( extra_criteria = tuple( _deep_annotate( elem, - {"ok_to_adapt_in_join_condition": True}, - annotate_callable=mark_unrelated_columns_as_ok_to_adapt, + {"should_not_adapt": True}, + annotate_callable=mark_exclude_cols, ) for elem in extra_criteria ) @@ -3321,14 +3344,16 @@ def mark_unrelated_columns_as_ok_to_adapt( if secondary is not None: secondary = secondary._anonymous_fromclause(flat=True) primary_aliasizer = ClauseAdapter( - secondary, exclude_fn=_ColInAnnotations("local") + secondary, + exclude_fn=_local_col_exclude, ) secondary_aliasizer = ClauseAdapter( dest_selectable, equivalents=self.child_equivalents ).chain(primary_aliasizer) if source_selectable is not None: primary_aliasizer = ClauseAdapter( - secondary, exclude_fn=_ColInAnnotations("local") + secondary, + exclude_fn=_local_col_exclude, ).chain( ClauseAdapter( source_selectable, @@ -3340,14 +3365,14 @@ def mark_unrelated_columns_as_ok_to_adapt( else: primary_aliasizer = ClauseAdapter( dest_selectable, - exclude_fn=_ColInAnnotations("local"), + exclude_fn=_local_col_exclude, equivalents=self.child_equivalents, ) if source_selectable is not None: primary_aliasizer.chain( ClauseAdapter( source_selectable, - exclude_fn=_ColInAnnotations("remote"), + exclude_fn=_remote_col_exclude, equivalents=self.parent_equivalents, ) ) @@ -3366,9 +3391,7 @@ def mark_unrelated_columns_as_ok_to_adapt( dest_selectable, ) - def create_lazy_clause( - self, reverse_direction: bool = False - ) -> Tuple[ + def create_lazy_clause(self, reverse_direction: bool = False) -> Tuple[ ColumnElement[bool], Dict[str, ColumnElement[Any]], Dict[ColumnElement[Any], ColumnElement[Any]], @@ -3428,25 +3451,29 @@ def col_to_bind( class _ColInAnnotations: - """Serializable object that tests for a name in c._annotations.""" + """Serializable object that tests for names in c._annotations. 
- __slots__ = ("name",) + TODO: does this need to be serializable anymore? can we find what the + use case was for that? - def __init__(self, name: str): - self.name = name + """ + + __slots__ = ("names",) + + def __init__(self, *names: str): + self.names = frozenset(names) def __call__(self, c: ClauseElement) -> bool: - return ( - self.name in c._annotations - or "ok_to_adapt_in_join_condition" in c._annotations - ) + return bool(self.names.intersection(c._annotations)) + +_local_col_exclude = _ColInAnnotations("local", "should_not_adapt") +_remote_col_exclude = _ColInAnnotations("remote", "should_not_adapt") -class Relationship( # type: ignore + +class Relationship( RelationshipProperty[_T], _DeclarativeMapped[_T], - WriteOnlyMapped[_T], # not compatible with Mapped[_T] - DynamicMapped[_T], # not compatible with Mapped[_T] ): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -3464,3 +3491,18 @@ class Relationship( # type: ignore inherit_cache = True """:meta private:""" + + +class _RelationshipDeclared( # type: ignore[misc] + Relationship[_T], + WriteOnlyMapped[_T], # not compatible with Mapped[_T] + DynamicMapped[_T], # not compatible with Mapped[_T] +): + """Relationship subclass used implicitly for declarative mapping.""" + + inherit_cache = True + """:meta private:""" + + @classmethod + def _mapper_property_name(cls) -> str: + return "Relationship" diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index ab632bdd564..df5a6534dce 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,5 +1,5 @@ # orm/scoping.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -86,8 +86,7 @@ class QueryPropertyDescriptor(Protocol): """ - def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: - ... + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... _O = TypeVar("_O", bound=object) @@ -281,11 +280,13 @@ def query_property( Session = scoped_session(sessionmaker()) + class MyClass: query: QueryPropertyDescriptor = Session.query_property() + # after mappers are defined - result = MyClass.query.filter(MyClass.name=='foo').all() + result = MyClass.query.filter(MyClass.name == "foo").all() Produces instances of the session's configured query class by default. To override and use a custom implementation, provide @@ -534,12 +535,12 @@ def reset(self) -> None: behalf of the :class:`_orm.scoping.scoped_session` class. This method provides for same "reset-only" behavior that the - :meth:_orm.Session.close method has provided historically, where the + :meth:`_orm.Session.close` method has provided historically, where the state of the :class:`_orm.Session` is reset as though the object were brand new, and ready to be used again. - The method may then be useful for :class:`_orm.Session` objects + This method may then be useful for :class:`_orm.Session` objects which set :paramref:`_orm.Session.close_resets_only` to ``False``, - so that "reset only" behavior is still available from this method. + so that "reset only" behavior is still available. .. versionadded:: 2.0.22 @@ -682,8 +683,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: - ... + ) -> Result[_T]: ... 
@overload def execute( @@ -695,8 +695,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... @overload def execute( @@ -708,8 +707,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: - ... + ) -> Result[Any]: ... def execute( self, @@ -734,9 +732,8 @@ def execute( E.g.:: from sqlalchemy import select - result = session.execute( - select(User).where(User.id == 5) - ) + + result = session.execute(select(User).where(User.id == 5)) The API contract of :meth:`_orm.Session.execute` is similar to that of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version @@ -966,10 +963,7 @@ def get( some_object = session.get(VersionedFoo, (5, 10)) - some_object = session.get( - VersionedFoo, - {"id": 5, "version_id": 10} - ) + some_object = session.get(VersionedFoo, {"id": 5, "version_id": 10}) .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved from the now legacy :meth:`_orm.Query.get` method. @@ -1092,8 +1086,7 @@ def get_one( Proxied for the :class:`_orm.Session` class on behalf of the :class:`_orm.scoping.scoped_session` class. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query - selects no rows. + Raises :class:`_exc.NoResultFound` if the query selects no rows. For a detailed documentation of the arguments see the method :meth:`.Session.get`. @@ -1232,7 +1225,7 @@ def is_modified( This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously committed value, if any. + value to its previously flushed or committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1574,14 +1567,12 @@ def merge( return self._proxied.merge(instance, load=load, options=options) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: - ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: - ... + ) -> RowReturningQuery[Tuple[_T]]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -1591,14 +1582,12 @@ def query( @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... @overload def query( @@ -1607,8 +1596,7 @@ def query( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def query( @@ -1618,8 +1606,7 @@ def query( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def query( @@ -1630,8 +1617,7 @@ def query( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... 
@overload def query( @@ -1643,8 +1629,7 @@ def query( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def query( @@ -1657,16 +1642,14 @@ def query( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: - ... + ) -> Query[Any]: ... def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -1818,8 +1801,7 @@ def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -1830,8 +1812,7 @@ def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -1873,8 +1854,7 @@ def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -1885,8 +1865,7 @@ def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d8619812719..ca7b2c2b59f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1,5 +1,5 @@ # orm/session.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -146,9 +146,9 @@ "object_session", ] -_sessions: weakref.WeakValueDictionary[ - int, Session -] = weakref.WeakValueDictionary() +_sessions: weakref.WeakValueDictionary[int, Session] = ( + weakref.WeakValueDictionary() +) """Weak-referencing dictionary of :class:`.Session` objects. """ @@ -188,8 +188,7 @@ def __call__( mapper: Optional[Mapper[Any]] = None, instance: Optional[object] = None, **kw: Any, - ) -> Connection: - ... + ) -> Connection: ... def _state_session(state: InstanceState[Any]) -> Optional[Session]: @@ -576,22 +575,67 @@ def is_executemany(self) -> bool: @property def is_select(self) -> bool: - """return True if this is a SELECT operation.""" + """return True if this is a SELECT operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Select` construct, such as + ``select(Entity).from_statement(select(..))`` + + """ return self.statement.is_select + @property + def is_from_statement(self) -> bool: + """return True if this operation is a + :meth:`_sql.Select.from_statement` operation. + + This is independent from :attr:`_orm.ORMExecuteState.is_select`, as a + ``select().from_statement()`` construct can be used with + INSERT/UPDATE/DELETE RETURNING types of statements as well. 
+ :attr:`_orm.ORMExecuteState.is_select` will only be set if the + :meth:`_sql.Select.from_statement` is itself against a + :class:`_sql.Select` construct. + + .. versionadded:: 2.0.30 + + """ + return self.statement.is_from_statement + @property def is_insert(self) -> bool: - """return True if this is an INSERT operation.""" + """return True if this is an INSERT operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Insert` construct, such as + ``select(Entity).from_statement(insert(..))`` + + """ return self.statement.is_dml and self.statement.is_insert @property def is_update(self) -> bool: - """return True if this is an UPDATE operation.""" + """return True if this is an UPDATE operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Update` construct, such as + ``select(Entity).from_statement(update(..))`` + + """ return self.statement.is_dml and self.statement.is_update @property def is_delete(self) -> bool: - """return True if this is a DELETE operation.""" + """return True if this is a DELETE operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Delete` construct, such as + ``select(Entity).from_statement(delete(..))`` + + """ return self.statement.is_dml and self.statement.is_delete @property @@ -1000,9 +1044,11 @@ def connection( def _begin(self, nested: bool = False) -> SessionTransaction: return SessionTransaction( self.session, - SessionTransactionOrigin.BEGIN_NESTED - if nested - else SessionTransactionOrigin.SUBTRANSACTION, + ( + SessionTransactionOrigin.BEGIN_NESTED + if nested + else SessionTransactionOrigin.SUBTRANSACTION + ), self, ) @@ -1165,6 +1211,17 @@ def _connection_for_bind( else: join_transaction_mode = "rollback_only" + if local_connect: + util.warn( + "The engine provided as bind produced a " + "connection that is already in a transaction. " + "This is usually caused by a core event, " + "such as 'engine_connect', that has left a " + "transaction open. The effective join " + "transaction mode used by this session is " + f"{join_transaction_mode!r}. To silence this " + "warning, do not leave transactions open" + ) if join_transaction_mode in ( "control_fully", "rollback_only", @@ -1512,12 +1569,16 @@ def __init__( operation. The complete heuristics for resolution are described at :meth:`.Session.get_bind`. Usage looks like:: - Session = sessionmaker(binds={ - SomeMappedClass: create_engine('postgresql+psycopg2://engine1'), - SomeDeclarativeBase: create_engine('postgresql+psycopg2://engine2'), - some_mapper: create_engine('postgresql+psycopg2://engine3'), - some_table: create_engine('postgresql+psycopg2://engine4'), - }) + Session = sessionmaker( + binds={ + SomeMappedClass: create_engine("postgresql+psycopg2://engine1"), + SomeDeclarativeBase: create_engine( + "postgresql+psycopg2://engine2" + ), + some_mapper: create_engine("postgresql+psycopg2://engine3"), + some_table: create_engine("postgresql+psycopg2://engine4"), + } + ) .. seealso:: @@ -1712,7 +1773,7 @@ def __init__( # the idea is that at some point NO_ARG will warn that in the future # the default will switch to close_resets_only=False. 
- if close_resets_only or close_resets_only is _NoArg.NO_ARG: + if close_resets_only in (True, _NoArg.NO_ARG): self._close_state = _SessionCloseState.CLOSE_IS_RESET else: self._close_state = _SessionCloseState.ACTIVE @@ -1819,9 +1880,11 @@ def _autobegin_t(self, begin: bool = False) -> SessionTransaction: ) trans = SessionTransaction( self, - SessionTransactionOrigin.BEGIN - if begin - else SessionTransactionOrigin.AUTOBEGIN, + ( + SessionTransactionOrigin.BEGIN + if begin + else SessionTransactionOrigin.AUTOBEGIN + ), ) assert self._transaction is trans return trans @@ -2057,8 +2120,7 @@ def _execute_internal( _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: Literal[True] = ..., - ) -> Any: - ... + ) -> Any: ... @overload def _execute_internal( @@ -2071,8 +2133,7 @@ def _execute_internal( _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result[Any]: - ... + ) -> Result[Any]: ... def _execute_internal( self, @@ -2215,8 +2276,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: - ... + ) -> Result[_T]: ... @overload def execute( @@ -2228,8 +2288,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... @overload def execute( @@ -2241,8 +2300,7 @@ def execute( bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: - ... + ) -> Result[Any]: ... def execute( self, @@ -2262,9 +2320,8 @@ def execute( E.g.:: from sqlalchemy import select - result = session.execute( - select(User).where(User.id == 5) - ) + + result = session.execute(select(User).where(User.id == 5)) The API contract of :meth:`_orm.Session.execute` is similar to that of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version @@ -2323,8 +2380,7 @@ def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -2335,8 +2391,7 @@ def scalar( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -2373,8 +2428,7 @@ def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -2385,8 +2439,7 @@ def scalars( execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, @@ -2472,12 +2525,12 @@ def reset(self) -> None: :class:`_orm.Session`, resetting the session to its initial state. This method provides for same "reset-only" behavior that the - :meth:_orm.Session.close method has provided historically, where the + :meth:`_orm.Session.close` method has provided historically, where the state of the :class:`_orm.Session` is reset as though the object were brand new, and ready to be used again. 
- The method may then be useful for :class:`_orm.Session` objects + This method may then be useful for :class:`_orm.Session` objects which set :paramref:`_orm.Session.close_resets_only` to ``False``, - so that "reset only" behavior is still available from this method. + so that "reset only" behavior is still available. .. versionadded:: 2.0.22 @@ -2795,14 +2848,12 @@ def get_bind( ) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: - ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: - ... + ) -> RowReturningQuery[Tuple[_T]]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -2812,14 +2863,12 @@ def query( @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... @overload def query( @@ -2828,8 +2877,7 @@ def query( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def query( @@ -2839,8 +2887,7 @@ def query( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def query( @@ -2851,8 +2898,7 @@ def query( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def query( @@ -2864,8 +2910,7 @@ def query( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def query( @@ -2878,16 +2923,14 @@ def query( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: - ... + ) -> Query[Any]: ... def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -2930,7 +2973,7 @@ def _identity_lookup( e.g.:: - obj = session._identity_lookup(inspect(SomeClass), (1, )) + obj = session._identity_lookup(inspect(SomeClass), (1,)) :param mapper: mapper in use :param primary_key_identity: the primary key we are searching for, as @@ -3001,7 +3044,8 @@ def no_autoflush(self) -> Iterator[Session]: @util.langhelpers.tag_method_for_warnings( "This warning originated from the Session 'autoflush' process, " "which was invoked automatically in response to a user-initiated " - "operation.", + "operation. Consider using the ``no_autoflush`` context manager " + "if this warning happened while initializing objects.", sa_exc.SAWarning, ) def _autoflush(self) -> None: @@ -3557,10 +3601,7 @@ def get( some_object = session.get(VersionedFoo, (5, 10)) - some_object = session.get( - VersionedFoo, - {"id": 5, "version_id": 10} - ) + some_object = session.get(VersionedFoo, {"id": 5, "version_id": 10}) ..
versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved from the now legacy :meth:`_orm.Query.get` method. @@ -3649,7 +3690,7 @@ def get( :return: The object instance, or ``None``. - """ + """ # noqa: E501 return self._get_impl( entity, ident, @@ -3677,8 +3718,7 @@ def get_one( """Return exactly one instance based on the given primary key identifier, or raise an exception if not found. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query - selects no rows. + Raises :class:`_exc.NoResultFound` if the query selects no rows. For a detailed documentation of the arguments see the method :meth:`.Session.get`. @@ -3768,9 +3808,9 @@ def _get_impl( if correct_keys: primary_key_identity = dict(primary_key_identity) for k in correct_keys: - primary_key_identity[ - pk_synonyms[k] - ] = primary_key_identity[k] + primary_key_identity[pk_synonyms[k]] = ( + primary_key_identity[k] + ) try: primary_key_identity = list( @@ -3974,14 +4014,7 @@ def _merge( else: key_is_persistent = True - if key in self.identity_map: - try: - merged = self.identity_map[key] - except KeyError: - # object was GC'ed right as we checked for it - merged = None - else: - merged = None + merged = self.identity_map.get(key) if merged is None: if key_is_persistent and key in _resolve_conflict_map: @@ -4545,11 +4578,11 @@ def grouping_key( self._bulk_save_mappings( mapper, states, - isupdate, - True, - return_defaults, - update_changed_only, - False, + isupdate=isupdate, + isstates=True, + return_defaults=return_defaults, + update_changed_only=update_changed_only, + render_nulls=False, ) def bulk_insert_mappings( @@ -4628,11 +4661,11 @@ def bulk_insert_mappings( self._bulk_save_mappings( mapper, mappings, - False, - False, - return_defaults, - False, - render_nulls, + isupdate=False, + isstates=False, + return_defaults=return_defaults, + update_changed_only=False, + render_nulls=render_nulls, ) def bulk_update_mappings( @@ -4674,13 +4707,20 @@ def bulk_update_mappings( """ self._bulk_save_mappings( - mapper, mappings, True, False, False, False, False + mapper, + mappings, + isupdate=True, + isstates=False, + return_defaults=False, + update_changed_only=False, + render_nulls=False, ) def _bulk_save_mappings( self, mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + *, isupdate: bool, isstates: bool, return_defaults: bool, @@ -4697,17 +4737,17 @@ def _bulk_save_mappings( mapper, mappings, transaction, - isstates, - update_changed_only, + isstates=isstates, + update_changed_only=update_changed_only, ) else: bulk_persistence._bulk_insert( mapper, mappings, transaction, - isstates, - return_defaults, - render_nulls, + isstates=isstates, + return_defaults=return_defaults, + render_nulls=render_nulls, ) transaction.commit() @@ -4725,7 +4765,7 @@ def is_modified( This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously committed value, if any. + value to its previously flushed or committed value, if any. 
It is in effect a more expensive and accurate version of checking for the given instance in the @@ -4895,7 +4935,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): # an Engine, which the Session will use for connection # resources - engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/') + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/") Session = sessionmaker(engine) @@ -4948,7 +4988,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): with engine.connect() as connection: with Session(bind=connection) as session: - # work with session + ... # work with session The class also includes a method :meth:`_orm.sessionmaker.configure`, which can be used to specify additional keyword arguments to the factory, which @@ -4963,7 +5003,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): # ... later, when an engine URL is read from a configuration # file or other events allow the engine to be created - engine = create_engine('sqlite:///foo.db') + engine = create_engine("sqlite:///foo.db") Session.configure(bind=engine) sess = Session() @@ -4988,8 +5028,7 @@ def __init__( expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... @overload def __init__( @@ -5000,8 +5039,7 @@ def __init__( expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... def __init__( self, @@ -5103,7 +5141,7 @@ def configure(self, **new_kw: Any) -> None: Session = sessionmaker() - Session.configure(bind=create_engine('sqlite://')) + Session.configure(bind=create_engine("sqlite://")) """ self.kw.update(new_kw) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index d9e1f854d77..d4bbf920993 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -1,5 +1,5 @@ # orm/state.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -78,8 +78,7 @@ class _InstanceDictProto(Protocol): - def __call__(self) -> Optional[IdentityMap]: - ... + def __call__(self) -> Optional[IdentityMap]: ... class _InstallLoaderCallableProto(Protocol[_O]): @@ -94,13 +93,12 @@ class _InstallLoaderCallableProto(Protocol[_O]): def __call__( self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] - ) -> None: - ... + ) -> None: ... @inspection._self_inspects class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): - """tracks state information at the instance level. + """Tracks state information at the instance level. The :class:`.InstanceState` is a key object used by the SQLAlchemy ORM in order to track the state of an object; @@ -150,7 +148,14 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): committed_state: Dict[str, Any] modified: bool = False + """When ``True`` the object was modified.""" expired: bool = False + """When ``True`` the object is :term:`expired`. + + .. seealso:: + + :ref:`session_expire` + """ _deleted: bool = False _load_pending: bool = False _orphaned_outside_of_session: bool = False @@ -171,11 +176,12 @@ def _instance_dict(self): expired_attributes: Set[str] """The set of keys which are 'expired' to be loaded by - the manager's deferred scalar loader, assuming no pending - changes. + the manager's deferred scalar loader, assuming no pending + changes. 
- see also the ``unmodified`` collection which is intersected - against this set when a refresh operation occurs.""" + See also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs. + """ callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]] """A namespace where a per-state loader callable can be associated. @@ -230,7 +236,6 @@ def transient(self) -> bool: def pending(self) -> bool: """Return ``True`` if the object is :term:`pending`. - .. seealso:: :ref:`session_object_states` diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py index 3d74ff2de22..a79874e1c7a 100644 --- a/lib/sqlalchemy/orm/state_changes.py +++ b/lib/sqlalchemy/orm/state_changes.py @@ -1,13 +1,11 @@ # orm/state_changes.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""State tracking utilities used by :class:`_orm.Session`. - -""" +"""State tracking utilities used by :class:`_orm.Session`.""" from __future__ import annotations diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1e58f4091a6..8ac34e2943b 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1,5 +1,5 @@ # orm/strategies.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,7 +8,7 @@ """sqlalchemy.orm.interfaces.LoaderStrategy - implementations, and related MapperOptions.""" +implementations, and related MapperOptions.""" from __future__ import annotations @@ -16,8 +16,10 @@ import itertools from typing import Any from typing import Dict +from typing import Optional from typing import Tuple from typing import TYPE_CHECKING +from typing import Union from . import attributes from . import exc as orm_exc @@ -45,7 +47,7 @@ from .session import _state_session from .state import InstanceState from .strategy_options import Load -from .util import _none_set +from .util import _none_only_set from .util import AliasedClass from .. import event from .. 
import exc as sa_exc @@ -57,8 +59,10 @@ from ..sql import visitors from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import Select +from ..util.typing import Literal if TYPE_CHECKING: + from .mapper import Mapper from .relationships import RelationshipProperty from ..sql.elements import ColumnElement @@ -384,7 +388,7 @@ def __init__(self, parent, strategy_key): super().__init__(parent, strategy_key) if hasattr(self.parent_property, "composite_class"): raise NotImplementedError( - "Deferred loading for composite " "types not implemented yet" + "Deferred loading for composite types not implemented yet" ) self.raiseload = self.strategy_opts.get("raiseload", False) self.columns = self.parent_property.columns @@ -758,7 +762,7 @@ def __init__( self._equated_columns[c] = self._equated_columns[col] self.logger.info( - "%s will use Session.get() to " "optimize instance loads", self + "%s will use Session.get() to optimize instance loads", self ) def init_class_attribute(self, mapper): @@ -932,8 +936,15 @@ def _load_for_state( elif LoaderCallableStatus.NEVER_SET in primary_key_identity: return LoaderCallableStatus.NEVER_SET - if _none_set.issuperset(primary_key_identity): - return None + # test for None alone in primary_key_identity based on + # allow_partial_pks preference. PASSIVE_NO_RESULT and NEVER_SET + # have already been tested above + if not self.mapper.allow_partial_pks: + if _none_only_set.intersection(primary_key_identity): + return None + else: + if _none_only_set.issuperset(primary_key_identity): + return None if ( self.key in state.dict @@ -1195,9 +1206,11 @@ def create_row_processor( key, self, loadopt, - loadopt._generate_extra_criteria(context) - if loadopt._extra_criteria - else None, + ( + loadopt._generate_extra_criteria(context) + if loadopt._extra_criteria + else None + ), ), key, ) @@ -1371,12 +1384,16 @@ def create_row_processor( adapter, populators, ): + if not context.compile_state.compile_options._enable_eagerloads: + return + ( effective_path, run_loader, execution_options, recursion_depth, ) = self._setup_for_recursion(context, path, loadopt, self.join_depth) + if not run_loader: # this will not emit SQL and will only emit for a many-to-one # "use get" load. the "_RELATED" part means it may return @@ -1418,7 +1435,6 @@ def _load_for_path( alternate_effective_path = path._truncate_recursive() extra_options = (new_opt,) else: - new_opt = None alternate_effective_path = path extra_options = () @@ -1672,9 +1688,11 @@ def _apply_joins( elif ltj > 2: middle = [ ( - orm_util.AliasedClass(item[0]) - if not inspect(item[0]).is_aliased_class - else item[0].entity, + ( + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity + ), item[1], ) for item in to_join[1:-1] @@ -1953,6 +1971,18 @@ def create_row_processor( adapter, populators, ): + if ( + loadopt + and context.compile_state.statement is not None + and context.compile_state.statement.is_dml + ): + util.warn_deprecated( + "The subqueryload loader option is not compatible with DML " + "statements such as INSERT, UPDATE. Only SELECT may be used." 
+ "This warning will become an exception in a future release.", + "2.0", + ) + if context.refresh_state: return self._immediateload_create_row_processor( context, @@ -2118,13 +2148,22 @@ def setup_query( if not compile_state.compile_options._enable_eagerloads: return + elif ( + loadopt + and compile_state.statement is not None + and compile_state.statement.is_dml + ): + util.warn_deprecated( + "The joinedload loader option is not compatible with DML " + "statements such as INSERT, UPDATE. Only SELECT may be used." + "This warning will become an exception in a future release.", + "2.0", + ) elif self.uselist: compile_state.multi_row_eager_loaders = True path = path[self.parent_property] - with_polymorphic = None - user_defined_adapter = ( self._init_user_defined_eager_proc( loadopt, compile_state, compile_state.attributes @@ -2328,9 +2367,11 @@ def _generate_row_adapter( to_adapt = orm_util.AliasedClass( self.mapper, - alias=alt_selectable._anonymous_fromclause(flat=True) - if alt_selectable is not None - else None, + alias=( + alt_selectable._anonymous_fromclause(flat=True) + if alt_selectable is not None + else None + ), flat=True, use_mapper_path=True, ) @@ -2500,13 +2541,13 @@ def _create_eager_join( or query_entity.entity_zero.represents_outer_join or (chained_from_outerjoin and isinstance(towrap, sql.Join)), _left_memo=self.parent, - _right_memo=self.mapper, + _right_memo=path[self.mapper], _extra_criteria=extra_join_criteria, ) else: # all other cases are innerjoin=='nested' approach eagerjoin = self._splice_nested_inner_join( - path, towrap, clauses, onclause, extra_join_criteria + path, path[-2], towrap, clauses, onclause, extra_join_criteria ) compile_state.eager_joins[query_entity_key] = eagerjoin @@ -2540,93 +2581,177 @@ def _create_eager_join( ) def _splice_nested_inner_join( - self, path, join_obj, clauses, onclause, extra_criteria, splicing=False + self, + path, + entity_we_want_to_splice_onto, + join_obj, + clauses, + onclause, + extra_criteria, + entity_inside_join_structure: Union[ + Mapper, None, Literal[False] + ] = False, + detected_existing_path: Optional[path_registry.PathRegistry] = None, ): # recursive fn to splice a nested join into an existing one. - # splicing=False means this is the outermost call, and it - # should return a value. splicing= is the recursive - # form, where it can return None to indicate the end of the recursion + # entity_inside_join_structure=False means this is the outermost call, + # and it should return a value. 
entity_inside_join_structure= + # indicates we've descended into a join and are looking at a FROM + # clause representing this mapper; if this is not + # entity_we_want_to_splice_onto then return None to end the recursive + # branch + + assert entity_we_want_to_splice_onto is path[-2] - if splicing is False: - # first call is always handed a join object - # from the outside + if entity_inside_join_structure is False: assert isinstance(join_obj, orm_util._ORMJoin) - elif isinstance(join_obj, sql.selectable.FromGrouping): + + if isinstance(join_obj, sql.selectable.FromGrouping): + # FromGrouping - continue descending into the structure return self._splice_nested_inner_join( path, + entity_we_want_to_splice_onto, join_obj.element, clauses, onclause, extra_criteria, - splicing, + entity_inside_join_structure, ) - elif not isinstance(join_obj, orm_util._ORMJoin): - if path[-2].isa(splicing): - return orm_util._ORMJoin( - join_obj, - clauses.aliased_insp, - onclause, - isouter=False, - _left_memo=splicing, - _right_memo=path[-1].mapper, - _extra_criteria=extra_criteria, - ) - else: - return None + elif isinstance(join_obj, orm_util._ORMJoin): + # _ORMJoin - continue descending into the structure - target_join = self._splice_nested_inner_join( - path, - join_obj.right, - clauses, - onclause, - extra_criteria, - join_obj._right_memo, - ) - if target_join is None: - right_splice = False + join_right_path = join_obj._right_memo + + # see if right side of join is viable target_join = self._splice_nested_inner_join( path, - join_obj.left, + entity_we_want_to_splice_onto, + join_obj.right, clauses, onclause, extra_criteria, - join_obj._left_memo, + entity_inside_join_structure=( + join_right_path[-1].mapper + if join_right_path is not None + else None + ), ) - if target_join is None: - # should only return None when recursively called, - # e.g. splicing refers to a from obj - assert ( - splicing is not False - ), "assertion failed attempting to produce joined eager loads" - return None - else: - right_splice = True - - if right_splice: - # for a right splice, attempt to flatten out - # a JOIN b JOIN c JOIN .. to avoid needless - # parenthesis nesting - if not join_obj.isouter and not target_join.isouter: - eagerjoin = join_obj._splice_into_center(target_join) + + if target_join is not None: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. 
to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, + target_join, + join_obj.onclause, + isouter=join_obj.isouter, + _left_memo=join_obj._left_memo, + ) + + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + else: - eagerjoin = orm_util._ORMJoin( + # see if left side of join is viable + target_join = self._splice_nested_inner_join( + path, + entity_we_want_to_splice_onto, join_obj.left, - target_join, - join_obj.onclause, - isouter=join_obj.isouter, - _left_memo=join_obj._left_memo, + clauses, + onclause, + extra_criteria, + entity_inside_join_structure=join_obj._left_memo, + detected_existing_path=join_right_path, ) - else: - eagerjoin = orm_util._ORMJoin( - target_join, - join_obj.right, - join_obj.onclause, - isouter=join_obj.isouter, - _right_memo=join_obj._right_memo, - ) - eagerjoin._target_adapter = target_join._target_adapter - return eagerjoin + if target_join is not None: + eagerjoin = orm_util._ORMJoin( + target_join, + join_obj.right, + join_obj.onclause, + isouter=join_obj.isouter, + _right_memo=join_obj._right_memo, + ) + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + + # neither side viable, return None, or fail if this was the top + # most call + if entity_inside_join_structure is False: + assert ( + False + ), "assertion failed attempting to produce joined eager loads" + return None + + # reached an endpoint (e.g. a table that's mapped, or an alias of that + # table). determine if we can use this endpoint to splice onto + + # is this the entity we want to splice onto in the first place? + if not entity_we_want_to_splice_onto.isa(entity_inside_join_structure): + return None + + # path check. if we know the path how this join endpoint got here, + # lets look at our path we are satisfying and see if we're in the + # wrong place. This is specifically for when our entity may + # appear more than once in the path, issue #11449 + # updated in issue #11965. + if detected_existing_path and len(detected_existing_path) > 2: + # this assertion is currently based on how this call is made, + # where given a join_obj, the call will have these parameters as + # entity_inside_join_structure=join_obj._left_memo + # and entity_inside_join_structure=join_obj._right_memo.mapper + assert detected_existing_path[-3] is entity_inside_join_structure + + # from that, see if the path we are targeting matches the + # "existing" path of this join all the way up to the midpoint + # of this join object (e.g. the relationship). + # if not, then this is not our target + # + # a test condition where this test is false looks like: + # + # desired splice: Node->kind->Kind + # path of desired splice: NodeGroup->nodes->Node->kind + # path we've located: NodeGroup->nodes->Node->common_node->Node + # + # above, because we want to splice kind->Kind onto + # NodeGroup->nodes->Node, this is not our path because it actually + # goes more steps than we want into self-referential + # ->common_node->Node + # + # a test condition where this test is true looks like: + # + # desired splice: B->c2s->C2 + # path of desired splice: A->bs->B->c2s + # path we've located: A->bs->B->c1s->C1 + # + # above, we want to splice c2s->C2 onto B, and the located path + # shows that the join ends with B->c1s->C1. 
so we will + # add another join onto that, which would create a "branch" that + # we might represent in a pseudopath as: + # + # B->c1s->C1 + # ->c2s->C2 + # + # i.e. A JOIN B ON JOIN C1 ON + # JOIN C2 ON + # + + if detected_existing_path[0:-2] != path.path[0:-1]: + return None + + return orm_util._ORMJoin( + join_obj, + clauses.aliased_insp, + onclause, + isouter=False, + _left_memo=entity_inside_join_structure, + _right_memo=path[path[-1].mapper], + _extra_criteria=extra_criteria, + ) def _create_eager_adapter(self, context, result, adapter, path, loadopt): compile_state = context.compile_state @@ -2675,6 +2800,10 @@ def create_row_processor( adapter, populators, ): + + if not context.compile_state.compile_options._enable_eagerloads: + return + if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " @@ -2954,6 +3083,9 @@ def create_row_processor( if not run_loader: return + if not context.compile_state.compile_options._enable_eagerloads: + return + if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " @@ -3111,7 +3243,7 @@ def _load_for_path( orig_query = context.compile_state.select_statement # the actual statement that was requested is this one: - # context_query = context.query + # context_query = context.user_passed_query # # that's not the cached one, however. So while it is of the identical # structure, if it has entities like AliasedInsp, which we get from @@ -3135,11 +3267,11 @@ def _load_for_path( effective_path = path[self.parent_property] - if orig_query is context.query: + if orig_query is context.user_passed_query: new_options = orig_query._with_options else: cached_options = orig_query._with_options - uncached_options = context.query._with_options + uncached_options = context.user_passed_query._with_options # propagate compile state options from the original query, # updating their "extra_criteria" as necessary. diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 6c81e8fe737..17bbe353495 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1,13 +1,12 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# orm/strategy_options.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -""" - -""" +""" """ from __future__ import annotations @@ -97,6 +96,7 @@ def contains_eager( attr: _AttrType, alias: Optional[_FromClauseArgument] = None, _is_chain: bool = False, + _propagate_to_loaders: bool = False, ) -> Self: r"""Indicate that the given attribute should be eagerly loaded from columns stated manually in the query. 
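Returning to the ``_splice_nested_inner_join`` rewrite completed above: the function walks an existing join structure to find the entity to splice onto, preferring a flat "branch" over needless nesting. A sketch of loader options that exercise that branching, assuming hypothetical mapped classes ``A``, ``B``, ``C1``, ``C2`` with relationships matching the pseudopath in the comments::

    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    # both chains share the A.bs segment, so the second chain is spliced
    # into the join built for the first rather than nested inside it
    stmt = select(A).options(
        joinedload(A.bs, innerjoin=True).joinedload(B.c1s, innerjoin=True),
        joinedload(A.bs, innerjoin=True).joinedload(B.c2s, innerjoin=True),
    )
    # intended shape, per the comment above:
    #   a JOIN b ON ... JOIN c1 ON ... JOIN c2 ON ...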
@@ -107,9 +107,7 @@ def contains_eager( The option is used in conjunction with an explicit join that loads the desired rows, i.e.:: - sess.query(Order).\ - join(Order.user).\ - options(contains_eager(Order.user)) + sess.query(Order).join(Order.user).options(contains_eager(Order.user)) The above query would join from the ``Order`` entity to its related ``User`` entity, and the returned ``Order`` objects would have the @@ -120,11 +118,9 @@ def contains_eager( :ref:`orm_queryguide_populate_existing` execution option assuming the primary collection of parent objects may already have been loaded:: - sess.query(User).\ - join(User.addresses).\ - filter(Address.email_address.like('%@aol.com')).\ - options(contains_eager(User.addresses)).\ - populate_existing() + sess.query(User).join(User.addresses).filter( + Address.email_address.like("%@aol.com") + ).options(contains_eager(User.addresses)).populate_existing() See the section :ref:`contains_eager` for complete usage details. @@ -159,7 +155,7 @@ def contains_eager( cloned = self._set_relationship_strategy( attr, {"lazy": "joined"}, - propagate_to_loaders=False, + propagate_to_loaders=_propagate_to_loaders, opts={"eager_from_alias": coerced_alias}, _reconcile_to_other=True if _is_chain else None, ) @@ -190,10 +186,18 @@ def load_only(self, *attrs: _AttrType, raiseload: bool = False) -> Self: the lead entity can be specifically referred to using the :class:`_orm.Load` constructor:: - stmt = select(User, Address).join(User.addresses).options( - Load(User).load_only(User.name, User.fullname), - Load(Address).load_only(Address.email_address) - ) + stmt = ( + select(User, Address) + .join(User.addresses) + .options( + Load(User).load_only(User.name, User.fullname), + Load(Address).load_only(Address.email_address), + ) + ) + + When used together with the + :ref:`populate_existing ` + execution option only the attributes listed will be refreshed. :param \*attrs: Attributes to be loaded, all others will be deferred. @@ -218,7 +222,7 @@ def load_only(self, *attrs: _AttrType, raiseload: bool = False) -> Self: """ cloned = self._set_column_strategy( - attrs, + _expand_column_strategy_attrs(attrs), {"deferred": False, "instrument": True}, ) @@ -246,28 +250,25 @@ def joinedload( examples:: # joined-load the "orders" collection on "User" - query(User).options(joinedload(User.orders)) + select(User).options(joinedload(User.orders)) # joined-load Order.items and then Item.keywords - query(Order).options( - joinedload(Order.items).joinedload(Item.keywords)) + select(Order).options(joinedload(Order.items).joinedload(Item.keywords)) # lazily load Order.items, but when Items are loaded, # joined-load the keywords collection - query(Order).options( - lazyload(Order.items).joinedload(Item.keywords)) + select(Order).options(lazyload(Order.items).joinedload(Item.keywords)) :param innerjoin: if ``True``, indicates that the joined eager load should use an inner join instead of the default of left outer join:: - query(Order).options(joinedload(Order.user, innerjoin=True)) + select(Order).options(joinedload(Order.user, innerjoin=True)) In order to chain multiple eager joins together where some may be OUTER and others INNER, right-nested joins are used to link them:: - query(A).options( - joinedload(A.bs, innerjoin=False). 
- joinedload(B.cs, innerjoin=True) + select(A).options( + joinedload(A.bs, innerjoin=False).joinedload(B.cs, innerjoin=True) ) The above query, linking A.bs via "outer" join and B.cs via "inner" @@ -282,10 +283,7 @@ def joinedload( will render as LEFT OUTER JOIN. For example, supposing ``A.bs`` is an outerjoin:: - query(A).options( - joinedload(A.bs). - joinedload(B.cs, innerjoin="unnested") - ) + select(A).options(joinedload(A.bs).joinedload(B.cs, innerjoin="unnested")) The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c", rather than as "a LEFT OUTER JOIN (b JOIN c)". @@ -315,13 +313,15 @@ def joinedload( :ref:`joined_eager_loading` - """ + """ # noqa: E501 loader = self._set_relationship_strategy( attr, {"lazy": "joined"}, - opts={"innerjoin": innerjoin} - if innerjoin is not None - else util.EMPTY_DICT, + opts=( + {"innerjoin": innerjoin} + if innerjoin is not None + else util.EMPTY_DICT + ), ) return loader @@ -335,17 +335,16 @@ def subqueryload(self, attr: _AttrType) -> Self: examples:: # subquery-load the "orders" collection on "User" - query(User).options(subqueryload(User.orders)) + select(User).options(subqueryload(User.orders)) # subquery-load Order.items and then Item.keywords - query(Order).options( - subqueryload(Order.items).subqueryload(Item.keywords)) + select(Order).options( + subqueryload(Order.items).subqueryload(Item.keywords) + ) # lazily load Order.items, but when Items are loaded, # subquery-load the keywords collection - query(Order).options( - lazyload(Order.items).subqueryload(Item.keywords)) - + select(Order).options(lazyload(Order.items).subqueryload(Item.keywords)) .. seealso:: @@ -370,16 +369,16 @@ def selectinload( examples:: # selectin-load the "orders" collection on "User" - query(User).options(selectinload(User.orders)) + select(User).options(selectinload(User.orders)) # selectin-load Order.items and then Item.keywords - query(Order).options( - selectinload(Order.items).selectinload(Item.keywords)) + select(Order).options( + selectinload(Order.items).selectinload(Item.keywords) + ) # lazily load Order.items, but when Items are loaded, # selectin-load the keywords collection - query(Order).options( - lazyload(Order.items).selectinload(Item.keywords)) + select(Order).options(lazyload(Order.items).selectinload(Item.keywords)) :param recursion_depth: optional int; when set to a positive integer in conjunction with a self-referential relationship, @@ -490,10 +489,10 @@ def noload(self, attr: _AttrType) -> Self: :func:`_orm.noload` applies to :func:`_orm.relationship` attributes only. - .. note:: Setting this loading strategy as the default strategy - for a relationship using the :paramref:`.orm.relationship.lazy` - parameter may cause issues with flushes, such if a delete operation - needs to load related objects and instead ``None`` was returned. + .. legacy:: The :func:`_orm.noload` option is **legacy**. As it + forces collections to be empty, which invariably leads to + non-intuitive and difficult to predict results. There are no + legitimate uses for this option in modern SQLAlchemy. .. seealso:: @@ -555,17 +554,20 @@ def defaultload(self, attr: _AttrType) -> Self: element of an element:: session.query(MyClass).options( - defaultload(MyClass.someattribute). 
- joinedload(MyOtherClass.someotherattribute) + defaultload(MyClass.someattribute).joinedload( + MyOtherClass.someotherattribute + ) ) :func:`.defaultload` is also useful for setting column-level options on a related class, namely that of :func:`.defer` and :func:`.undefer`:: - session.query(MyClass).options( - defaultload(MyClass.someattribute). - defer("some_column"). - undefer("some_other_column") + session.scalars( + select(MyClass).options( + defaultload(MyClass.someattribute) + .defer("some_column") + .undefer("some_other_column") + ) ) .. seealso:: @@ -589,8 +591,7 @@ def defer(self, key: _AttrType, raiseload: bool = False) -> Self: from sqlalchemy.orm import defer session.query(MyClass).options( - defer(MyClass.attribute_one), - defer(MyClass.attribute_two) + defer(MyClass.attribute_one), defer(MyClass.attribute_two) ) To specify a deferred load of an attribute on a related class, @@ -606,11 +607,11 @@ def defer(self, key: _AttrType, raiseload: bool = False) -> Self: at once using :meth:`_orm.Load.options`:: - session.query(MyClass).options( + select(MyClass).options( defaultload(MyClass.someattr).options( defer(RelatedClass.some_column), defer(RelatedClass.some_other_column), - defer(RelatedClass.another_column) + defer(RelatedClass.another_column), ) ) @@ -635,7 +636,9 @@ def defer(self, key: _AttrType, raiseload: bool = False) -> Self: strategy = {"deferred": True, "instrument": True} if raiseload: strategy["raiseload"] = True - return self._set_column_strategy((key,), strategy) + return self._set_column_strategy( + _expand_column_strategy_attrs((key,)), strategy + ) def undefer(self, key: _AttrType) -> Self: r"""Indicate that the given column-oriented attribute should be @@ -656,12 +659,10 @@ def undefer(self, key: _AttrType) -> Self: ) # undefer all columns specific to a single class using Load + * - session.query(MyClass, MyOtherClass).options( - Load(MyClass).undefer("*")) + session.query(MyClass, MyOtherClass).options(Load(MyClass).undefer("*")) # undefer a column on a related object - session.query(MyClass).options( - defaultload(MyClass.items).undefer(MyClass.text)) + select(MyClass).options(defaultload(MyClass.items).undefer(MyClass.text)) :param key: Attribute to be undeferred. @@ -674,9 +675,10 @@ def undefer(self, key: _AttrType) -> Self: :func:`_orm.undefer_group` - """ + """ # noqa: E501 return self._set_column_strategy( - (key,), {"deferred": False, "instrument": True} + _expand_column_strategy_attrs((key,)), + {"deferred": False, "instrument": True}, ) def undefer_group(self, name: str) -> Self: @@ -694,8 +696,9 @@ def undefer_group(self, name: str) -> Self: spelled out using relationship loader options, such as :func:`_orm.defaultload`:: - session.query(MyClass).options( - defaultload("someattr").undefer_group("large_attrs")) + select(MyClass).options( + defaultload("someattr").undefer_group("large_attrs") + ) .. seealso:: @@ -776,12 +779,10 @@ def selectin_polymorphic(self, classes: Iterable[Type[Any]]) -> Self: return self @overload - def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: - ... + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: ... @overload - def _coerce_strat(self, strategy: Literal[None]) -> None: - ... + def _coerce_strat(self, strategy: Literal[None]) -> None: ... 
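The ``...`` stub bodies being pulled up onto their ``def`` lines here and throughout this patch follow the dummy-implementation style of newer Black releases; a minimal self-contained illustration of the form::

    from typing import Optional, overload

    class Coercer:
        @overload
        def coerce(self, value: str) -> int: ...

        @overload
        def coerce(self, value: None) -> None: ...

        def coerce(self, value: Optional[str]) -> Optional[int]:
            # the real implementation backing the two stubs above
            return int(value) if value is not None else None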
def _coerce_strat( self, strategy: Optional[_StrategySpec] @@ -1033,6 +1034,8 @@ def _construct_for_existing_path( def _adapt_cached_option_to_uncached_option( self, context: QueryContext, uncached_opt: ORMOption ) -> ORMOption: + if uncached_opt is self: + return self return self._adjust_for_extra_criteria(context) def _prepend_path(self, path: PathRegistry) -> Load: @@ -1048,47 +1051,51 @@ def _adjust_for_extra_criteria(self, context: QueryContext) -> Load: returning a new instance of this ``Load`` object. """ - orig_query = context.compile_state.select_statement - orig_cache_key: Optional[CacheKey] = None - replacement_cache_key: Optional[CacheKey] = None - found_crit = False + # avoid generating cache keys for the queries if we don't + # actually have any extra_criteria options, which is the + # common case + for value in self.context: + if value._extra_criteria: + break + else: + return self - def process(opt: _LoadElement) -> _LoadElement: - nonlocal orig_cache_key, replacement_cache_key, found_crit + replacement_cache_key = context.user_passed_query._generate_cache_key() - found_crit = True + if replacement_cache_key is None: + return self - if orig_cache_key is None or replacement_cache_key is None: - orig_cache_key = orig_query._generate_cache_key() - replacement_cache_key = context.query._generate_cache_key() + orig_query = context.compile_state.select_statement + orig_cache_key = orig_query._generate_cache_key() + assert orig_cache_key is not None - assert orig_cache_key is not None - assert replacement_cache_key is not None + def process( + opt: _LoadElement, + replacement_cache_key: CacheKey, + orig_cache_key: CacheKey, + ) -> _LoadElement: + cloned_opt = opt._clone() - opt._extra_criteria = tuple( + cloned_opt._extra_criteria = tuple( replacement_cache_key._apply_params_to_element( orig_cache_key, crit ) - for crit in opt._extra_criteria + for crit in cloned_opt._extra_criteria ) - return opt + return cloned_opt - # avoid generating cache keys for the queries if we don't - # actually have any extra_criteria options, which is the - # common case - new_context = tuple( - process(value._clone()) if value._extra_criteria else value + cloned = self._clone() + cloned.context = tuple( + ( + process(value, replacement_cache_key, orig_cache_key) + if value._extra_criteria + else value + ) for value in self.context ) - - if found_crit: - cloned = self._clone() - cloned.context = new_context - return cloned - else: - return self + return cloned def _reconcile_query_entities_with_us(self, mapper_entities, raiseerr): """called at process time to allow adjustment of the root @@ -1097,7 +1104,6 @@ def _reconcile_query_entities_with_us(self, mapper_entities, raiseerr): """ path = self.path - ezero = None for ent in mapper_entities: ezero = ent.entity_zero if ezero and orm_util._entity_corresponds_to( @@ -1120,7 +1126,20 @@ def _process( mapper_entities, raiseerr ) + # if the context has a current path, this is a lazy load + has_current_path = bool(compile_state.compile_options._current_path) + for loader in self.context: + # issue #11292 + # historically, propagate_to_loaders was only considered at + # object loading time, whether or not to carry along options + # onto an object's loaded state where it would be used by lazyload. + # however, the defaultload() option needs to propagate in case + # its sub-options propagate_to_loaders, but its sub-options + # that dont propagate should not be applied for lazy loaders. 
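The comment above (it concludes with the ``continue`` check just below) distinguishes sub-options that propagate to later lazy loads from those that must not. A sketch of each kind, reusing the hypothetical ``User``/``Address`` mapping from the earlier sketch: ``defer()`` propagates by default, while ``contains_eager()`` does not, per the new ``_propagate_to_loaders=False`` default earlier in this patch::

    from sqlalchemy.orm import contains_eager, defaultload, defer

    # applied when User.addresses is lazy-loaded later:
    propagating = defaultload(User.addresses).defer(Address.email)

    # must be skipped on that lazy-load compilation, which is what the
    # has_current_path check above enforces:
    non_propagating = contains_eager(User.addresses)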
+ # so we check again + if has_current_path and not loader.propagate_to_loaders: + continue loader.process_compile_state( self, compile_state, @@ -1178,13 +1197,11 @@ def options(self, *opts: _AbstractLoad) -> Self: query = session.query(Author) query = query.options( - joinedload(Author.book).options( - load_only(Book.summary, Book.excerpt), - joinedload(Book.citations).options( - joinedload(Citation.author) - ) - ) - ) + joinedload(Author.book).options( + load_only(Book.summary, Book.excerpt), + joinedload(Book.citations).options(joinedload(Citation.author)), + ) + ) :param \*opts: A series of loader option objects (ultimately :class:`_orm.Load` objects) which should be applied to the path @@ -1611,9 +1628,10 @@ def _raise_for_no_match(self, parent_loader, mapper_entities): f"Mapped class {path[0]} does not apply to any of the " f"root entities in this query, e.g. " f"""{ - ", ".join(str(x.entity_zero) - for x in mapper_entities if x.entity_zero - )}. Please """ + ", ".join( + str(x.entity_zero) + for x in mapper_entities if x.entity_zero + )}. Please """ "specify the full path " "from one of the root entities to the target " "attribute. " @@ -1627,13 +1645,17 @@ def _adjust_effective_path_for_current_path( loads, and adjusts the given path to be relative to the current_path. - E.g. given a loader path and current path:: + E.g. given a loader path and current path: + + .. sourcecode:: text lp: User -> orders -> Order -> items -> Item -> keywords -> Keyword cp: User -> orders -> Order -> items - The adjusted path would be:: + The adjusted path would be: + + .. sourcecode:: text Item -> keywords -> Keyword @@ -2079,9 +2101,9 @@ def __getstate__(self): d["_extra_criteria"] = () if self._path_with_polymorphic_path: - d[ - "_path_with_polymorphic_path" - ] = self._path_with_polymorphic_path.serialize() + d["_path_with_polymorphic_path"] = ( + self._path_with_polymorphic_path.serialize() + ) if self._of_type: if self._of_type.is_aliased_class: @@ -2114,11 +2136,11 @@ class _TokenStrategyLoad(_LoadElement): e.g.:: - raiseload('*') - Load(User).lazyload('*') - defer('*') + raiseload("*") + Load(User).lazyload("*") + defer("*") load_only(User.name, User.email) # will create a defer('*') - joinedload(User.addresses).raiseload('*') + joinedload(User.addresses).raiseload("*") """ @@ -2373,6 +2395,23 @@ def loader_unbound_fn(fn: _FN) -> _FN: return fn +def _expand_column_strategy_attrs( + attrs: Tuple[_AttrType, ...], +) -> Tuple[_AttrType, ...]: + return cast( + "Tuple[_AttrType, ...]", + tuple( + a + for attr in attrs + for a in ( + cast("QueryableAttribute[Any]", attr)._column_strategy_attrs() + if hasattr(attr, "_column_strategy_attrs") + else (attr,) + ) + ), + ) + + # standalone functions follow. docstrings are filled in # by the ``@loader_unbound_fn`` decorator. @@ -2386,6 +2425,7 @@ def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: def load_only(*attrs: _AttrType, raiseload: bool = False) -> _AbstractLoad: # TODO: attrs against different classes. 
we likely have to # add some extra state to Load of some kind + attrs = _expand_column_strategy_attrs(attrs) _, lead_element, _ = _parse_attr_argument(attrs[0]) return Load(lead_element).load_only(*attrs, raiseload=raiseload) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 036c26dd6be..8f85a41a2c0 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -1,5 +1,5 @@ # orm/sync.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -86,8 +86,9 @@ def clear(dest, dest_mapper, synchronize_pairs): not in orm_util._none_set ): raise AssertionError( - "Dependency rule tried to blank-out primary key " - "column '%s' on instance '%s'" % (r, orm_util.state_str(dest)) + f"Dependency rule on column '{l}' " + "tried to blank-out primary key " + f"column '{r}' on instance '{orm_util.state_str(dest)}'" ) try: dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 20fe022076b..80897f29262 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -1,5 +1,5 @@ # orm/unitofwork.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ea2f1a12e93..ca607af1be4 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1,5 +1,5 @@ # orm/util.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -35,6 +35,7 @@ from . import attributes # noqa from . import exc +from . 
import exc as orm_exc from ._typing import _O from ._typing import insp_is_aliased_class from ._typing import insp_is_mapper @@ -42,6 +43,7 @@ from .base import _class_to_mapper as _class_to_mapper from .base import _MappedAnnotationBase from .base import _never_set as _never_set # noqa: F401 +from .base import _none_only_set as _none_only_set # noqa: F401 from .base import _none_set as _none_set # noqa: F401 from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper @@ -85,14 +87,12 @@ from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation as _de_stringify_annotation -from ..util.typing import ( - de_stringify_union_elements as _de_stringify_union_elements, -) from ..util.typing import eval_name_only as _eval_name_only +from ..util.typing import fixup_container_fwd_refs +from ..util.typing import get_origin from ..util.typing import is_origin_of_cls from ..util.typing import Literal from ..util.typing import Protocol -from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -121,7 +121,6 @@ from ..sql.selectable import Selectable from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType - from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -138,7 +137,6 @@ ) ) - _de_stringify_partial = functools.partial( functools.partial, locals_=util.immutabledict( @@ -163,8 +161,7 @@ def __call__( *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, - ) -> Type[Any]: - ... + ) -> Type[Any]: ... de_stringify_annotation = cast( @@ -172,27 +169,8 @@ def __call__( ) -class _DeStringifyUnionElements(Protocol): - def __call__( - self, - cls: Type[Any], - annotation: ArgsTypeProcotol, - originating_module: str, - *, - str_cleanup_fn: Optional[Callable[[str, str], str]] = None, - ) -> Type[Any]: - ... - - -de_stringify_union_elements = cast( - _DeStringifyUnionElements, - _de_stringify_partial(_de_stringify_union_elements), -) - - class _EvalNameOnly(Protocol): - def __call__(self, name: str, module_name: str) -> Any: - ... + def __call__(self, name: str, module_name: str) -> Any: ... eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) @@ -250,7 +228,7 @@ def __new__( values.clear() values.discard("all") - self = super().__new__(cls, values) # type: ignore + self = super().__new__(cls, values) self.save_update = "save-update" in values self.delete = "delete" in values self.refresh_expire = "refresh-expire" in values @@ -259,9 +237,7 @@ def __new__( self.delete_orphan = "delete-orphan" in values if self.delete_orphan and not self.delete: - util.warn( - "The 'delete-orphan' cascade " "option requires 'delete'." - ) + util.warn("The 'delete-orphan' cascade option requires 'delete'.") return self def __repr__(self): @@ -478,9 +454,7 @@ def identity_key( E.g.:: - >>> row = engine.execute(\ - text("select * from table where a=1 and b=2")\ - ).first() + >>> row = engine.execute(text("select * from table where a=1 and b=2")).first() >>> identity_key(MyClass, row=row) (, (1, 2), None) @@ -491,7 +465,7 @@ def identity_key( .. 
versionadded:: 1.2 added identity_token - """ + """ # noqa: E501 if class_ is not None: mapper = class_mapper(class_) if row is None: @@ -669,9 +643,9 @@ class AliasedClass( # find all pairs of users with the same name user_alias = aliased(User) - session.query(User, user_alias).\ - join((user_alias, User.id > user_alias.id)).\ - filter(User.name == user_alias.name) + session.query(User, user_alias).join( + (user_alias, User.id > user_alias.id) + ).filter(User.name == user_alias.name) :class:`.AliasedClass` is also capable of mapping an existing mapped class to an entirely new selectable, provided this selectable is column- @@ -695,6 +669,7 @@ class to an entirely new selectable, provided this selectable is column- using :func:`_sa.inspect`:: from sqlalchemy import inspect + my_alias = aliased(MyClass) insp = inspect(my_alias) @@ -755,12 +730,16 @@ def __init__( insp, alias, name, - with_polymorphic_mappers - if with_polymorphic_mappers - else mapper.with_polymorphic_mappers, - with_polymorphic_discriminator - if with_polymorphic_discriminator is not None - else mapper.polymorphic_on, + ( + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers + ), + ( + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on + ), base_alias, use_mapper_path, adapt_on_names, @@ -971,9 +950,9 @@ def __init__( self._weak_entity = weakref.ref(entity) self.mapper = mapper - self.selectable = ( - self.persist_selectable - ) = self.local_table = selectable + self.selectable = self.persist_selectable = self.local_table = ( + selectable + ) self.name = name self.polymorphic_on = polymorphic_on self._base_alias = weakref.ref(_base_alias or self) @@ -1068,6 +1047,7 @@ def _with_polymorphic_factory( aliased: bool = False, innerjoin: bool = False, adapt_on_names: bool = False, + name: Optional[str] = None, _use_mapper_path: bool = False, ) -> AliasedClass[_O]: primary_mapper = _class_to_mapper(base) @@ -1088,6 +1068,7 @@ def _with_polymorphic_factory( return AliasedClass( base, selectable, + name=name, with_polymorphic_mappers=mappers, adapt_on_names=adapt_on_names, with_polymorphic_discriminator=polymorphic_on, @@ -1229,8 +1210,7 @@ def _orm_adapt_element( self, obj: _CE, key: Optional[str] = None, - ) -> _CE: - ... + ) -> _CE: ... else: _orm_adapt_element = _adapt_element @@ -1380,7 +1360,10 @@ class LoaderCriteriaOption(CriteriaOption): def __init__( self, entity_or_base: _EntityType[Any], - where_criteria: _ColumnExpressionArgument[bool], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, @@ -1539,7 +1522,7 @@ def _inspect_mc( def _inspect_generic_alias( class_: Type[_O], ) -> Optional[Mapper[_O]]: - origin = cast("Type[_O]", typing_get_origin(class_)) + origin = cast("Type[_O]", get_origin(class_)) return _inspect_mc(origin) @@ -1583,7 +1566,7 @@ class Bundle( _propagate_attrs: _PropagateAttrsType = util.immutabledict() - proxy_set = util.EMPTY_SET # type: ignore + proxy_set = util.EMPTY_SET exprs: List[_ColumnsClauseElement] @@ -1596,8 +1579,7 @@ def __init__( bn = Bundle("mybundle", MyClass.x, MyClass.y) - for row in session.query(bn).filter( - bn.c.x == 5).filter(bn.c.y == 4): + for row in session.query(bn).filter(bn.c.x == 5).filter(bn.c.y == 4): print(row.mybundle.x, row.mybundle.y) :param name: name of the bundle. 
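The ``Bundle`` examples reformatted above still use the legacy ``Query`` form; an equivalent 2.0-style sketch, with ``MyClass`` and ``session`` assumed to exist::

    from sqlalchemy import select
    from sqlalchemy.orm import Bundle

    bn = Bundle("mybundle", MyClass.x, MyClass.y)
    for row in session.execute(select(bn).where(bn.c.x == 5, bn.c.y == 4)):
        print(row.mybundle.x, row.mybundle.y)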
@@ -1606,7 +1588,7 @@ def __init__( can be returned as a "single entity" outside of any enclosing tuple in the same manner as a mapped entity. - """ + """ # noqa: E501 self.name = self._label = name coerced_exprs = [ coercions.expect( @@ -1661,24 +1643,24 @@ def entity_namespace( Nesting of bundles is also supported:: - b1 = Bundle("b1", - Bundle('b2', MyClass.a, MyClass.b), - Bundle('b3', MyClass.x, MyClass.y) - ) + b1 = Bundle( + "b1", + Bundle("b2", MyClass.a, MyClass.b), + Bundle("b3", MyClass.x, MyClass.y), + ) - q = sess.query(b1).filter( - b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + q = sess.query(b1).filter(b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) .. seealso:: :attr:`.Bundle.c` - """ + """ # noqa: E501 c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" - def _clone(self): + def _clone(self, **kw): cloned = self.__class__.__new__(self.__class__) cloned.__dict__.update(self.__dict__) return cloned @@ -1739,25 +1721,24 @@ def create_row_processor( from sqlalchemy.orm import Bundle + class DictBundle(Bundle): def create_row_processor(self, query, procs, labels): - 'Override create_row_processor to return values as - dictionaries' + "Override create_row_processor to return values as dictionaries" def proc(row): - return dict( - zip(labels, (proc(row) for proc in procs)) - ) + return dict(zip(labels, (proc(row) for proc in procs))) + return proc A result from the above :class:`_orm.Bundle` will return dictionary values:: - bn = DictBundle('mybundle', MyClass.data1, MyClass.data2) - for row in session.execute(select(bn)).where(bn.c.data1 == 'd1'): - print(row.mybundle['data1'], row.mybundle['data2']) + bn = DictBundle("mybundle", MyClass.data1, MyClass.data2) + for row in session.execute(select(bn)).where(bn.c.data1 == "d1"): + print(row.mybundle["data1"], row.mybundle["data2"]) - """ + """ # noqa: E501 keyed_tuple = result_tuple(labels, [() for l in labels]) def proc(row: Row[Any]) -> Any: @@ -1940,7 +1921,7 @@ def _splice_into_center(self, other): self.onclause, isouter=self.isouter, _left_memo=self._left_memo, - _right_memo=other._left_memo, + _right_memo=other._left_memo._path_registry, ) return _ORMJoin( @@ -1983,7 +1964,6 @@ def with_parent( stmt = select(Address).where(with_parent(some_user, User.addresses)) - The SQL rendered is the same as that rendered when a lazy loader would fire off from the given parent on that attribute, meaning that the appropriate state is taken from the parent object in @@ -1996,9 +1976,7 @@ def with_parent( a1 = aliased(Address) a2 = aliased(Address) - stmt = select(a1, a2).where( - with_parent(u1, User.addresses.of_type(a2)) - ) + stmt = select(a1, a2).where(with_parent(u1, User.addresses.of_type(a2))) The above use is equivalent to using the :func:`_orm.with_parent.from_entity` argument:: @@ -2023,7 +2001,7 @@ def with_parent( .. 
versionadded:: 1.2 - """ + """ # noqa: E501 prop_t: RelationshipProperty[Any] if isinstance(prop, str): @@ -2117,14 +2095,13 @@ def _entity_corresponds_to_use_path_impl( someoption(A).someoption(C.d) # -> fn(A, C) -> False a1 = aliased(A) - someoption(a1).someoption(A.b) # -> fn(a1, A) -> False - someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True + someoption(a1).someoption(A.b) # -> fn(a1, A) -> False + someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True wp = with_polymorphic(A, [A1, A2]) someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True - """ if insp_is_aliased_class(given): return ( @@ -2151,7 +2128,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: mapper ) elif given.with_polymorphic_mappers: - return mapper in given.with_polymorphic_mappers + return mapper in given.with_polymorphic_mappers or given.isa(mapper) else: return given.isa(mapper) @@ -2233,7 +2210,7 @@ def _cleanup_mapped_str_annotation( inner: Optional[Match[str]] - mm = re.match(r"^(.+?)\[(.+)\]$", annotation) + mm = re.match(r"^([^ \|]+?)\[(.+)\]$", annotation) if not mm: return annotation @@ -2273,7 +2250,7 @@ def _cleanup_mapped_str_annotation( while True: stack.append(real_symbol if mm is inner else inner.group(1)) g2 = inner.group(2) - inner = re.match(r"^(.+?)\[(.+)\]$", g2) + inner = re.match(r"^([^ \|]+?)\[(.+)\]$", g2) if inner is None: stack.append(g2) break @@ -2295,8 +2272,10 @@ def _cleanup_mapped_str_annotation( # ['Mapped', "'Optional[Dict[str, str]]'"] not re.match(r"""^["'].*["']$""", stack[-1]) # avoid further generics like Dict[] such as - # ['Mapped', 'dict[str, str] | None'] - and not re.match(r".*\[.*\]", stack[-1]) + # ['Mapped', 'dict[str, str] | None'], + # ['Mapped', 'list[int] | list[str]'], + # ['Mapped', 'Union[list[int], list[str]]'], + and not re.search(r"[\[\]]", stack[-1]) ): stripchars = "\"' " stack[-1] = ", ".join( @@ -2318,7 +2297,7 @@ def _extract_mapped_subtype( is_dataclass_field: bool, expect_mapped: bool = True, raiseerr: bool = True, -) -> Optional[Tuple[Union[type, str], Optional[type]]]: +) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]: """given an annotation, figure out if it's ``Mapped[something]`` and if so, return the ``something`` part. @@ -2328,7 +2307,7 @@ def _extract_mapped_subtype( if raw_annotation is None: if required: - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f"Python typing annotation is required for attribute " f'"{cls.__name__}.{key}" when primary argument(s) for ' f'"{attr_cls.__name__}" construct are None or not present' @@ -2336,6 +2315,11 @@ def _extract_mapped_subtype( return None try: + # destringify the "outside" of the annotation. note we are not + # adding include_generic so it will *not* dig into generic contents, + # which will remain as ForwardRef or plain str under future annotations + # mode. The full destringify happens later when mapped_column goes + # to do a full lookup in the registry type_annotations_map. annotated = de_stringify_annotation( cls, raw_annotation, @@ -2343,14 +2327,14 @@ def _extract_mapped_subtype( str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f"Could not interpret annotation {raw_annotation}. " "Check that it uses names that are correctly imported at the " "module level. See chained stack trace for more hints." 
) from ce except NameError as ne: if raiseerr and "Mapped[" in raw_annotation: # type: ignore - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f"Could not interpret annotation {raw_annotation}. " "Check that it uses names that are correctly imported at the " "module level. See chained stack trace for more hints." @@ -2379,7 +2363,7 @@ def _extract_mapped_subtype( ): return None - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( f'Type annotation for "{cls.__name__}.{key}" ' "can't be correctly interpreted for " "Annotated Declarative Table form. ORM annotations " @@ -2400,8 +2384,20 @@ def _extract_mapped_subtype( return annotated, None if len(annotated.__args__) != 1: - raise sa_exc.ArgumentError( + raise orm_exc.MappedAnnotationError( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0], annotated.__origin__ + return ( + # fix dict/list/set args to be ForwardRef, see #11814 + fixup_container_fwd_refs(annotated.__args__[0]), + annotated.__origin__, + ) + + +def _mapper_property_as_plain_name(prop: Type[Any]) -> str: + if hasattr(prop, "_mapper_property_name"): + name = prop._mapper_property_name() + else: + name = None + return util.clsname_as_plain_name(prop, name) diff --git a/lib/sqlalchemy/orm/writeonly.py b/lib/sqlalchemy/orm/writeonly.py index 416a0399f93..fe9c8e96e89 100644 --- a/lib/sqlalchemy/orm/writeonly.py +++ b/lib/sqlalchemy/orm/writeonly.py @@ -1,5 +1,5 @@ # orm/writeonly.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -196,8 +196,7 @@ def get_collection( dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -206,8 +205,7 @@ def get_collection( dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -218,8 +216,7 @@ def get_collection( passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... 
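The annotation-extraction hunks above swap ``sa_exc.ArgumentError`` for the more specific ``orm_exc.MappedAnnotationError``. A sketch of a declarative mapping that trips the "annotation is required" case quoted above (names are hypothetical)::

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class Widget(Base):
        __tablename__ = "widget"
        id: Mapped[int] = mapped_column(primary_key=True)
        # no Mapped[] annotation and no type argument: per the message in
        # the hunk above, this now raises MappedAnnotationError
        name = mapped_column()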
def get_collection( self, @@ -239,15 +236,11 @@ def get_collection( return DynamicCollectionAdapter(data) # type: ignore[return-value] @util.memoized_property - def _append_token( # type:ignore[override] - self, - ) -> attributes.AttributeEventToken: + def _append_token(self) -> attributes.AttributeEventToken: return attributes.AttributeEventToken(self, attributes.OP_APPEND) @util.memoized_property - def _remove_token( # type:ignore[override] - self, - ) -> attributes.AttributeEventToken: + def _remove_token(self) -> attributes.AttributeEventToken: return attributes.AttributeEventToken(self, attributes.OP_REMOVE) def fire_append_event( diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 7929b6e4bed..51bf0ec7992 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -1,5 +1,5 @@ -# sqlalchemy/pool/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# pool/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 90ed32ec27b..ed4d7c115ab 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -1,14 +1,12 @@ -# sqlalchemy/pool.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# pool/base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Base constructs for connection pools. - -""" +"""Base constructs for connection pools.""" from __future__ import annotations @@ -147,17 +145,14 @@ class _AsyncConnDialect(_ConnDialect): class _CreatorFnType(Protocol): - def __call__(self) -> DBAPIConnection: - ... + def __call__(self) -> DBAPIConnection: ... class _CreatorWRecFnType(Protocol): - def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: - ... + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: ... class Pool(log.Identified, event.EventTarget): - """Abstract base class for connection pools.""" dispatch: dispatcher[Pool] @@ -471,6 +466,7 @@ def _do_return_conn(self, record: ConnectionPoolEntry) -> None: raise NotImplementedError() def status(self) -> str: + """Returns a brief description of the state of this pool.""" raise NotImplementedError() @@ -633,7 +629,6 @@ def close(self) -> None: class _ConnectionRecord(ConnectionPoolEntry): - """Maintains a position in a connection pool which references a pooled connection. @@ -729,11 +724,13 @@ def checkout(cls, pool: Pool) -> _ConnectionFairy: rec.fairy_ref = ref = weakref.ref( fairy, - lambda ref: _finalize_fairy( - None, rec, pool, ref, echo, transaction_was_reset=False - ) - if _finalize_fairy is not None - else None, + lambda ref: ( + _finalize_fairy( + None, rec, pool, ref, echo, transaction_was_reset=False + ) + if _finalize_fairy is not None + else None + ), ) _strong_ref_connection_records[ref] = rec if echo: @@ -1074,14 +1071,13 @@ class PoolProxiedConnection(ManagesConnection): if typing.TYPE_CHECKING: - def commit(self) -> None: - ... + def commit(self) -> None: ... + + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... - def cursor(self) -> DBAPICursor: - ... + def rollback(self) -> None: ... - def rollback(self) -> None: - ... + def __getattr__(self, key: str) -> Any: ... 
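The ``PoolProxiedConnection`` stubs above (now including ``__getattr__`` and a ``cursor()`` that accepts arguments) describe the pass-through DBAPI surface. A sketch of reaching it via ``Engine.raw_connection()``, using an in-memory SQLite database for illustration::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    conn = engine.raw_connection()  # a PoolProxiedConnection
    try:
        cursor = conn.cursor()  # proxies the DBAPI cursor() call
        cursor.execute("SELECT 1")
        print(cursor.fetchone())
        conn.commit()
    finally:
        conn.close()  # returns the connection to the pool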
@property def is_valid(self) -> bool: @@ -1189,7 +1185,6 @@ def __getattr__(self, key: Any) -> Any: class _ConnectionFairy(PoolProxiedConnection): - """Proxies a DBAPI connection and provides return-on-dereference support. diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py index 762418b14f2..4ceb260f79b 100644 --- a/lib/sqlalchemy/pool/events.py +++ b/lib/sqlalchemy/pool/events.py @@ -1,5 +1,5 @@ -# sqlalchemy/pool/events.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# pool/events.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -35,10 +35,12 @@ class PoolEvents(event.Events[Pool]): from sqlalchemy import event + def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): "handle an on checkout event" - event.listen(Pool, 'checkout', my_on_checkout) + + event.listen(Pool, "checkout", my_on_checkout) In addition to accepting the :class:`_pool.Pool` class and :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts @@ -49,7 +51,7 @@ def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") # will associate with engine.pool - event.listen(engine, 'checkout', my_on_checkout) + event.listen(engine, "checkout", my_on_checkout) """ # noqa: E501 @@ -173,7 +175,7 @@ def checkout( def checkin( self, - dbapi_connection: DBAPIConnection, + dbapi_connection: Optional[DBAPIConnection], connection_record: ConnectionPoolEntry, ) -> None: """Called when a connection returns to the pool. diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index af4f788e27d..f3d53ddb84d 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -1,14 +1,12 @@ -# sqlalchemy/pool.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# pool/impl.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Pool implementation classes. - -""" +"""Pool implementation classes.""" from __future__ import annotations import threading @@ -43,21 +41,30 @@ class QueuePool(Pool): - """A :class:`_pool.Pool` that imposes a limit on the number of open connections. :class:`.QueuePool` is the default pooling implementation used for - all :class:`_engine.Engine` objects, unless the SQLite dialect is - in use with a ``:memory:`` database. + all :class:`_engine.Engine` objects other than SQLite with a ``:memory:`` + database. + + The :class:`.QueuePool` class **is not compatible** with asyncio and + :func:`_asyncio.create_async_engine`. The + :class:`.AsyncAdaptedQueuePool` class is used automatically when + using :func:`_asyncio.create_async_engine`, if no other kind of pool + is specified. + + .. seealso:: + + :class:`.AsyncAdaptedQueuePool` """ - _is_asyncio = False # type: ignore[assignment] + _is_asyncio = False - _queue_class: Type[ - sqla_queue.QueueCommon[ConnectionPoolEntry] - ] = sqla_queue.Queue + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.Queue + ) _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] @@ -124,6 +131,7 @@ def __init__( :class:`_pool.Pool` constructor. 
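As the revised ``QueuePool`` docstring above states, synchronous engines default to ``QueuePool`` while async engines select ``AsyncAdaptedQueuePool`` automatically; a quick check (the aiosqlite URL is a placeholder that requires that driver to be installed)::

    from sqlalchemy import create_engine
    from sqlalchemy.ext.asyncio import create_async_engine

    sync_engine = create_engine("sqlite:///demo.db")
    print(type(sync_engine.pool).__name__)  # QueuePool

    async_engine = create_async_engine("sqlite+aiosqlite:///demo.db")
    print(type(async_engine.pool).__name__)  # AsyncAdaptedQueuePool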
""" + Pool.__init__(self, creator, **kw) self._pool = self._queue_class(pool_size, use_lifo=use_lifo) self._overflow = 0 - pool_size @@ -249,20 +257,31 @@ def checkedout(self) -> int: class AsyncAdaptedQueuePool(QueuePool): - _is_asyncio = True # type: ignore[assignment] - _queue_class: Type[ - sqla_queue.QueueCommon[ConnectionPoolEntry] - ] = sqla_queue.AsyncAdaptedQueue + """An asyncio-compatible version of :class:`.QueuePool`. + + This pool is used by default when using :class:`.AsyncEngine` engines that + were generated from :func:`_asyncio.create_async_engine`. It uses an + asyncio-compatible queue implementation that does not use + ``threading.Lock``. + + The arguments and operation of :class:`.AsyncAdaptedQueuePool` are + otherwise identical to that of :class:`.QueuePool`. + + """ + + _is_asyncio = True + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.AsyncAdaptedQueue + ) _dialect = _AsyncConnDialect() class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): - _queue_class = sqla_queue.FallbackAsyncAdaptedQueue + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue # type: ignore[assignment] # noqa: E501 class NullPool(Pool): - """A Pool which does not pool connections. Instead it literally opens and closes the underlying DB-API connection @@ -272,6 +291,9 @@ class NullPool(Pool): invalidation are not supported by this Pool implementation, since no connections are held persistently. + The :class:`.NullPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. + """ def status(self) -> str: @@ -302,7 +324,6 @@ def dispose(self) -> None: class SingletonThreadPool(Pool): - """A Pool that maintains one connection per thread. Maintains one connection per each thread, never moving a connection to a @@ -320,6 +341,9 @@ class SingletonThreadPool(Pool): scenarios using a SQLite ``:memory:`` database and is not recommended for production use. + The :class:`.SingletonThreadPool` class **is not compatible** with asyncio + and :func:`_asyncio.create_async_engine`. + Options are the same as those of :class:`_pool.Pool`, as well as: @@ -332,7 +356,7 @@ class SingletonThreadPool(Pool): """ - _is_asyncio = False # type: ignore[assignment] + _is_asyncio = False def __init__( self, @@ -422,13 +446,14 @@ def connect(self) -> PoolProxiedConnection: class StaticPool(Pool): - """A Pool of exactly one connection, used for all requests. Reconnect-related functions such as ``recycle`` and connection invalidation (which is also used to support auto-reconnect) are only partially supported right now and may not yield good results. + The :class:`.StaticPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. """ @@ -486,7 +511,6 @@ def _do_get(self) -> ConnectionPoolEntry: class AssertionPool(Pool): - """A :class:`_pool.Pool` that allows at most one checked out connection at any given time. @@ -494,6 +518,9 @@ class AssertionPool(Pool): at a time. Useful for debugging code that is using more connections than desired. + The :class:`.AssertionPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. 
+ """ _conn: Optional[ConnectionPoolEntry] diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 19782bd7cfd..56b90ec99e8 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1,13 +1,11 @@ # schema.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Compatibility namespace for sqlalchemy.sql.schema and related. - -""" +"""Compatibility namespace for sqlalchemy.sql.schema and related.""" from __future__ import annotations @@ -65,6 +63,7 @@ from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .sql.schema import SchemaConst as SchemaConst from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import SchemaVisitable as SchemaVisitable from .sql.schema import Sequence as Sequence from .sql.schema import Table as Table from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index a81509fed74..188f709d7e4 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -1,5 +1,5 @@ # sql/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py index 5c0cc6247a9..0a6f60115f1 100644 --- a/lib/sqlalchemy/sql/_dml_constructors.py +++ b/lib/sqlalchemy/sql/_dml_constructors.py @@ -1,5 +1,5 @@ # sql/_dml_constructors.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -24,10 +24,7 @@ def insert(table: _DMLTableArgument) -> Insert: from sqlalchemy import insert - stmt = ( - insert(user_table). - values(name='username', fullname='Full Username') - ) + stmt = insert(user_table).values(name="username", fullname="Full Username") Similar functionality is available via the :meth:`_expression.TableClause.insert` method on @@ -78,7 +75,7 @@ def insert(table: _DMLTableArgument) -> Insert: :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` - """ + """ # noqa: E501 return Insert(table) @@ -90,9 +87,7 @@ def update(table: _DMLTableArgument) -> Update: from sqlalchemy import update stmt = ( - update(user_table). - where(user_table.c.id == 5). - values(name='user #5') + update(user_table).where(user_table.c.id == 5).values(name="user #5") ) Similar functionality is available via the @@ -109,7 +104,7 @@ def update(table: _DMLTableArgument) -> Update: :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` - """ + """ # noqa: E501 return Update(table) @@ -120,10 +115,7 @@ def delete(table: _DMLTableArgument) -> Delete: from sqlalchemy import delete - stmt = ( - delete(user_table). 
- where(user_table.c.id == 5) - ) + stmt = delete(user_table).where(user_table.c.id == 5) Similar functionality is available via the :meth:`_expression.TableClause.delete` method on diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 27197375d2d..3359998f3d8 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1,5 +1,5 @@ # sql/_elements_constructors.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,7 +10,6 @@ import typing from typing import Any from typing import Callable -from typing import Iterable from typing import Mapping from typing import Optional from typing import overload @@ -49,6 +48,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument @@ -125,11 +125,8 @@ def and_( # type: ignore[empty-body] from sqlalchemy import and_ stmt = select(users_table).where( - and_( - users_table.c.name == 'wendy', - users_table.c.enrolled == True - ) - ) + and_(users_table.c.name == "wendy", users_table.c.enrolled == True) + ) The :func:`.and_` conjunction is also available using the Python ``&`` operator (though note that compound expressions @@ -137,9 +134,8 @@ def and_( # type: ignore[empty-body] operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == 'wendy') & - (users_table.c.enrolled == True) - ) + (users_table.c.name == "wendy") & (users_table.c.enrolled == True) + ) The :func:`.and_` operation is also implicit in some cases; the :meth:`_expression.Select.where` @@ -147,9 +143,11 @@ def and_( # type: ignore[empty-body] times against a statement, which will have the effect of each clause being combined using :func:`.and_`:: - stmt = select(users_table).\ - where(users_table.c.name == 'wendy').\ - where(users_table.c.enrolled == True) + stmt = ( + select(users_table) + .where(users_table.c.name == "wendy") + .where(users_table.c.enrolled == True) + ) The :func:`.and_` construct must be given at least one positional argument in order to be valid; a :func:`.and_` construct with no @@ -159,6 +157,7 @@ def and_( # type: ignore[empty-body] specified:: from sqlalchemy import true + criteria = and_(true(), *expressions) The above expression will compile to SQL as the expression ``true`` @@ -190,11 +189,8 @@ def and_(*clauses): # noqa: F811 from sqlalchemy import and_ stmt = select(users_table).where( - and_( - users_table.c.name == 'wendy', - users_table.c.enrolled == True - ) - ) + and_(users_table.c.name == "wendy", users_table.c.enrolled == True) + ) The :func:`.and_` conjunction is also available using the Python ``&`` operator (though note that compound expressions @@ -202,9 +198,8 @@ def and_(*clauses): # noqa: F811 operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == 'wendy') & - (users_table.c.enrolled == True) - ) + (users_table.c.name == "wendy") & (users_table.c.enrolled == True) + ) The :func:`.and_` operation is also implicit in some cases; the :meth:`_expression.Select.where` @@ -212,9 +207,11 @@ def and_(*clauses): # noqa: F811 times against a statement, which will have the effect of each clause being combined using :func:`.and_`:: - stmt = 
select(users_table).\ - where(users_table.c.name == 'wendy').\ - where(users_table.c.enrolled == True) + stmt = ( + select(users_table) + .where(users_table.c.name == "wendy") + .where(users_table.c.enrolled == True) + ) The :func:`.and_` construct must be given at least one positional argument in order to be valid; a :func:`.and_` construct with no @@ -224,6 +221,7 @@ def and_(*clauses): # noqa: F811 specified:: from sqlalchemy import true + criteria = and_(true(), *expressions) The above expression will compile to SQL as the expression ``true`` @@ -241,7 +239,7 @@ def and_(*clauses): # noqa: F811 :func:`.or_` - """ + """ # noqa: E501 return BooleanClauseList.and_(*clauses) @@ -307,9 +305,12 @@ def asc( e.g.:: from sqlalchemy import asc + stmt = select(users_table).order_by(asc(users_table.c.name)) - will produce SQL as:: + will produce SQL as: + + .. sourcecode:: sql SELECT id, name FROM user ORDER BY name ASC @@ -346,9 +347,11 @@ def collate( e.g.:: - collate(mycolumn, 'utf8_bin') + collate(mycolumn, "utf8_bin") + + produces: - produces:: + .. sourcecode:: sql mycolumn COLLATE utf8_bin @@ -373,9 +376,12 @@ def between( E.g.:: from sqlalchemy import between + stmt = select(users_table).where(between(users_table.c.id, 5, 7)) - Would produce SQL resembling:: + Would produce SQL resembling: + + .. sourcecode:: sql SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 @@ -436,16 +442,12 @@ def outparam( return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) -# mypy insists that BinaryExpression and _HasClauseElement protocol overlap. -# they do not. at all. bug in mypy? @overload -def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore - ... +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... @overload -def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: - ... +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ... def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: @@ -497,10 +499,13 @@ def bindparam( from sqlalchemy import bindparam - stmt = select(users_table).\ - where(users_table.c.name == bindparam('username')) + stmt = select(users_table).where( + users_table.c.name == bindparam("username") + ) + + The above statement, when rendered, will produce SQL similar to: - The above statement, when rendered, will produce SQL similar to:: + .. sourcecode:: sql SELECT id, name FROM user WHERE name = :username @@ -508,22 +513,25 @@ def bindparam( would typically be applied at execution time to a method like :meth:`_engine.Connection.execute`:: - result = connection.execute(stmt, username='wendy') + result = connection.execute(stmt, {"username": "wendy"}) Explicit use of :func:`.bindparam` is also common when producing UPDATE or DELETE statements that are to be invoked multiple times, where the WHERE criterion of the statement is to change on each invocation, such as:: - stmt = (users_table.update(). - where(user_table.c.name == bindparam('username')). 
- values(fullname=bindparam('fullname')) - ) + stmt = ( + users_table.update() + .where(user_table.c.name == bindparam("username")) + .values(fullname=bindparam("fullname")) + ) connection.execute( - stmt, [{"username": "wendy", "fullname": "Wendy Smith"}, - {"username": "jack", "fullname": "Jack Jones"}, - ] + stmt, + [ + {"username": "wendy", "fullname": "Wendy Smith"}, + {"username": "jack", "fullname": "Jack Jones"}, + ], ) SQLAlchemy's Core expression system makes wide use of @@ -532,7 +540,7 @@ def bindparam( coerced into fixed :func:`.bindparam` constructs. For example, given a comparison operation such as:: - expr = users_table.c.name == 'Wendy' + expr = users_table.c.name == "Wendy" The above expression will produce a :class:`.BinaryExpression` construct, where the left side is the :class:`_schema.Column` object @@ -540,9 +548,11 @@ def bindparam( :class:`.BindParameter` representing the literal value:: print(repr(expr.right)) - BindParameter('%(4327771088 name)s', 'Wendy', type_=String()) + BindParameter("%(4327771088 name)s", "Wendy", type_=String()) - The expression above will render SQL such as:: + The expression above will render SQL such as: + + .. sourcecode:: sql user.name = :name_1 @@ -551,10 +561,12 @@ def bindparam( along where it is later used within statement execution. If we invoke a statement like the following:: - stmt = select(users_table).where(users_table.c.name == 'Wendy') + stmt = select(users_table).where(users_table.c.name == "Wendy") result = connection.execute(stmt) - We would see SQL logging output as:: + We would see SQL logging output as: + + .. sourcecode:: sql SELECT "user".id, "user".name FROM "user" @@ -572,9 +584,11 @@ def bindparam( bound placeholders based on the arguments passed, as in:: stmt = users_table.insert() - result = connection.execute(stmt, name='Wendy') + result = connection.execute(stmt, {"name": "Wendy"}) + + The above will produce SQL output as: - The above will produce SQL output as:: + .. sourcecode:: sql INSERT INTO "user" (name) VALUES (%(name)s) {'name': 'Wendy'} @@ -647,12 +661,12 @@ def bindparam( :param quote: True if this parameter name requires quoting and is not currently known as a SQLAlchemy reserved word; this currently - only applies to the Oracle backend, where bound names must + only applies to the Oracle Database backends, where bound names must sometimes be quoted. :param isoutparam: if True, the parameter should be treated like a stored procedure - "OUT" parameter. This applies to backends such as Oracle which + "OUT" parameter. This applies to backends such as Oracle Database which support OUT parameters. :param expanding: @@ -738,16 +752,17 @@ def case( from sqlalchemy import case - stmt = select(users_table).\ - where( - case( - (users_table.c.name == 'wendy', 'W'), - (users_table.c.name == 'jack', 'J'), - else_='E' - ) - ) + stmt = select(users_table).where( + case( + (users_table.c.name == "wendy", "W"), + (users_table.c.name == "jack", "J"), + else_="E", + ) + ) - The above statement will produce SQL resembling:: + The above statement will produce SQL resembling: + + .. sourcecode:: sql SELECT id, name FROM user WHERE CASE @@ -765,14 +780,9 @@ def case( compared against keyed to result expressions. 
The statement below is equivalent to the preceding statement:: - stmt = select(users_table).\ - where( - case( - {"wendy": "W", "jack": "J"}, - value=users_table.c.name, - else_='E' - ) - ) + stmt = select(users_table).where( + case({"wendy": "W", "jack": "J"}, value=users_table.c.name, else_="E") + ) The values which are accepted as result values in :paramref:`.case.whens` as well as with :paramref:`.case.else_` are @@ -787,20 +797,16 @@ def case( from sqlalchemy import case, literal_column case( - ( - orderline.c.qty > 100, - literal_column("'greaterthan100'") - ), - ( - orderline.c.qty > 10, - literal_column("'greaterthan10'") - ), - else_=literal_column("'lessthan10'") + (orderline.c.qty > 100, literal_column("'greaterthan100'")), + (orderline.c.qty > 10, literal_column("'greaterthan10'")), + else_=literal_column("'lessthan10'"), ) The above will render the given constants without using bound parameters for the result values (but still for the comparison - values), as in:: + values), as in: + + .. sourcecode:: sql CASE WHEN (orderline.qty > :qty_1) THEN 'greaterthan100' @@ -821,8 +827,8 @@ def case( resulting value, e.g.:: case( - (users_table.c.name == 'wendy', 'W'), - (users_table.c.name == 'jack', 'J') + (users_table.c.name == "wendy", "W"), + (users_table.c.name == "jack", "J"), ) In the second form, it accepts a Python dictionary of comparison @@ -830,10 +836,7 @@ def case( :paramref:`.case.value` to be present, and values will be compared using the ``==`` operator, e.g.:: - case( - {"wendy": "W", "jack": "J"}, - value=users_table.c.name - ) + case({"wendy": "W", "jack": "J"}, value=users_table.c.name) :param value: An optional SQL expression which will be used as a fixed "comparison point" for candidate values within a dictionary @@ -846,7 +849,7 @@ def case( expressions evaluate to true. - """ + """ # noqa: E501 return Case(*whens, value=value, else_=else_) @@ -864,7 +867,9 @@ def cast( stmt = select(cast(product_table.c.unit_price, Numeric(10, 4))) - The above statement will produce SQL resembling:: + The above statement will produce SQL resembling: + + .. sourcecode:: sql SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product @@ -933,11 +938,11 @@ def try_cast( from sqlalchemy import select, try_cast, Numeric - stmt = select( - try_cast(product_table.c.unit_price, Numeric(10, 4)) - ) + stmt = select(try_cast(product_table.c.unit_price, Numeric(10, 4))) - The above would render on Microsoft SQL Server as:: + The above would render on Microsoft SQL Server as: + + .. sourcecode:: sql SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4)) FROM product_table @@ -968,7 +973,9 @@ def column( id, name = column("id"), column("name") stmt = select(id, name).select_from("user") - The above statement would produce SQL like:: + The above statement would produce SQL like: + + .. sourcecode:: sql SELECT id, name FROM user @@ -1004,13 +1011,14 @@ def column( from sqlalchemy import table, column, select - user = table("user", - column("id"), - column("name"), - column("description"), + user = table( + "user", + column("id"), + column("name"), + column("description"), ) - stmt = select(user.c.description).where(user.c.name == 'wendy') + stmt = select(user.c.description).where(user.c.name == "wendy") A :func:`_expression.column` / :func:`.table` construct like that illustrated @@ -1057,7 +1065,9 @@ def desc( stmt = select(users_table).order_by(desc(users_table.c.name)) - will produce SQL as:: + will produce SQL as: + + .. 
sourcecode:: sql SELECT id, name FROM user ORDER BY name DESC @@ -1090,16 +1100,26 @@ def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce a column-expression-level unary ``DISTINCT`` clause. - This applies the ``DISTINCT`` keyword to an individual column - expression, and is typically contained within an aggregate function, - as in:: + This applies the ``DISTINCT`` keyword to an **individual column + expression** (e.g. not the whole statement), and renders **specifically + in that column position**; this is used for containment within + an aggregate function, as in:: from sqlalchemy import distinct, func - stmt = select(func.count(distinct(users_table.c.name))) - The above would produce an expression resembling:: + stmt = select(users_table.c.id, func.count(distinct(users_table.c.name))) - SELECT COUNT(DISTINCT name) FROM user + The above would produce a statement resembling: + + .. sourcecode:: sql + + SELECT user.id, count(DISTINCT user.name) FROM user + + .. tip:: The :func:`_sql.distinct` function does **not** apply DISTINCT + to the full SELECT statement, instead applying a DISTINCT modifier + to **individual column expressions**. For general ``SELECT DISTINCT`` + support, use the + :meth:`_sql.Select.distinct` method on :class:`_sql.Select`. The :func:`.distinct` function is also available as a column-level method, e.g. :meth:`_expression.ColumnElement.distinct`, as in:: @@ -1122,7 +1142,7 @@ def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :data:`.func` - """ + """ # noqa: E501 return UnaryExpression._create_distinct(expr) @@ -1152,6 +1172,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: :param field: The field to extract. + .. warning:: This field is used as a literal SQL string. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + :param expr: A column or Python scalar expression serving as the right side of the ``EXTRACT`` expression. E.g.:: from sqlalchemy import extract from sqlalchemy import table, column - logged_table = table("user", - column("id"), - column("date_created"), + logged_table = table( + "user", + column("id"), + column("date_created"), ) stmt = select(logged_table.c.id).where( @@ -1174,9 +1198,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: Similarly, one can also select an extracted component:: - stmt = select( - extract("YEAR", logged_table.c.date_created) - ).where(logged_table.c.id == 1) + stmt = select(extract("YEAR", logged_table.c.date_created)).where( + logged_table.c.id == 1 + ) The implementation of ``EXTRACT`` may vary across database backends. Users are reminded to consult their database documentation. @@ -1235,7 +1259,8 @@ def funcfilter( E.g.:: from sqlalchemy import funcfilter - funcfilter(func.count(1), MyClass.name == 'some name') + + funcfilter(func.count(1), MyClass.name == "some name") Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')". @@ -1292,10 +1317,11 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: from sqlalchemy import desc, nulls_first - stmt = select(users_table).order_by( - nulls_first(desc(users_table.c.name))) + stmt = select(users_table).order_by(nulls_first(desc(users_table.c.name))) + + The SQL expression from the above would resemble: - The SQL expression from the above would resemble:: + ..
sourcecode:: sql SELECT id, name FROM user ORDER BY name DESC NULLS FIRST @@ -1306,7 +1332,8 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: function version, as in:: stmt = select(users_table).order_by( - users_table.c.name.desc().nulls_first()) + users_table.c.name.desc().nulls_first() + ) .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from :func:`.nullsfirst` in previous releases. @@ -1322,7 +1349,7 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :meth:`_expression.Select.order_by` - """ + """ # noqa: E501 return UnaryExpression._create_nulls_first(column) @@ -1336,10 +1363,11 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: from sqlalchemy import desc, nulls_last - stmt = select(users_table).order_by( - nulls_last(desc(users_table.c.name))) + stmt = select(users_table).order_by(nulls_last(desc(users_table.c.name))) - The SQL expression from the above would resemble:: + The SQL expression from the above would resemble: + + .. sourcecode:: sql SELECT id, name FROM user ORDER BY name DESC NULLS LAST @@ -1349,8 +1377,7 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: rather than as its standalone function version, as in:: - stmt = select(users_table).order_by( - users_table.c.name.desc().nulls_last()) + stmt = select(users_table).order_by(users_table.c.name.desc().nulls_last()) .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from :func:`.nullslast` in previous releases. @@ -1366,7 +1393,7 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :meth:`_expression.Select.order_by` - """ + """ # noqa: E501 return UnaryExpression._create_nulls_last(column) @@ -1381,11 +1408,8 @@ def or_( # type: ignore[empty-body] from sqlalchemy import or_ stmt = select(users_table).where( - or_( - users_table.c.name == 'wendy', - users_table.c.name == 'jack' - ) - ) + or_(users_table.c.name == "wendy", users_table.c.name == "jack") + ) The :func:`.or_` conjunction is also available using the Python ``|`` operator (though note that compound expressions @@ -1393,9 +1417,8 @@ def or_( # type: ignore[empty-body] operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == 'wendy') | - (users_table.c.name == 'jack') - ) + (users_table.c.name == "wendy") | (users_table.c.name == "jack") + ) The :func:`.or_` construct must be given at least one positional argument in order to be valid; a :func:`.or_` construct with no @@ -1405,6 +1428,7 @@ def or_( # type: ignore[empty-body] specified:: from sqlalchemy import false + or_criteria = or_(false(), *expressions) The above expression will compile to SQL as the expression ``false`` @@ -1436,11 +1460,8 @@ def or_(*clauses): # noqa: F811 from sqlalchemy import or_ stmt = select(users_table).where( - or_( - users_table.c.name == 'wendy', - users_table.c.name == 'jack' - ) - ) + or_(users_table.c.name == "wendy", users_table.c.name == "jack") + ) The :func:`.or_` conjunction is also available using the Python ``|`` operator (though note that compound expressions @@ -1448,9 +1469,8 @@ def or_(*clauses): # noqa: F811 operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == 'wendy') | - (users_table.c.name == 'jack') - ) + (users_table.c.name == "wendy") | (users_table.c.name == "jack") + ) The :func:`.or_` construct must be given at least one positional argument in order to be valid; a :func:`.or_` construct with no @@ -1460,6 +1480,7 @@ def 
or_(*clauses): # noqa: F811 specified:: from sqlalchemy import false + or_criteria = or_(false(), *expressions) The above expression will compile to SQL as the expression ``false`` @@ -1477,26 +1498,17 @@ def or_(*clauses): # noqa: F811 :func:`.and_` - """ + """ # noqa: E501 return BooleanClauseList.or_(*clauses) def over( element: FunctionElement[_T], - partition_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, - order_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: r"""Produce an :class:`.Over` object against a function. @@ -1508,19 +1520,23 @@ def over( func.row_number().over(order_by=mytable.c.some_column) - Would produce:: + Would produce: + + .. sourcecode:: sql ROW_NUMBER() OVER(ORDER BY some_column) - Ranges are also possible using the :paramref:`.expression.over.range_` - and :paramref:`.expression.over.rows` parameters. These + Ranges are also possible using the :paramref:`.expression.over.range_`, + :paramref:`.expression.over.rows`, and :paramref:`.expression.over.groups` + parameters. These mutually-exclusive parameters each accept a 2-tuple, which contains a combination of integers and None:: - func.row_number().over( - order_by=my_table.c.some_column, range_=(None, 0)) + func.row_number().over(order_by=my_table.c.some_column, range_=(None, 0)) + + The above would produce: - The above would produce:: + .. sourcecode:: sql ROW_NUMBER() OVER(ORDER BY some_column RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) @@ -1531,19 +1547,23 @@ def over( * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING:: - func.row_number().over(order_by='x', range_=(-5, 10)) + func.row_number().over(order_by="x", range_=(-5, 10)) * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW:: - func.row_number().over(order_by='x', rows=(None, 0)) + func.row_number().over(order_by="x", rows=(None, 0)) * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING:: - func.row_number().over(order_by='x', range_=(-2, None)) + func.row_number().over(order_by="x", range_=(-2, None)) * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: - func.row_number().over(order_by='x', range_=(1, 3)) + func.row_number().over(order_by="x", range_=(1, 3)) + + * GROUPS BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: + + func.row_number().over(order_by="x", groups=(1, 3)) :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`, or other compatible construct. @@ -1556,10 +1576,14 @@ def over( :param range\_: optional range clause for the window. This is a tuple value which can contain integer values or ``None``, and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause. - :param rows: optional rows clause for the window. This is a tuple value which can contain integer values or None, and will render a ROWS BETWEEN PRECEDING / FOLLOWING clause. + :param groups: optional groups clause for the window. This is a + tuple value which can contain integer values or ``None``, + and will render a GROUPS BETWEEN PRECEDING / FOLLOWING clause. + + .. versionadded:: 2.0.40 This function is also available from the :data:`~.expression.func` construct itself via the :meth:`.FunctionElement.over` method. 
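As a quick, runnable sketch of the frame parameters documented above (the ``scores`` table and its columns are assumptions invented for illustration; per the versionadded note, the ``groups`` parameter requires SQLAlchemy 2.0.40+ and a backend that supports GROUPS frames, such as PostgreSQL 11+)::

    from sqlalchemy import Column, Integer, MetaData, Table, func, over, select

    metadata = MetaData()
    scores = Table(
        "scores",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("points", Integer),
    )

    # ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    running = over(func.sum(scores.c.points), order_by=scores.c.id, rows=(None, 0))

    # GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING
    peers = over(func.sum(scores.c.points), order_by=scores.c.points, groups=(-1, 1))

    print(select(scores.c.id, running, peers))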
@@ -1572,8 +1596,8 @@ def over( :func:`_expression.within_group` - """ - return Over(element, partition_by, order_by, range_, rows) + """ # noqa: E501 + return Over(element, partition_by, order_by, range_, rows, groups) @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") @@ -1603,7 +1627,7 @@ def text(text: str) -> TextClause: E.g.:: t = text("SELECT * FROM users WHERE id=:user_id") - result = connection.execute(t, user_id=12) + result = connection.execute(t, {"user_id": 12}) For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape:: @@ -1621,9 +1645,11 @@ def text(text: str) -> TextClause: method allows specification of return columns including names and types:: - t = text("SELECT * FROM users WHERE id=:user_id").\ - bindparams(user_id=7).\ - columns(id=Integer, name=String) + t = ( + text("SELECT * FROM users WHERE id=:user_id") + .bindparams(user_id=7) + .columns(id=Integer, name=String) + ) for id, name in connection.execute(t): print(id, name) @@ -1633,7 +1659,7 @@ def text(text: str) -> TextClause: such as for the WHERE clause of a SELECT statement:: s = select(users.c.id, users.c.name).where(text("id=:user_id")) - result = connection.execute(s, user_id=12) + result = connection.execute(s, {"user_id": 12}) :func:`_expression.text` is also used for the construction of a full, standalone statement using plain text. @@ -1705,9 +1731,7 @@ def tuple_( from sqlalchemy import tuple_ - tuple_(table.c.col1, table.c.col2).in_( - [(1, 2), (5, 12), (10, 19)] - ) + tuple_(table.c.col1, table.c.col2).in_([(1, 2), (5, 12), (10, 19)]) .. versionchanged:: 1.3.6 Added support for SQLite IN tuples. @@ -1757,10 +1781,9 @@ def type_coerce( :meth:`_expression.ColumnElement.label`:: stmt = select( - type_coerce(log_table.date_string, StringDateTime()).label('date') + type_coerce(log_table.date_string, StringDateTime()).label("date") ) - A type that features bound-value handling will also have that behavior take effect when literal values or :func:`.bindparam` constructs are passed to :func:`.type_coerce` as targets. 
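Tying together the ``text()`` fragments shown earlier in this hunk, a minimal end-to-end sketch; the in-memory SQLite engine and the ``users`` rows here are assumptions made for demonstration only::

    from sqlalchemy import Integer, String, create_engine, text

    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.exec_driver_sql("CREATE TABLE users (id INTEGER, name VARCHAR)")
        conn.exec_driver_sql("INSERT INTO users VALUES (7, 'wendy')")

        # bound values travel separately from the SQL string
        stmt = (
            text("SELECT * FROM users WHERE id=:user_id")
            .bindparams(user_id=7)
            .columns(id=Integer, name=String)
        )
        for id_, name in conn.execute(stmt):
            print(id_, name)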
@@ -1821,11 +1844,10 @@ def within_group( the :meth:`.FunctionElement.within_group` method, e.g.:: from sqlalchemy import within_group + stmt = select( department.c.id, - func.percentile_cont(0.5).within_group( - department.c.salary.desc() - ) + func.percentile_cont(0.5).within_group(department.c.salary.desc()), ) The above statement would produce SQL similar to diff --git a/lib/sqlalchemy/sql/_orm_types.py b/lib/sqlalchemy/sql/_orm_types.py index 90986ec0ccb..c37d805ef3f 100644 --- a/lib/sqlalchemy/sql/_orm_types.py +++ b/lib/sqlalchemy/sql/_orm_types.py @@ -1,5 +1,5 @@ # sql/_orm_types.py -# Copyright (C) 2022 the SQLAlchemy authors and contributors +# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py index edff0d66910..9e1a084a3f5 100644 --- a/lib/sqlalchemy/sql/_py_util.py +++ b/lib/sqlalchemy/sql/_py_util.py @@ -1,5 +1,5 @@ # sql/_py_util.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 41e8b6eb164..ae83efa5d79 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -1,5 +1,5 @@ # sql/_selectable_constructors.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,6 @@ from typing import overload from typing import Tuple from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union from . import coercions @@ -47,6 +46,7 @@ from ._typing import _T7 from ._typing import _T8 from ._typing import _T9 + from ._typing import _TP from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE @@ -55,9 +55,6 @@ from .selectable import SelectBase -_T = TypeVar("_T", bound=Any) - - def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -106,9 +103,28 @@ def cte( ) +# TODO: mypy requires the _TypedSelectable overloads in all compound select +# constructors since _SelectStatementForCompoundArgument includes +# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]] +# pyright does not have this issue +_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"] + + +@overload +def except_( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def except_( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + def except_( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -121,9 +137,21 @@ def except_( return CompoundSelect._create_except(*selects) +@overload +def except_all( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def except_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... 
+ + def except_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -155,16 +183,16 @@ def exists( :meth:`_sql.SelectBase.exists` method:: exists_criteria = ( - select(table2.c.col2). - where(table1.c.col1 == table2.c.col2). - exists() + select(table2.c.col2).where(table1.c.col1 == table2.c.col2).exists() ) The EXISTS criteria is then used inside of an enclosing SELECT:: stmt = select(table1.c.col1).where(exists_criteria) - The above statement will then be of the form:: + The above statement will then be of the form: + + .. sourcecode:: sql SELECT col1 FROM table1 WHERE EXISTS (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1) @@ -181,9 +209,21 @@ def exists( return Exists(__argument) +@overload +def intersect( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def intersect( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + def intersect( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -196,9 +236,21 @@ def intersect( return CompoundSelect._create_intersect(*selects) +@overload def intersect_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def intersect_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def intersect_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -225,11 +277,14 @@ def join( E.g.:: - j = join(user_table, address_table, - user_table.c.id == address_table.c.user_id) + j = join( + user_table, address_table, user_table.c.id == address_table.c.user_id + ) stmt = select(user_table).select_from(j) - would emit SQL along the lines of:: + would emit SQL along the lines of: + + .. sourcecode:: sql SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id @@ -263,7 +318,7 @@ def join( :class:`_expression.Join` - the type of object produced. - """ + """ # noqa: E501 return Join(left, right, onclause, isouter, full) @@ -330,20 +385,19 @@ def outerjoin( @overload -def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: - ... +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: ... @overload -def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: - ... +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] +) -> Select[Tuple[_T0, _T1]]: ... @overload def select( __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] -) -> Select[Tuple[_T0, _T1, _T2]]: - ... +) -> Select[Tuple[_T0, _T1, _T2]]: ... @overload @@ -352,8 +406,7 @@ def select( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], -) -> Select[Tuple[_T0, _T1, _T2, _T3]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ... @overload @@ -363,8 +416,7 @@ def select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... 
@overload @@ -375,8 +427,7 @@ def select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload @@ -388,8 +439,7 @@ def select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload @@ -402,8 +452,7 @@ def select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... @overload @@ -417,8 +466,7 @@ def select( __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: ... @overload @@ -433,16 +481,16 @@ def select( __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], __ent9: _TCCA[_T9], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: - ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: ... # END OVERLOADED FUNCTIONS select @overload -def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: - ... +def select( + *entities: _ColumnsClauseArgument[Any], **__kw: Any +) -> Select[Any]: ... def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: @@ -536,13 +584,14 @@ class via the from sqlalchemy import func selectable = people.tablesample( - func.bernoulli(1), - name='alias', - seed=func.random()) + func.bernoulli(1), name="alias", seed=func.random() + ) stmt = select(selectable.c.people_id) Assuming ``people`` with a column ``people_id``, the above - statement would render as:: + statement would render as: + + .. sourcecode:: sql SELECT alias.people_id FROM people AS alias TABLESAMPLE bernoulli(:bernoulli_1) @@ -560,9 +609,21 @@ class via the return TableSample._factory(selectable, sampling, name=name, seed=seed) +@overload def union( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def union( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def union( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -582,9 +643,21 @@ def union( return CompoundSelect._create_union(*selects) +@overload def union_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def union_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def union_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return a ``UNION ALL`` of multiple selectables. 
The returned object is an instance of @@ -618,14 +691,14 @@ def values( from sqlalchemy import column from sqlalchemy import values + from sqlalchemy import Integer + from sqlalchemy import String value_expr = values( - column('id', Integer), - column('name', String), - name="my_values" - ).data( - [(1, 'name1'), (2, 'name2'), (3, 'name3')] - ) + column("id", Integer), + column("name", String), + name="my_values", + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) :param \*columns: column expressions, typically composed using :func:`_expression.column` objects. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index c9e183058e6..8e3c66e553f 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,5 +1,5 @@ # sql/_typing.py -# Copyright (C) 2022 the SQLAlchemy authors and contributors +# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,6 +11,8 @@ from typing import Any from typing import Callable from typing import Dict +from typing import Generic +from typing import Iterable from typing import Mapping from typing import NoReturn from typing import Optional @@ -51,10 +53,10 @@ from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement - from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column from .selectable import Alias + from .selectable import CompoundSelect from .selectable import CTE from .selectable import FromClause from .selectable import Join @@ -68,9 +70,14 @@ from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Connection + from ..engine import Dialect + from ..engine import Engine + from ..engine.mock import MockConnection from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _CE = TypeVar("_CE", bound="ColumnElement[Any]") @@ -78,18 +85,25 @@ _CLE = TypeVar("_CLE", bound="ClauseElement") -class _HasClauseElement(Protocol): +class _HasClauseElement(Protocol, Generic[_T_co]): """indicates a class that has a __clause_element__() method""" - def __clause_element__(self) -> ColumnsClauseRole: - ... + def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: ... class _CoreAdapterProto(Protocol): """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" - def __call__(self, obj: _CE) -> _CE: - ... + def __call__(self, obj: _CE) -> _CE: ... + + +class _HasDialect(Protocol): + """protocol for Engine/Connection-like objects that have a dialect + attribute. + """ + + @property + def dialect(self) -> Dialect: ...
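To make the ``_HasClauseElement`` protocol above concrete: any object exposing ``__clause_element__()`` is coerced to the SQL expression it returns. A sketch, with the wrapper class and table invented for illustration::

    from sqlalchemy import Column, Integer, MetaData, Table, select

    metadata = MetaData()
    accounts = Table("accounts", metadata, Column("balance", Integer))


    class BalanceRef:
        # satisfies _HasClauseElement[int]; coercion calls this method
        def __clause_element__(self):
            return accounts.c.balance


    # accepted wherever a column expression argument is typed
    print(select(BalanceRef()))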
# match column types that are not ORM entities @@ -97,6 +111,7 @@ def __call__(self, obj: _CE) -> _CE: "_NOT_ENTITY", int, str, + bool, "datetime", "date", "time", @@ -106,13 +121,15 @@ def __call__(self, obj: _CE) -> _CE: "Decimal", ) +_StarOrOne = Literal["*", 1] + _MAYBE_ENTITY = TypeVar( "_MAYBE_ENTITY", roles.ColumnsClauseRole, - Literal["*", 1], + _StarOrOne, Type[Any], - Inspectable[_HasClauseElement], - _HasClauseElement, + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], ) @@ -126,7 +143,7 @@ def __call__(self, obj: _CE) -> _CE: str, "TextClause", "ColumnElement[_T]", - _HasClauseElement, + _HasClauseElement[_T], roles.ExpressionElementRole[_T], ] @@ -134,10 +151,10 @@ def __call__(self, obj: _CE) -> _CE: roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, "SQLCoreOperations[_T]", - Literal["*", 1], + _StarOrOne, Type[_T], - Inspectable[_HasClauseElement], - _HasClauseElement, + Inspectable[_HasClauseElement[_T]], + _HasClauseElement[_T], ] """open-ended SELECT columns clause argument. @@ -171,9 +188,10 @@ def __call__(self, obj: _CE) -> _CE: _ColumnExpressionArgument = Union[ "ColumnElement[_T]", - _HasClauseElement, + _HasClauseElement[_T], "SQLCoreOperations[_T]", roles.ExpressionElementRole[_T], + roles.TypedColumnsClauseRole[_T], Callable[[], "ColumnElement[_T]"], "LambdaElement", ] @@ -198,6 +216,12 @@ def __call__(self, obj: _CE) -> _CE: _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] +_ByArgument = Union[ + Iterable[_ColumnExpressionOrStrLabelArgument[Any]], + _ColumnExpressionOrStrLabelArgument[Any], +] +"""Used for keyword-based ``order_by`` and ``partition_by`` parameters.""" + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" @@ -205,8 +229,8 @@ def __call__(self, obj: _CE) -> _CE: _FromClauseArgument = Union[ roles.FromClauseRole, Type[Any], - Inspectable[_HasClauseElement], - _HasClauseElement, + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], ] """A FROM clause, like we would send to select().select_from(). @@ -227,13 +251,15 @@ def __call__(self, obj: _CE) -> _CE: """ _SelectStatementForCompoundArgument = Union[ - "SelectBase", roles.CompoundElementRole + "Select[_TP]", + "CompoundSelect[_TP]", + roles.CompoundElementRole, ] """SELECT statement acceptable by ``union()`` and other SQL set operations""" _DMLColumnArgument = Union[ str, - _HasClauseElement, + _HasClauseElement[Any], roles.DMLColumnRole, "SQLCoreOperations[Any]", ] @@ -264,8 +290,8 @@ def __call__(self, obj: _CE) -> _CE: "Alias", "CTE", Type[Any], - Inspectable[_HasClauseElement], - _HasClauseElement, + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], ] _PropagateAttrsType = util.immutabledict[str, Any] @@ -278,58 +304,51 @@ def __call__(self, obj: _CE) -> _CE: _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] +_CreateDropBind = Union["Engine", "Connection", "MockConnection"] + if TYPE_CHECKING: - def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: - ... + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... - def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: - ... + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: ... - def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: - ... + def is_named_from_clause( + t: FromClauseRole, + ) -> TypeGuard[NamedFromClause]: ... - def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: - ... 
+ def is_column_element( + c: ClauseElement, + ) -> TypeGuard[ColumnElement[Any]]: ... def is_keyed_column_element( c: ClauseElement, - ) -> TypeGuard[KeyedColumnElement[Any]]: - ... + ) -> TypeGuard[KeyedColumnElement[Any]]: ... - def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: - ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... - def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: - ... + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: ... - def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: - ... + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: ... - def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: - ... + def is_table_value_type( + t: TypeEngine[Any], + ) -> TypeGuard[TableValueType]: ... - def is_selectable(t: Any) -> TypeGuard[Selectable]: - ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: ... def is_select_base( - t: Union[Executable, ReturnsRows] - ) -> TypeGuard[SelectBase]: - ... + t: Union[Executable, ReturnsRows], + ) -> TypeGuard[SelectBase]: ... def is_select_statement( - t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select[Any]]: - ... + t: Union[Executable, ReturnsRows], + ) -> TypeGuard[Select[Any]]: ... - def is_table(t: FromClause) -> TypeGuard[TableClause]: - ... + def is_table(t: FromClause) -> TypeGuard[TableClause]: ... - def is_subquery(t: FromClause) -> TypeGuard[Subquery]: - ... + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... - def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: - ... + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: ... else: is_sql_compiler = operator.attrgetter("is_sql") @@ -357,7 +376,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: return hasattr(s, "quote") -def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: +def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement[Any]]: return hasattr(s, "__clause_element__") @@ -380,20 +399,17 @@ def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn: @overload def Nullable( val: "SQLCoreOperations[_T]", -) -> "SQLCoreOperations[Optional[_T]]": - ... +) -> "SQLCoreOperations[Optional[_T]]": ... @overload def Nullable( val: roles.ExpressionElementRole[_T], -) -> roles.ExpressionElementRole[Optional[_T]]: - ... +) -> roles.ExpressionElementRole[Optional[_T]]: ... @overload -def Nullable(val: Type[_T]) -> Type[Optional[_T]]: - ... +def Nullable(val: Type[_T]) -> Type[Optional[_T]]: ... def Nullable( @@ -417,25 +433,21 @@ def Nullable( @overload def NotNullable( val: "SQLCoreOperations[Optional[_T]]", -) -> "SQLCoreOperations[_T]": - ... +) -> "SQLCoreOperations[_T]": ... @overload def NotNullable( val: roles.ExpressionElementRole[Optional[_T]], -) -> roles.ExpressionElementRole[_T]: - ... +) -> roles.ExpressionElementRole[_T]: ... @overload -def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: - ... +def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: ... @overload -def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: - ... +def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: ... 
def NotNullable( diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 08ff47d3d64..bf445ff330d 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -1,5 +1,5 @@ # sql/annotation.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -67,16 +67,14 @@ def _deannotate( self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -99,9 +97,11 @@ def _gen_annotations_cache_key( tuple( ( key, - value._gen_cache_key(anon_map, []) - if isinstance(value, HasCacheKey) - else value, + ( + value._gen_cache_key(anon_map, []) + if isinstance(value, HasCacheKey) + else value + ), ) for key, value in [ (key, self._annotations[key]) @@ -119,8 +119,7 @@ class SupportsWrappingAnnotations(SupportsAnnotations): if TYPE_CHECKING: @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... def _annotate(self, values: _AnnotationDict) -> Self: """return a copy of this ClauseElement with annotations @@ -141,16 +140,14 @@ def _deannotate( self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -214,16 +211,14 @@ def _deannotate( self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -316,16 +311,14 @@ def _deannotate( self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> Annotated: - ... + ) -> Annotated: ... def _deannotate( self, @@ -395,9 +388,9 @@ def entity_namespace(self) -> _EntityNamespace: # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes: Dict[ - Type[SupportsWrappingAnnotations], Type[Annotated] -] = {} +annotated_classes: Dict[Type[SupportsWrappingAnnotations], Type[Annotated]] = ( + {} +) _SA = TypeVar("_SA", bound="SupportsAnnotations") @@ -487,15 +480,13 @@ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: @overload def _deep_deannotate( element: Literal[None], values: Optional[Sequence[str]] = None -) -> Literal[None]: - ... +) -> Literal[None]: ... @overload def _deep_deannotate( element: _SA, values: Optional[Sequence[str]] = None -) -> _SA: - ... +) -> _SA: ... 
def _deep_deannotate( diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 104c5958a07..e27296b5332 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1,14 +1,12 @@ # sql/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -"""Foundational utilities common to many sql modules. - -""" +"""Foundational utilities common to many sql modules.""" from __future__ import annotations @@ -68,11 +66,11 @@ from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument from ._typing import _CLE + from .compiler import SQLCompiler from .elements import BindParameter from .elements import ClauseList from .elements import ColumnClause # noqa from .elements import ColumnElement - from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -155,14 +153,12 @@ def _from_column_default( class _EntityNamespace(Protocol): - def __getattr__(self, key: str) -> SQLCoreOperations[Any]: - ... + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: ... class _HasEntityNamespace(Protocol): @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: @@ -261,8 +257,7 @@ def _select_iterables( class _GenerativeType(compat_typing.Protocol): - def _generate(self) -> Self: - ... + def _generate(self) -> Self: ... def _generative(fn: _Fn) -> _Fn: @@ -366,6 +361,8 @@ class _DialectArgView(MutableMapping[str, Any]): """ + __slots__ = ("obj",) + def __init__(self, obj): self.obj = obj @@ -484,7 +481,7 @@ def argument_for(cls, dialect_name, argument_name, default): Index.argument_for("mydialect", "length", None) - some_index = Index('a', 'b', mydialect_length=5) + some_index = Index("a", "b", mydialect_length=5) The :meth:`.DialectKWArgs.argument_for` method is a per-argument way of adding extra arguments to the @@ -524,7 +521,7 @@ def argument_for(cls, dialect_name, argument_name, default): construct_arg_dictionary[cls] = {} construct_arg_dictionary[cls][argument_name] = default - @util.memoized_property + @property def dialect_kwargs(self): """A collection of keyword arguments specified as dialect-specific options to this construct. @@ -552,14 +549,15 @@ def kwargs(self): _kw_registry = util.PopulateDict(_kw_reg_for_dialect) - def _kw_reg_for_dialect_cls(self, dialect_name): + @classmethod + def _kw_reg_for_dialect_cls(cls, dialect_name): construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] d = _DialectArgDict() if construct_arg_dictionary is None: d._defaults.update({"*": None}) else: - for cls in reversed(self.__class__.__mro__): + for cls in reversed(cls.__mro__): if cls in construct_arg_dictionary: d._defaults.update(construct_arg_dictionary[cls]) return d @@ -573,7 +571,7 @@ def dialect_options(self): and ``<argument_name>``. For example, the ``postgresql_where`` argument would be locatable as:: - arg = my_object.dialect_options['postgresql']['where'] + arg = my_object.dialect_options["postgresql"]["where"] ..
versionadded:: 0.9.2 @@ -583,9 +581,7 @@ def dialect_options(self): """ - return util.PopulateDict( - util.portable_instancemethod(self._kw_reg_for_dialect_cls) - ) + return util.PopulateDict(self._kw_reg_for_dialect_cls) def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: # validate remaining kwargs that they all specify DB prefixes @@ -661,7 +657,9 @@ class CompileState: _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, statement: Executable, compiler: SQLCompiler, **kw: Any + ) -> CompileState: # factory construction. if statement._propagate_attrs: @@ -801,14 +799,11 @@ def __add__(self, other): if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... - def __setattr__(self, key: str, value: Any) -> None: - ... + def __setattr__(self, key: str, value: Any) -> None: ... - def __delattr__(self, key: str) -> None: - ... + def __delattr__(self, key: str) -> None: ... class Options(metaclass=_MetaOptions): @@ -924,11 +919,7 @@ def from_execution_options( execution_options, ) = QueryContext.default_load_options.from_execution_options( "_sa_orm_load_options", - { - "populate_existing", - "autoflush", - "yield_per" - }, + {"populate_existing", "autoflush", "yield_per"}, execution_options, statement._execution_options, ) @@ -966,14 +957,11 @@ def from_execution_options( if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... - def __setattr__(self, key: str, value: Any) -> None: - ... + def __setattr__(self, key: str, value: Any) -> None: ... - def __delattr__(self, key: str) -> None: - ... + def __delattr__(self, key: str) -> None: ... class CacheableOptions(Options, HasCacheKey): @@ -1038,6 +1026,7 @@ class Executable(roles.StatementRole): ] is_select = False + is_from_statement = False is_update = False is_insert = False is_text = False @@ -1058,24 +1047,21 @@ def _compile_w_cache( **kw: Any, ) -> Tuple[ Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats - ]: - ... + ]: ... def _execute_on_connection( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... def _execute_on_scalar( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> Any: - ... + ) -> Any: ... @util.ro_non_memoized_property def _all_selected_columns(self): @@ -1179,13 +1165,12 @@ def execution_options( render_nulls: bool = ..., is_delete_using: bool = ..., is_update_from: bool = ..., + preserve_rowcount: bool = False, **opt: Any, - ) -> Self: - ... + ) -> Self: ... @overload - def execution_options(self, **opt: Any) -> Self: - ... + def execution_options(self, **opt: Any) -> Self: ... @_generative def execution_options(self, **kw: Any) -> Self: @@ -1237,6 +1222,7 @@ def execution_options(self, **kw: Any) -> Self: from sqlalchemy import event + @event.listens_for(some_engine, "before_execute") def _process_opt(conn, statement, multiparams, params, execution_options): "run a SQL function before invoking a statement" @@ -1338,8 +1324,19 @@ def _set_parent_with_dispatch( self.dispatch.after_parent_attach(self, parent) +class SchemaVisitable(SchemaEventTarget, visitors.Visitable): + """Base class for elements that are targets of a :class:`.SchemaVisitor`. + + .. 
versionadded:: 2.0.41 + + """ + + class SchemaVisitor(ClauseVisitor): - """Define the visiting for ``SchemaItem`` objects.""" + """Define the visiting for ``SchemaItem`` and more + generally ``SchemaVisitable`` objects. + + """ __traverse_options__ = {"schema_visitor": True} @@ -1366,7 +1363,7 @@ class _SentinelColumnCharacterization(NamedTuple): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") +_COL = TypeVar("_COL", bound="ColumnElement[Any]") class _ColumnMetrics(Generic[_COL_co]): @@ -1488,14 +1485,14 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): mean either two columns with the same key, in which case the column returned by key access is **arbitrary**:: - >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> x1, x2 = Column("x", Integer), Column("x", Integer) >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)]) >>> list(cc) [Column('x', Integer(), table=None), Column('x', Integer(), table=None)] - >>> cc['x'] is x1 + >>> cc["x"] is x1 False - >>> cc['x'] is x2 + >>> cc["x"] is x2 True Or it can also mean the same column multiple times. These cases are @@ -1591,20 +1588,17 @@ def __iter__(self) -> Iterator[_COL_co]: return iter([col for _, col, _ in self._collection]) @overload - def __getitem__(self, key: Union[str, int]) -> _COL_co: - ... + def __getitem__(self, key: Union[str, int]) -> _COL_co: ... @overload def __getitem__( self, key: Tuple[Union[str, int], ...] - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: - ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... @overload def __getitem__( self, key: slice - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: - ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... def __getitem__( self, key: Union[str, int, slice, Tuple[Union[str, int], ...]] @@ -1657,9 +1651,15 @@ def compare(self, other: ColumnCollection[Any, Any]) -> bool: def __eq__(self, other: Any) -> bool: return self.compare(other) + @overload + def get(self, key: str, default: None = None) -> Optional[_COL_co]: ... + + @overload + def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ... 
+ def get( - self, key: str, default: Optional[_COL_co] = None - ) -> Optional[_COL_co]: + self, key: str, default: Optional[_COL] = None + ) -> Optional[Union[_COL_co, _COL]]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1940,16 +1940,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - def add( - self, column: ColumnElement[Any], key: Optional[str] = None + def add( # type: ignore[override] + self, column: _NAMEDCOL, key: Optional[str] = None ) -> None: - named_column = cast(_NAMEDCOL, column) - if key is not None and named_column.key != key: + if key is not None and column.key != key: raise exc.ArgumentError( "DedupeColumnCollection requires columns be under " "the same key as their .key" ) - key = named_column.key + key = column.key if key is None: raise exc.ArgumentError( @@ -1959,17 +1958,17 @@ def add( if key in self._index: existing = self._index[key][1] - if existing is named_column: + if existing is column: return - self.replace(named_column) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(named_column, "proxy_set") + util.memoized_property.reset(column, "proxy_set") else: - self._append_new_column(key, named_column) + self._append_new_column(key, column) def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None: l = len(self._collection) @@ -2044,8 +2043,8 @@ def replace( e.g.:: - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) + t = Table("sometable", metadata, Column("col1", Integer)) + t.columns.replace(Column("col1", Integer, key="columnone")) will remove the original 'col1' from the collection, and add the new column under the name 'columnname'. @@ -2148,12 +2147,12 @@ def __eq__(self, other): l.append(c == local) return elements.and_(*l) - def __hash__(self): + def __hash__(self): # type: ignore[override] return hash(tuple(x for x in self)) def _entity_namespace( - entity: Union[_HasEntityNamespace, ExternallyTraversible] + entity: Union[_HasEntityNamespace, ExternallyTraversible], ) -> _EntityNamespace: """Return the nearest .entity_namespace for the given entity. diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 500e3e4dd72..cec0450aa61 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -1,5 +1,5 @@ # sql/cache_key.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,6 +11,7 @@ from itertools import zip_longest import typing from typing import Any +from typing import Callable from typing import Dict from typing import Iterable from typing import Iterator @@ -36,6 +37,7 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ClauseElement + from .elements import ColumnElement from .visitors import _TraverseInternalsType from ..engine.interfaces import _CoreSingleExecuteParams @@ -43,8 +45,7 @@ class _CacheKeyTraversalDispatchType(Protocol): def __call__( s, self: HasCacheKey, visitor: _CacheKeyTraversal - ) -> CacheKey: - ... + ) -> _CacheKeyTraversalDispatchTypeReturn: ... 
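
A note on the ``get()`` overloads just above: this is the standard typing pattern for a mapping-style accessor whose return type depends on whether a default was supplied. A minimal standalone sketch of the same pattern (illustrative names, not part of this patch):

    from typing import Dict, Generic, Optional, TypeVar, Union, overload

    _V_co = TypeVar("_V_co", covariant=True)
    _D = TypeVar("_D")


    class ReadOnlyIndex(Generic[_V_co]):
        def __init__(self, index: Dict[str, _V_co]) -> None:
            self._index = index

        @overload
        def get(self, key: str, default: None = None) -> Optional[_V_co]: ...

        @overload
        def get(self, key: str, default: _D) -> Union[_V_co, _D]: ...

        def get(
            self, key: str, default: Optional[_D] = None
        ) -> Optional[Union[_V_co, _D]]:
            # only the no-default call narrows to Optional[...]; with an
            # explicit default the checker sees Union[_V_co, _D]
            if key in self._index:
                return self._index[key]
            return default
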
class CacheConst(enum.Enum): @@ -75,6 +76,18 @@ class CacheTraverseTarget(enum.Enum): ANON_NAME, ) = tuple(CacheTraverseTarget) +_CacheKeyTraversalDispatchTypeReturn = Sequence[ + Tuple[ + str, + Any, + Union[ + Callable[..., Tuple[Any, ...]], + CacheTraverseTarget, + InternalTraversal, + ], + ] +] + class HasCacheKey: """Mixin for objects which can produce a cache key. @@ -290,11 +303,13 @@ def _gen_cache_key( result += ( attrname, obj["compile_state_plugin"], - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None, + ( + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None + ), ) elif meth is InternalTraversal.dp_annotations_key: # obj is here is the _annotations dict. Table uses @@ -324,7 +339,7 @@ def _gen_cache_key( ), ) else: - result += meth( + result += meth( # type: ignore attrname, obj, self, anon_map, bindparams ) return result @@ -501,7 +516,7 @@ def _whats_different(self, other: CacheKey) -> Iterator[str]: e2, ) else: - pickup_index = stack.pop(-1) + stack.pop(-1) break def _diff(self, other: CacheKey) -> str: @@ -543,18 +558,17 @@ def _generate_param_dict(self) -> Dict[str, Any]: _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} + @util.preload_module("sqlalchemy.sql.elements") def _apply_params_to_element( - self, original_cache_key: CacheKey, target_element: ClauseElement - ) -> ClauseElement: - if target_element._is_immutable: + self, original_cache_key: CacheKey, target_element: ColumnElement[Any] + ) -> ColumnElement[Any]: + if target_element._is_immutable or original_cache_key is self: return target_element - translate = { - k.key: v.value - for k, v in zip(original_cache_key.bindparams, self.bindparams) - } - - return target_element.params(translate) + elements = util.preloaded.sql_elements + return elements._OverrideBinds( + target_element, self.bindparams, original_cache_key.bindparams + ) def _ad_hoc_cache_key_from_args( @@ -606,9 +620,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): InternalTraversal.dp_memoized_select_entities ) - visit_string = ( - visit_boolean - ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_string = visit_boolean = visit_operator = visit_plain_obj = ( + CACHE_IN_PLACE + ) visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY visit_anon_name = ANON_NAME @@ -655,9 +669,11 @@ def visit_multi( ) -> Tuple[Any, ...]: return ( attrname, - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj, + ( + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj + ), ) def visit_multi_list( @@ -671,9 +687,11 @@ def visit_multi_list( return ( attrname, tuple( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem + ( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + ) for elem in obj ), ) @@ -834,12 +852,16 @@ def visit_setup_join_tuple( return tuple( ( target._gen_cache_key(anon_map, bindparams), - onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None, - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None, + ( + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None + ), + ( + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None + ), tuple([(key, flags[key]) for key in sorted(flags)]), ) for (target, onclause, from_, flags) in 
obj @@ -933,9 +955,11 @@ def visit_string_multi_dict( tuple( ( key, - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value, + ( + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value + ), ) for key, value in [(key, obj[key]) for key in sorted(obj)] ), @@ -981,9 +1005,11 @@ def visit_dml_ordered_values( attrname, tuple( ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key, + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key + ), value._gen_cache_key(anon_map, bindparams), ) for key, value in obj @@ -1004,9 +1030,11 @@ def visit_dml_values( attrname, tuple( ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k, + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k + ), obj[k]._gen_cache_key(anon_map, bindparams), ) for k in obj diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index c4d340713ba..ac0393a6056 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -1,5 +1,5 @@ # sql/coercions.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,7 +29,6 @@ from typing import TypeVar from typing import Union -from . import operators from . import roles from . import visitors from ._typing import is_from_clause @@ -58,9 +57,9 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement - from .elements import DQLDMLClauseElement from .elements import NamedColumn from .elements import SQLCoreOperations + from .elements import TextClause from .schema import Column from .selectable import _ColumnsClauseElement from .selectable import _JoinTargetProtocol @@ -76,7 +75,7 @@ _T = TypeVar("_T", bound=Any) -def _is_literal(element): +def _is_literal(element: Any) -> bool: """Return whether or not the element is a "literal" in the context of a SQL expression construct. @@ -165,8 +164,7 @@ def expect( role: Type[roles.TruncatedLabelRole], element: Any, **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -176,8 +174,7 @@ def expect( *, as_key: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -185,8 +182,7 @@ def expect( role: Type[roles.LiteralValueRole], element: Any, **kw: Any, -) -> BindParameter[Any]: - ... +) -> BindParameter[Any]: ... @overload @@ -194,8 +190,7 @@ def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, -) -> Column[Any]: - ... +) -> Union[Column[Any], str]: ... @overload @@ -203,8 +198,7 @@ def expect( role: Type[roles.DDLConstraintColumnRole], element: Any, **kw: Any, -) -> Union[Column[Any], str]: - ... +) -> Union[Column[Any], str]: ... @overload @@ -212,8 +206,7 @@ def expect( role: Type[roles.StatementOptionRole], element: Any, **kw: Any, -) -> DQLDMLClauseElement: - ... +) -> Union[ColumnElement[Any], TextClause]: ... @overload @@ -221,8 +214,7 @@ def expect( role: Type[roles.LabeledColumnExprRole[Any]], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> NamedColumn[_T]: - ... +) -> NamedColumn[_T]: ... @overload @@ -234,8 +226,7 @@ def expect( ], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> ColumnElement[_T]: - ... +) -> ColumnElement[_T]: ... 
@overload @@ -249,8 +240,7 @@ def expect( ], element: Any, **kw: Any, -) -> ColumnElement[Any]: - ... +) -> ColumnElement[Any]: ... @overload @@ -258,8 +248,7 @@ def expect( role: Type[roles.DMLTableRole], element: _DMLTableArgument, **kw: Any, -) -> _DMLTableElement: - ... +) -> _DMLTableElement: ... @overload @@ -267,8 +256,7 @@ def expect( role: Type[roles.HasCTERole], element: HasCTE, **kw: Any, -) -> HasCTE: - ... +) -> HasCTE: ... @overload @@ -276,8 +264,7 @@ def expect( role: Type[roles.SelectStatementRole], element: SelectBase, **kw: Any, -) -> SelectBase: - ... +) -> SelectBase: ... @overload @@ -285,8 +272,7 @@ def expect( role: Type[roles.FromClauseRole], element: _FromClauseArgument, **kw: Any, -) -> FromClause: - ... +) -> FromClause: ... @overload @@ -296,8 +282,7 @@ def expect( *, explicit_subquery: Literal[True] = ..., **kw: Any, -) -> Subquery: - ... +) -> Subquery: ... @overload @@ -305,8 +290,7 @@ def expect( role: Type[roles.ColumnsClauseRole], element: _ColumnsClauseArgument[Any], **kw: Any, -) -> _ColumnsClauseElement: - ... +) -> _ColumnsClauseElement: ... @overload @@ -314,8 +298,7 @@ def expect( role: Type[roles.JoinTargetRole], element: _JoinTargetProtocol, **kw: Any, -) -> _JoinTargetProtocol: - ... +) -> _JoinTargetProtocol: ... # catchall for not-yet-implemented overloads @@ -324,8 +307,7 @@ def expect( role: Type[_SR], element: Any, **kw: Any, -) -> Any: - ... +) -> Any: ... def expect( @@ -510,6 +492,7 @@ def _raise_for_expected( element: Any, argname: Optional[str] = None, resolved: Optional[Any] = None, + *, advice: Optional[str] = None, code: Optional[str] = None, err: Optional[Exception] = None, @@ -612,7 +595,7 @@ def _no_text_coercion( class _NoTextCoercion(RoleImpl): __slots__ = () - def _literal_coercion(self, element, argname=None, **kw): + def _literal_coercion(self, element, *, argname=None, **kw): if isinstance(element, str) and issubclass( elements.TextClause, self._role_class ): @@ -630,7 +613,7 @@ class _CoerceLiterals(RoleImpl): def _text_coercion(self, element, argname=None): return _no_text_coercion(element, argname) - def _literal_coercion(self, element, argname=None, **kw): + def _literal_coercion(self, element, *, argname=None, **kw): if isinstance(element, str): if self._coerce_star and element == "*": return elements.ColumnClause("*", is_literal=True) @@ -658,7 +641,8 @@ def _implicit_coercions( self, element, resolved, - argname, + argname=None, + *, type_=None, literal_execute=False, **kw, @@ -676,7 +660,7 @@ def _implicit_coercions( literal_execute=literal_execute, ) - def _literal_coercion(self, element, argname=None, type_=None, **kw): + def _literal_coercion(self, element, **kw): return element @@ -688,6 +672,7 @@ def _raise_for_expected( element: Any, argname: Optional[str] = None, resolved: Optional[Any] = None, + *, advice: Optional[str] = None, code: Optional[str] = None, err: Optional[Exception] = None, @@ -762,7 +747,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): __slots__ = () def _literal_coercion( - self, element, name=None, type_=None, argname=None, is_crud=False, **kw + self, element, *, name=None, type_=None, is_crud=False, **kw ): if ( element is None @@ -804,15 +789,22 @@ def _raise_for_expected(self, element, argname=None, resolved=None, **kw): class BinaryElementImpl(ExpressionElementImpl, RoleImpl): __slots__ = () - def _literal_coercion( - self, element, expr, operator, bindparam_type=None, argname=None, **kw + def _literal_coercion( # type: ignore[override] + self, + element, + *, + expr, + 
operator, + bindparam_type=None, + argname=None, + **kw, ): try: return expr._bind_param(operator, element, type_=bindparam_type) except exc.ArgumentError as err: self._raise_for_expected(element, err=err) - def _post_coercion(self, resolved, expr, bindparam_type=None, **kw): + def _post_coercion(self, resolved, *, expr, bindparam_type=None, **kw): if resolved.type._isnull and not expr.type._isnull: resolved = resolved._with_binary_element_type( bindparam_type if bindparam_type is not None else expr.type @@ -850,31 +842,32 @@ def _warn_for_implicit_coercion(self, elem): % (elem.__class__.__name__) ) - def _literal_coercion(self, element, expr, operator, **kw): - if isinstance(element, collections_abc.Iterable) and not isinstance( - element, str - ): + @util.preload_module("sqlalchemy.sql.elements") + def _literal_coercion(self, element, *, expr, operator, **kw): # type: ignore[override] # noqa: E501 + if util.is_non_string_iterable(element): non_literal_expressions: Dict[ - Optional[operators.ColumnOperators], - operators.ColumnOperators, + Optional[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], ] = {} element = list(element) for o in element: if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): + if not isinstance( + o, util.preloaded.sql_elements.ColumnElement + ) and not hasattr(o, "__clause_element__"): self._raise_for_expected(element, **kw) else: non_literal_expressions[o] = o - elif o is None: - non_literal_expressions[o] = elements.Null() if non_literal_expressions: return elements.ClauseList( *[ - non_literal_expressions[o] - if o in non_literal_expressions - else expr._bind_param(operator, o) + ( + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) + ) for o in element ] ) @@ -884,7 +877,7 @@ def _literal_coercion(self, element, expr, operator, **kw): else: self._raise_for_expected(element, **kw) - def _post_coercion(self, element, expr, operator, **kw): + def _post_coercion(self, element, *, expr, operator, **kw): if element._is_select_base: # for IN, we are doing scalar_subquery() coercion without # a warning @@ -910,12 +903,10 @@ class OnClauseImpl(_ColumnCoercions, RoleImpl): _coerce_consts = True - def _literal_coercion( - self, element, name=None, type_=None, argname=None, is_crud=False, **kw - ): + def _literal_coercion(self, element, **kw): self._raise_for_expected(element) - def _post_coercion(self, resolved, original_element=None, **kw): + def _post_coercion(self, resolved, *, original_element=None, **kw): # this is a hack right now as we want to use coercion on an # ORM InstrumentedAttribute, but we want to return the object # itself if it is one, not its clause element. 
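
For context on the ``InElementImpl._literal_coercion`` rework above: the behavior it preserves is that plain Python values inside an IN list become bound parameters typed from the left-hand expression, while SQL expressions pass through into the resulting ClauseList unchanged. A quick illustration via the public API (rendering shown approximately):

    from sqlalchemy import Integer, column

    x = column("x", Integer)
    y = column("y", Integer)

    # 1 and 2 are literals, bound via expr._bind_param(); y is a
    # ColumnElement and is passed through unchanged
    expr = x.in_([1, 2, y])
    print(expr)  # roughly: x IN (:x_1, :x_2, y)
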
@@ -1000,7 +991,7 @@ def _implicit_coercions( class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () - def _post_coercion(self, element, as_key=False, **kw): + def _post_coercion(self, element, *, as_key=False, **kw): if as_key: return element.key else: @@ -1010,7 +1001,7 @@ def _post_coercion(self, element, as_key=False, **kw): class ConstExprImpl(RoleImpl): __slots__ = () - def _literal_coercion(self, element, argname=None, **kw): + def _literal_coercion(self, element, *, argname=None, **kw): if element is None: return elements.Null() elif element is False: @@ -1036,7 +1027,7 @@ def _implicit_coercions( else: self._raise_for_expected(element, argname, resolved) - def _literal_coercion(self, element, argname=None, **kw): + def _literal_coercion(self, element, **kw): """coerce the given value to :class:`._truncated_label`. Existing :class:`._truncated_label` and @@ -1086,7 +1077,9 @@ def _implicit_coercions( else: self._raise_for_expected(element, argname, resolved) - def _literal_coercion(self, element, name, type_, **kw): + def _literal_coercion( # type: ignore[override] + self, element, *, name, type_, **kw + ): if element is None: return None else: @@ -1128,7 +1121,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): _guess_straight_column = re.compile(r"^\w\S*$", re.I) def _raise_for_expected( - self, element, argname=None, resolved=None, advice=None, **kw + self, element, argname=None, resolved=None, *, advice=None, **kw ): if not advice and isinstance(element, list): advice = ( @@ -1152,9 +1145,9 @@ def _text_coercion(self, element, argname=None): % { "column": util.ellipses_string(element), "argname": "for argument %s" % (argname,) if argname else "", - "literal_column": "literal_column" - if guess_is_literal - else "column", + "literal_column": ( + "literal_column" if guess_is_literal else "column" + ), } ) @@ -1166,7 +1159,9 @@ class ReturnsRowsImpl(RoleImpl): class StatementImpl(_CoerceLiterals, RoleImpl): __slots__ = () - def _post_coercion(self, resolved, original_element, argname=None, **kw): + def _post_coercion( + self, resolved, *, original_element, argname=None, **kw + ): if resolved is not original_element and not isinstance( original_element, str ): @@ -1232,7 +1227,7 @@ class JoinTargetImpl(RoleImpl): _skip_clauseelement_for_target_match = True - def _literal_coercion(self, element, argname=None, **kw): + def _literal_coercion(self, element, *, argname=None, **kw): self._raise_for_expected(element, argname) def _implicit_coercions( @@ -1240,6 +1235,7 @@ def _implicit_coercions( element: Any, resolved: Any, argname: Optional[str] = None, + *, legacy: bool = False, **kw: Any, ) -> Any: @@ -1273,6 +1269,7 @@ def _implicit_coercions( element: Any, resolved: Any, argname: Optional[str] = None, + *, explicit_subquery: bool = False, allow_select: bool = True, **kw: Any, @@ -1294,7 +1291,7 @@ def _implicit_coercions( else: self._raise_for_expected(element, argname, resolved) - def _post_coercion(self, element, deannotate=False, **kw): + def _post_coercion(self, element, *, deannotate=False, **kw): if deannotate: return element._deannotate() else: @@ -1309,7 +1306,7 @@ def _implicit_coercions( element: Any, resolved: Any, argname: Optional[str] = None, - explicit_subquery: bool = False, + *, allow_select: bool = False, **kw: Any, ) -> Any: @@ -1329,7 +1326,7 @@ def _implicit_coercions( class AnonymizedFromClauseImpl(StrictFromClauseImpl): __slots__ = () - def _post_coercion(self, element, flat=False, name=None, **kw): + def _post_coercion(self, 
element, *, flat=False, name=None, **kw): assert name is None return element._anonymous_fromclause(flat=flat) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb6899c5e9a..2353fa39e40 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1,5 +1,5 @@ # sql/compiler.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,6 +29,7 @@ import collections.abc as collections_abc import contextlib from enum import IntEnum +import functools import itertools import operator import re @@ -73,38 +74,49 @@ from .base import _from_objects from .base import _NONE_NAME from .base import _SentinelDefaultCharacterization -from .base import Executable from .base import NO_ARG -from .elements import ClauseElement from .elements import quoted_name -from .schema import Column from .sqltypes import TupleType -from .type_api import TypeEngine from .visitors import prefix_anon_map -from .visitors import Visitable from .. import exc from .. import util from ..util import FastIntFlag from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import Self from ..util.typing import TypedDict if typing.TYPE_CHECKING: from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState + from .base import Executable from .cache_key import CacheKey from .ddl import ExecutableDDLElement from .dml import Insert + from .dml import Update from .dml import UpdateBase + from .dml import UpdateDMLState from .dml import ValuesBase from .elements import _truncated_label + from .elements import BinaryExpression from .elements import BindParameter + from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import False_ from .elements import Label + from .elements import Null + from .elements import True_ from .functions import Function + from .schema import Column + from .schema import Constraint + from .schema import ForeignKeyConstraint + from .schema import Index + from .schema import PrimaryKeyConstraint from .schema import Table + from .schema import UniqueConstraint + from .selectable import _ColumnsClauseElement from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -114,7 +126,10 @@ from .selectable import Select from .selectable import SelectState from .type_api import _BindProcessorType - from .type_api import _SentinelProcessorType + from .type_api import TypeDecorator + from .type_api import TypeEngine + from .type_api import UserDefinedType + from .visitors import Visitable from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _DBAPIAnyExecuteParams @@ -126,6 +141,7 @@ from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType + _FromHintsType = Dict["FromClause", str] RESERVED_WORDS = { @@ -382,8 +398,7 @@ def __call__( name: str, objects: Sequence[Any], type_: TypeEngine[Any], - ) -> None: - ... + ) -> None: ... # integer indexes into ResultColumnsEntry used by cursor.py. 
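
Most of the import churn above moves names under ``typing.TYPE_CHECKING``, the usual way to keep annotation-only imports out of the runtime import graph and to break import cycles. The general shape, sketched here with a stand-in stdlib module rather than the sqlalchemy modules involved:

    from __future__ import annotations

    import typing

    if typing.TYPE_CHECKING:
        # evaluated only by type checkers; never imported at runtime,
        # which is what breaks circular-import chains
        from decimal import Decimal


    def double(value: Decimal) -> Decimal:
        # with lazy annotations the hint is a string, so Decimal does
        # not need to exist at runtime in this module
        return value + value
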
@@ -546,8 +561,8 @@ class _InsertManyValues(NamedTuple):

     """

-    sentinel_param_keys: Optional[Sequence[Union[str, int]]] = None
-    """parameter str keys / int indexes in each param dictionary / tuple
+    sentinel_param_keys: Optional[Sequence[str]] = None
+    """parameter str keys in each param dictionary / tuple
     that would link to the client side "sentinel" values for that row, which
     we can use to match up parameter sets to result rows.

@@ -557,6 +572,10 @@ class _InsertManyValues(NamedTuple):

     .. versionadded:: 2.0.10

+    .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
+       only, used against the "compiled parameters" collection before
+       the parameters were converted by bound parameter processors
+
     """

     implicit_sentinel: bool = False
@@ -601,7 +620,8 @@ class _InsertManyValuesBatch(NamedTuple):
     replaced_parameters: _DBAPIAnyExecuteParams
     processed_setinputsizes: Optional[_GenericSetInputSizesType]
     batch: Sequence[_DBAPISingleExecuteParams]
-    batch_size: int
+    sentinel_values: Sequence[Tuple[Any, ...]]
+    current_batch_size: int
     batchnum: int
     total_batches: int
     rows_sorted: bool
@@ -737,7 +757,6 @@ def warn(self, stmt_type="SELECT"):


 class Compiled:
-
     """Represent a compiled SQL or DDL expression.

     The ``__str__`` method of the ``Compiled`` object should produce
@@ -867,6 +886,7 @@ def __init__(
             self.string = self.process(self.statement, **compile_kwargs)

             if render_schema_translate:
+                assert schema_translate_map is not None
                 self.string = self.preparer._render_schema_translates(
                     self.string, schema_translate_map
                 )
@@ -899,7 +919,7 @@ def visit_unsupported_compilation(self, element, err, **kw):
         raise exc.UnsupportedCompilationError(self, type(element)) from err

     @property
-    def sql_compiler(self):
+    def sql_compiler(self) -> SQLCompiler:
         """Return a Compiled that is capable of processing SQL expressions.

         If this compiler is one, it would likely just return 'self'.
@@ -967,7 +987,6 @@ def visit_unsupported_compilation(
 class _CompileLabel(
     roles.BinaryElementRole[Any], elements.CompilerColumnElement
 ):
-
     """lightweight label object which acts as an expression.Label."""

     __visit_name__ = "label"
@@ -1037,19 +1056,19 @@ class SQLCompiler(Compiled):

     extract_map = EXTRACT_MAP

-    bindname_escape_characters: ClassVar[
-        Mapping[str, str]
-    ] = util.immutabledict(
-        {
-            "%": "P",
-            "(": "A",
-            ")": "Z",
-            ":": "C",
-            ".": "_",
-            "[": "_",
-            "]": "_",
-            " ": "_",
-        }
+    bindname_escape_characters: ClassVar[Mapping[str, str]] = (
+        util.immutabledict(
+            {
+                "%": "P",
+                "(": "A",
+                ")": "Z",
+                ":": "C",
+                ".": "_",
+                "[": "_",
+                "]": "_",
+                " ": "_",
+            }
+        )
     )
     """A mapping (e.g. dict or similar) containing a lookup of
     characters keyed to replacement characters which will be applied to all
@@ -1343,6 +1362,7 @@ def __init__(
         column_keys: Optional[Sequence[str]] = None,
         for_executemany: bool = False,
         linting: Linting = NO_LINTING,
+        _supporting_against: Optional[SQLCompiler] = None,
         **kwargs: Any,
     ):
         """Construct a new :class:`.SQLCompiler` object.
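
Regarding the ``sentinel_param_keys`` change noted in the versionchanged above: the keys are now plain string keys looked up in the compiled-parameter dictionaries before bind processors run; later in this diff, ``_deliver_insertmanyvalues_batches`` turns them into an ``operator.itemgetter``. The lookup reduces to something like this (key names here are made up):

    import operator

    # two hypothetical sentinel keys; with 2+ keys itemgetter returns tuples
    sentinel_param_keys = ("sa_sentinel_a", "sa_sentinel_b")
    _sentinel_from_params = operator.itemgetter(*sentinel_param_keys)

    compiled_parameters = [
        {"data": "x", "sa_sentinel_a": 1, "sa_sentinel_b": 10},
        {"data": "y", "sa_sentinel_a": 2, "sa_sentinel_b": 20},
    ]

    # one sentinel tuple per parameter set, later used to match result
    # rows back to their parameter sets deterministically
    sentinel_values = [_sentinel_from_params(cp) for cp in compiled_parameters]
    assert sentinel_values == [(1, 10), (2, 20)]
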
@@ -1445,6 +1465,24 @@ def __init__( self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + if _supporting_against: + self.__dict__.update( + { + k: v + for k, v in _supporting_against.__dict__.items() + if k + not in { + "state", + "dialect", + "preparer", + "positional", + "_numeric_binds", + "compilation_bindtemplate", + "bindtemplate", + } + } + ) + if self.state is CompilerState.STRING_APPLIED: if self.positional: if self._numeric_binds: @@ -1659,19 +1697,9 @@ def find_position(m: re.Match[str]) -> str: for v in self._insertmanyvalues.insert_crud_params ] - sentinel_param_int_idxs = ( - [ - self.positiontup.index(cast(str, _param_key)) - for _param_key in self._insertmanyvalues.sentinel_param_keys # noqa: E501 - ] - if self._insertmanyvalues.sentinel_param_keys is not None - else None - ) - self._insertmanyvalues = self._insertmanyvalues._replace( single_values_expr=single_values_expr, insert_crud_params=insert_crud_params, - sentinel_param_keys=sentinel_param_int_idxs, ) def _process_numeric(self): @@ -1740,21 +1768,11 @@ def _process_numeric(self): for v in self._insertmanyvalues.insert_crud_params ] - sentinel_param_int_idxs = ( - [ - self.positiontup.index(cast(str, _param_key)) - for _param_key in self._insertmanyvalues.sentinel_param_keys # noqa: E501 - ] - if self._insertmanyvalues.sentinel_param_keys is not None - else None - ) - self._insertmanyvalues = self._insertmanyvalues._replace( # This has the numbers (:1, :2) single_values_expr=single_values_expr, # The single binds are instead %s so they can be formatted insert_crud_params=insert_crud_params, - sentinel_param_keys=sentinel_param_int_idxs, ) @util.memoized_property @@ -1770,11 +1788,15 @@ def _bind_processors( for key, value in ( ( self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - if not bindparam.type._is_tuple_type - else tuple( - elem_type._cached_bind_processor(self.dialect) - for elem_type in cast(TupleType, bindparam.type).types + ( + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in cast( + TupleType, bindparam.type + ).types + ) ), ) for bindparam in self.bind_names @@ -1782,28 +1804,11 @@ def _bind_processors( if value is not None } - @util.memoized_property - def _imv_sentinel_value_resolvers( - self, - ) -> Optional[Sequence[Optional[_SentinelProcessorType[Any]]]]: - imv = self._insertmanyvalues - if imv is None or imv.sentinel_columns is None: - return None - - sentinel_value_resolvers = [ - _scol.type._cached_sentinel_value_processor(self.dialect) - for _scol in imv.sentinel_columns - ] - if util.NONE_SET.issuperset(sentinel_value_resolvers): - return None - else: - return sentinel_value_resolvers - def is_subquery(self): return len(self.stack) > 1 @property - def sql_compiler(self): + def sql_compiler(self) -> Self: return self def construct_expanded_state( @@ -2080,11 +2085,11 @@ def _process_parameters_for_postcompile( if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: - replacement_expressions[ - escaped_name - ] = self.render_literal_bindparam( - parameter, - render_literal_value=parameters.pop(escaped_name), + replacement_expressions[escaped_name] = ( + self.render_literal_bindparam( + parameter, + render_literal_value=parameters.pop(escaped_name), + ) ) continue @@ -2293,12 +2298,14 @@ def get(lastrowid, parameters): else: return row_fn( ( - autoinc_getter(lastrowid, parameters) - if autoinc_getter is 
not None
-                        else lastrowid
+                        (
+                            autoinc_getter(lastrowid, parameters)
+                            if autoinc_getter is not None
+                            else lastrowid
+                        )
+                        if col is autoinc_col
+                        else getter(parameters)
                     )
-                    if col is autoinc_col
-                    else getter(parameters)
                     for getter, col in getters
                 )

@@ -2328,11 +2335,15 @@ def _inserted_primary_key_from_returning_getter(self):
         getters = cast(
             "List[Tuple[Callable[[Any], Any], bool]]",
             [
-                (operator.itemgetter(ret[col]), True)
-                if col in ret
-                else (
-                    operator.methodcaller("get", param_key_getter(col), None),
-                    False,
+                (
+                    (operator.itemgetter(ret[col]), True)
+                    if col in ret
+                    else (
+                        operator.methodcaller(
+                            "get", param_key_getter(col), None
+                        ),
+                        False,
+                    )
                 )
                 for col in table.primary_key
             ],
@@ -2348,15 +2359,80 @@ def get(row, parameters):

         return get

-    def default_from(self):
+    def default_from(self) -> str:
         """Called when a SELECT statement has no froms,
         and no FROM clause is to be appended.

-        Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
+        Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
+        output.

         """
         return ""

+    def visit_override_binds(self, override_binds, **kw):
+        """SQL compile the nested element of an _OverrideBinds with
+        bindparams swapped out.
+
+        The _OverrideBinds is not normally expected to be compiled; it
+        is meant to be used when an already cached statement is to be used,
+        the compilation was already performed, and only the bound params should
+        be swapped in at execution time.
+
+        However, there are test cases that exercise this object, and
+        additionally the ORM subquery loader is known to feed in expressions
+        which include this construct into new queries (discovered in #11173),
+        so it has to do the right thing at compile time as well.
+
+        """
+
+        # get SQL text first
+        sqltext = override_binds.element._compiler_dispatch(self, **kw)
+
+        # for a test compile that is not for caching, change binds after the
+        # fact.  note that we don't try to
+        # swap the bindparam as we compile, because our element may be
+        # elsewhere in the statement already (e.g. a subquery or perhaps a
+        # CTE) and was already visited / compiled.  See
+        # test_relationship_criteria.py ->
+        # test_selectinload_local_criteria_subquery
+        for k in override_binds.translate:
+            if k not in self.binds:
+                continue
+            bp = self.binds[k]
+
+            # so this would work, just change the value of bp in place.
+            # but we don't want to mutate things outside.
+ # bp.value = override_binds.translate[bp.key] + # continue + + # instead, need to replace bp with new_bp or otherwise accommodate + # in all internal collections + new_bp = bp._with_value( + override_binds.translate[bp.key], + maintain_key=True, + required=False, + ) + + name = self.bind_names[bp] + self.binds[k] = self.binds[name] = new_bp + self.bind_names[new_bp] = name + self.bind_names.pop(bp, None) + + if bp in self.post_compile_params: + self.post_compile_params |= {new_bp} + if bp in self.literal_execute_params: + self.literal_execute_params |= {new_bp} + + ckbm_tuple = self._cache_key_bind_match + if ckbm_tuple: + ckbm, cksm = ckbm_tuple + for bp in bp._cloned_set: + if bp.key in cksm: + cb = cksm[bp.key] + ckbm[cb].append(new_bp) + + return sqltext + def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" @@ -2401,9 +2477,9 @@ def visit_label_reference( resolve_dict[order_by_elem.name] ) ): - kwargs[ - "render_label_as_label" - ] = element.element._order_by_label_element + kwargs["render_label_as_label"] = ( + element.element._order_by_label_element + ) return self.process( element.element, within_columns_clause=within_columns_clause, @@ -2506,7 +2582,7 @@ def visit_label( def _fallback_column_name(self, column): raise exc.CompileError( - "Cannot compile Column object until " "its 'name' is assigned." + "Cannot compile Column object until its 'name' is assigned." ) def visit_lambda_element(self, element, **kw): @@ -2649,9 +2725,9 @@ def visit_textual_select( ) if populate_result_map: - self._ordered_columns = ( - self._textual_ordered_columns - ) = taf.positional + self._ordered_columns = self._textual_ordered_columns = ( + taf.positional + ) # enable looser result column matching when the SQL text links to # Column objects by name only @@ -2675,16 +2751,16 @@ def visit_textual_select( return text - def visit_null(self, expr, **kw): + def visit_null(self, expr: Null, **kw: Any) -> str: return "NULL" - def visit_true(self, expr, **kw): + def visit_true(self, expr: True_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "true" else: return "1" - def visit_false(self, expr, **kw): + def visit_false(self, expr: False_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "false" else: @@ -2778,36 +2854,60 @@ def visit_cast(self, cast, **kwargs): def _format_frame_clause(self, range_, **kw): return "%s AND %s" % ( - "UNBOUNDED PRECEDING" - if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" - if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" - % (self.process(elements.literal(abs(range_[0])), **kw),) - if range_[0] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[0]), **kw),), - "UNBOUNDED FOLLOWING" - if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" - if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" - % (self.process(elements.literal(abs(range_[1])), **kw),) - if range_[1] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[1]), **kw),), + ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[0])), **kw + ), + ) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),) + ) + ) + ), + ( + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[1] is 
elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[1])), **kw + ), + ) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),) + ) + ) + ), ) def visit_over(self, over, **kwargs): text = over.element._compiler_dispatch(self, **kwargs) - if over.range_: + if over.range_ is not None: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs ) - elif over.rows: + elif over.rows is not None: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( over.rows, **kwargs ) + elif over.groups is not None: + range_ = "GROUPS BETWEEN %s" % self._format_frame_clause( + over.groups, **kwargs + ) else: range_ = None @@ -2858,7 +2958,7 @@ def visit_function( **kwargs: Any, ) -> str: if add_to_result_map is not None: - add_to_result_map(func.name, func.name, (), func.type) + add_to_result_map(func.name, func.name, (func.name,), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) @@ -2906,7 +3006,7 @@ def visit_sequence(self, sequence, **kw): % self.dialect.name ) - def function_argspec(self, func, **kwargs): + def function_argspec(self, func: Function[Any], **kwargs: Any) -> str: return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( @@ -3036,9 +3136,12 @@ def visit_truediv_binary(self, binary, operator, **kw): + self.process( elements.Cast( binary.right, - binary.right.type - if binary.right.type._type_affinity is sqltypes.Numeric - else sqltypes.Numeric(), + ( + binary.right.type + if binary.right.type._type_affinity + is sqltypes.Numeric + else sqltypes.Numeric() + ), ), **kw, ) @@ -3367,8 +3470,12 @@ def visit_custom_op_unary_modifier(self, element, operator, **kw): ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw - ): + self, + binary: BinaryExpression[Any], + opstring: str, + eager_grouping: bool = False, + **kw: Any, + ) -> str: _in_operator_expression = kw.get("_in_operator_expression", False) kw["_in_operator_expression"] = True @@ -3537,19 +3644,25 @@ def visit_not_between_op_binary(self, binary, operator, **kw): **kw, ) - def visit_regexp_match_op_binary(self, binary, operator, **kw): + def visit_regexp_match_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + def visit_not_regexp_match_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_regexp_replace_op_binary(self, binary, operator, **kw): + def visit_regexp_replace_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expression replacements" % self.dialect.name @@ -3565,6 +3678,7 @@ def visit_bindparam( render_postcompile=False, **kwargs, ): + if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: @@ -3755,7 +3869,9 @@ def render_literal_bindparam( else: return self.render_literal_value(value, bindparam.type) - def render_literal_value(self, value, type_): + def render_literal_value( + self, value: Any, type_: sqltypes.TypeEngine[Any] + ) -> str: """Render the value of a bind parameter as a quoted literal. 
This is used for statement sections that do not accept bind parameters @@ -3991,15 +4107,28 @@ def visit_cte( del self.level_name_by_cte[existing_cte_reference_cte] else: - # if the two CTEs are deep-copy identical, consider them - # the same, **if** they are clones, that is, they came from - # the ORM or other visit method if ( - cte._is_clone_of is not None - or existing_cte._is_clone_of is not None - ) and cte.compare(existing_cte): + # if the two CTEs have the same hash, which we expect + # here means that one/both is an annotated of the other + (hash(cte) == hash(existing_cte)) + # or... + or ( + ( + # if they are clones, i.e. they came from the ORM + # or some other visit method + cte._is_clone_of is not None + or existing_cte._is_clone_of is not None + ) + # and are deep-copy identical + and cte.compare(existing_cte) + ) + ): + # then consider these two CTEs the same is_new_cte = False else: + # otherwise these are two CTEs that either will render + # differently, or were indicated separately by the user, + # with the same name raise exc.CompileError( "Multiple, unrelated CTEs found with " "the same name: %r" % cte_name @@ -4101,7 +4230,7 @@ def visit_cte( if self.preparer._requires_quotes(cte_name): cte_name = self.preparer.quote(cte_name) text += self.get_render_as_alias_suffix(cte_name) - return text + return text # type: ignore[no-any-return] else: return self.preparer.format_alias(cte, cte_name) @@ -4163,7 +4292,7 @@ def visit_alias( inner = "(%s)" % (inner,) return inner else: - enclosing_alias = kwargs["enclosing_alias"] = alias + kwargs["enclosing_alias"] = alias if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): @@ -4193,12 +4322,14 @@ def visit_alias( "%s%s" % ( self.preparer.quote(col.name), - " %s" - % self.dialect.type_compiler_instance.process( - col.type, **kwargs - ) - if alias._render_derived_w_types - else "", + ( + " %s" + % self.dialect.type_compiler_instance.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "" + ), ) for col in alias.c ) @@ -4302,6 +4433,11 @@ def _add_to_result_map( objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: + + # note objects must be non-empty for cursor.py to handle the + # collection properly + assert objects + if keyname is None or keyname == "*": self._ordered_columns = False self._ad_hoc_textual = True @@ -4375,7 +4511,7 @@ def _label_select_column( _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - _add_to_result_map(keyname, name, (), type_) + _add_to_result_map(keyname, name, (keyname,), type_) # if we redefined col_expr for type expressions, wrap the # callable with one that adds the original column to the targets @@ -4509,7 +4645,9 @@ def format_from_hint_text(self, sqltext, table, hint, iscrud): def get_select_hint_text(self, byfroms): return None - def get_from_hint_text(self, table, text): + def get_from_hint_text( + self, table: FromClause, text: Optional[str] + ) -> Optional[str]: return None def get_crud_hint_text(self, table, text): @@ -4590,9 +4728,9 @@ def visit_select( compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) - kwargs[ - "ambiguous_table_name_map" - ] = compile_state._ambiguous_table_name_map + kwargs["ambiguous_table_name_map"] = ( + compile_state._ambiguous_table_name_map + ) select_stmt = compile_state.statement @@ -4994,7 +5132,7 @@ def get_cte_preamble(self, recursive): else: return "WITH" - def get_select_precolumns(self, select, **kw): + def 
get_select_precolumns(self, select: Select[Any], **kw: Any) -> str: """Called when building a ``SELECT`` statement, position is just before column list. @@ -5039,7 +5177,7 @@ def for_update_clause(self, select, **kw): def returning_clause( self, stmt: UpdateBase, - returning_cols: Sequence[ColumnElement[Any]], + returning_cols: Sequence[_ColumnsClauseElement], *, populate_result_map: bool, **kw: Any, @@ -5129,6 +5267,7 @@ def visit_table( use_schema=True, from_linter=None, ambiguous_table_name_map=None, + enclosing_alias=None, **kwargs, ): if from_linter: @@ -5147,7 +5286,11 @@ def visit_table( ret = self.preparer.quote(table.name) if ( - not effective_schema + ( + enclosing_alias is None + or enclosing_alias.element is not table + ) + and not effective_schema and ambiguous_table_name_map and table.name in ambiguous_table_name_map ): @@ -5307,13 +5450,22 @@ def _deliver_insertmanyvalues_batches( self, statement: str, parameters: _DBAPIMultiExecuteParams, + compiled_parameters: List[_MutableCoreSingleExecuteParams], generic_setinputsizes: Optional[_GenericSetInputSizesType], batch_size: int, sort_by_parameter_order: bool, + schema_translate_map: Optional[SchemaTranslateMapType], ) -> Iterator[_InsertManyValuesBatch]: imv = self._insertmanyvalues assert imv is not None + if not imv.sentinel_param_keys: + _sentinel_from_params = None + else: + _sentinel_from_params = operator.itemgetter( + *imv.sentinel_param_keys + ) + lenparams = len(parameters) if imv.is_default_expr and not self.dialect.supports_default_metavalue: # backend doesn't support @@ -5345,15 +5497,24 @@ def _deliver_insertmanyvalues_batches( downgraded = False if use_row_at_a_time: - for batchnum, param in enumerate( - cast("Sequence[_DBAPISingleExecuteParams]", parameters), 1 + for batchnum, (param, compiled_param) in enumerate( + cast( + "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]", # noqa: E501 + zip(parameters, compiled_parameters), + ), + 1, ): yield _InsertManyValuesBatch( statement, param, generic_setinputsizes, [param], - batch_size, + ( + [_sentinel_from_params(compiled_param)] + if _sentinel_from_params + else [] + ), + 1, batchnum, lenparams, sort_by_parameter_order, @@ -5361,7 +5522,19 @@ def _deliver_insertmanyvalues_batches( ) return - executemany_values = f"({imv.single_values_expr})" + if schema_translate_map: + rst = functools.partial( + self.preparer._render_schema_translates, + schema_translate_map=schema_translate_map, + ) + else: + rst = None + + imv_single_values_expr = imv.single_values_expr + if rst: + imv_single_values_expr = rst(imv_single_values_expr) + + executemany_values = f"({imv_single_values_expr})" statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__") # Use optional insertmanyvalues_max_parameters @@ -5384,7 +5557,10 @@ def _deliver_insertmanyvalues_batches( ), ) - batches = list(parameters) + batches = cast("List[Sequence[Any]]", list(parameters)) + compiled_batches = cast( + "List[Sequence[Any]]", list(compiled_parameters) + ) processed_setinputsizes: Optional[_GenericSetInputSizesType] = None batchnum = 1 @@ -5395,6 +5571,12 @@ def _deliver_insertmanyvalues_batches( insert_crud_params = imv.insert_crud_params assert insert_crud_params is not None + if rst: + insert_crud_params = [ + (col, key, rst(expr), st) + for col, key, expr, st in insert_crud_params + ] + escaped_bind_names: Mapping[str, str] expand_pos_lower_index = expand_pos_upper_index = 0 @@ -5442,10 +5624,10 @@ def apply_placeholders(keys, formatted): if imv.embed_values_counter: 
executemany_values_w_comma = (
-                    f"({imv.single_values_expr}, _IMV_VALUES_COUNTER), "
+                    f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
                 )
             else:
-                executemany_values_w_comma = f"({imv.single_values_expr}), "
+                executemany_values_w_comma = f"({imv_single_values_expr}), "

             all_names_we_will_expand: Set[str] = set()
             for elem in imv.insert_crud_params:
@@ -5478,8 +5660,16 @@ def apply_placeholders(keys, formatted):
             )

         while batches:
-            batch = cast("Sequence[Any]", batches[0:batch_size])
+            batch = batches[0:batch_size]
+            compiled_batch = compiled_batches[0:batch_size]
+
             batches[0:batch_size] = []
+            compiled_batches[0:batch_size] = []
+
+            if batches:
+                current_batch_size = batch_size
+            else:
+                current_batch_size = len(batch)

             if generic_setinputsizes:
                 # if setinputsizes is present, expand this collection to
@@ -5489,7 +5679,7 @@ def _deliver_insertmanyvalues_batches(
                     (new_key, len_, typ)
                     for new_key, len_, typ in (
                         (f"{key}_{index}", len_, typ)
-                        for index in range(len(batch))
+                        for index in range(current_batch_size)
                         for key, len_, typ in generic_setinputsizes
                     )
                 ]
@@ -5499,6 +5689,9 @@ def _deliver_insertmanyvalues_batches(
             num_ins_params = imv.num_positional_params_counted

             batch_iterator: Iterable[Sequence[Any]]
+            extra_params_left: Sequence[Any]
+            extra_params_right: Sequence[Any]
+
             if num_ins_params == len(batch[0]):
                 extra_params_left = extra_params_right = ()
                 batch_iterator = batch
@@ -5521,7 +5714,7 @@ def _deliver_insertmanyvalues_batches(
                 )[:-2]
             else:
                 expanded_values_string = (
-                    (executemany_values_w_comma * len(batch))
+                    (executemany_values_w_comma * current_batch_size)
                 )[:-2]

             if self._numeric_binds and num_ins_params > 0:
@@ -5537,7 +5730,7 @@ def _deliver_insertmanyvalues_batches(
                 assert not extra_params_right

                 start = expand_pos_lower_index + 1
-                end = num_ins_params * (len(batch)) + start
+                end = num_ins_params * (current_batch_size) + start

                 # need to format here, since statement may contain
                 # unescaped %, while values_string contains just (%s, %s)
@@ -5587,7 +5780,12 @@ def _deliver_insertmanyvalues_batches(
                 replaced_parameters,
                 processed_setinputsizes,
                 batch,
-                batch_size,
+                (
+                    [_sentinel_from_params(cb) for cb in compiled_batch]
+                    if _sentinel_from_params
+                    else []
+                ),
+                current_batch_size,
                 batchnum,
                 total_batches,
                 sort_by_parameter_order,
@@ -5595,13 +5793,19 @@ def _deliver_insertmanyvalues_batches(
             )
             batchnum += 1

-    def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
+    def visit_insert(
+        self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
+    ):
         compile_state = insert_stmt._compile_state_factory(
             insert_stmt, self, **kw
         )
         insert_stmt = compile_state.statement

-        toplevel = not self.stack
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack

         if toplevel:
             self.isinsert = True
@@ -5629,14 +5833,12 @@ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
         # params inside them. After multiple attempts to figure this out,
         # this very simplistic "count after" works and is
         # likely the least amount of callcounts, though looks clumsy
-        if self.positional:
+        if self.positional and visiting_cte is None:
             # if we are inside a CTE, don't count parameters
             # here since they won't be for insertmanyvalues. keep
             # visited_bindparam at None so no counting happens.
# see #9173
-            has_visiting_cte = "visiting_cte" in kw
-            if not has_visiting_cte:
-                visited_bindparam = []
+            visited_bindparam = []

         crud_params_struct = crud._get_crud_params(
             self,
@@ -5724,7 +5926,6 @@ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
         returning_cols = self.implicit_returning or insert_stmt._returning
         if returning_cols:
             add_sentinel_cols = crud_params_struct.use_sentinel_columns
-
             if add_sentinel_cols is not None:
                 assert use_insertmanyvalues

@@ -5831,9 +6032,9 @@ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
                     insert_stmt._post_values_clause is not None
                 ),
                 sentinel_columns=add_sentinel_cols,
-                num_sentinel_columns=len(add_sentinel_cols)
-                if add_sentinel_cols
-                else 0,
+                num_sentinel_columns=(
+                    len(add_sentinel_cols) if add_sentinel_cols else 0
+                ),
                 implicit_sentinel=implicit_sentinel,
             )
         elif compile_state._has_multi_parameters:
@@ -5927,9 +6128,9 @@ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
                     insert_stmt._post_values_clause is not None
                 ),
                 sentinel_columns=add_sentinel_cols,
-                num_sentinel_columns=len(add_sentinel_cols)
-                if add_sentinel_cols
-                else 0,
+                num_sentinel_columns=(
+                    len(add_sentinel_cols) if add_sentinel_cols else 0
+                ),
                 sentinel_param_keys=named_sentinel_params,
                 implicit_sentinel=implicit_sentinel,
                 embed_values_counter=embed_sentinel_value,
             )
@@ -5966,6 +6167,10 @@ def update_limit_clause(self, update_stmt):
         """Provide a hook for MySQL to add LIMIT to the UPDATE"""
         return None

+    def delete_limit_clause(self, delete_stmt):
+        """Provide a hook for MySQL to add LIMIT to the DELETE"""
+        return None
+
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
         """Provide a hook to override the initial table clause
         in an UPDATE statement.
@@ -5990,13 +6195,25 @@ def update_from_clause(
             "criteria within UPDATE"
         )

-    def visit_update(self, update_stmt, **kw):
-        compile_state = update_stmt._compile_state_factory(
-            update_stmt, self, **kw
+    def visit_update(
+        self,
+        update_stmt: Update,
+        visiting_cte: Optional[CTE] = None,
+        **kw: Any,
+    ) -> str:
+        compile_state = update_stmt._compile_state_factory(  # type: ignore[call-arg] # noqa: E501
+            update_stmt, self, **kw  # type: ignore[arg-type]
         )
-        update_stmt = compile_state.statement
+        if TYPE_CHECKING:
+            assert isinstance(compile_state, UpdateDMLState)
+        update_stmt = compile_state.statement  # type: ignore[assignment]
+
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack

-        toplevel = not self.stack
         if toplevel:
             self.isupdate = True
             if not self.dml_compile_state:
@@ -6124,10 +6341,10 @@ def visit_update(self, update_stmt, **kw):

         self.stack.pop(-1)

-        return text
+        return text  # type: ignore[no-any-return]

     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Provide a hook to override the generation of a DELETE..FROM clause.
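
On the new ``delete_limit_clause()`` hook above: as with ``update_limit_clause()``, the base implementation returns None, and ``visit_delete()`` (next hunk) appends whatever string a dialect returns after the WHERE clause. A dialect supporting ``DELETE ... LIMIT`` might override it roughly as follows; "mydialect" and its "limit" argument are hypothetical and would have to be registered through ``DialectKWArgs.argument_for()`` first:

    from sqlalchemy.sql.compiler import SQLCompiler


    class MyDialectCompiler(SQLCompiler):
        def delete_limit_clause(self, delete_stmt):
            # read the hypothetical per-statement dialect kwarg; None
            # means no LIMIT was requested
            limit = delete_stmt.dialect_options["mydialect"]["limit"]
            if limit is not None:
                # integer-only LIMIT rendered as a literal, mirroring
                # what MySQL does for UPDATE ... LIMIT
                return "LIMIT %d" % int(limit)
            return None
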
@@ -6147,13 +6364,18 @@ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
             self, asfrom=True, iscrud=True, **kw
         )

-    def visit_delete(self, delete_stmt, **kw):
+    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
         compile_state = delete_stmt._compile_state_factory(
             delete_stmt, self, **kw
         )
         delete_stmt = compile_state.statement

-        toplevel = not self.stack
+        if visiting_cte is not None:
+            kw["visiting_cte"] = visiting_cte
+            toplevel = False
+        else:
+            toplevel = not self.stack
+
         if toplevel:
             self.isdelete = True
             if not self.dml_compile_state:
@@ -6248,6 +6470,10 @@ def visit_delete(self, delete_stmt, **kw):
         if t:
             text += " WHERE " + t

+        limit_clause = self.delete_limit_clause(delete_stmt)
+        if limit_clause:
+            text += " " + limit_clause
+
         if (
             self.implicit_returning or delete_stmt._returning
         ) and not self.returning_precedes_values:
@@ -6312,9 +6538,11 @@ def visit_unsupported_compilation(self, element, err, **kw):
             url = util.preloaded.engine_url
             dialect = url.URL.create(element.stringify_dialect).get_dialect()()

-            compiler = dialect.statement_compiler(dialect, None)
+            compiler = dialect.statement_compiler(
+                dialect, None, _supporting_against=self
+            )
             if not isinstance(compiler, StrSQLCompiler):
-                return compiler.process(element)
+                return compiler.process(element, **kw)

         return super().visit_unsupported_compilation(element, err)

@@ -6330,13 +6558,15 @@ def visit_json_getitem_op_binary(self, binary, operator, **kw):
     def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
         return self.visit_getitem_binary(binary, operator, **kw)

-    def visit_sequence(self, seq, **kw):
-        return "<next sequence value: %s>" % self.preparer.format_sequence(seq)
+    def visit_sequence(self, sequence, **kw):
+        return (
+            f"<next sequence value: {self.preparer.format_sequence(sequence)}>"
+        )

     def returning_clause(
         self,
         stmt: UpdateBase,
-        returning_cols: Sequence[ColumnElement[Any]],
+        returning_cols: Sequence[_ColumnsClauseElement],
         *,
         populate_result_map: bool,
         **kw: Any,
@@ -6357,7 +6587,7 @@ def update_from_clause(
         )

     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         kw["asfrom"] = True
         return ", " + ", ".join(
@@ -6365,7 +6595,7 @@ def delete_extra_from_clause(
             for t in extra_froms
         )

-    def visit_empty_set_expr(self, type_, **kw):
+    def visit_empty_set_expr(self, element_types, **kw):
         return "SELECT 1 WHERE 1!=1"

     def get_from_hint_text(self, table, text):
         return "[%s]" % text
@@ -6402,11 +6632,10 @@ def __init__(
             schema_translate_map: Optional[SchemaTranslateMapType] = ...,
             render_schema_translate: bool = ...,
             compile_kwargs: Mapping[str, Any] = ...,
-        ):
-            ...
+        ): ...

-    @util.memoized_property
-    def sql_compiler(self):
+    @util.ro_memoized_property
+    def sql_compiler(self) -> SQLCompiler:
         return self.dialect.statement_compiler(
             self.dialect, None, schema_translate_map=self.schema_translate_map
         )
@@ -6570,10 +6799,10 @@ def visit_drop_table(self, drop, **kw):
     def visit_drop_view(self, drop, **kw):
         return "\nDROP VIEW " + self.preparer.format_table(drop.element)

-    def _verify_index_table(self, index):
+    def _verify_index_table(self, index: Index) -> None:
         if index.table is None:
             raise exc.CompileError(
-                "Index '%s' is not associated " "with any table." % index.name
+                "Index '%s' is not associated with any table."
% index.name ) def visit_create_index( @@ -6621,7 +6850,9 @@ def visit_drop_index(self, drop, **kw): return text + self._prepared_index_name(index, include_schema=True) - def _prepared_index_name(self, index, include_schema=False): + def _prepared_index_name( + self, index: Index, include_schema: bool = False + ) -> str: if index.table is not None: effective_schema = self.preparer.schema_for_object(index.table) else: @@ -6631,7 +6862,7 @@ def _prepared_index_name(self, index, include_schema=False): else: schema_name = None - index_name = self.preparer.format_index(index) + index_name: str = self.preparer.format_index(index) if schema_name: index_name = schema_name + "." + index_name @@ -6768,13 +6999,13 @@ def create_table_suffix(self, table): def post_create_table(self, table): return "" - def get_column_default_string(self, column): + def get_column_default_string(self, column: Column[Any]) -> Optional[str]: if isinstance(column.server_default, schema.DefaultClause): return self.render_default_string(column.server_default.arg) else: return None - def render_default_string(self, default): + def render_default_string(self, default: Union[Visitable, str]) -> str: if isinstance(default, str): return self.sql_compiler.render_literal_value( default, sqltypes.STRINGTYPE @@ -6812,7 +7043,9 @@ def visit_column_check_constraint(self, constraint, **kw): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint, **kw): + def visit_primary_key_constraint( + self, constraint: PrimaryKeyConstraint, **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -6861,7 +7094,9 @@ def define_constraint_remote_table(self, constraint, table, preparer): return preparer.format_table(table) - def visit_unique_constraint(self, constraint, **kw): + def visit_unique_constraint( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -6876,22 +7111,37 @@ def visit_unique_constraint(self, constraint, **kw): text += self.define_constraint_deferrability(constraint) return text - def define_unique_constraint_distinct(self, constraint, **kw): + def define_unique_constraint_distinct( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: return "" - def define_constraint_cascades(self, constraint): + def define_constraint_cascades( + self, constraint: ForeignKeyConstraint + ) -> str: text = "" if constraint.ondelete is not None: - text += " ON DELETE %s" % self.preparer.validate_sql_phrase( - constraint.ondelete, FK_ON_DELETE - ) + text += self.define_constraint_ondelete_cascade(constraint) + if constraint.onupdate is not None: - text += " ON UPDATE %s" % self.preparer.validate_sql_phrase( - constraint.onupdate, FK_ON_UPDATE - ) + text += self.define_constraint_onupdate_cascade(constraint) return text - def define_constraint_deferrability(self, constraint): + def define_constraint_ondelete_cascade( + self, constraint: ForeignKeyConstraint + ) -> str: + return " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, FK_ON_DELETE + ) + + def define_constraint_onupdate_cascade( + self, constraint: ForeignKeyConstraint + ) -> str: + return " ON UPDATE %s" % self.preparer.validate_sql_phrase( + constraint.onupdate, FK_ON_UPDATE + ) + + def define_constraint_deferrability(self, constraint: Constraint) -> str: text = "" if constraint.deferrable is not None: if constraint.deferrable: @@ -6931,19 +7181,21 @@ def visit_identity_column(self, identity, **kw): class 
GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str: return "FLOAT" - def visit_DOUBLE(self, type_, **kw): + def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str: return "DOUBLE" - def visit_DOUBLE_PRECISION(self, type_, **kw): + def visit_DOUBLE_PRECISION( + self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any + ) -> str: return "DOUBLE PRECISION" - def visit_REAL(self, type_, **kw): + def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str: return "REAL" - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str: if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -6954,7 +7206,7 @@ def visit_NUMERIC(self, type_, **kw): "scale": type_.scale, } - def visit_DECIMAL(self, type_, **kw): + def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str: if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -6965,122 +7217,138 @@ def visit_DECIMAL(self, type_, **kw): "scale": type_.scale, } - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str: return "INTEGER" - def visit_SMALLINT(self, type_, **kw): + def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str: return "SMALLINT" - def visit_BIGINT(self, type_, **kw): + def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str: return "BIGINT" - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str: return "TIMESTAMP" - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str: return "DATETIME" - def visit_DATE(self, type_, **kw): + def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str: return "DATE" - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str: return "TIME" - def visit_CLOB(self, type_, **kw): + def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str: return "CLOB" - def visit_NCLOB(self, type_, **kw): + def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str: return "NCLOB" - def _render_string_type(self, type_, name, length_override=None): + def _render_string_type( + self, name: str, length: Optional[int], collation: Optional[str] + ) -> str: text = name - if length_override: - text += "(%d)" % length_override - elif type_.length: - text += "(%d)" % type_.length - if type_.collation: - text += ' COLLATE "%s"' % type_.collation + if length: + text += f"({length})" + if collation: + text += f' COLLATE "{collation}"' return text - def visit_CHAR(self, type_, **kw): - return self._render_string_type(type_, "CHAR") + def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str: + return self._render_string_type("CHAR", type_.length, type_.collation) + + def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str: + return self._render_string_type("NCHAR", type_.length, type_.collation) - def visit_NCHAR(self, type_, **kw): - return self._render_string_type(type_, "NCHAR") + def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str: + return self._render_string_type( + "VARCHAR", type_.length, type_.collation + ) - def visit_VARCHAR(self, type_, **kw): - return self._render_string_type(type_, "VARCHAR") + def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str: + return self._render_string_type( + "NVARCHAR", type_.length, type_.collation + 
) - def visit_NVARCHAR(self, type_, **kw): - return self._render_string_type(type_, "NVARCHAR") + def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str: + return self._render_string_type("TEXT", type_.length, type_.collation) - def visit_TEXT(self, type_, **kw): - return self._render_string_type(type_, "TEXT") + def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str: + return "UUID" - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str: return "BLOB" - def visit_BINARY(self, type_, **kw): + def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str: return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_, **kw): + def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str: return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str: return "BOOLEAN" - def visit_uuid(self, type_, **kw): - return self._render_string_type(type_, "CHAR", length_override=32) + def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str: + if not type_.native_uuid or not self.dialect.supports_native_uuid: + return self._render_string_type("CHAR", length=32, collation=None) + else: + return self.visit_UUID(type_, **kw) - def visit_large_binary(self, type_, **kw): + def visit_large_binary( + self, type_: sqltypes.LargeBinary, **kw: Any + ) -> str: return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_, **kw): + def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str: return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_, **kw): + def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str: return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_, **kw): + def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str: return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_, **kw): + def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str: return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_, **kw): + def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str: return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_, **kw): + def visit_small_integer( + self, type_: sqltypes.SmallInteger, **kw: Any + ) -> str: return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_, **kw): + def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str: return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_, **kw): + def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str: return self.visit_REAL(type_, **kw) - def visit_float(self, type_, **kw): + def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str: return self.visit_FLOAT(type_, **kw) - def visit_double(self, type_, **kw): + def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str: return self.visit_DOUBLE(type_, **kw) - def visit_numeric(self, type_, **kw): + def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str: return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_, **kw): + def visit_string(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_, **kw): + def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_, **kw): + def visit_text(self, type_: 
sqltypes.Text, **kw: Any) -> str: return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_, **kw): + def visit_unicode_text( + self, type_: sqltypes.UnicodeText, **kw: Any + ) -> str: return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_, **kw): + def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): @@ -7090,10 +7358,14 @@ def visit_null(self, type_, **kw): "type on this Column?" % type_ ) - def visit_type_decorator(self, type_, **kw): + def visit_type_decorator( + self, type_: TypeDecorator[Any], **kw: Any + ) -> str: return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_, **kw): + def visit_user_defined( + self, type_: UserDefinedType[Any], **kw: Any + ) -> str: return type_.get_col_spec(**kw) @@ -7131,17 +7403,14 @@ def visit_user_defined(self, type_, **kw): class _SchemaForObjectCallable(Protocol): - def __call__(self, obj: Any) -> str: - ... + def __call__(self, __obj: Any) -> str: ... class _BindNameForColProtocol(Protocol): - def __call__(self, col: ColumnClause[Any]) -> str: - ... + def __call__(self, col: ColumnClause[Any]) -> str: ... class IdentifierPreparer: - """Handle quoting and case-folding of identifiers based on options.""" reserved_words = RESERVED_WORDS @@ -7171,12 +7440,12 @@ class IdentifierPreparer: def __init__( self, - dialect, - initial_quote='"', - final_quote=None, - escape_quote='"', - quote_case_sensitive_collations=True, - omit_schema=False, + dialect: Dialect, + initial_quote: str = '"', + final_quote: Optional[str] = None, + escape_quote: str = '"', + quote_case_sensitive_collations: bool = True, + omit_schema: bool = False, ): """Construct a new ``IdentifierPreparer`` object. @@ -7229,7 +7498,9 @@ def symbol_getter(obj): prep._includes_none_schema_translate = includes_none return prep - def _render_schema_translates(self, statement, schema_translate_map): + def _render_schema_translates( + self, statement: str, schema_translate_map: SchemaTranslateMapType + ) -> str: d = schema_translate_map if None in d: if not self._includes_none_schema_translate: @@ -7241,7 +7512,7 @@ def _render_schema_translates(self, statement, schema_translate_map): "schema_translate_map dictionaries." 
) - d["_none"] = d[None] + d["_none"] = d[None] # type: ignore[index] def replace(m): name = m.group(2) @@ -7434,7 +7705,9 @@ def format_collation(self, collation_name): else: return collation_name - def format_sequence(self, sequence, use_schema=True): + def format_sequence( + self, sequence: schema.Sequence, use_schema: bool = True + ) -> str: name = self.quote(sequence.name) effective_schema = self.schema_for_object(sequence) @@ -7471,7 +7744,9 @@ def format_savepoint(self, savepoint, name=None): return ident @util.preload_module("sqlalchemy.sql.naming") - def format_constraint(self, constraint, _alembic_quote=True): + def format_constraint( + self, constraint: Union[Constraint, Index], _alembic_quote: bool = True + ) -> Optional[str]: naming = util.preloaded.sql_naming if constraint.name is _NONE_NAME: @@ -7484,6 +7759,7 @@ def format_constraint(self, constraint, _alembic_quote=True): else: name = constraint.name + assert name is not None if constraint.__visit_name__ == "index": return self.truncate_and_render_index_name( name, _alembic_quote=_alembic_quote @@ -7493,7 +7769,9 @@ def format_constraint(self, constraint, _alembic_quote=True): name, _alembic_quote=_alembic_quote ) - def truncate_and_render_index_name(self, name, _alembic_quote=True): + def truncate_and_render_index_name( + self, name: str, _alembic_quote: bool = True + ) -> str: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived @@ -7505,7 +7783,9 @@ def truncate_and_render_index_name(self, name, _alembic_quote=True): name, max_, _alembic_quote ) - def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + def truncate_and_render_constraint_name( + self, name: str, _alembic_quote: bool = True + ) -> str: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. 
can be reflected # as IdentifierPreparer is long lived @@ -7517,7 +7797,9 @@ def truncate_and_render_constraint_name(self, name, _alembic_quote=True): name, max_, _alembic_quote ) - def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + def _truncate_and_render_maxlen_name( + self, name: str, max_: int, _alembic_quote: bool + ) -> str: if isinstance(name, elements._truncated_label): if len(name) > max_: name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] @@ -7529,13 +7811,21 @@ def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): else: return self.quote(name) - def format_index(self, index): - return self.format_constraint(index) + def format_index(self, index: Index) -> str: + name = self.format_constraint(index) + assert name is not None + return name - def format_table(self, table, use_schema=True, name=None): + def format_table( + self, + table: FromClause, + use_schema: bool = True, + name: Optional[str] = None, + ) -> str: """Prepare a quoted table and schema name.""" - if name is None: + if TYPE_CHECKING: + assert isinstance(table, NamedFromClause) name = table.name result = self.quote(name) @@ -7567,17 +7857,18 @@ def format_label_name( def format_column( self, - column, - use_table=False, - name=None, - table_name=None, - use_schema=False, - anon_map=None, - ): + column: ColumnElement[Any], + use_table: bool = False, + name: Optional[str] = None, + table_name: Optional[str] = None, + use_schema: bool = False, + anon_map: Optional[Mapping[str, Any]] = None, + ) -> str: """Prepare a quoted column name.""" if name is None: name = column.name + assert name is not None if anon_map is not None and isinstance( name, elements._truncated_label @@ -7645,7 +7936,7 @@ def _r_identifiers(self): ) return r - def unformat_identifiers(self, identifiers): + def unformat_identifiers(self, identifiers: str) -> Sequence[str]: """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e51403eceda..4a592ff7b97 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -1,5 +1,5 @@ # sql/crud.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -241,7 +241,7 @@ def _get_crud_params( stmt_parameter_tuples = list(spd.items()) spd_str_key = {_column_as_key(key) for key in spd} else: - stmt_parameter_tuples = spd = spd_str_key = None + stmt_parameter_tuples = spd_str_key = None # if we have statement parameters - set defaults in the # compiled params @@ -393,9 +393,9 @@ def _create_bind_param( process: Literal[True] = ..., required: bool = False, name: Optional[str] = None, + force_anonymous: bool = False, **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -404,8 +404,7 @@ def _create_bind_param( col: ColumnElement[Any], value: Any, **kw: Any, -) -> str: - ... +) -> str: ... 
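The `force_anonymous` flag threaded through these `_create_bind_param` overloads ends up passing `name=None` to `BindParameter`, which switches the compiler to SQLAlchemy's anonymous bind naming; that is what keeps parameter names from colliding when the same INSERT or UPDATE is rendered both inside a CTE and in the enclosing statement. A minimal sketch of the underlying naming behavior, using only the public `bindparam` constructor (the key and value here are illustrative):

from sqlalchemy import bindparam

# A string key is used verbatim and can collide across nested statements;
# a key of None asks SQLAlchemy to generate a unique anonymous token.
named = bindparam("data", value=42)
anon = bindparam(None, value=42)

print(named.key)  # "data"
print(anon.key)   # an anonymous token, roughly "%(140551... param)s"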
def _create_bind_param( @@ -415,10 +414,14 @@ def _create_bind_param( process: bool = True, required: bool = False, name: Optional[str] = None, + force_anonymous: bool = False, **kw: Any, ) -> Union[str, elements.BindParameter[Any]]: - if name is None: + if force_anonymous: + name = None + elif name is None: name = col.key + bindparam = elements.BindParameter( name, value, type_=col.type, required=required ) @@ -488,7 +491,7 @@ def _key_getters_for_crud_column( ) def _column_as_key( - key: Union[ColumnClause[Any], str] + key: Union[ColumnClause[Any], str], ) -> Union[str, Tuple[str, str]]: str_key = c_key_role(key) if hasattr(key, "table") and key.table in _et: @@ -834,6 +837,7 @@ def _append_param_parameter( ): value = parameters.pop(col_key) + has_visiting_cte = kw.get("visiting_cte") is not None col_value = compiler.preparer.format_column( c, use_table=compile_state.include_table_with_column_exprs ) @@ -859,11 +863,14 @@ def _append_param_parameter( c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not _compile_state_isinsert(compile_state) - or not compile_state._has_multi_parameters - else "%s_m0" % _col_bind_name(c), + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), accumulate_bind_names=accumulated_bind_names, + force_anonymous=has_visiting_cte, **kw, ) elif value._is_bind_parameter: @@ -884,10 +891,12 @@ def _append_param_parameter( compiler, c, value, - name=_col_bind_name(c) - if not _compile_state_isinsert(compile_state) - or not compile_state._has_multi_parameters - else "%s_m0" % _col_bind_name(c), + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), accumulate_bind_names=accumulated_bind_names, **kw, ) @@ -1213,8 +1222,7 @@ def _create_insert_prefetch_bind_param( c: ColumnElement[Any], process: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -1223,8 +1231,7 @@ def _create_insert_prefetch_bind_param( c: ColumnElement[Any], process: Literal[False], **kw: Any, -) -> elements.BindParameter[Any]: - ... +) -> elements.BindParameter[Any]: ... def _create_insert_prefetch_bind_param( @@ -1247,8 +1254,7 @@ def _create_update_prefetch_bind_param( c: ColumnElement[Any], process: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -1257,8 +1263,7 @@ def _create_update_prefetch_bind_param( c: ColumnElement[Any], process: Literal[False], **kw: Any, -) -> elements.BindParameter[Any]: - ... +) -> elements.BindParameter[Any]: ... 
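The `"%s_m0"` / `"%s_m%d"` naming in the hunks above is observable from the outside whenever a multi-VALUES INSERT is compiled: each parameter set receives an `_m<N>` suffix. A short, self-contained illustration against the default string dialect (the table and column names are invented for the example):

from sqlalchemy import Column, Integer, MetaData, String, Table, insert

metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)

# Each parameter set in a multi-VALUES insert gets an _m<N> suffix,
# matching the "%s_m0" / "%s_m%d" naming in the hunks above.
stmt = insert(users).values([{"name": "spongebob"}, {"name": "sandy"}])
print(stmt)
# INSERT INTO users (name) VALUES (:name_m0), (:name_m1)

The `has_visiting_cte` check then layers the anonymous naming from the previous note on top of this scheme when such an INSERT lives inside a CTE.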
def _create_update_prefetch_bind_param( @@ -1288,7 +1293,7 @@ def __init__(self, original, index): def compare(self, other, **kw): raise NotImplementedError() - def _copy_internals(self, other, **kw): + def _copy_internals(self, **kw): raise NotImplementedError() def __eq__(self, other): @@ -1437,6 +1442,7 @@ def _extend_values_for_multiparams( values_0 = initial_values values = [initial_values] + has_visiting_cte = kw.get("visiting_cte") is not None mp = compile_state._multi_parameters assert mp is not None for i, row in enumerate(mp[1:]): @@ -1453,7 +1459,8 @@ def _extend_values_for_multiparams( compiler, col, row[key], - name="%s_m%d" % (col.key, i + 1), + name=("%s_m%d" % (col.key, i + 1)), + force_anonymous=has_visiting_cte, **kw, ) else: diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 06bbcae2e4b..70a83cb8a73 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -1,5 +1,5 @@ # sql/ddl.py -# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -17,11 +17,14 @@ import typing from typing import Any from typing import Callable +from typing import Generic from typing import Iterable from typing import List from typing import Optional from typing import Sequence as typing_Sequence from typing import Tuple +from typing import TypeVar +from typing import Union from . import roles from .base import _generative @@ -38,10 +41,12 @@ from .compiler import Compiled from .compiler import DDLCompiler from .elements import BindParameter + from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint + from .schema import Index from .schema import SchemaItem - from .schema import Sequence + from .schema import Sequence as Sequence # noqa: F401 from .schema import Table from .selectable import TableClause from ..engine.base import Connection @@ -50,6 +55,8 @@ from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType +_SI = TypeVar("_SI", bound=Union["SchemaItem", str]) + class BaseDDLElement(ClauseElement): """The root of DDL constructs, including those that are sub-elements @@ -87,7 +94,7 @@ class DDLIfCallable(Protocol): def __call__( self, ddl: BaseDDLElement, - target: SchemaItem, + target: Union[SchemaItem, str], bind: Optional[Connection], tables: Optional[List[Table]] = None, state: Optional[Any] = None, @@ -95,8 +102,7 @@ def __call__( dialect: Dialect, compiler: Optional[DDLCompiler] = ..., checkfirst: bool, - ) -> bool: - ... + ) -> bool: ... class DDLIf(typing.NamedTuple): @@ -107,7 +113,7 @@ class DDLIf(typing.NamedTuple): def _should_execute( self, ddl: BaseDDLElement, - target: SchemaItem, + target: Union[SchemaItem, str], bind: Optional[Connection], compiler: Optional[DDLCompiler] = None, **kw: Any, @@ -156,8 +162,8 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): event.listen( users, - 'after_create', - AddConstraint(constraint).execute_if(dialect='postgresql') + "after_create", + AddConstraint(constraint).execute_if(dialect="postgresql"), ) .. 
seealso:: @@ -173,7 +179,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): """ _ddl_if: Optional[DDLIf] = None - target: Optional[SchemaItem] = None + target: Union[SchemaItem, str, None] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -232,20 +238,20 @@ def execute_if( Used to provide a wrapper for event listening:: event.listen( - metadata, - 'before_create', - DDL("my_ddl").execute_if(dialect='postgresql') - ) + metadata, + "before_create", + DDL("my_ddl").execute_if(dialect="postgresql"), + ) :param dialect: May be a string or tuple of strings. If a string, it will be compared to the name of the executing database dialect:: - DDL('something').execute_if(dialect='postgresql') + DDL("something").execute_if(dialect="postgresql") If a tuple, specifies multiple dialect names:: - DDL('something').execute_if(dialect=('postgresql', 'mysql')) + DDL("something").execute_if(dialect=("postgresql", "mysql")) :param callable\_: A callable, which will be invoked with three positional arguments as well as optional keyword @@ -343,17 +349,19 @@ class DDL(ExecutableDDLElement): from sqlalchemy import event, DDL - tbl = Table('users', metadata, Column('uid', Integer)) - event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger')) + tbl = Table("users", metadata, Column("uid", Integer)) + event.listen(tbl, "before_create", DDL("DROP TRIGGER users_trigger")) - spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE') - event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb')) + spow = DDL("ALTER TABLE %(table)s SET secretpowers TRUE") + event.listen(tbl, "after_create", spow.execute_if(dialect="somedb")) - drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + drop_spow = DDL("ALTER TABLE users SET secretpowers FALSE") connection.execute(drop_spow) When operating on Table events, the following ``statement`` - string substitutions are available:: + string substitutions are available: + + .. sourcecode:: text %(table)s - the Table name, with any required quoting applied %(schema)s - the schema name, with any required quoting applied @@ -414,7 +422,7 @@ def __repr__(self): ) -class _CreateDropBase(ExecutableDDLElement): +class _CreateDropBase(ExecutableDDLElement, Generic[_SI]): """Base class for DDL constructs that represent CREATE and DROP or equivalents. 
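The `execute_if` examples reformatted above are runnable as-is; for completeness, a sketch of the tuple-of-dialects form (the DDL text is illustrative, and the `%(table)s` token is substituted at event time as described above):

from sqlalchemy import DDL, Column, Integer, MetaData, Table, event

metadata = MetaData()
tbl = Table("users", metadata, Column("uid", Integer))

# Fires after CREATE TABLE, but only on the named dialects.
ddl = DDL("COMMENT ON TABLE %(table)s IS 'managed'")
event.listen(
    tbl,
    "after_create",
    ddl.execute_if(dialect=("postgresql", "mysql")),
)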
@@ -424,15 +432,15 @@ class _CreateDropBase(ExecutableDDLElement): """ - def __init__( - self, - element, - ): + element: _SI + + def __init__(self, element: _SI) -> None: self.element = self.target = element self._ddl_if = getattr(element, "_ddl_if", None) @property - def stringify_dialect(self): + def stringify_dialect(self): # type: ignore[override] + assert not isinstance(self.element, str) return self.element.create_drop_stringify_dialect def _create_rule_disable(self, compiler): @@ -446,19 +454,19 @@ def _create_rule_disable(self, compiler): return False -class _CreateBase(_CreateDropBase): - def __init__(self, element, if_not_exists=False): +class _CreateBase(_CreateDropBase[_SI]): + def __init__(self, element: _SI, if_not_exists: bool = False) -> None: super().__init__(element) self.if_not_exists = if_not_exists -class _DropBase(_CreateDropBase): - def __init__(self, element, if_exists=False): +class _DropBase(_CreateDropBase[_SI]): + def __init__(self, element: _SI, if_exists: bool = False) -> None: super().__init__(element) self.if_exists = if_exists -class CreateSchema(_CreateBase): +class CreateSchema(_CreateBase[str]): """Represent a CREATE SCHEMA statement. The argument here is the string name of the schema. @@ -471,15 +479,15 @@ class CreateSchema(_CreateBase): def __init__( self, - name, - if_not_exists=False, - ): + name: str, + if_not_exists: bool = False, + ) -> None: """Create a new :class:`.CreateSchema` construct.""" super().__init__(element=name, if_not_exists=if_not_exists) -class DropSchema(_DropBase): +class DropSchema(_DropBase[str]): """Represent a DROP SCHEMA statement. The argument here is the string name of the schema. @@ -492,17 +500,17 @@ class DropSchema(_DropBase): def __init__( self, - name, - cascade=False, - if_exists=False, - ): + name: str, + cascade: bool = False, + if_exists: bool = False, + ) -> None: """Create a new :class:`.DropSchema` construct.""" super().__init__(element=name, if_exists=if_exists) self.cascade = cascade -class CreateTable(_CreateBase): +class CreateTable(_CreateBase["Table"]): """Represent a CREATE TABLE statement.""" __visit_name__ = "create_table" @@ -514,7 +522,7 @@ def __init__( typing_Sequence[ForeignKeyConstraint] ] = None, if_not_exists: bool = False, - ): + ) -> None: """Create a :class:`.CreateTable` construct. :param element: a :class:`_schema.Table` that's the subject @@ -536,7 +544,7 @@ def __init__( self.include_foreign_key_constraints = include_foreign_key_constraints -class _DropView(_DropBase): +class _DropView(_DropBase["Table"]): """Semi-public 'DROP VIEW' construct. Used by the test suite for dialect-agnostic drops of views. 
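With `_CreateDropBase` now generic over `_SI`, the schema-level constructs bind the element to a plain string while the table-level constructs keep `SchemaItem` elements; the typed constructors in the hunks above reflect that split. A quick sketch of the string-element variants (the schema name is arbitrary, and the rendered strings shown are approximate, from the default dialect):

from sqlalchemy.schema import CreateSchema, DropSchema

# The element is the schema name itself (a str), not a SchemaItem.
create = CreateSchema("analytics", if_not_exists=True)
drop = DropSchema("analytics", cascade=True, if_exists=True)

print(create)  # roughly: CREATE SCHEMA IF NOT EXISTS analytics
print(drop)    # roughly: DROP SCHEMA IF EXISTS analytics CASCADE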
@@ -548,7 +556,9 @@ class _DropView(_DropBase): class CreateConstraint(BaseDDLElement): - def __init__(self, element: Constraint): + element: Constraint + + def __init__(self, element: Constraint) -> None: self.element = element @@ -569,6 +579,7 @@ class CreateColumn(BaseDDLElement): from sqlalchemy import schema from sqlalchemy.ext.compiler import compiles + @compiles(schema.CreateColumn) def compile(element, compiler, **kw): column = element.element @@ -577,9 +588,9 @@ def compile(element, compiler, **kw): return compiler.visit_create_column(element, **kw) text = "%s SPECIAL DIRECTIVE %s" % ( - column.name, - compiler.type_compiler.process(column.type) - ) + column.name, + compiler.type_compiler.process(column.type), + ) default = compiler.get_column_default_string(column) if default is not None: text += " DEFAULT " + default @@ -589,8 +600,8 @@ def compile(element, compiler, **kw): if column.constraints: text += " ".join( - compiler.process(const) - for const in column.constraints) + compiler.process(const) for const in column.constraints + ) return text The above construct can be applied to a :class:`_schema.Table` @@ -601,17 +612,21 @@ def compile(element, compiler, **kw): metadata = MetaData() - table = Table('mytable', MetaData(), - Column('x', Integer, info={"special":True}, primary_key=True), - Column('y', String(50)), - Column('z', String(20), info={"special":True}) - ) + table = Table( + "mytable", + MetaData(), + Column("x", Integer, info={"special": True}, primary_key=True), + Column("y", String(50)), + Column("z", String(20), info={"special": True}), + ) metadata.create_all(conn) Above, the directives we've added to the :attr:`_schema.Column.info` collection - will be detected by our custom compilation scheme:: + will be detected by our custom compilation scheme: + + .. sourcecode:: sql CREATE TABLE mytable ( x SPECIAL DIRECTIVE INTEGER NOT NULL, @@ -636,18 +651,21 @@ def compile(element, compiler, **kw): from sqlalchemy.schema import CreateColumn + @compiles(CreateColumn, "postgresql") def skip_xmin(element, compiler, **kw): - if element.element.name == 'xmin': + if element.element.name == "xmin": return None else: return compiler.visit_create_column(element, **kw) - my_table = Table('mytable', metadata, - Column('id', Integer, primary_key=True), - Column('xmin', Integer) - ) + my_table = Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("xmin", Integer), + ) Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE`` which only includes the ``id`` column in the string; the ``xmin`` column @@ -657,16 +675,18 @@ def skip_xmin(element, compiler, **kw): __visit_name__ = "create_column" - def __init__(self, element): + element: Column[Any] + + def __init__(self, element: Column[Any]) -> None: self.element = element -class DropTable(_DropBase): +class DropTable(_DropBase["Table"]): """Represent a DROP TABLE statement.""" __visit_name__ = "drop_table" - def __init__(self, element: Table, if_exists: bool = False): + def __init__(self, element: Table, if_exists: bool = False) -> None: """Create a :class:`.DropTable` construct. 
:param element: a :class:`_schema.Table` that's the subject @@ -681,30 +701,24 @@ def __init__(self, element: Table, if_exists: bool = False): super().__init__(element, if_exists=if_exists) -class CreateSequence(_CreateBase): +class CreateSequence(_CreateBase["Sequence"]): """Represent a CREATE SEQUENCE statement.""" __visit_name__ = "create_sequence" - def __init__(self, element: Sequence, if_not_exists: bool = False): - super().__init__(element, if_not_exists=if_not_exists) - -class DropSequence(_DropBase): +class DropSequence(_DropBase["Sequence"]): """Represent a DROP SEQUENCE statement.""" __visit_name__ = "drop_sequence" - def __init__(self, element: Sequence, if_exists: bool = False): - super().__init__(element, if_exists=if_exists) - -class CreateIndex(_CreateBase): +class CreateIndex(_CreateBase["Index"]): """Represent a CREATE INDEX statement.""" __visit_name__ = "create_index" - def __init__(self, element, if_not_exists=False): + def __init__(self, element: Index, if_not_exists: bool = False) -> None: """Create a :class:`.Createindex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -718,12 +732,12 @@ def __init__(self, element, if_not_exists=False): super().__init__(element, if_not_exists=if_not_exists) -class DropIndex(_DropBase): +class DropIndex(_DropBase["Index"]): """Represent a DROP INDEX statement.""" __visit_name__ = "drop_index" - def __init__(self, element, if_exists=False): + def __init__(self, element: Index, if_exists: bool = False) -> None: """Create a :class:`.DropIndex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -737,38 +751,88 @@ def __init__(self, element, if_exists=False): super().__init__(element, if_exists=if_exists) -class AddConstraint(_CreateBase): +class AddConstraint(_CreateBase["Constraint"]): """Represent an ALTER TABLE ADD CONSTRAINT statement.""" __visit_name__ = "add_constraint" - def __init__(self, element): + def __init__( + self, + element: Constraint, + *, + isolate_from_table: bool = True, + ) -> None: + """Construct a new :class:`.AddConstraint` construct. + + :param element: a :class:`.Constraint` object + + :param isolate_from_table: optional boolean, defaults to True. Has + the effect of the incoming constraint being isolated from being + included in a CREATE TABLE sequence when associated with a + :class:`.Table`. + + .. versionadded:: 2.0.39 - added + :paramref:`.AddConstraint.isolate_from_table`, defaulting + to True. Previously, the behavior of this parameter was implicitly + turned on in all cases. + + """ super().__init__(element) - element._create_rule = util.portable_instancemethod( - self._create_rule_disable - ) + + if isolate_from_table: + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) -class DropConstraint(_DropBase): +class DropConstraint(_DropBase["Constraint"]): """Represent an ALTER TABLE DROP CONSTRAINT statement.""" __visit_name__ = "drop_constraint" - def __init__(self, element, cascade=False, if_exists=False, **kw): + def __init__( + self, + element: Constraint, + *, + cascade: bool = False, + if_exists: bool = False, + isolate_from_table: bool = True, + **kw: Any, + ) -> None: + """Construct a new :class:`.DropConstraint` construct. 
+ + :param element: a :class:`.Constraint` object + :param cascade: optional boolean, indicates backend-specific + "CASCADE CONSTRAINT" directive should be rendered if available + :param if_exists: optional boolean, indicates backend-specific + "IF EXISTS" directive should be rendered if available + :param isolate_from_table: optional boolean, defaults to True. Has + the effect of the incoming constraint being isolated from being + included in a CREATE TABLE sequence when associated with a + :class:`.Table`. + + .. versionadded:: 2.0.39 - added + :paramref:`.DropConstraint.isolate_from_table`, defaulting + to True. Previously, the behavior of this parameter was implicitly + turned on in all cases. + + """ self.cascade = cascade super().__init__(element, if_exists=if_exists, **kw) - element._create_rule = util.portable_instancemethod( - self._create_rule_disable - ) + + if isolate_from_table: + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) -class SetTableComment(_CreateDropBase): +class SetTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE IS statement.""" __visit_name__ = "set_table_comment" -class DropTableComment(_CreateDropBase): +class DropTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE '' statement. Note this varies a lot across database backends. @@ -778,33 +842,34 @@ class DropTableComment(_CreateDropBase): __visit_name__ = "drop_table_comment" -class SetColumnComment(_CreateDropBase): +class SetColumnComment(_CreateDropBase["Column[Any]"]): """Represent a COMMENT ON COLUMN IS statement.""" __visit_name__ = "set_column_comment" -class DropColumnComment(_CreateDropBase): +class DropColumnComment(_CreateDropBase["Column[Any]"]): """Represent a COMMENT ON COLUMN IS NULL statement.""" __visit_name__ = "drop_column_comment" -class SetConstraintComment(_CreateDropBase): +class SetConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS statement.""" __visit_name__ = "set_constraint_comment" -class DropConstraintComment(_CreateDropBase): +class DropConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS NULL statement.""" __visit_name__ = "drop_constraint_comment" class InvokeDDLBase(SchemaVisitor): - def __init__(self, connection): + def __init__(self, connection, **kw): self.connection = connection + assert not kw, f"Unexpected keywords: {kw.keys()}" @contextlib.contextmanager def with_ddl_events(self, target, **kw): @@ -1021,10 +1086,12 @@ def visit_metadata(self, metadata): reversed( sort_tables_and_constraints( unsorted_tables, - filter_fn=lambda constraint: False - if not self.dialect.supports_alter - or constraint.name is None - else None, + filter_fn=lambda constraint: ( + False + if not self.dialect.supports_alter + or constraint.name is None + else None + ), ) ) ) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 5dbf3e3573f..62c1be452e1 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -1,12 +1,11 @@ # sql/default_comparator.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Default implementation of SQL comparison operations. 
-""" +"""Default implementation of SQL comparison operations.""" from __future__ import annotations @@ -56,7 +55,6 @@ def _boolean_compare( negate_op: Optional[OperatorType] = None, reverse: bool = False, _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), - _any_all_expr: bool = False, result_type: Optional[TypeEngine[bool]] = None, **kwargs: Any, ) -> OperatorExpression[bool]: @@ -90,7 +88,7 @@ def _boolean_compare( negate=negate_op, modifiers=kwargs, ) - elif _any_all_expr: + elif expr._is_collection_aggregate: obj = coercions.expect( roles.ConstExprRole, element=obj, operator=op, expr=expr ) @@ -248,7 +246,7 @@ def _unsupported_impl( expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any ) -> NoReturn: raise NotImplementedError( - "Operator '%s' is not supported on " "this expression" % op.__name__ + "Operator '%s' is not supported on this expression" % op.__name__ ) @@ -297,9 +295,11 @@ def _match_impl( operator=operators.match_op, ), result_type=type_api.MATCHTYPE, - negate_op=operators.not_match_op - if op is operators.match_op - else operators.match_op, + negate_op=( + operators.not_match_op + if op is operators.match_op + else operators.match_op + ), **kw, ) @@ -341,9 +341,11 @@ def _between_impl( group=False, ), op, - negate=operators.not_between_op - if op is operators.between_op - else operators.between_op, + negate=( + operators.not_between_op + if op is operators.between_op + else operators.between_op + ), modifiers=kw, ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 4ca6ed338f4..f5071146be2 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1,5 +1,5 @@ # sql/dml.py -# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2009-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -23,6 +23,7 @@ from typing import Optional from typing import overload from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -42,6 +43,7 @@ from .base import _generative from .base import _select_iterables from .base import ColumnCollection +from .base import ColumnSet from .base import CompileState from .base import DialectKWArgs from .base import Executable @@ -91,14 +93,11 @@ from .selectable import Select from .selectable import Selectable - def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: - ... + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ... - def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: - ... + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ... - def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: - ... + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ... else: isupdate = operator.attrgetter("isupdate") @@ -137,9 +136,11 @@ def __init__( @classmethod def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: return { - "name": statement.table.name - if is_named_from_clause(statement.table) - else None, + "name": ( + statement.table.name + if is_named_from_clause(statement.table) + else None + ), "table": statement.table, } @@ -163,8 +164,7 @@ def dml_table(self) -> _DMLTableElement: if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: - ... + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: ... 
@classmethod def _get_multi_crud_kv_pairs( @@ -190,13 +190,15 @@ def _get_crud_kv_pairs( return [ ( coercions.expect(roles.DMLColumnRole, k), - v - if not needs_to_be_cacheable - else coercions.expect( - roles.ExpressionElementRole, - v, - type_=NullType(), - is_crud=True, + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ) ), ) for k, v in kv_iterator @@ -306,12 +308,14 @@ def _process_values(self, statement: ValuesBase) -> None: def _process_multi_values(self, statement: ValuesBase) -> None: for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ - { - c.key: value - for c, value in zip(statement.table.c, parameter_set) - } - if isinstance(parameter_set, collections_abc.Sequence) - else parameter_set + ( + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + ) for parameter_set in parameters ] @@ -396,9 +400,9 @@ class UpdateBase( __visit_name__ = "update_base" - _hints: util.immutabledict[ - Tuple[_DMLTableElement, str], str - ] = util.EMPTY_DICT + _hints: util.immutabledict[Tuple[_DMLTableElement, str], str] = ( + util.EMPTY_DICT + ) named_with_column = False _label_style: SelectLabelStyle = ( @@ -407,19 +411,25 @@ class UpdateBase( table: _DMLTableElement _return_defaults = False - _return_defaults_columns: Optional[ - Tuple[_ColumnsClauseElement, ...] - ] = None + _return_defaults_columns: Optional[Tuple[_ColumnsClauseElement, ...]] = ( + None + ) _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None _returning: Tuple[_ColumnsClauseElement, ...] = () is_dml = True def _generate_fromclause_column_proxies( - self, fromclause: FromClause + self, + fromclause: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], ) -> None: - fromclause._columns._populate_separate_keys( - col._make_proxy(fromclause) + columns._populate_separate_keys( + col._make_proxy( + fromclause, primary_key=primary_key, foreign_keys=foreign_keys + ) for col in self._all_selected_columns if is_column_element(col) ) @@ -523,11 +533,11 @@ def return_defaults( E.g.:: - stmt = table.insert().values(data='newdata').return_defaults() + stmt = table.insert().values(data="newdata").return_defaults() result = connection.execute(stmt) - server_created_at = result.returned_defaults['created_at'] + server_created_at = result.returned_defaults["created_at"] When used against an UPDATE statement :meth:`.UpdateBase.return_defaults` instead looks for columns that @@ -685,6 +695,16 @@ def return_defaults( return self + def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: + """Return ``True`` if this :class:`.ReturnsRows` is + 'derived' from the given :class:`.FromClause`. + + Since these are DMLs, we dont want such statements ever being adapted + so we return False for derives. 
+ + """ + return False + @_generative def returning( self, @@ -1030,7 +1050,7 @@ def values( users.insert().values(name="some name") - users.update().where(users.c.id==5).values(name="some name") + users.update().where(users.c.id == 5).values(name="some name") :param \*args: As an alternative to passing key/value parameters, a dictionary, tuple, or list of dictionaries or tuples can be passed @@ -1060,13 +1080,17 @@ def values( this syntax is supported on backends such as SQLite, PostgreSQL, MySQL, but not necessarily others:: - users.insert().values([ - {"name": "some name"}, - {"name": "some other name"}, - {"name": "yet another name"}, - ]) + users.insert().values( + [ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ] + ) + + The above form would render a multiple VALUES statement similar to: - The above form would render a multiple VALUES statement similar to:: + .. sourcecode:: sql INSERT INTO users (name) VALUES (:name_1), @@ -1244,7 +1268,7 @@ def from_select( e.g.:: sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5) - ins = table2.insert().from_select(['a', 'b'], sel) + ins = table2.insert().from_select(["a", "b"], sel) :param names: a sequence of string column names or :class:`_schema.Column` @@ -1295,8 +1319,7 @@ def from_select( @overload def returning( self, __ent0: _TCCA[_T0], *, sort_by_parameter_order: bool = False - ) -> ReturningInsert[Tuple[_T0]]: - ... + ) -> ReturningInsert[Tuple[_T0]]: ... @overload def returning( @@ -1305,8 +1328,7 @@ def returning( __ent1: _TCCA[_T1], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1]]: ... @overload def returning( @@ -1316,8 +1338,7 @@ def returning( __ent2: _TCCA[_T2], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: ... @overload def returning( @@ -1328,8 +1349,7 @@ def returning( __ent3: _TCCA[_T3], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def returning( @@ -1341,8 +1361,7 @@ def returning( __ent4: _TCCA[_T4], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def returning( @@ -1355,8 +1374,7 @@ def returning( __ent5: _TCCA[_T5], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def returning( @@ -1370,8 +1388,7 @@ def returning( __ent6: _TCCA[_T6], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def returning( @@ -1386,8 +1403,9 @@ def returning( __ent7: _TCCA[_T7], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> ReturningInsert[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... # END OVERLOADED FUNCTIONS self.returning @@ -1397,16 +1415,14 @@ def returning( *cols: _ColumnsClauseArgument[Any], sort_by_parameter_order: bool = False, **__kw: Any, - ) -> ReturningInsert[Any]: - ... + ) -> ReturningInsert[Any]: ... 
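`Insert.from_select`, whose docstring example was reformatted above, composes an INSERT directly from a SELECT; a complete, stringifiable version of that pattern (both table layouts are invented for the example):

from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()
src = Table("src", metadata, Column("a", Integer), Column("b", String(20)))
dst = Table("dst", metadata, Column("a", Integer), Column("b", String(20)))

sel = select(src.c.a, src.c.b).where(src.c.a > 5)
ins = dst.insert().from_select(["a", "b"], sel)
print(ins)
# INSERT INTO dst (a, b) SELECT src.a, src.b FROM src WHERE src.a > :a_1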
def returning( self, *cols: _ColumnsClauseArgument[Any], sort_by_parameter_order: bool = False, **__kw: Any, - ) -> ReturningInsert[Any]: - ... + ) -> ReturningInsert[Any]: ... class ReturningInsert(Insert, TypedReturnsRows[_TP]): @@ -1541,9 +1557,7 @@ def ordered_values(self, *args: Tuple[_DMLColumnArgument, Any]) -> Self: E.g.:: - stmt = table.update().ordered_values( - ("name", "ed"), ("ident": "foo") - ) + stmt = table.update().ordered_values(("name", "ed"), ("ident", "foo")) .. seealso:: @@ -1556,7 +1570,7 @@ def ordered_values(self, *args: Tuple[_DMLColumnArgument, Any]) -> Self: :paramref:`_expression.update.preserve_parameter_order` parameter, which will be removed in SQLAlchemy 2.0. - """ + """ # noqa: E501 if self._values: raise exc.ArgumentError( "This statement already has values present" @@ -1596,20 +1610,19 @@ def inline(self) -> Self: # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: - ... + def returning( + self, __ent0: _TCCA[_T0] + ) -> ReturningUpdate[Tuple[_T0]]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> ReturningUpdate[Tuple[_T0, _T1]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1]]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: ... @overload def returning( @@ -1618,8 +1631,7 @@ def returning( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def returning( @@ -1629,8 +1641,7 @@ def returning( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def returning( @@ -1641,8 +1652,7 @@ def returning( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def returning( @@ -1654,8 +1664,7 @@ def returning( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def returning( @@ -1668,21 +1677,20 @@ def returning( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> ReturningUpdate[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... # END OVERLOADED FUNCTIONS self.returning @overload def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningUpdate[Any]: - ... + ) -> ReturningUpdate[Any]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningUpdate[Any]: - ... + ) -> ReturningUpdate[Any]: ... class ReturningUpdate(Update, TypedReturnsRows[_TP]): @@ -1734,20 +1742,19 @@ def __init__(self, table: _DMLTableArgument): # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: - ... + def returning( + self, __ent0: _TCCA[_T0] + ) -> ReturningDelete[Tuple[_T0]]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> ReturningDelete[Tuple[_T0, _T1]]: - ... 
+ ) -> ReturningDelete[Tuple[_T0, _T1]]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: - ... + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: ... @overload def returning( @@ -1756,8 +1763,7 @@ def returning( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: - ... + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def returning( @@ -1767,8 +1773,7 @@ def returning( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def returning( @@ -1779,8 +1784,7 @@ def returning( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def returning( @@ -1792,8 +1796,7 @@ def returning( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def returning( @@ -1806,21 +1809,20 @@ def returning( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + ) -> ReturningDelete[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... # END OVERLOADED FUNCTIONS self.returning @overload def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Any]: - ... + ) -> ReturningDelete[Any]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Any]: - ... + ) -> ReturningDelete[Any]: ... 
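The `ordered_values` docstring corrected above (the old `("ident": "foo")` was not valid tuple syntax) is shown here in complete form; the table is illustrative:

from sqlalchemy import Column, MetaData, String, Table

metadata = MetaData()
t = Table(
    "t",
    metadata,
    Column("name", String(30)),
    Column("ident", String(30)),
)

# The SET clause order is guaranteed to follow the tuples, which matters
# on backends that evaluate SET expressions left to right (e.g. MySQL).
stmt = t.update().ordered_values(("name", "ed"), ("ident", "foo"))
print(stmt)  # UPDATE t SET name=:name, ident=:ident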
class ReturningDelete(Update, TypedReturnsRows[_TP]): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 90ee100aae0..2d9ee575620 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1,5 +1,5 @@ # sql/elements.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -14,7 +14,7 @@ from __future__ import annotations from decimal import Decimal -from enum import IntEnum +from enum import Enum import itertools import operator import re @@ -77,14 +77,19 @@ from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly from ..util.typing import Literal +from ..util.typing import ParamSpec from ..util.typing import Self + if typing.TYPE_CHECKING: + from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _HasDialect from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument + from .base import ColumnSet from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled @@ -103,9 +108,9 @@ from .type_api import TypeEngine from .visitors import _CloneCallableType from .visitors import _TraverseInternalsType + from .visitors import anon_map from ..engine import Connection from ..engine import Dialect - from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType @@ -113,6 +118,7 @@ from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result + _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] @@ -129,8 +135,7 @@ def literal( value: Any, type_: _TypeEngineArgument[_T], literal_execute: bool = False, -) -> BindParameter[_T]: - ... +) -> BindParameter[_T]: ... @overload @@ -138,8 +143,7 @@ def literal( value: _T, type_: None = None, literal_execute: bool = False, -) -> BindParameter[_T]: - ... +) -> BindParameter[_T]: ... @overload @@ -147,8 +151,7 @@ def literal( value: Any, type_: Optional[_TypeEngineArgument[Any]] = None, literal_execute: bool = False, -) -> BindParameter[Any]: - ... +) -> BindParameter[Any]: ... 
def literal( @@ -245,7 +248,7 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.url") def compile( self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> Compiled: @@ -281,7 +284,7 @@ def compile( from sqlalchemy.sql import table, column, select - t = table('t', column('x')) + t = table("t", column("x")) s = select(t).where(t.c.x == 5) @@ -297,8 +300,7 @@ def compile( if bind: dialect = bind.dialect elif self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() + dialect = self._default_dialect() else: url = util.preloaded.engine_url dialect = url.URL.create( @@ -307,6 +309,10 @@ def compile( return self._compiler(dialect, **kw) + def _default_dialect(self): + default = util.preloaded.engine_default + return default.StrCompileDialect() + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" @@ -387,8 +393,7 @@ def _order_by_label_element(self) -> Optional[Label[Any]]: def get_children( self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any - ) -> Iterable[ClauseElement]: - ... + ) -> Iterable[ClauseElement]: ... @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -404,6 +409,10 @@ def _set_propagate_attrs(self, values: Mapping[str, Any]) -> Self: self._propagate_attrs = util.immutabledict(values) return self + def _default_compiler(self) -> SQLCompiler: + dialect = self._default_dialect() + return dialect.statement_compiler(dialect, self) # type: ignore + def _clone(self, **kw: Any) -> Self: """Create a shallow copy of this ClauseElement. @@ -452,7 +461,7 @@ def _with_binary_element_type(self, type_): return self @property - def _constructor(self): + def _constructor(self): # type: ignore[override] """return the 'constructor' for this ClauseElement. This is for the purposes for creating a new object of @@ -585,10 +594,10 @@ def params( :func:`_expression.bindparam` elements replaced with values taken from the given dictionary:: - >>> clause = column('x') + bindparam('foo') + >>> clause = column("x") + bindparam("foo") >>> print(clause.compile().params) {'foo':None} - >>> print(clause.params({'foo':7}).compile().params) + >>> print(clause.params({"foo": 7}).compile().params) {'foo':7} """ @@ -685,6 +694,7 @@ def _compile_w_cache( else: elem_cache_key = None + extracted_params: Optional[Sequence[BindParameter[Any]]] if elem_cache_key is not None: if TYPE_CHECKING: assert compiled_cache is not None @@ -778,11 +788,10 @@ def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: def compile( # noqa: A001 self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, - ) -> SQLCompiler: - ... + ) -> SQLCompiler: ... class CompilerColumnElement( @@ -800,6 +809,7 @@ class CompilerColumnElement( __slots__ = () _propagate_attrs = util.EMPTY_DICT + _is_collection_aggregate = False # SQLCoreOperations should be suiting the ExpressionElementRole @@ -814,18 +824,15 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): if typing.TYPE_CHECKING: @util.non_memoized_property - def _propagate_attrs(self) -> _PropagateAttrsType: - ... + def _propagate_attrs(self) -> _PropagateAttrsType: ... def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... 
+ ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def op( @@ -836,8 +843,7 @@ def op( *, return_type: _TypeEngineArgument[_OPT], python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[_OPT]]: - ... + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... @overload def op( @@ -847,8 +853,7 @@ def op( is_comparison: bool = ..., return_type: Optional[_TypeEngineArgument[Any]] = ..., python_impl: Optional[Callable[..., Any]] = ..., - ) -> Callable[[Any], BinaryExpression[Any]]: - ... + ) -> Callable[[Any], BinaryExpression[Any]]: ... def op( self, @@ -857,38 +862,30 @@ def op( is_comparison: bool = False, return_type: Optional[_TypeEngineArgument[Any]] = None, python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[Any]]: - ... + ) -> Callable[[Any], BinaryExpression[Any]]: ... def bool_op( self, opstring: str, precedence: int = 0, python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[bool]]: - ... + ) -> Callable[[Any], BinaryExpression[bool]]: ... - def __and__(self, other: Any) -> BooleanClauseList: - ... + def __and__(self, other: Any) -> BooleanClauseList: ... - def __or__(self, other: Any) -> BooleanClauseList: - ... + def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> ColumnElement[_T_co]: - ... + def __invert__(self) -> ColumnElement[_T_co]: ... - def __lt__(self, other: Any) -> ColumnElement[bool]: - ... + def __lt__(self, other: Any) -> ColumnElement[bool]: ... - def __le__(self, other: Any) -> ColumnElement[bool]: - ... + def __le__(self, other: Any) -> ColumnElement[bool]: ... # declare also that this class has an hash method otherwise # it may be assumed to be None by type checkers since the # object defines __eq__ and python sets it to None in that case: # https://docs.python.org/3/reference/datamodel.html#object.__hash__ - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... @@ -896,226 +893,172 @@ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> ColumnElement[bool]: - ... + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: - ... + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def __gt__(self, other: Any) -> ColumnElement[bool]: - ... + def __gt__(self, other: Any) -> ColumnElement[bool]: ... - def __ge__(self, other: Any) -> ColumnElement[bool]: - ... + def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> UnaryExpression[_T_co]: - ... + def __neg__(self) -> UnaryExpression[_T_co]: ... - def __contains__(self, other: Any) -> ColumnElement[bool]: - ... + def __contains__(self, other: Any) -> ColumnElement[bool]: ... - def __getitem__(self, index: Any) -> ColumnElement[Any]: - ... + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... @overload - def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: - ... + def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __lshift__(self, other: Any) -> ColumnElement[Any]: - ... 
+ def __lshift__(self, other: Any) -> ColumnElement[Any]: ... - def __lshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: - ... + def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __rshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... - def __rshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... @overload - def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: - ... + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @overload - def concat(self, other: Any) -> ColumnElement[Any]: - ... + def concat(self, other: Any) -> ColumnElement[Any]: ... - def concat(self, other: Any) -> ColumnElement[Any]: - ... + def concat(self, other: Any) -> ColumnElement[Any]: ... def like( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def ilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def bitwise_xor(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_xor(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_or(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_or(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_and(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_and(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_not(self) -> UnaryExpression[_T_co]: - ... + def bitwise_not(self) -> UnaryExpression[_T_co]: ... - def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_rshift(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_rshift(self, other: Any) -> BinaryExpression[Any]: ... def in_( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_in( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notin_( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_like( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notlike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_ilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def is_(self, other: Any) -> BinaryExpression[bool]: - ... + def is_(self, other: Any) -> BinaryExpression[bool]: ... - def is_not(self, other: Any) -> BinaryExpression[bool]: - ... + def is_not(self, other: Any) -> BinaryExpression[bool]: ... - def isnot(self, other: Any) -> BinaryExpression[bool]: - ... + def isnot(self, other: Any) -> BinaryExpression[bool]: ... 
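The long run of stubs above is typing-only: the class mixes in ``TypingOnly`` and the whole block sits under ``typing.TYPE_CHECKING``, so at runtime the operators still come from ``ColumnOperators``. A minimal sketch of the behavior these annotations promise, using only public API (column names are illustrative)::

    from sqlalchemy import column

    expr = column("x").is_not(None)      # typed ColumnElement[bool], not a Python bool
    print(expr)                          # x IS NOT NULL
    combined = (column("x") > 5) & expr  # __and__ -> BooleanClauseList
    print(combined)                      # x > :x_1 AND x IS NOT NULL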
def startswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def istartswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def endswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def iendswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... - def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: - ... + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def icontains(self, other: Any, **kw: Any) -> ColumnElement[bool]: - ... + def icontains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: - ... + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... def regexp_match( self, pattern: Any, flags: Optional[str] = None - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def regexp_replace( self, pattern: Any, replacement: Any, flags: Optional[str] = None - ) -> ColumnElement[str]: - ... + ) -> ColumnElement[str]: ... - def desc(self) -> UnaryExpression[_T_co]: - ... + def desc(self) -> UnaryExpression[_T_co]: ... - def asc(self) -> UnaryExpression[_T_co]: - ... + def asc(self) -> UnaryExpression[_T_co]: ... - def nulls_first(self) -> UnaryExpression[_T_co]: - ... + def nulls_first(self) -> UnaryExpression[_T_co]: ... - def nullsfirst(self) -> UnaryExpression[_T_co]: - ... + def nullsfirst(self) -> UnaryExpression[_T_co]: ... - def nulls_last(self) -> UnaryExpression[_T_co]: - ... + def nulls_last(self) -> UnaryExpression[_T_co]: ... - def nullslast(self) -> UnaryExpression[_T_co]: - ... + def nullslast(self) -> UnaryExpression[_T_co]: ... - def collate(self, collation: str) -> CollationClause: - ... + def collate(self, collation: str) -> CollationClause: ... def between( self, cleft: Any, cright: Any, symmetric: bool = False - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: - ... + def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: ... - def any_(self) -> CollectionAggregate[Any]: - ... + def any_(self) -> CollectionAggregate[Any]: ... - def all_(self) -> CollectionAggregate[Any]: - ... + def all_(self) -> CollectionAggregate[Any]: ... # numeric overloads. These need more tweaking # in particular they all need to have a variant for Optiona[_T] @@ -1126,159 +1069,129 @@ def all_(self) -> CollectionAggregate[Any]: def __add__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload def __add__( self: _SQO[str], other: Any, - ) -> ColumnElement[str]: - ... + ) -> ColumnElement[str]: ... - def __add__(self, other: Any) -> ColumnElement[Any]: - ... + @overload + def __add__(self, other: Any) -> ColumnElement[Any]: ... + + def __add__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: - ... + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... - def __radd__(self, other: Any) -> ColumnElement[Any]: - ... 
+ def __radd__(self, other: Any) -> ColumnElement[Any]: ... @overload def __sub__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __sub__(self, other: Any) -> ColumnElement[Any]: - ... + def __sub__(self, other: Any) -> ColumnElement[Any]: ... - def __sub__(self, other: Any) -> ColumnElement[Any]: - ... + def __sub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rsub__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __rsub__(self, other: Any) -> ColumnElement[Any]: - ... + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... - def __rsub__(self, other: Any) -> ColumnElement[Any]: - ... + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __mul__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __mul__(self, other: Any) -> ColumnElement[Any]: - ... + def __mul__(self, other: Any) -> ColumnElement[Any]: ... - def __mul__(self, other: Any) -> ColumnElement[Any]: - ... + def __mul__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rmul__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __rmul__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... - def __rmul__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __mod__(self, other: Any) -> ColumnElement[Any]: - ... + def __mod__(self, other: Any) -> ColumnElement[Any]: ... - def __mod__(self, other: Any) -> ColumnElement[Any]: - ... + def __mod__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rmod__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... - def __rmod__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... @overload def __truediv__( self: _SQO[int], other: Any - ) -> ColumnElement[_NUMERIC]: - ... + ) -> ColumnElement[_NUMERIC]: ... @overload - def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: - ... + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... @overload - def __truediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... - def __truediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rtruediv__( self: _SQO[_NMT], other: Any - ) -> ColumnElement[_NUMERIC]: - ... + ) -> ColumnElement[_NUMERIC]: ... @overload - def __rtruediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... - def __rtruediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __floordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... 
@overload - def __floordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __floordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __rfloordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... @overload - def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... class SQLColumnExpression( @@ -1384,9 +1297,9 @@ class ColumnElement( .. sourcecode:: pycon+sql >>> from sqlalchemy.sql import column - >>> column('a') + column('b') + >>> column("a") + column("b") - >>> print(column('a') + column('b')) + >>> print(column("a") + column("b")) {printsql}a + b .. seealso:: @@ -1404,6 +1317,7 @@ class ColumnElement( _is_column_element = True _insert_sentinel: bool = False _omit_from_statements = False + _is_collection_aggregate = False foreign_keys: AbstractSet[ForeignKey] = frozenset() @@ -1474,7 +1388,9 @@ def _non_anon_label(self) -> Optional[str]: SQL. Concretely, this is the "name" of a column or a label in a - SELECT statement; ```` and ```` below:: + SELECT statement; ```` and ```` below: + + .. sourcecode:: sql SELECT FROM table @@ -1527,16 +1443,12 @@ def _non_anon_label(self) -> Optional[str]: _alt_names: Sequence[str] = () @overload - def self_group( - self: ColumnElement[_T], against: Optional[OperatorType] = None - ) -> ColumnElement[_T]: - ... + def self_group(self, against: None = None) -> ColumnElement[_T]: ... @overload def self_group( - self: ColumnElement[Any], against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: - ... + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: ... def self_group( self, against: Optional[OperatorType] = None @@ -1552,12 +1464,10 @@ def self_group( return self @overload - def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: - ... + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: ... @overload - def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: - ... + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: ... def _negate(self) -> ColumnElement[Any]: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: @@ -1740,6 +1650,8 @@ def _make_proxy( self, selectable: FromClause, *, + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], name: Optional[str] = None, key: Optional[str] = None, name_is_truncatable: bool = False, @@ -1761,9 +1673,11 @@ def _make_proxy( assert key is not None co: ColumnClause[_T] = ColumnClause( - coercions.expect(roles.TruncatedLabelRole, name) - if name_is_truncatable - else name, + ( + coercions.expect(roles.TruncatedLabelRole, name) + if name_is_truncatable + else name + ), type_=getattr(self, "type", None), _selectable=selectable, ) @@ -2014,8 +1928,9 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): from sqlalchemy import bindparam - stmt = select(users_table).\ - where(users_table.c.name == bindparam('username')) + stmt = select(users_table).where( + users_table.c.name == bindparam("username") + ) Detailed discussion of how :class:`.BindParameter` is used is at :func:`.bindparam`. 
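A runnable version of the pattern documented above (table and column names are illustrative), showing that the named parameter is carried on the compiled statement with no value until one is supplied::

    from sqlalchemy import bindparam, column, select, table

    users_table = table("users", column("id"), column("name"))

    stmt = select(users_table).where(
        users_table.c.name == bindparam("username")
    )
    print(stmt.compile().params)  # {'username': None}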
@@ -2075,9 +1990,12 @@ def __init__( if unique: self.key = _anonymous_label.safe_construct( id(self), - key - if key is not None and not isinstance(key, _anonymous_label) - else "param", + ( + key + if key is not None + and not isinstance(key, _anonymous_label) + else "param" + ), sanitize_key=True, ) self._key_is_anon = True @@ -2138,13 +2056,13 @@ def __init__( check_value = value[0] else: check_value = value - cast( - "BindParameter[typing_Tuple[Any, ...]]", self - ).type = type_._resolve_values_to_types(check_value) + cast("BindParameter[typing_Tuple[Any, ...]]", self).type = ( + type_._resolve_values_to_types(check_value) + ) else: - cast( - "BindParameter[typing_Tuple[Any, ...]]", self - ).type = type_ + cast("BindParameter[typing_Tuple[Any, ...]]", self).type = ( + type_ + ) else: self.type = type_ @@ -2209,8 +2127,8 @@ def _negate_in_binary(self, negated_op, original_op): else: return self - def _with_binary_element_type(self, type_): - c = ClauseElement._clone(self) + def _with_binary_element_type(self, type_: TypeEngine[Any]) -> Self: + c: Self = ClauseElement._clone(self) c.type = type_ return c @@ -2300,8 +2218,9 @@ class TypeClause(DQLDMLClauseElement): _traverse_internals: _TraverseInternalsType = [ ("type", InternalTraversal.dp_type) ] + type: TypeEngine[Any] - def __init__(self, type_): + def __init__(self, type_: TypeEngine[Any]): self.type = type_ @@ -2329,7 +2248,6 @@ class TextClause( t = text("SELECT * FROM users") result = connection.execute(t) - The :class:`_expression.TextClause` construct is produced using the :func:`_expression.text` function; see that function for full documentation. @@ -2358,6 +2276,8 @@ class TextClause( _omit_from_statements = False + _is_collection_aggregate = False + @property def _hide_froms(self) -> Iterable[FromClause]: return () @@ -2378,7 +2298,7 @@ def _select_iterable(self) -> _SelectIterable: _allow_label_resolve = False @property - def _is_star(self): + def _is_star(self): # type: ignore[override] return self.text == "*" def __init__(self, text: str): @@ -2404,16 +2324,19 @@ def bindparams( Given a text construct such as:: from sqlalchemy import text - stmt = text("SELECT id, name FROM user WHERE name=:name " - "AND timestamp=:timestamp") + + stmt = text( + "SELECT id, name FROM user WHERE name=:name AND timestamp=:timestamp" + ) the :meth:`_expression.TextClause.bindparams` method can be used to establish the initial value of ``:name`` and ``:timestamp``, using simple keyword arguments:: - stmt = stmt.bindparams(name='jack', - timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)) + stmt = stmt.bindparams( + name="jack", timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5) + ) Where above, new :class:`.BindParameter` objects will be generated with the names ``name`` and ``timestamp``, and @@ -2428,10 +2351,11 @@ def bindparams( argument, then an optional value and type:: from sqlalchemy import bindparam + stmt = stmt.bindparams( - bindparam('name', value='jack', type_=String), - bindparam('timestamp', type_=DateTime) - ) + bindparam("name", value="jack", type_=String), + bindparam("timestamp", type_=DateTime), + ) Above, we specified the type of :class:`.DateTime` for the ``timestamp`` bind, and the type of :class:`.String` for the ``name`` @@ -2441,8 +2365,9 @@ def bindparams( Additional bound parameters can be supplied at statement execution time, e.g.:: - result = connection.execute(stmt, - timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)) + result = connection.execute( + stmt, timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5) 
+ ) The :meth:`_expression.TextClause.bindparams` method can be called repeatedly, @@ -2452,15 +2377,15 @@ def bindparams( first with typing information, and a second time with value information, and it will be combined:: - stmt = text("SELECT id, name FROM user WHERE name=:name " - "AND timestamp=:timestamp") + stmt = text( + "SELECT id, name FROM user WHERE name=:name " + "AND timestamp=:timestamp" + ) stmt = stmt.bindparams( - bindparam('name', type_=String), - bindparam('timestamp', type_=DateTime) + bindparam("name", type_=String), bindparam("timestamp", type_=DateTime) ) stmt = stmt.bindparams( - name='jack', - timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5) + name="jack", timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5) ) The :meth:`_expression.TextClause.bindparams` @@ -2474,18 +2399,17 @@ def bindparams( object:: stmt1 = text("select id from table where name=:name").bindparams( - bindparam("name", value='name1', unique=True) + bindparam("name", value="name1", unique=True) ) stmt2 = text("select id from table where name=:name").bindparams( - bindparam("name", value='name2', unique=True) + bindparam("name", value="name2", unique=True) ) - union = union_all( - stmt1.columns(column("id")), - stmt2.columns(column("id")) - ) + union = union_all(stmt1.columns(column("id")), stmt2.columns(column("id"))) + + The above statement will render as: - The above statement will render as:: + .. sourcecode:: sql select id from table where name=:name_1 UNION ALL select id from table where name=:name_2 @@ -2495,7 +2419,7 @@ def bindparams( :func:`_expression.text` constructs. - """ + """ # noqa: E501 self._bindparams = new_params = self._bindparams.copy() for bind in binds: @@ -2526,7 +2450,9 @@ def bindparams( @util.preload_module("sqlalchemy.sql.selectable") def columns( - self, *cols: _ColumnExpressionArgument[Any], **types: TypeEngine[Any] + self, + *cols: _ColumnExpressionArgument[Any], + **types: _TypeEngineArgument[Any], ) -> TextualSelect: r"""Turn this :class:`_expression.TextClause` object into a :class:`_expression.TextualSelect` @@ -2547,12 +2473,13 @@ def columns( from sqlalchemy.sql import column, text stmt = text("SELECT id, name FROM some_table") - stmt = stmt.columns(column('id'), column('name')).subquery('st') + stmt = stmt.columns(column("id"), column("name")).subquery("st") - stmt = select(mytable).\ - select_from( - mytable.join(stmt, mytable.c.name == stmt.c.name) - ).where(stmt.c.id > 5) + stmt = ( + select(mytable) + .select_from(mytable.join(stmt, mytable.c.name == stmt.c.name)) + .where(stmt.c.id > 5) + ) Above, we pass a series of :func:`_expression.column` elements to the :meth:`_expression.TextClause.columns` method positionally. 
These @@ -2573,10 +2500,10 @@ def columns( stmt = text("SELECT id, name, timestamp FROM some_table") stmt = stmt.columns( - column('id', Integer), - column('name', Unicode), - column('timestamp', DateTime) - ) + column("id", Integer), + column("name", Unicode), + column("timestamp", DateTime), + ) for id, name, timestamp in connection.execute(stmt): print(id, name, timestamp) @@ -2585,11 +2512,7 @@ def columns( types alone may be used, if only type conversion is needed:: stmt = text("SELECT id, name, timestamp FROM some_table") - stmt = stmt.columns( - id=Integer, - name=Unicode, - timestamp=DateTime - ) + stmt = stmt.columns(id=Integer, name=Unicode, timestamp=DateTime) for id, name, timestamp in connection.execute(stmt): print(id, name, timestamp) @@ -2603,26 +2526,31 @@ def columns( the result set will match to those columns positionally, meaning the name or origin of the column in the textual SQL doesn't matter:: - stmt = text("SELECT users.id, addresses.id, users.id, " - "users.name, addresses.email_address AS email " - "FROM users JOIN addresses ON users.id=addresses.user_id " - "WHERE users.id = 1").columns( - User.id, - Address.id, - Address.user_id, - User.name, - Address.email_address - ) + stmt = text( + "SELECT users.id, addresses.id, users.id, " + "users.name, addresses.email_address AS email " + "FROM users JOIN addresses ON users.id=addresses.user_id " + "WHERE users.id = 1" + ).columns( + User.id, + Address.id, + Address.user_id, + User.name, + Address.email_address, + ) - query = session.query(User).from_statement(stmt).options( - contains_eager(User.addresses)) + query = ( + session.query(User) + .from_statement(stmt) + .options(contains_eager(User.addresses)) + ) The :meth:`_expression.TextClause.columns` method provides a direct route to calling :meth:`_expression.FromClause.subquery` as well as :meth:`_expression.SelectBase.cte` against a textual SELECT statement:: - stmt = stmt.columns(id=Integer, name=String).cte('st') + stmt = stmt.columns(id=Integer, name=String).cte("st") stmt = select(sometable).where(sometable.c.id == stmt.c.id) @@ -2646,9 +2574,11 @@ def columns( ] positional_input_cols = [ - ColumnClause(col.key, types.pop(col.key)) - if col.key in types - else col + ( + ColumnClause(col.key, types.pop(col.key)) + if col.key in types + else col + ) for col in input_cols ] keyed_input_cols: List[NamedColumn[Any]] = [ @@ -2673,7 +2603,9 @@ def comparator(self): # be using this method. 
return self.type.comparator_factory(self) # type: ignore - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[Any]]: if against is operators.in_op: return Grouping(self) else: @@ -2693,9 +2625,11 @@ class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): _traverse_internals: _TraverseInternalsType = [] _singleton: Null - @util.memoized_property - def type(self): - return type_api.NULLTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.NULLTYPE @classmethod def _instance(cls) -> Null: @@ -2721,9 +2655,11 @@ class False_( _traverse_internals: _TraverseInternalsType = [] _singleton: False_ - @util.memoized_property - def type(self): - return type_api.BOOLEANTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.BOOLEANTYPE def _negate(self) -> True_: return True_._singleton @@ -2749,9 +2685,11 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): _traverse_internals: _TraverseInternalsType = [] _singleton: True_ - @util.memoized_property - def type(self): - return type_api.BOOLEANTYPE + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return type_api.BOOLEANTYPE def _negate(self) -> False_: return False_._singleton @@ -2872,7 +2810,9 @@ def append(self, clause): def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[Any]]: if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2895,7 +2835,9 @@ class OperatorExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if ( self.group and operators.is_precedent(self.operator, against) @@ -2957,6 +2899,9 @@ def _construct_for_op( *(left_flattened + right_flattened), ) + if right._is_collection_aggregate: + negate = None + return BinaryExpression( left, right, op, type_=type_, negate=negate, modifiers=modifiers ) @@ -3043,6 +2988,10 @@ def _construct_for_list( self.clauses = clauses self.operator = operator self.type = type_ + for c in clauses: + if c._propagate_attrs: + self._propagate_attrs = c._propagate_attrs + break return self def _negate(self) -> Any: @@ -3151,9 +3100,11 @@ def _construct( # which will link elements against the operator. 
flattened_clauses = itertools.chain.from_iterable( - (c for c in to_flat._flattened_operator_clauses) - if getattr(to_flat, "operator", None) is operator - else (to_flat,) + ( + (c for c in to_flat._flattened_operator_clauses) + if getattr(to_flat, "operator", None) is operator + else (to_flat,) + ) for to_flat in convert_clauses ) @@ -3250,7 +3201,9 @@ def or_( def _select_iterable(self) -> _SelectIterable: return (self,) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[bool]]: if not self.clauses: return self else: @@ -3333,7 +3286,7 @@ def _bind_param(self, operator, obj, type_=None, expanding=False): ] ) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: # Tuple is parenthesized by definition. return self @@ -3346,14 +3299,13 @@ class Case(ColumnElement[_T]): from sqlalchemy import case - stmt = select(users_table).\ - where( - case( - (users_table.c.name == 'wendy', 'W'), - (users_table.c.name == 'jack', 'J'), - else_='E' - ) - ) + stmt = select(users_table).where( + case( + (users_table.c.name == "wendy", "W"), + (users_table.c.name == "jack", "J"), + else_="E", + ) + ) Details on :class:`.Case` usage is at :func:`.case`. @@ -3395,7 +3347,7 @@ def __init__( except TypeError: pass - whenlist = [ + self.whens = [ ( coercions.expect( roles.ExpressionElementRole, @@ -3407,24 +3359,28 @@ def __init__( for (c, r) in new_whens ] - if whenlist: - type_ = whenlist[-1][-1].type - else: - type_ = None - if value is None: self.value = None else: self.value = coercions.expect(roles.ExpressionElementRole, value) - self.type = cast(_T, type_) - self.whens = whenlist - if else_ is not None: self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None + type_ = next( + ( + then.type + # Iterate `whens` in reverse to match previous behaviour + # where type of final element took priority + for *_, then in reversed(self.whens) + if not then.type._isnull + ), + self.else_.type if self.else_ is not None else type_api.NULLTYPE, + ) + self.type = cast(_T, type_) + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( @@ -3562,7 +3518,9 @@ def typed_expression(self): def wrapped_column_expression(self): return self.clause - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> TypeCoerce[_T]: grouped = self.clause.self_group(against=against) if grouped is not self.clause: return TypeCoerce(grouped, self.type) @@ -3757,7 +3715,7 @@ def _create_bitwise_not( @property def _order_by_label_element(self) -> Optional[Label[Any]]: - if self.modifier in (operators.desc_op, operators.asc_op): + if operators.is_order_by_modifier(self.modifier): return self.element._order_by_label_element else: return None @@ -3777,7 +3735,9 @@ def _negate(self): else: return ClauseElement._negate(self) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3795,6 +3755,7 @@ class CollectionAggregate(UnaryExpression[_T]): """ inherit_cache = True + _is_collection_aggregate = True @classmethod def _create_any( @@ -3831,15 +3792,19 @@ def _create_all( # operate and reverse_operate are hardwired to # dispatch onto the type comparator directly, so that we can # ensure "reversed" behavior. 
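The hardwired dispatch noted in the comment above is what keeps ANY/ALL on the right-hand side of the rendered comparison, the only placement SQL permits, even when the Python expression is written with the aggregate on the right of ``==``. A sketch with public constructs (column name illustrative)::

    from sqlalchemy import any_, column

    crit = 5 == any_(column("data"))
    print(crit)   # e.g. :data_1 = ANY (data)
    # with the negate=None handling added in _construct_for_op above,
    # negation wraps the whole comparison rather than flipping the operator:
    print(~crit)  # e.g. NOT (:data_1 = ANY (data))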
- def operate(self, op, *other, **kwargs): + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_T]: if not operators.is_comparison(op): raise exc.ArgumentError( "Only comparison operators may be used with ANY/ALL" ) - kwargs["reverse"] = kwargs["_any_all_expr"] = True + kwargs["reverse"] = True return self.comparator.operate(operators.mirror(op), *other, **kwargs) - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[_T]: # comparison operators should never call reverse_operate assert not operators.is_comparison(op) raise exc.ArgumentError( @@ -3863,7 +3828,7 @@ def __init__(self, element, operator, negate): def wrapped_column_expression(self): return self.element - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self def _negate(self): @@ -3882,9 +3847,9 @@ class BinaryExpression(OperatorExpression[_T]): .. sourcecode:: pycon+sql >>> from sqlalchemy.sql import column - >>> column('a') + column('b') + >>> column("a") + column("b") - >>> print(column('a') + column('b')) + >>> print(column("a") + column("b")) {printsql}a + b """ @@ -3922,10 +3887,9 @@ class BinaryExpression(OperatorExpression[_T]): """ - modifiers: Optional[Mapping[str, Any]] - left: ColumnElement[Any] right: ColumnElement[Any] + modifiers: Mapping[str, Any] def __init__( self, @@ -3973,7 +3937,7 @@ def __bool__(self): The rationale here is so that ColumnElement objects can be hashable. What? Well, suppose you do this:: - c1, c2 = column('x'), column('y') + c1, c2 = column("x"), column("y") s1 = set([c1, c2]) We do that **a lot**, columns inside of sets is an extremely basic @@ -4006,8 +3970,7 @@ def __bool__(self): def __invert__( self: BinaryExpression[_T], - ) -> BinaryExpression[_T]: - ... + ) -> BinaryExpression[_T]: ... @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4024,7 +3987,7 @@ def _negate(self): modifiers=self.modifiers, ) else: - return super()._negate() + return self.self_group()._negate() class Slice(ColumnElement[Any]): @@ -4064,7 +4027,7 @@ def __init__(self, start, stop, step, _name=None): ) self.type = type_api.NULLTYPE - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: assert against is operator.getitem return self @@ -4083,7 +4046,7 @@ class GroupedElement(DQLDMLClauseElement): element: ClauseElement - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self def _ungroup(self): @@ -4147,8 +4110,65 @@ def __setstate__(self, state): self.element = state["element"] self.type = state["type"] + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + -class _OverRange(IntEnum): +class _OverrideBinds(Grouping[_T]): + """used by cache_key->_apply_params_to_element to allow compilation / + execution of a SQL element that's been cached, using an alternate set of + bound parameter values. + + This is used by the ORM to swap new parameter values into expressions + that are embedded into loader options like with_expression(), + selectinload(). Previously, this task was accomplished using the + .params() method which would perform a deep-copy instead. This deep + copy proved to be too expensive for more complex expressions. 
+ + See #11085 + + """ + + __visit_name__ = "override_binds" + + def __init__( + self, + element: ColumnElement[_T], + bindparams: Sequence[BindParameter[Any]], + replaces_params: Sequence[BindParameter[Any]], + ): + self.element = element + self.translate = { + k.key: v.value for k, v in zip(replaces_params, bindparams) + } + + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[typing_Tuple[Any, ...]]: + """generate a cache key for the given element, substituting its bind + values for the translation values present.""" + + existing_bps: List[BindParameter[Any]] = [] + ck = self.element._gen_cache_key(anon_map, existing_bps) + + bindparams.extend( + ( + bp._with_value( + self.translate[bp.key], maintain_key=True, required=False + ) + if bp.key in self.translate + else bp + ) + for bp in existing_bps + ) + + return ck + + +class _OverRange(Enum): RANGE_UNBOUNDED = 0 RANGE_CURRENT = 1 @@ -4156,6 +4176,8 @@ class _OverRange(IntEnum): RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED RANGE_CURRENT = _OverRange.RANGE_CURRENT +_IntOrRange = Union[int, _OverRange] + class Over(ColumnElement[_T]): """Represent an OVER clause. @@ -4175,6 +4197,7 @@ class Over(ColumnElement[_T]): ("partition_by", InternalTraversal.dp_clauseelement), ("range_", InternalTraversal.dp_plain_obj), ("rows", InternalTraversal.dp_plain_obj), + ("groups", InternalTraversal.dp_plain_obj), ] order_by: Optional[ClauseList] = None @@ -4184,25 +4207,18 @@ class Over(ColumnElement[_T]): """The underlying expression object to which this :class:`.Over` object refers.""" - range_: Optional[typing_Tuple[int, int]] + range_: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] + rows: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] + groups: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] def __init__( self, element: ColumnElement[_T], - partition_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, - order_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ): self.element = element if order_by is not None: @@ -4215,19 +4231,14 @@ def __init__( _literal_as_text_role=roles.ByOfRole, ) - if range_: - self.range_ = self._interpret_range(range_) - if rows: - raise exc.ArgumentError( - "'range_' and 'rows' are mutually exclusive" - ) - else: - self.rows = None - elif rows: - self.rows = self._interpret_range(rows) - self.range_ = None + if sum(bool(item) for item in (range_, rows, groups)) > 1: + raise exc.ArgumentError( + "only one of 'rows', 'range_', or 'groups' may be provided" + ) else: - self.rows = self.range_ = None + self.range_ = self._interpret_range(range_) if range_ else None + self.rows = self._interpret_range(rows) if rows else None + self.groups = self._interpret_range(groups) if groups else None def __reduce__(self): return self.__class__, ( @@ -4236,22 +4247,28 @@ def __reduce__(self): self.order_by, self.range_, self.rows, + self.groups, ) def _interpret_range( - self, range_: typing_Tuple[Optional[int], Optional[int]] - ) -> typing_Tuple[int, int]: + self, + range_: typing_Tuple[Optional[_IntOrRange], Optional[_IntOrRange]], + ) -> typing_Tuple[_IntOrRange, 
_IntOrRange]: if not isinstance(range_, tuple) or len(range_) != 2: raise exc.ArgumentError("2-tuple expected for range/rows") - lower: int - upper: int + r0, r1 = range_ + + lower: _IntOrRange + upper: _IntOrRange - if range_[0] is None: + if r0 is None: lower = RANGE_UNBOUNDED + elif isinstance(r0, _OverRange): + lower = r0 else: try: - lower = int(range_[0]) + lower = int(r0) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" @@ -4260,11 +4277,13 @@ def _interpret_range( if lower == 0: lower = RANGE_CURRENT - if range_[1] is None: + if r1 is None: upper = RANGE_UNBOUNDED + elif isinstance(r1, _OverRange): + upper = r1 else: try: - upper = int(range_[1]) + upper = int(r1) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" @@ -4275,9 +4294,11 @@ def _interpret_range( return lower, upper - @util.memoized_property - def type(self): - return self.element.type + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.element.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4301,7 +4322,7 @@ class WithinGroup(ColumnElement[_T]): ``rank()``, ``dense_rank()``, etc. It's supported only by certain database backends, such as PostgreSQL, - Oracle and MS SQL Server. + Oracle Database and MS SQL Server. The :class:`.WithinGroup` construct extracts its type from the method :meth:`.FunctionElement.within_group_type`. If this returns @@ -4320,7 +4341,7 @@ class WithinGroup(ColumnElement[_T]): def __init__( self, - element: FunctionElement[_T], + element: Union[FunctionElement[_T], FunctionFilter[_T]], *order_by: _ColumnExpressionArgument[Any], ): self.element = element @@ -4334,7 +4355,15 @@ def __reduce__(self): tuple(self.order_by) if self.order_by is not None else () ) - def over(self, partition_by=None, order_by=None, range_=None, rows=None): + def over( + self, + *, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4348,15 +4377,36 @@ def over(self, partition_by=None, order_by=None, range_=None, rows=None): order_by=order_by, range_=range_, rows=rows, + groups=groups, ) - @util.memoized_property - def type(self): - wgt = self.element.within_group_type(self) - if wgt is not None: - return wgt - else: - return self.element.type + @overload + def filter(self) -> Self: ... + + @overload + def filter( + self, + __criterion0: _ColumnExpressionArgument[bool], + *criterion: _ColumnExpressionArgument[bool], + ) -> FunctionFilter[_T]: ... 
+ + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> Union[Self, FunctionFilter[_T]]: + """Produce a FILTER clause against this function.""" + if not criterion: + return self + return FunctionFilter(self, *criterion) + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + wgt = self.element.within_group_type(self) + if wgt is not None: + return wgt + else: + return self.element.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4371,7 +4421,7 @@ def _from_objects(self) -> List[FromClause]: ) -class FunctionFilter(ColumnElement[_T]): +class FunctionFilter(Generative, ColumnElement[_T]): """Represent a function FILTER clause. This is a special operator against aggregate and window functions, @@ -4400,13 +4450,14 @@ class FunctionFilter(ColumnElement[_T]): def __init__( self, - func: FunctionElement[_T], + func: Union[FunctionElement[_T], WithinGroup[_T]], *criterion: _ColumnExpressionArgument[bool], ): self.func = func - self.filter(*criterion) + self.filter.non_generative(self, *criterion) # type: ignore - def filter(self, *criterion): + @_generative + def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: """Produce an additional FILTER against the function. This method adds additional criteria to the initial criteria @@ -4444,6 +4495,7 @@ def over( ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: """Produce an OVER clause against this filtered function. @@ -4452,12 +4504,13 @@ def over( The expression:: - func.rank().filter(MyClass.y > 5).over(order_by='x') + func.rank().filter(MyClass.y > 5).over(order_by="x") is shorthand for:: from sqlalchemy import over, funcfilter - over(funcfilter(func.rank(), MyClass.y > 5), order_by='x') + + over(funcfilter(func.rank(), MyClass.y > 5), order_by="x") See :func:`_expression.over` for a full description. @@ -4468,17 +4521,35 @@ def over( order_by=order_by, range_=range_, rows=rows, + groups=groups, ) - def self_group(self, against=None): + def within_group( + self, *order_by: _ColumnExpressionArgument[Any] + ) -> WithinGroup[_T]: + """Produce a WITHIN GROUP (ORDER BY expr) clause against + this function. + """ + return WithinGroup(self, *order_by) + + def within_group_type( + self, within_group: WithinGroup[_T] + ) -> Optional[TypeEngine[_T]]: + return None + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: return self - @util.memoized_property - def type(self): - return self.func.type + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.func.type @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4509,7 +4580,7 @@ def description(self) -> str: return self.name @HasMemoized.memoized_attribute - def _tq_key_label(self): + def _tq_key_label(self) -> Optional[str]: """table qualified label based on column key. 
for table-bound columns this is _; @@ -4567,6 +4638,8 @@ def _make_proxy( self, selectable: FromClause, *, + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], name: Optional[str] = None, key: Optional[str] = None, name_is_truncatable: bool = False, @@ -4575,9 +4648,11 @@ def _make_proxy( **kw: Any, ) -> typing_Tuple[str, ColumnClause[_T]]: c = ColumnClause( - coercions.expect(roles.TruncatedLabelRole, name or self.name) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect(roles.TruncatedLabelRole, name or self.name) + if name_is_truncatable + else (name or self.name) + ), type_=self.type, _selectable=selectable, is_literal=False, @@ -4596,6 +4671,9 @@ def _make_proxy( return c.key, c +_PS = ParamSpec("_PS") + + class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): """Represents a column label (AS). @@ -4693,13 +4771,18 @@ def _order_by_label_element(self): def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]: return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): return self._apply_to_inner(self._element._negate) - def _apply_to_inner(self, fn, *arg, **kw): + def _apply_to_inner( + self, + fn: Callable[_PS, ColumnElement[_T]], + *arg: _PS.args, + **kw: _PS.kwargs, + ) -> Label[_T]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) @@ -4707,11 +4790,11 @@ def _apply_to_inner(self, fn, *arg, **kw): return self @property - def primary_key(self): + def primary_key(self): # type: ignore[override] return self.element.primary_key @property - def foreign_keys(self): + def foreign_keys(self): # type: ignore[override] return self.element.foreign_keys def _copy_internals( @@ -4737,6 +4820,8 @@ def _make_proxy( self, selectable: FromClause, *, + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], name: Optional[str] = None, compound_select_cols: Optional[Sequence[ColumnElement[Any]]] = None, **kw: Any, @@ -4749,6 +4834,8 @@ def _make_proxy( disallow_is_literal=True, name_is_truncatable=isinstance(name, _truncated_label), compound_select_cols=compound_select_cols, + primary_key=primary_key, + foreign_keys=foreign_keys, ) # there was a note here to remove this assertion, which was here @@ -4792,7 +4879,9 @@ class ColumnClause( id, name = column("id"), column("name") stmt = select(id, name).select_from("user") - The above statement would produce SQL like:: + The above statement would produce SQL like: + + .. 
sourcecode:: sql SELECT id, name FROM user @@ -4838,7 +4927,7 @@ class is usable by itself in those cases where behavioral requirements _is_multiparam_column = False @property - def _is_star(self): + def _is_star(self): # type: ignore[override] return self.is_literal and self.name == "*" def __init__( @@ -4981,6 +5070,8 @@ def _make_proxy( self, selectable: FromClause, *, + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], name: Optional[str] = None, key: Optional[str] = None, name_is_truncatable: bool = False, @@ -5005,9 +5096,11 @@ def _make_proxy( ) ) c = self._constructor( - coercions.expect(roles.TruncatedLabelRole, name or self.name) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect(roles.TruncatedLabelRole, name or self.name) + if name_is_truncatable + else (name or self.name) + ), type_=self.type, _selectable=selectable, is_literal=is_literal, @@ -5058,15 +5151,25 @@ class CollationClause(ColumnElement[str]): ] @classmethod + @util.preload_module("sqlalchemy.sql.sqltypes") def _create_collation_expression( cls, expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: + + sqltypes = util.preloaded.sql_sqltypes + expr = coercions.expect(roles.ExpressionElementRole[str], expression) + + if expr.type._type_affinity is sqltypes.String: + collate_type = expr.type._with_collation(collation) + else: + collate_type = expr.type + return BinaryExpression( expr, CollationClause(collation), operators.collate, - type_=expr.type, + type_=collate_type, ) def __init__(self, collation): @@ -5110,7 +5213,7 @@ class quoted_name(util.MemoizedSlots, str): A :class:`.quoted_name` object with ``quote=True`` is also prevented from being modified in the case of a so-called "name normalize" option. Certain database backends, such as - Oracle, Firebird, and DB2 "normalize" case-insensitive names + Oracle Database, Firebird, and DB2 "normalize" case-insensitive names as uppercase. The SQLAlchemy dialects for these backends convert from SQLAlchemy's lower-case-means-insensitive convention to the upper-case-means-insensitive conventions of those backends. @@ -5131,11 +5234,11 @@ class quoted_name(util.MemoizedSlots, str): from sqlalchemy import inspect from sqlalchemy.sql import quoted_name - engine = create_engine("oracle+cx_oracle://some_dsn") + engine = create_engine("oracle+oracledb://some_dsn") print(inspect(engine).has_table(quoted_name("some_table", True))) - The above logic will run the "has table" logic against the Oracle backend, - passing the name exactly as ``"some_table"`` without converting to + The above logic will run the "has table" logic against the Oracle Database + backend, passing the name exactly as ``"some_table"`` without converting to upper case. .. versionchanged:: 1.2 The :class:`.quoted_name` construct is now @@ -5150,13 +5253,11 @@ class quoted_name(util.MemoizedSlots, str): @overload @classmethod - def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: - ... + def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: ... @overload @classmethod - def construct(cls, value: None, quote: Optional[bool]) -> None: - ... + def construct(cls, value: None, quote: Optional[bool]) -> None: ... 
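As the two overloads just above encode, ``construct()`` passes ``None`` through unchanged while strings come back as ``quoted_name``; a small sketch of the public behavior::

    from sqlalchemy.sql import quoted_name

    n = quoted_name("some_table", quote=True)
    print(n, n.quote)                         # some_table True
    print(quoted_name.construct(None, True))  # None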
@classmethod def construct( @@ -5202,12 +5303,12 @@ def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: return cols -def _type_from_args(args): +def _type_from_args(args: Sequence[ColumnElement[_T]]) -> TypeEngine[_T]: for a in args: if not a.type._isnull: return a.type else: - return type_api.NULLTYPE + return type_api.NULLTYPE # type: ignore def _corresponding_column_or_error(fromclause, column, require_embedded=False): @@ -5223,6 +5324,20 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): return c +class _memoized_property_but_not_nulltype( + util.memoized_property["TypeEngine[_T]"] +): + """memoized property, but dont memoize NullType""" + + def __get__(self, obj, cls): + if obj is None: + return self + result = self.fget(obj) + if not result._isnull: + obj.__dict__[self.__name__] = result + return result + + class AnnotatedColumnElement(Annotated): _Annotated__element: ColumnElement[Any] @@ -5234,6 +5349,7 @@ def __init__(self, element, values): "_tq_key_label", "_tq_label", "_non_anon_label", + "type", ): self.__dict__.pop(attr, None) for attr in ("name", "key", "table"): @@ -5242,7 +5358,14 @@ def __init__(self, element, values): def _with_annotations(self, values): clone = super()._with_annotations(values) - clone.__dict__.pop("comparator", None) + for attr in ( + "comparator", + "_proxy_key", + "_tq_key_label", + "_tq_label", + "_non_anon_label", + ): + clone.__dict__.pop(attr, None) return clone @util.memoized_property @@ -5250,6 +5373,20 @@ def name(self): """pull 'name' from parent, if not present""" return self._Annotated__element.name + @_memoized_property_but_not_nulltype + def type(self): + """pull 'type' from parent and don't cache if null. + + type is routinely changed on existing columns within the + mapped_column() initialization process, and "type" is also consulted + during the creation of SQL expressions. Therefore it can change after + it was already retrieved. At the same time we don't want annotated + objects having overhead when expressions are produced, so continue + to memoize, but only when we have a non-null type. + + """ + return self._Annotated__element.type + @util.memoized_property def table(self): """pull 'table' from parent, if not present""" @@ -5299,11 +5436,12 @@ class conv(_truncated_label): E.g. when we create a :class:`.Constraint` using a naming convention as follows:: - m = MetaData(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) - t = Table('t', m, Column('x', Integer), - CheckConstraint('x > 5', name='x5')) + m = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + t = Table( + "t", m, Column("x", Integer), CheckConstraint("x > 5", name="x5") + ) The name of the above constraint will be rendered as ``"ck_t_x5"``. 
That is, the existing name ``x5`` is used in the naming convention as the @@ -5316,11 +5454,15 @@ class conv(_truncated_label): use this explicitly as follows:: - m = MetaData(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) - t = Table('t', m, Column('x', Integer), - CheckConstraint('x > 5', name=conv('ck_t_x5'))) + m = MetaData( + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} + ) + t = Table( + "t", + m, + Column("x", Integer), + CheckConstraint("x > 5", name=conv("ck_t_x5")), + ) Where above, the :func:`_schema.conv` marker indicates that the constraint name here is final, and the name will render as ``"ck_t_x5"`` and not diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py index b34d0741209..601092fd912 100644 --- a/lib/sqlalchemy/sql/events.py +++ b/lib/sqlalchemy/sql/events.py @@ -1,5 +1,5 @@ -# sqlalchemy/sql/events.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# sql/events.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -63,13 +63,14 @@ class DDLEvents(event.Events[SchemaEventTarget]): from sqlalchemy import Table, Column, Metadata, Integer m = MetaData() - some_table = Table('some_table', m, Column('data', Integer)) + some_table = Table("some_table", m, Column("data", Integer)) + @event.listens_for(some_table, "after_create") def after_create(target, connection, **kw): - connection.execute(text( - "ALTER TABLE %s SET name=foo_%s" % (target.name, target.name) - )) + connection.execute( + text("ALTER TABLE %s SET name=foo_%s" % (target.name, target.name)) + ) some_engine = create_engine("postgresql://scott:tiger@host/test") @@ -127,10 +128,11 @@ def after_create(target, connection, **kw): as listener callables:: from sqlalchemy import DDL + event.listen( some_table, "after_create", - DDL("ALTER TABLE %(table)s SET name=foo_%(table)s") + DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"), ) **Event Propagation to MetaData Copies** @@ -149,7 +151,7 @@ def after_create(target, connection, **kw): some_table, "after_create", DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"), - propagate=True + propagate=True, ) new_metadata = MetaData() @@ -169,7 +171,7 @@ def after_create(target, connection, **kw): :ref:`schema_ddl_sequences` - """ + """ # noqa: E501 _target_class_doc = "SomeSchemaClassOrObject" _dispatch_target = SchemaEventTarget @@ -358,16 +360,17 @@ def column_reflect( metadata = MetaData() - @event.listens_for(metadata, 'column_reflect') + + @event.listens_for(metadata, "column_reflect") def receive_column_reflect(inspector, table, column_info): # receives for all Table objects that are reflected # under this MetaData + ... # will use the above event hook my_table = Table("my_table", metadata, autoload_with=some_engine) - .. versionadded:: 1.4.0b2 The :meth:`_events.DDLEvents.column_reflect` hook may now be applied to a :class:`_schema.MetaData` object as well as the :class:`_schema.MetaData` class itself where it will @@ -379,9 +382,11 @@ def receive_column_reflect(inspector, table, column_info): from sqlalchemy import Table - @event.listens_for(Table, 'column_reflect') + + @event.listens_for(Table, "column_reflect") def receive_column_reflect(inspector, table, column_info): # receives for all Table objects that are reflected + ... 
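A self-contained version of the class-level hook just shown; the key-normalization body is illustrative and not part of the original docstring::

    from sqlalchemy import Table, event

    @event.listens_for(Table, "column_reflect")
    def receive_column_reflect(inspector, table, column_info):
        # runs for every reflected Table; mutate column_info in place,
        # e.g. map reflected names to lower-case attribute keys
        column_info["key"] = column_info["name"].lower()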
It can also be applied to a specific :class:`_schema.Table` at the point that one is being reflected using the @@ -390,9 +395,7 @@ def receive_column_reflect(inspector, table, column_info): t1 = Table( "my_table", autoload_with=some_engine, - listeners=[ - ('column_reflect', receive_column_reflect) - ] + listeners=[("column_reflect", receive_column_reflect)], ) The dictionary of column information as returned by the diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b25fb50d40f..dc7dee13b12 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1,14 +1,11 @@ # sql/expression.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Defines the public namespace for SQL expression constructs. - - -""" +"""Defines the public namespace for SQL expression constructs.""" from __future__ import annotations diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fc23e9d2156..02ed4fa6bbb 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,14 +1,11 @@ # sql/functions.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls -"""SQL function API, factories, and built-in functions. - -""" +"""SQL function API, factories, and built-in functions.""" from __future__ import annotations @@ -17,13 +14,16 @@ from typing import Any from typing import cast from typing import Dict +from typing import List from typing import Mapping from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import annotation from . import coercions @@ -62,20 +62,34 @@ if TYPE_CHECKING: + from ._typing import _ByArgument + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrLiteralArgument + from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _StarOrOne from ._typing import _TypeEngineArgument + from .base import _EntityNamespace + from .elements import ClauseElement + from .elements import KeyedColumnElement + from .elements import TableValuedColumn + from .operators import OperatorType from ..engine.base import Connection from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter + from ..util.typing import Self _T = TypeVar("_T", bound=Any) +_S = TypeVar("_S", bound=Any) -_registry: util.defaultdict[ - str, Dict[str, Type[Function[Any]]] -] = util.defaultdict(dict) +_registry: util.defaultdict[str, Dict[str, Type[Function[Any]]]] = ( + util.defaultdict(dict) +) -def register_function(identifier, fn, package="_default"): +def register_function( + identifier: str, fn: Type[Function[Any]], package: str = "_default" +) -> None: """Associate a callable with a particular func. name. 
This is normally called by GenericFunction, but is also @@ -138,7 +152,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): clause_expr: Grouping[Any] - def __init__(self, *clauses: Any): + def __init__( + self, *clauses: _ColumnExpressionOrLiteralArgument[Any] + ) -> None: r"""Construct a :class:`.FunctionElement`. :param \*clauses: list of column expressions that form the arguments @@ -154,7 +170,7 @@ def __init__(self, *clauses: Any): :class:`.Function` """ - args = [ + args: Sequence[_ColumnExpressionArgument[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, @@ -171,7 +187,7 @@ def __init__(self, *clauses: Any): _non_anon_label = None @property - def _proxy_key(self): + def _proxy_key(self) -> Any: return super()._proxy_key or getattr(self, "name", None) def _execute_on_connection( @@ -184,7 +200,9 @@ def _execute_on_connection( self, distilled_params, execution_options ) - def scalar_table_valued(self, name, type_=None): + def scalar_table_valued( + self, name: str, type_: Optional[_TypeEngineArgument[_T]] = None + ) -> ScalarFunctionColumn[_T]: """Return a column expression that's against this :class:`_functions.FunctionElement` as a scalar table-valued expression. @@ -217,7 +235,9 @@ def scalar_table_valued(self, name, type_=None): return ScalarFunctionColumn(self, name, type_) - def table_valued(self, *expr, **kw): + def table_valued( + self, *expr: _ColumnExpressionOrStrLabelArgument[Any], **kw: Any + ) -> TableValuedAlias: r"""Return a :class:`_sql.TableValuedAlias` representation of this :class:`_functions.FunctionElement` with table-valued expressions added. @@ -225,9 +245,8 @@ def table_valued(self, *expr, **kw): .. sourcecode:: pycon+sql - >>> fn = ( - ... func.generate_series(1, 5). - ... table_valued("value", "start", "stop", "step") + >>> fn = func.generate_series(1, 5).table_valued( + ... "value", "start", "stop", "step" ... ) >>> print(select(fn)) @@ -244,7 +263,9 @@ def table_valued(self, *expr, **kw): .. sourcecode:: pycon+sql - >>> fn = func.generate_series(4, 1, -1).table_valued("gen", with_ordinality="ordinality") + >>> fn = func.generate_series(4, 1, -1).table_valued( + ... "gen", with_ordinality="ordinality" + ... ) >>> print(select(fn)) {printsql}SELECT anon_1.gen, anon_1.ordinality FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) WITH ORDINALITY AS anon_1 @@ -303,7 +324,9 @@ def table_valued(self, *expr, **kw): return new_func.alias(name=name, joins_implicitly=joins_implicitly) - def column_valued(self, name=None, joins_implicitly=False): + def column_valued( + self, name: Optional[str] = None, joins_implicitly: bool = False + ) -> TableValuedColumn[_T]: """Return this :class:`_functions.FunctionElement` as a column expression that selects from itself as a FROM clause. @@ -345,7 +368,7 @@ def column_valued(self, name=None, joins_implicitly=False): return self.alias(name=name, joins_implicitly=joins_implicitly).column @util.ro_non_memoized_property - def columns(self): + def columns(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 r"""The set of columns exported by this :class:`.FunctionElement`. This is a placeholder collection that allows the function to be @@ -354,7 +377,7 @@ def columns(self): .. 
sourcecode:: pycon+sql >>> from sqlalchemy import column, select, func - >>> stmt = select(column('x'), column('y')).select_from(func.myfunction()) + >>> stmt = select(column("x"), column("y")).select_from(func.myfunction()) >>> print(stmt) {printsql}SELECT x, y FROM myfunction() @@ -371,7 +394,7 @@ def columns(self): return self.c @util.ro_memoized_property - def c(self): + def c(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 """synonym for :attr:`.FunctionElement.columns`.""" return ColumnCollection( @@ -379,16 +402,21 @@ def c(self): ) @property - def _all_selected_columns(self): + def _all_selected_columns(self) -> Sequence[KeyedColumnElement[Any]]: if is_table_value_type(self.type): - cols = self.type._elements + # TODO: this might not be fully accurate + cols = cast( + "Sequence[KeyedColumnElement[Any]]", self.type._elements + ) else: cols = [self.label(None)] return cols @property - def exported_columns(self): + def exported_columns( # type: ignore[override] + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: return self.columns @HasMemoized.memoized_attribute @@ -399,7 +427,15 @@ def clauses(self) -> ClauseList: """ return cast(ClauseList, self.clause_expr.element) - def over(self, partition_by=None, order_by=None, rows=None, range_=None): + def over( + self, + *, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, + rows: Optional[Tuple[Optional[int], Optional[int]]] = None, + range_: Optional[Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this function. Used against aggregate or so-called "window" functions, @@ -407,12 +443,13 @@ def over(self, partition_by=None, order_by=None, rows=None, range_=None): The expression:: - func.row_number().over(order_by='x') + func.row_number().over(order_by="x") is shorthand for:: from sqlalchemy import over - over(func.row_number(), order_by='x') + + over(func.row_number(), order_by="x") See :func:`_expression.over` for a full description. @@ -429,9 +466,12 @@ def over(self, partition_by=None, order_by=None, rows=None, range_=None): order_by=order_by, rows=rows, range_=range_, + groups=groups, ) - def within_group(self, *order_by): + def within_group( + self, *order_by: _ColumnExpressionArgument[Any] + ) -> WithinGroup[_T]: """Produce a WITHIN GROUP (ORDER BY expr) clause against this function. Used against so-called "ordered set aggregate" and "hypothetical @@ -449,7 +489,19 @@ def within_group(self, *order_by): """ return WithinGroup(self, *order_by) - def filter(self, *criterion): + @overload + def filter(self) -> Self: ... + + @overload + def filter( + self, + __criterion0: _ColumnExpressionArgument[bool], + *criterion: _ColumnExpressionArgument[bool], + ) -> FunctionFilter[_T]: ... + + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> Union[Self, FunctionFilter[_T]]: """Produce a FILTER clause against this function. Used against aggregate and window functions, @@ -462,6 +514,7 @@ def filter(self, *criterion): is shorthand for:: from sqlalchemy import funcfilter + funcfilter(func.count(1), True) .. seealso:: @@ -479,7 +532,9 @@ def filter(self, *criterion): return self return FunctionFilter(self, *criterion) - def as_comparison(self, left_index, right_index): + def as_comparison( + self, left_index: int, right_index: int + ) -> FunctionAsBinary: """Interpret this expression as a boolean comparison between two values. 
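To make the newly keyword-only ``over()`` signature above concrete, a small runnable sketch of the frame-tuple convention it documents, where ``None`` means UNBOUNDED and ``0`` means CURRENT ROW; the column names are illustrative. The ``groups`` parameter added in this patch takes the same two-element shape as ``rows`` and ``range_``::

    from sqlalchemy import Integer, column, func, select

    value = column("value", Integer)

    # ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: a running total
    running_total = func.sum(value).over(
        partition_by=column("group_id"),
        order_by=value,
        rows=(None, 0),
    )
    print(select(running_total))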
@@ -516,7 +571,7 @@ def as_comparison(self, left_index, right_index): An ORM example is as follows:: class Venue(Base): - __tablename__ = 'venue' + __tablename__ = "venue" id = Column(Integer, primary_key=True) name = Column(String) @@ -524,9 +579,10 @@ class Venue(Base): "Venue", primaryjoin=func.instr( remote(foreign(name)), name + "/" - ).as_comparison(1, 2) == 1, + ).as_comparison(1, 2) + == 1, viewonly=True, - order_by=name + order_by=name, ) Above, the "Venue" class can load descendant "Venue" objects by @@ -554,10 +610,12 @@ class Venue(Base): return FunctionAsBinary(self, left_index, right_index) @property - def _from_objects(self): + def _from_objects(self) -> Any: return self.clauses._from_objects - def within_group_type(self, within_group): + def within_group_type( + self, within_group: WithinGroup[_S] + ) -> Optional[TypeEngine[_S]]: """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the :class:`.WithinGroup` construct. @@ -569,7 +627,9 @@ def within_group_type(self, within_group): return None - def alias(self, name=None, joins_implicitly=False): + def alias( + self, name: Optional[str] = None, joins_implicitly: bool = False + ) -> TableValuedAlias: r"""Produce a :class:`_expression.Alias` construct against this :class:`.FunctionElement`. @@ -647,7 +707,7 @@ def alias(self, name=None, joins_implicitly=False): joins_implicitly=joins_implicitly, ) - def select(self) -> Select[Any]: + def select(self) -> Select[Tuple[_T]]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -661,7 +721,14 @@ def select(self) -> Select[Any]: s = s.execution_options(**self._execution_options) return s - def _bind_param(self, operator, obj, type_=None, **kw): + def _bind_param( + self, + operator: OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + **kw: Any, + ) -> BindParameter[_T]: return BindParameter( None, obj, @@ -669,10 +736,11 @@ def _bind_param(self, operator, obj, type_=None, **kw): _compared_to_type=self.type, unique=True, type_=type_, + expanding=expanding, **kw, ) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> ClauseElement: # type: ignore[override] # noqa E501 # for the moment, we are parenthesizing all array-returning # expressions against getitem. This may need to be made # more portable if in the future we support other DBs @@ -685,7 +753,7 @@ def self_group(self, against=None): return super().self_group(against=against) @property - def entity_namespace(self): + def entity_namespace(self) -> _EntityNamespace: """overrides FromClause.entity_namespace as functions are generally column expressions and not FromClauses. 
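A Core-only sketch of what ``as_comparison()`` does, using illustrative column names: the rendered SQL is unchanged, but the result is a binary expression whose sides are the function's first and second arguments (1-based indexes), which is what lets the ORM accept it inside ``primaryjoin`` as in the ``Venue`` example above::

    from sqlalchemy import column, func

    parent_path = column("parent_path")
    child_path = column("child_path")

    expr = func.instr(parent_path, child_path).as_comparison(1, 2)
    print(expr)  # instr(parent_path, child_path)
    print(expr.left_expr, expr.right_expr)  # parent_path child_path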
@@ -707,12 +775,12 @@ class FunctionAsBinary(BinaryExpression[Any]): left_index: int right_index: int - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key(self, anon_map: Any, bindparams: Any) -> Any: return ColumnElement._gen_cache_key(self, anon_map, bindparams) def __init__( self, fn: FunctionElement[Any], left_index: int, right_index: int - ): + ) -> None: self.sql_function = fn self.left_index = left_index self.right_index = right_index @@ -721,7 +789,7 @@ def __init__( self.type = sqltypes.BOOLEANTYPE self.negate = None self._is_implicitly_boolean = True - self.modifiers = {} + self.modifiers = util.immutabledict({}) @property def left_expr(self) -> ColumnElement[Any]: @@ -764,7 +832,7 @@ def __init__( fn: FunctionElement[_T], name: str, type_: Optional[_TypeEngineArgument[_T]] = None, - ): + ) -> None: self.fn = fn self.name = name @@ -818,8 +886,11 @@ class _FunctionGenerator: .. sourcecode:: pycon+sql - >>> print(func.my_string(u'hi', type_=Unicode) + ' ' + - ... func.my_string(u'there', type_=Unicode)) + >>> print( + ... func.my_string("hi", type_=Unicode) + ... + " " + ... + func.my_string("there", type_=Unicode) + ... ) {printsql}my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3) The object returned by a :data:`.func` call is usually an instance of @@ -860,8 +931,8 @@ class _FunctionGenerator: """ # noqa - def __init__(self, **opts): - self.__names = [] + def __init__(self, **opts: Any) -> None: + self.__names: List[str] = [] self.opts = opts def __getattr__(self, name: str) -> _FunctionGenerator: @@ -881,12 +952,10 @@ def __getattr__(self, name: str) -> _FunctionGenerator: @overload def __call__( self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any - ) -> Function[_T]: - ... + ) -> Function[_T]: ... @overload - def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: - ... + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: ... def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: o = self.opts.copy() @@ -917,148 +986,274 @@ def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: # statically generated** by tools/generate_sql_functions.py @property - def aggregate_strings(self) -> Type[aggregate_strings]: - ... + def aggregate_strings(self) -> Type[aggregate_strings]: ... @property - def ansifunction(self) -> Type[AnsiFunction[Any]]: - ... + def ansifunction(self) -> Type[AnsiFunction[Any]]: ... - @property - def array_agg(self) -> Type[array_agg[Any]]: - ... + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 - @property - def cast(self) -> Type[Cast[Any]]: - ... + @overload + def array_agg( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... - @property - def char_length(self) -> Type[char_length]: - ... + @overload + def array_agg( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... - @property - def coalesce(self) -> Type[coalesce[Any]]: - ... + @overload + def array_agg( + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... - @property - def concat(self) -> Type[concat]: - ... 
+ def array_agg( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... @property - def count(self) -> Type[count]: - ... + def cast(self) -> Type[Cast[Any]]: ... @property - def cube(self) -> Type[cube[Any]]: - ... + def char_length(self) -> Type[char_length]: ... - @property - def cume_dist(self) -> Type[cume_dist]: - ... + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def coalesce( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: ... + + @overload + def coalesce( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: ... + + @overload + def coalesce( + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: ... + + def coalesce( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: ... @property - def current_date(self) -> Type[current_date]: - ... + def concat(self) -> Type[concat]: ... @property - def current_time(self) -> Type[current_time]: - ... + def count(self) -> Type[count]: ... @property - def current_timestamp(self) -> Type[current_timestamp]: - ... + def cube(self) -> Type[cube[Any]]: ... @property - def current_user(self) -> Type[current_user]: - ... + def cume_dist(self) -> Type[cume_dist]: ... @property - def dense_rank(self) -> Type[dense_rank]: - ... + def current_date(self) -> Type[current_date]: ... @property - def extract(self) -> Type[Extract]: - ... + def current_time(self) -> Type[current_time]: ... @property - def grouping_sets(self) -> Type[grouping_sets[Any]]: - ... + def current_timestamp(self) -> Type[current_timestamp]: ... @property - def localtime(self) -> Type[localtime]: - ... + def current_user(self) -> Type[current_user]: ... @property - def localtimestamp(self) -> Type[localtimestamp]: - ... + def dense_rank(self) -> Type[dense_rank]: ... @property - def max(self) -> Type[max[Any]]: # noqa: A001 - ... + def extract(self) -> Type[Extract]: ... @property - def min(self) -> Type[min[Any]]: # noqa: A001 - ... + def grouping_sets(self) -> Type[grouping_sets[Any]]: ... @property - def mode(self) -> Type[mode[Any]]: - ... + def localtime(self) -> Type[localtime]: ... @property - def next_value(self) -> Type[next_value]: - ... + def localtimestamp(self) -> Type[localtimestamp]: ... + + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def max( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> max[_T]: ... + + @overload + def max( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> max[_T]: ... + + @overload + def max( # noqa: A001 + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> max[_T]: ... 
+ + def max( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> max[_T]: ... + + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def min( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> min[_T]: ... + + @overload + def min( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> min[_T]: ... + + @overload + def min( # noqa: A001 + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> min[_T]: ... + + def min( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> min[_T]: ... @property - def now(self) -> Type[now]: - ... + def mode(self) -> Type[mode[Any]]: ... @property - def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: - ... + def next_value(self) -> Type[next_value]: ... @property - def percent_rank(self) -> Type[percent_rank]: - ... + def now(self) -> Type[now]: ... @property - def percentile_cont(self) -> Type[percentile_cont[Any]]: - ... + def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: ... @property - def percentile_disc(self) -> Type[percentile_disc[Any]]: - ... + def percent_rank(self) -> Type[percent_rank]: ... @property - def random(self) -> Type[random]: - ... + def percentile_cont(self) -> Type[percentile_cont[Any]]: ... @property - def rank(self) -> Type[rank]: - ... + def percentile_disc(self) -> Type[percentile_disc[Any]]: ... @property - def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]: - ... + def random(self) -> Type[random]: ... @property - def rollup(self) -> Type[rollup[Any]]: - ... + def rank(self) -> Type[rank]: ... @property - def session_user(self) -> Type[session_user]: - ... + def rollup(self) -> Type[rollup[Any]]: ... @property - def sum(self) -> Type[sum[Any]]: # noqa: A001 - ... + def session_user(self) -> Type[session_user]: ... + + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def sum( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> sum[_T]: ... + + @overload + def sum( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> sum[_T]: ... + + @overload + def sum( # noqa: A001 + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> sum[_T]: ... + + def sum( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> sum[_T]: ... @property - def sysdate(self) -> Type[sysdate]: - ... + def sysdate(self) -> Type[sysdate]: ... @property - def user(self) -> Type[user]: - ... + def user(self) -> Type[user]: ... 
# END GENERATED FUNCTION ACCESSORS @@ -1131,13 +1326,31 @@ class Function(FunctionElement[_T]): """ + @overload + def __init__( + self, + name: str, + *clauses: _ColumnExpressionOrLiteralArgument[_T], + type_: None = ..., + packagenames: Optional[Tuple[str, ...]] = ..., + ) -> None: ... + + @overload + def __init__( + self, + name: str, + *clauses: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T] = ..., + packagenames: Optional[Tuple[str, ...]] = ..., + ) -> None: ... + def __init__( self, name: str, - *clauses: Any, + *clauses: _ColumnExpressionOrLiteralArgument[Any], type_: Optional[_TypeEngineArgument[_T]] = None, packagenames: Optional[Tuple[str, ...]] = None, - ): + ) -> None: """Construct a :class:`.Function`. The :data:`.func` construct is normally used to construct @@ -1153,7 +1366,14 @@ def __init__( FunctionElement.__init__(self, *clauses) - def _bind_param(self, operator, obj, type_=None, **kw): + def _bind_param( + self, + operator: OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + **kw: Any, + ) -> BindParameter[_T]: return BindParameter( self.name, obj, @@ -1161,6 +1381,7 @@ def _bind_param(self, operator, obj, type_=None, **kw): _compared_to_type=self.type, type_=type_, unique=True, + expanding=expanding, **kw, ) @@ -1187,10 +1408,12 @@ class that is instantiated automatically when called from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.types import DateTime + class as_utc(GenericFunction): type = DateTime() inherit_cache = True + print(select(func.as_utc())) User-defined generic functions can be organized into @@ -1238,6 +1461,7 @@ class GeoBuffer(GenericFunction): from sqlalchemy.sql import quoted_name + class GeoBuffer(GenericFunction): type = Geometry() package = "geo" @@ -1306,7 +1530,9 @@ def _register_generic_function( # Set _register to True to register child classes by default cls._register = True - def __init__(self, *args, **kwargs): + def __init__( + self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any + ) -> None: parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: parsed_args = [ @@ -1332,8 +1558,8 @@ def __init__(self, *args, **kwargs): ) -register_function("cast", Cast) -register_function("extract", Extract) +register_function("cast", Cast) # type: ignore +register_function("extract", Extract) # type: ignore class next_value(GenericFunction[int]): @@ -1353,7 +1579,7 @@ class next_value(GenericFunction[int]): ("sequence", InternalTraversal.dp_named_ddl_element) ] - def __init__(self, seq, **kw): + def __init__(self, seq: schema.Sequence, **kw: Any) -> None: assert isinstance( seq, schema.Sequence ), "next_value() accepts a Sequence object as input." 
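A runnable sketch of the packaged ``GenericFunction`` pattern from the ``GeoBuffer`` docstring above; the ``Geometry`` type here is a stand-in :class:`.UserDefinedType`, since the real one would come from a third-party library such as GeoAlchemy::

    from sqlalchemy import column, func, select
    from sqlalchemy.sql.functions import GenericFunction
    from sqlalchemy.types import UserDefinedType


    class Geometry(UserDefinedType):
        cache_ok = True

        def get_col_spec(self, **kw):
            return "GEOMETRY"


    class GeoBuffer(GenericFunction):
        type = Geometry()
        package = "geo"
        name = "ST_Buffer"
        identifier = "buffer"
        inherit_cache = True


    # reached via its package and identifier, rendered under its SQL name
    stmt = select(func.geo.buffer(column("geom"), 2.0))
    print(stmt)  # SELECT ST_Buffer(geom, ...) AS ...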
@@ -1362,14 +1588,14 @@ def __init__(self, seq, **kw): seq.data_type or getattr(self, "type", None) ) - def compare(self, other, **kw): + def compare(self, other: Any, **kw: Any) -> bool: return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) @property - def _from_objects(self): + def _from_objects(self) -> Any: return [] @@ -1378,17 +1604,52 @@ class AnsiFunction(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args, **kwargs): + def __init__( + self, *args: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> None: GenericFunction.__init__(self, *args, **kwargs) class ReturnTypeFromArgs(GenericFunction[_T]): - """Define a function whose return type is the same as its arguments.""" + """Define a function whose return type is bound to the type of its + arguments. + """ inherit_cache = True - def __init__(self, *args, **kwargs): - fn_args = [ + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def __init__( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + def __init__( + self, *args: _ColumnExpressionOrLiteralArgument[_T], **kwargs: Any + ) -> None: + fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, @@ -1444,7 +1705,7 @@ class concat(GenericFunction[str]): .. sourcecode:: pycon+sql - >>> print(select(func.concat('a', 'b'))) + >>> print(select(func.concat("a", "b"))) {printsql}SELECT concat(:concat_2, :concat_3) AS concat_1 String concatenation in SQLAlchemy is more commonly available using the @@ -1469,7 +1730,7 @@ class char_length(GenericFunction[int]): type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg, **kw): + def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any) -> None: # slight hack to limit to just one positional argument # not sure why this one function has this special treatment super().__init__(arg, **kw) @@ -1492,21 +1753,30 @@ class count(GenericFunction[int]): from sqlalchemy import select from sqlalchemy import table, column - my_table = table('some_table', column('id')) + my_table = table("some_table", column("id")) stmt = select(func.count()).select_from(my_table) - Executing ``stmt`` would emit:: + Executing ``stmt`` would emit: + + .. sourcecode:: sql SELECT count(*) AS count_1 FROM some_table """ + type = sqltypes.Integer() inherit_cache = True - def __init__(self, expression=None, **kwargs): + def __init__( + self, + expression: Union[ + _ColumnExpressionArgument[Any], _StarOrOne, None + ] = None, + **kwargs: Any, + ) -> None: if expression is None: expression = literal_column("*") super().__init__(expression, **kwargs) @@ -1575,7 +1845,7 @@ class user(AnsiFunction[str]): inherit_cache = True -class array_agg(GenericFunction[_T]): +class array_agg(ReturnTypeFromArgs[Sequence[_T]]): """Support for the ARRAY_AGG function. 
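The hunks above give the ``ReturnTypeFromArgs`` family (``coalesce``, ``max``, ``min``, ``sum``, and now ``array_agg``) explicit per-signature overloads; the runtime behavior they describe can be checked with ad-hoc columns, where the SQL type (and, for typing tools, the generic ``_T``) follows the first argument::

    from sqlalchemy import Integer, column, func

    x = column("x", Integer)

    total = func.coalesce(func.sum(x), 0)
    print(type(total).__name__, total.type)  # coalesce INTEGER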
The ``func.array_agg(expr)`` construct returns an expression of @@ -1595,8 +1865,10 @@ class array_agg(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args, **kwargs): - fn_args = [ + def __init__( + self, *args: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> None: + fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self ) @@ -1624,9 +1896,13 @@ class OrderedSetAgg(GenericFunction[_T]): array_for_multi_clause = False inherit_cache = True - def within_group_type(self, within_group): + def within_group_type( + self, within_group: WithinGroup[Any] + ) -> TypeEngine[Any]: func_clauses = cast(ClauseList, self.clause_expr.element) - order_by = sqlutil.unwrap_order_by(within_group.order_by) + order_by: Sequence[ColumnElement[Any]] = sqlutil.unwrap_order_by( + within_group.order_by + ) if self.array_for_multi_clause and len(func_clauses.clauses) > 1: return sqltypes.ARRAY(order_by[0].type) else: @@ -1747,6 +2023,7 @@ class cube(GenericFunction[_T]): .. versionadded:: 1.2 """ + _has_args = True inherit_cache = True @@ -1764,6 +2041,7 @@ class rollup(GenericFunction[_T]): .. versionadded:: 1.2 """ + _has_args = True inherit_cache = True @@ -1783,9 +2061,7 @@ class grouping_sets(GenericFunction[_T]): from sqlalchemy import tuple_ stmt = select( - func.sum(table.c.value), - table.c.col_1, table.c.col_2, - table.c.col_3 + func.sum(table.c.value), table.c.col_1, table.c.col_2, table.c.col_3 ).group_by( func.grouping_sets( tuple_(table.c.col_1, table.c.col_2), @@ -1793,10 +2069,10 @@ class grouping_sets(GenericFunction[_T]): ) ) - .. versionadded:: 1.2 - """ + """ # noqa: E501 + _has_args = True inherit_cache = True @@ -1824,5 +2100,7 @@ class aggregate_strings(GenericFunction[str]): _has_args = True inherit_cache = True - def __init__(self, clause, separator): + def __init__( + self, clause: _ColumnExpressionArgument[Any], separator: str + ) -> None: super().__init__(clause, separator) diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 7aef605ac72..21c69fed5af 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -1,5 +1,5 @@ # sql/lambdas.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -256,10 +256,7 @@ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): self.closure_cache_key = cache_key - try: - rec = lambda_cache[tracker_key + cache_key] - except KeyError: - rec = None + rec = lambda_cache.get(tracker_key + cache_key) else: cache_key = _cache_key.NO_CACHE rec = None @@ -278,7 +275,7 @@ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): rec = AnalyzedFunction( tracker, self, apply_propagate_attrs, fn ) - rec.closure_bindparams = bindparams + rec.closure_bindparams = list(bindparams) lambda_cache[key] = rec else: rec = lambda_cache[key] @@ -303,7 +300,9 @@ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): while lambda_element is not None: rec = lambda_element._rec if rec.bindparam_trackers: - tracker_instrumented_fn = rec.tracker_instrumented_fn + tracker_instrumented_fn = ( + rec.tracker_instrumented_fn # type:ignore [union-attr] # noqa: E501 + ) for tracker in rec.bindparam_trackers: tracker( lambda_element.fn, @@ -407,9 +406,9 @@ def _gen_cache_key(self, anon_map, bindparams): while parent is not None: assert parent.closure_cache_key is not CacheConst.NO_CACHE 
- parent_closure_cache_key: Tuple[ - Any, ... - ] = parent.closure_cache_key + parent_closure_cache_key: Tuple[Any, ...] = ( + parent.closure_cache_key + ) cache_key = ( (parent.fn.__code__,) + parent_closure_cache_key + cache_key @@ -437,7 +436,7 @@ class DeferredLambdaElement(LambdaElement): def __init__( self, - fn: _LambdaType, + fn: _AnyLambdaType, role: Type[roles.SQLRole], opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, lambda_args: Tuple[Any, ...] = (), @@ -518,7 +517,6 @@ class StatementLambdaElement( stmt += lambda s: s.where(table.c.col == parameter) - .. versionadded:: 1.4 .. seealso:: @@ -535,8 +533,7 @@ def __init__( role: Type[SQLRole], opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, apply_propagate_attrs: Optional[ClauseElement] = None, - ): - ... + ): ... def __add__( self, other: _StmtLambdaElementType[Any] @@ -559,9 +556,7 @@ def add_criteria( ... stmt = lambda_stmt( ... lambda: select(table.c.x, table.c.y), ... ) - ... stmt = stmt.add_criteria( - ... lambda: table.c.x > parameter - ... ) + ... stmt = stmt.add_criteria(lambda: table.c.x > parameter) ... return stmt The :meth:`_sql.StatementLambdaElement.add_criteria` method is @@ -572,18 +567,15 @@ def add_criteria( >>> def my_stmt(self, foo): ... stmt = lambda_stmt( ... lambda: select(func.max(foo.x, foo.y)), - ... track_closure_variables=False - ... ) - ... stmt = stmt.add_criteria( - ... lambda: self.where_criteria, - ... track_on=[self] + ... track_closure_variables=False, ... ) + ... stmt = stmt.add_criteria(lambda: self.where_criteria, track_on=[self]) ... return stmt See :func:`_sql.lambda_stmt` for a description of the parameters accepted. - """ + """ # noqa: E501 opts = self.opts + dict( enable_tracking=enable_tracking, @@ -612,7 +604,7 @@ def _proxied(self) -> Any: return self._rec_expected_expr @property - def _with_options(self): + def _with_options(self): # type: ignore[override] return self._proxied._with_options @property @@ -620,7 +612,7 @@ def _effective_plugin_target(self): return self._proxied._effective_plugin_target @property - def _execution_options(self): + def _execution_options(self): # type: ignore[override] return self._proxied._execution_options @property @@ -628,27 +620,27 @@ def _all_selected_columns(self): return self._proxied._all_selected_columns @property - def is_select(self): + def is_select(self): # type: ignore[override] return self._proxied.is_select @property - def is_update(self): + def is_update(self): # type: ignore[override] return self._proxied.is_update @property - def is_insert(self): + def is_insert(self): # type: ignore[override] return self._proxied.is_insert @property - def is_text(self): + def is_text(self): # type: ignore[override] return self._proxied.is_text @property - def is_delete(self): + def is_delete(self): # type: ignore[override] return self._proxied.is_delete @property - def is_dml(self): + def is_dml(self): # type: ignore[override] return self._proxied.is_dml def spoil(self) -> NullLambdaStatement: @@ -737,9 +729,9 @@ class AnalyzedCode: "closure_trackers", "build_py_wrappers", ) - _fns: weakref.WeakKeyDictionary[ - CodeType, AnalyzedCode - ] = weakref.WeakKeyDictionary() + _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = ( + weakref.WeakKeyDictionary() + ) _generation_mutex = threading.RLock() @@ -1180,16 +1172,16 @@ def _instrument_and_run_function(self, lambda_element): closure_pywrappers.append(bind) else: value = fn.__globals__[name] - new_globals[name] = bind = PyWrapper(fn, name, value) + new_globals[name] = 
PyWrapper(fn, name, value) # rewrite the original fn. things that look like they will # become bound parameters are wrapped in a PyWrapper. - self.tracker_instrumented_fn = ( - tracker_instrumented_fn - ) = self._rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, + self.tracker_instrumented_fn = tracker_instrumented_fn = ( + self._rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) ) # now invoke the function. This will give us a new SQL diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 03c9aab67ba..ce68acf15b9 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -1,15 +1,12 @@ -# sqlalchemy/naming.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# sql/naming.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -"""Establish constraint and index naming conventions. - - -""" +"""Establish constraint and index naming conventions.""" from __future__ import annotations diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 6402d0fd1b2..d5f876cb0d8 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1,5 +1,5 @@ # sql/operators.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -77,8 +77,7 @@ def __call__( right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def __call__( @@ -87,8 +86,7 @@ def __call__( right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... def __call__( self, @@ -96,8 +94,7 @@ def __call__( right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... add = cast(OperatorType, _uncast_add) @@ -151,6 +148,7 @@ def __and__(self, other: Any) -> Operators: is equivalent to:: from sqlalchemy import and_ + and_(a, b) Care should be taken when using ``&`` regarding @@ -175,6 +173,7 @@ def __or__(self, other: Any) -> Operators: is equivalent to:: from sqlalchemy import or_ + or_(a, b) Care should be taken when using ``|`` regarding @@ -199,6 +198,7 @@ def __invert__(self) -> Operators: is equivalent to:: from sqlalchemy import not_ + not_(a) """ @@ -227,7 +227,7 @@ def op( This function can also be used to make bitwise operators explicit. For example:: - somecolumn.op('&')(0xff) + somecolumn.op("&")(0xFF) is a bitwise AND of the value in ``somecolumn``. @@ -278,7 +278,7 @@ def op( e.g.:: - >>> expr = column('x').op('+', python_impl=lambda a, b: a + b)('y') + >>> expr = column("x").op("+", python_impl=lambda a, b: a + b)("y") The operator for the above expression will also work for non-SQL left and right objects:: @@ -392,10 +392,9 @@ class custom_op(OperatorType, Generic[_T]): from sqlalchemy.sql import operators from sqlalchemy import Numeric - unary = UnaryExpression(table.c.somecolumn, - modifier=operators.custom_op("!"), - type_=Numeric) - + unary = UnaryExpression( + table.c.somecolumn, modifier=operators.custom_op("!"), type_=Numeric + ) .. 
seealso:: @@ -403,7 +402,7 @@ class custom_op(OperatorType, Generic[_T]): :meth:`.Operators.bool_op` - """ + """ # noqa: E501 __name__ = "custom_op" @@ -466,8 +465,7 @@ def __call__( right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def __call__( @@ -476,8 +474,7 @@ def __call__( right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... def __call__( self, @@ -545,13 +542,11 @@ def eq(a, b): def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... def __lt__(self, other: Any) -> ColumnOperators: """Implement the ``<`` operator. @@ -574,8 +569,7 @@ def __le__(self, other: Any) -> ColumnOperators: # https://docs.python.org/3/reference/datamodel.html#object.__hash__ if TYPE_CHECKING: - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... else: __hash__ = Operators.__hash__ @@ -623,8 +617,7 @@ def is_not_distinct_from(self, other: Any) -> ColumnOperators: # deprecated 1.4; see #5435 if TYPE_CHECKING: - def isnot_distinct_from(self, other: Any) -> ColumnOperators: - ... + def isnot_distinct_from(self, other: Any) -> ColumnOperators: ... else: isnot_distinct_from = is_not_distinct_from @@ -707,14 +700,15 @@ def like( ) -> ColumnOperators: r"""Implement the ``like`` operator. - In a column context, produces the expression:: + In a column context, produces the expression: + + .. sourcecode:: sql a LIKE other E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.like("%foobar%")) + stmt = select(sometable).where(sometable.c.column.like("%foobar%")) :param other: expression to be compared :param escape: optional escape character, renders the ``ESCAPE`` @@ -734,18 +728,21 @@ def ilike( ) -> ColumnOperators: r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE. - In a column context, produces an expression either of the form:: + In a column context, produces an expression either of the form: + + .. sourcecode:: sql lower(a) LIKE lower(other) - Or on backends that support the ILIKE operator:: + Or on backends that support the ILIKE operator: + + .. sourcecode:: sql a ILIKE other E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.ilike("%foobar%")) + stmt = select(sometable).where(sometable.c.column.ilike("%foobar%")) :param other: expression to be compared :param escape: optional escape character, renders the ``ESCAPE`` @@ -757,7 +754,7 @@ def ilike( :meth:`.ColumnOperators.like` - """ + """ # noqa: E501 return self.operate(ilike_op, other, escape=escape) def bitwise_xor(self, other: Any) -> ColumnOperators: @@ -851,12 +848,15 @@ def in_(self, other: Any) -> ColumnOperators: The given parameter ``other`` may be: - * A list of literal values, e.g.:: + * A list of literal values, + e.g.:: stmt.where(column.in_([1, 2, 3])) In this calling form, the list of items is converted to a set of - bound parameters the same length as the list given:: + bound parameters the same length as the list given: + + .. sourcecode:: sql WHERE COL IN (?, ?, ?) 
@@ -864,16 +864,20 @@ def in_(self, other: Any) -> ColumnOperators: :func:`.tuple_` containing multiple expressions:: from sqlalchemy import tuple_ + stmt.where(tuple_(col1, col2).in_([(1, 10), (2, 20), (3, 30)])) - * An empty list, e.g.:: + * An empty list, + e.g.:: stmt.where(column.in_([])) In this calling form, the expression renders an "empty set" expression. These expressions are tailored to individual backends and are generally trying to get an empty SELECT statement as a - subquery. Such as on SQLite, the expression is:: + subquery. Such as on SQLite, the expression is: + + .. sourcecode:: sql WHERE col IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1) @@ -883,10 +887,12 @@ def in_(self, other: Any) -> ColumnOperators: * A bound parameter, e.g. :func:`.bindparam`, may be used if it includes the :paramref:`.bindparam.expanding` flag:: - stmt.where(column.in_(bindparam('value', expanding=True))) + stmt.where(column.in_(bindparam("value", expanding=True))) In this calling form, the expression renders a special non-SQL - placeholder expression that looks like:: + placeholder expression that looks like: + + .. sourcecode:: sql WHERE COL IN ([EXPANDING_value]) @@ -896,7 +902,9 @@ def in_(self, other: Any) -> ColumnOperators: connection.execute(stmt, {"value": [1, 2, 3]}) - The database would be passed a bound parameter for each value:: + The database would be passed a bound parameter for each value: + + .. sourcecode:: sql WHERE COL IN (?, ?, ?) @@ -904,7 +912,9 @@ def in_(self, other: Any) -> ColumnOperators: If an empty list is passed, a special "empty list" expression, which is specific to the database in use, is rendered. On - SQLite this would be:: + SQLite this would be: + + .. sourcecode:: sql WHERE COL IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1) @@ -915,13 +925,12 @@ def in_(self, other: Any) -> ColumnOperators: correlated scalar select:: stmt.where( - column.in_( - select(othertable.c.y). - where(table.c.x == othertable.c.x) - ) + column.in_(select(othertable.c.y).where(table.c.x == othertable.c.x)) ) - In this calling form, :meth:`.ColumnOperators.in_` renders as given:: + In this calling form, :meth:`.ColumnOperators.in_` renders as given: + + .. sourcecode:: sql WHERE COL IN (SELECT othertable.y FROM othertable WHERE othertable.x = table.x) @@ -930,7 +939,7 @@ def in_(self, other: Any) -> ColumnOperators: construct, or a :func:`.bindparam` construct that includes the :paramref:`.bindparam.expanding` flag set to True. - """ + """ # noqa: E501 return self.operate(in_op, other) def not_in(self, other: Any) -> ColumnOperators: @@ -964,8 +973,7 @@ def not_in(self, other: Any) -> ColumnOperators: # deprecated 1.4; see #5429 if TYPE_CHECKING: - def notin_(self, other: Any) -> ColumnOperators: - ... + def notin_(self, other: Any) -> ColumnOperators: ... else: notin_ = not_in @@ -994,8 +1002,7 @@ def not_like( def notlike( self, other: Any, escape: Optional[str] = None - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... else: notlike = not_like @@ -1024,8 +1031,7 @@ def not_ilike( def notilike( self, other: Any, escape: Optional[str] = None - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... else: notilike = not_ilike @@ -1063,8 +1069,7 @@ def is_not(self, other: Any) -> ColumnOperators: # deprecated 1.4; see #5429 if TYPE_CHECKING: - def isnot(self, other: Any) -> ColumnOperators: - ... + def isnot(self, other: Any) -> ColumnOperators: ... else: isnot = is_not @@ -1078,14 +1083,15 @@ def startswith( r"""Implement the ``startswith`` operator. 
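Tying together the ``in_()`` calling forms reworked above in one runnable sketch on ad-hoc Core objects (table and column names are illustrative)::

    from sqlalchemy import bindparam, column, select, table

    t = table("t", column("x"))

    # list form; since 1.4 this renders one expanding parameter that
    # becomes one bound parameter per element at execution time
    print(select(t).where(t.c.x.in_([1, 2, 3])))

    # expanding bindparam form; the list is supplied at execution time,
    # e.g. conn.execute(stmt, {"xs": [1, 2, 3]})
    stmt = select(t).where(t.c.x.in_(bindparam("xs", expanding=True)))
    print(stmt)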
Produces a LIKE expression that tests against a match for the start - of a string value:: + of a string value: + + .. sourcecode:: sql column LIKE <other> || '%' E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.startswith("foobar")) + stmt = select(sometable).where(sometable.c.column.startswith("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1114,7 +1120,9 @@ def startswith( somecolumn.startswith("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE :param || '%' ESCAPE '/' @@ -1130,7 +1138,9 @@ def startswith( somecolumn.startswith("foo/%bar", escape="^") - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE :param || '%' ESCAPE '^' @@ -1150,7 +1160,7 @@ def startswith( :meth:`.ColumnOperators.like` - """ + """ # noqa: E501 return self.operate( startswith_op, other, escape=escape, autoescape=autoescape ) @@ -1165,14 +1175,15 @@ def istartswith( version of :meth:`.ColumnOperators.startswith`. Produces a LIKE expression that tests against an insensitive - match for the start of a string value:: + match for the start of a string value: + + .. sourcecode:: sql lower(column) LIKE lower(<other>) || '%' E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.istartswith("foobar")) + stmt = select(sometable).where(sometable.c.column.istartswith("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1201,7 +1212,9 @@ def istartswith( somecolumn.istartswith("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql lower(somecolumn) LIKE lower(:param) || '%' ESCAPE '/' @@ -1217,7 +1230,9 @@ def istartswith( somecolumn.istartswith("foo/%bar", escape="^") - Will render as:: + Will render as: + + .. sourcecode:: sql lower(somecolumn) LIKE lower(:param) || '%' ESCAPE '^' @@ -1232,7 +1247,7 @@ def istartswith( .. seealso:: :meth:`.ColumnOperators.startswith` - """ + """ # noqa: E501 return self.operate( istartswith_op, other, escape=escape, autoescape=autoescape ) @@ -1246,14 +1261,15 @@ def endswith( r"""Implement the 'endswith' operator. Produces a LIKE expression that tests against a match for the end - of a string value:: + of a string value: + + .. sourcecode:: sql column LIKE '%' || <other> E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.endswith("foobar")) + stmt = select(sometable).where(sometable.c.column.endswith("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1282,7 +1298,9 @@ def endswith( somecolumn.endswith("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE '%' || :param ESCAPE '/' @@ -1298,7 +1316,9 @@ def endswith( somecolumn.endswith("foo/%bar", escape="^") - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE '%' || :param ESCAPE '^' @@ -1318,7 +1338,7 @@ def endswith( :meth:`.ColumnOperators.like` - """ + """ # noqa: E501 return self.operate( endswith_op, other, escape=escape, autoescape=autoescape ) @@ -1333,14 +1353,15 @@ def iendswith( version of :meth:`.ColumnOperators.endswith`. Produces a LIKE expression that tests against an insensitive match - for the end of a string value:: + for the end of a string value: + + ..
sourcecode:: sql lower(column) LIKE '%' || lower(<other>) E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.iendswith("foobar")) + stmt = select(sometable).where(sometable.c.column.iendswith("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1369,7 +1390,9 @@ def iendswith( somecolumn.iendswith("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql lower(somecolumn) LIKE '%' || lower(:param) ESCAPE '/' @@ -1385,7 +1408,9 @@ def iendswith( somecolumn.iendswith("foo/%bar", escape="^") - Will render as:: + Will render as: + + .. sourcecode:: sql lower(somecolumn) LIKE '%' || lower(:param) ESCAPE '^' @@ -1400,7 +1425,7 @@ def iendswith( .. seealso:: :meth:`.ColumnOperators.endswith` - """ + """ # noqa: E501 return self.operate( iendswith_op, other, escape=escape, autoescape=autoescape ) @@ -1409,14 +1434,15 @@ def contains(self, other: Any, **kw: Any) -> ColumnOperators: r"""Implement the 'contains' operator. Produces a LIKE expression that tests against a match for the middle - of a string value:: + of a string value: + + .. sourcecode:: sql column LIKE '%' || <other> || '%' E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.contains("foobar")) + stmt = select(sometable).where(sometable.c.column.contains("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1445,7 +1471,9 @@ def contains(self, other: Any, **kw: Any) -> ColumnOperators: somecolumn.contains("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE '%' || :param || '%' ESCAPE '/' @@ -1461,7 +1489,9 @@ def contains(self, other: Any, **kw: Any) -> ColumnOperators: somecolumn.contains("foo/%bar", escape="^") - Will render as:: + Will render as: + + .. sourcecode:: sql somecolumn LIKE '%' || :param || '%' ESCAPE '^' @@ -1482,7 +1512,7 @@ def contains(self, other: Any, **kw: Any) -> ColumnOperators: :meth:`.ColumnOperators.like` - """ + """ # noqa: E501 return self.operate(contains_op, other, **kw) def icontains(self, other: Any, **kw: Any) -> ColumnOperators: @@ -1490,14 +1520,15 @@ def icontains(self, other: Any, **kw: Any) -> ColumnOperators: version of :meth:`.ColumnOperators.contains`. Produces a LIKE expression that tests against an insensitive match - for the middle of a string value:: + for the middle of a string value: + + .. sourcecode:: sql lower(column) LIKE '%' || lower(<other>) || '%' E.g.:: - stmt = select(sometable).\ - where(sometable.c.column.icontains("foobar")) + stmt = select(sometable).where(sometable.c.column.icontains("foobar")) Since the operator uses ``LIKE``, wildcard characters ``"%"`` and ``"_"`` that are present inside the expression @@ -1526,7 +1557,9 @@ def icontains(self, other: Any, **kw: Any) -> ColumnOperators: somecolumn.icontains("foo%bar", autoescape=True) - Will render as:: + Will render as: + + .. sourcecode:: sql lower(somecolumn) LIKE '%' || lower(:param) || '%' ESCAPE '/' @@ -1542,7 +1575,9 @@ def icontains(self, other: Any, **kw: Any) -> ColumnOperators: somecolumn.icontains("foo/%bar", escape="^") - Will render as:: + Will render as: + + ..
sourcecode:: sql lower(somecolumn) LIKE '%' || lower(:param) || '%' ESCAPE '^' @@ -1558,7 +1593,7 @@ def icontains(self, other: Any, **kw: Any) -> ColumnOperators: :meth:`.ColumnOperators.contains` - """ + """ # noqa: E501 return self.operate(icontains_op, other, **kw) def match(self, other: Any, **kwargs: Any) -> ColumnOperators: @@ -1582,7 +1617,7 @@ def match(self, other: Any, **kwargs: Any) -> ColumnOperators: :class:`_mysql.match` - MySQL specific construct with additional features. - * Oracle - renders ``CONTAINS(x, y)`` + * Oracle Database - renders ``CONTAINS(x, y)`` * other backends may provide special implementations. * Backends without any special implementation will emit the operator as "MATCH". This is compatible with SQLite, for @@ -1599,7 +1634,7 @@ def regexp_match( E.g.:: stmt = select(table.c.some_column).where( - table.c.some_column.regexp_match('^(b|c)') + table.c.some_column.regexp_match("^(b|c)") ) :meth:`_sql.ColumnOperators.regexp_match` attempts to resolve to @@ -1610,7 +1645,7 @@ def regexp_match( Examples include: * PostgreSQL - renders ``x ~ y`` or ``x !~ y`` when negated. - * Oracle - renders ``REGEXP_LIKE(x, y)`` + * Oracle Database - renders ``REGEXP_LIKE(x, y)`` * SQLite - uses SQLite's ``REGEXP`` placeholder operator and calls into the Python ``re.match()`` builtin. * other backends may provide special implementations. @@ -1618,9 +1653,9 @@ def regexp_match( the operator as "REGEXP" or "NOT REGEXP". This is compatible with SQLite and MySQL, for example. - Regular expression support is currently implemented for Oracle, - PostgreSQL, MySQL and MariaDB. Partial support is available for - SQLite. Support among third-party dialects may vary. + Regular expression support is currently implemented for Oracle + Database, PostgreSQL, MySQL and MariaDB. Partial support is available + for SQLite. Support among third-party dialects may vary. :param pattern: The regular expression pattern string or column clause. @@ -1657,11 +1692,7 @@ def regexp_replace( E.g.:: stmt = select( - table.c.some_column.regexp_replace( - 'b(..)', - 'X\1Y', - flags='g' - ) + table.c.some_column.regexp_replace("b(..)", "X\1Y", flags="g") ) :meth:`_sql.ColumnOperators.regexp_replace` attempts to resolve to @@ -1671,8 +1702,8 @@ def regexp_replace( **not backend agnostic**. Regular expression replacement support is currently implemented for - Oracle, PostgreSQL, MySQL 8 or greater and MariaDB. Support among - third-party dialects may vary. + Oracle Database, PostgreSQL, MySQL 8 or greater and MariaDB. Support + among third-party dialects may vary. :param pattern: The regular expression pattern string or column clause. @@ -1728,8 +1759,7 @@ def nulls_first(self) -> ColumnOperators: # deprecated 1.4; see #5435 if TYPE_CHECKING: - def nullsfirst(self) -> ColumnOperators: - ... + def nullsfirst(self) -> ColumnOperators: ... else: nullsfirst = nulls_first @@ -1747,8 +1777,7 @@ def nulls_last(self) -> ColumnOperators: # deprecated 1.4; see #5429 if TYPE_CHECKING: - def nullslast(self) -> ColumnOperators: - ... + def nullslast(self) -> ColumnOperators: ... else: nullslast = nulls_last @@ -1819,10 +1848,10 @@ def any_(self) -> ColumnOperators: See the documentation for :func:`_sql.any_` for examples. .. note:: be sure to not confuse the newer - :meth:`_sql.ColumnOperators.any_` method with its older - :class:`_types.ARRAY`-specific counterpart, the - :meth:`_types.ARRAY.Comparator.any` method, which a different - calling syntax and usage pattern. 
+ :meth:`_sql.ColumnOperators.any_` method with the **legacy** + version of this method, the :meth:`_types.ARRAY.Comparator.any` + method that's specific to :class:`_types.ARRAY`, which uses a + different calling style. """ return self.operate(any_op) @@ -1834,10 +1863,10 @@ def all_(self) -> ColumnOperators: See the documentation for :func:`_sql.all_` for examples. .. note:: be sure to not confuse the newer - :meth:`_sql.ColumnOperators.all_` method with its older - :class:`_types.ARRAY`-specific counterpart, the - :meth:`_types.ARRAY.Comparator.all` method, which a different - calling syntax and usage pattern. + :meth:`_sql.ColumnOperators.all_` method with the **legacy** + version of this method, the :meth:`_types.ARRAY.Comparator.all` + method that's specific to :class:`_types.ARRAY`, which uses a + different calling style. """ return self.operate(all_op) @@ -1968,8 +1997,7 @@ def is_true(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def istrue(a: Any) -> Any: - ... + def istrue(a: Any) -> Any: ... else: istrue = is_true @@ -1984,8 +2012,7 @@ def is_false(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isfalse(a: Any) -> Any: - ... + def isfalse(a: Any) -> Any: ... else: isfalse = is_false @@ -2007,8 +2034,7 @@ def is_not_distinct_from(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isnot_distinct_from(a: Any, b: Any) -> Any: - ... + def isnot_distinct_from(a: Any, b: Any) -> Any: ... else: isnot_distinct_from = is_not_distinct_from @@ -2030,8 +2056,7 @@ def is_not(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isnot(a: Any, b: Any) -> Any: - ... + def isnot(a: Any, b: Any) -> Any: ... else: isnot = is_not @@ -2063,8 +2088,7 @@ def not_like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: if TYPE_CHECKING: @_operator_fn - def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: - ... + def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ... else: notlike_op = not_like_op @@ -2086,8 +2110,7 @@ def not_ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: if TYPE_CHECKING: @_operator_fn - def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: - ... + def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ... else: notilike_op = not_ilike_op @@ -2109,8 +2132,9 @@ def not_between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: if TYPE_CHECKING: @_operator_fn - def notbetween_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: - ... + def notbetween_op( + a: Any, b: Any, c: Any, symmetric: bool = False + ) -> Any: ... else: notbetween_op = not_between_op @@ -2132,8 +2156,7 @@ def not_in_op(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def notin_op(a: Any, b: Any) -> Any: - ... + def notin_op(a: Any, b: Any) -> Any: ... else: notin_op = not_in_op @@ -2198,8 +2221,7 @@ def not_startswith_op( @_operator_fn def notstartswith_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... else: notstartswith_op = not_startswith_op @@ -2243,8 +2265,7 @@ def not_endswith_op( @_operator_fn def notendswith_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... else: notendswith_op = not_endswith_op @@ -2288,8 +2309,7 @@ def not_contains_op( @_operator_fn def notcontains_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... 
else: notcontains_op = not_contains_op @@ -2346,8 +2366,7 @@ def not_match_op(a: Any, b: Any, **kw: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def notmatch_op(a: Any, b: Any, **kw: Any) -> Any: - ... + def notmatch_op(a: Any, b: Any, **kw: Any) -> Any: ... else: notmatch_op = not_match_op @@ -2392,8 +2411,7 @@ def nulls_first_op(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def nullsfirst_op(a: Any) -> Any: - ... + def nullsfirst_op(a: Any) -> Any: ... else: nullsfirst_op = nulls_first_op @@ -2408,8 +2426,7 @@ def nulls_last_op(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def nullslast_op(a: Any) -> Any: - ... + def nullslast_op(a: Any) -> Any: ... else: nullslast_op = nulls_last_op @@ -2501,6 +2518,12 @@ def is_associative(op: OperatorType) -> bool: return op in _associative +def is_order_by_modifier(op: Optional[OperatorType]) -> bool: + return op in _order_by_modifier + + +_order_by_modifier = {desc_op, asc_op, nulls_first_op, nulls_last_op} + _natural_self_precedent = _associative.union( [getitem, json_getitem_op, json_path_getitem_op] ) @@ -2582,9 +2605,13 @@ class _OpLimit(IntEnum): } -def is_precedent(operator: OperatorType, against: OperatorType) -> bool: +def is_precedent( + operator: OperatorType, against: Optional[OperatorType] +) -> bool: if operator is against and is_natural_self_precedent(operator): return False + elif against is None: + return True else: return bool( _PRECEDENCE.get( diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 6f299224328..da69616dc46 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -1,5 +1,5 @@ # sql/roles.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -227,8 +227,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): def _anonymous_fromclause( self, *, name: Optional[str] = None, flat: bool = False - ) -> FromClause: - ... + ) -> FromClause: ... class ReturnsRowsRole(SQLRole): @@ -246,8 +245,7 @@ class StatementRole(SQLRole): if TYPE_CHECKING: @util.memoized_property - def _propagate_attrs(self) -> _PropagateAttrsType: - ... + def _propagate_attrs(self) -> _PropagateAttrsType: ... else: _propagate_attrs = util.EMPTY_DICT diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c464d7eb0ea..db0678a7378 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1,5 +1,5 @@ # sql/schema.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -60,6 +60,7 @@ from . import type_api from . import visitors from .base import _DefaultDescriptionTuple +from .base import _NoArg from .base import _NoneName from .base import _SentinelColumnCharacterization from .base import _SentinelDefaultCharacterization @@ -67,6 +68,7 @@ from .base import DialectKWArgs from .base import Executable from .base import SchemaEventTarget as SchemaEventTarget +from .base import SchemaVisitable as SchemaVisitable from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause @@ -76,7 +78,6 @@ from .selectable import TableClause from .type_api import to_instance from .visitors import ExternallyTraversible -from .visitors import InternalTraversal from .. import event from .. import exc from .. 
import inspection @@ -91,23 +92,24 @@ if typing.TYPE_CHECKING: from ._typing import _AutoIncrementType + from ._typing import _CreateDropBind from ._typing import _DDLColumnArgument from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument + from .base import ColumnSet from .base import ReadOnlyColumnCollection from .compiler import DDLCompiler from .elements import BindParameter + from .elements import KeyedColumnElement from .functions import Function from .type_api import TypeEngine - from .visitors import _TraverseInternalsType from .visitors import anon_map from ..engine import Connection from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import ExecutionContext - from ..engine.mock import MockConnection from ..engine.reflection import _ReflectionInfo from ..sql.selectable import FromClause @@ -116,14 +118,14 @@ _TAB = TypeVar("_TAB", bound="Table") -_CreateDropBind = Union["Engine", "Connection", "MockConnection"] - _ConstraintNameArgument = Optional[Union[str, _NoneName]] _ServerDefaultArgument = Union[ "FetchedValue", str, TextClause, ColumnElement[Any] ] +_ServerOnUpdateArgument = _ServerDefaultArgument + class SchemaConst(Enum): RETAIN_SCHEMA = 1 @@ -159,15 +161,15 @@ class SchemaConst(Enum): """ -RETAIN_SCHEMA: Final[ - Literal[SchemaConst.RETAIN_SCHEMA] -] = SchemaConst.RETAIN_SCHEMA -BLANK_SCHEMA: Final[ - Literal[SchemaConst.BLANK_SCHEMA] -] = SchemaConst.BLANK_SCHEMA -NULL_UNSPECIFIED: Final[ - Literal[SchemaConst.NULL_UNSPECIFIED] -] = SchemaConst.NULL_UNSPECIFIED +RETAIN_SCHEMA: Final[Literal[SchemaConst.RETAIN_SCHEMA]] = ( + SchemaConst.RETAIN_SCHEMA +) +BLANK_SCHEMA: Final[Literal[SchemaConst.BLANK_SCHEMA]] = ( + SchemaConst.BLANK_SCHEMA +) +NULL_UNSPECIFIED: Final[Literal[SchemaConst.NULL_UNSPECIFIED]] = ( + SchemaConst.NULL_UNSPECIFIED +) def _get_table_key(name: str, schema: Optional[str]) -> str: @@ -209,7 +211,7 @@ def replace( @inspection._self_inspects -class SchemaItem(SchemaEventTarget, visitors.Visitable): +class SchemaItem(SchemaVisitable): """Base class for items that define a database schema.""" __visit_name__ = "schema_item" @@ -225,7 +227,7 @@ def _init_items(self, *args: SchemaItem, **kw: Any) -> None: except AttributeError as err: raise exc.ArgumentError( "'SchemaItem' object, such as a 'Column' or a " - "'Constraint' expected, got %r" % item + f"'Constraint' expected, got {item!r}" ) from err else: spwd(self, **kw) @@ -319,9 +321,10 @@ class Table( e.g.:: mytable = Table( - "mytable", metadata, - Column('mytable_id', Integer, primary_key=True), - Column('value', String(50)) + "mytable", + metadata, + Column("mytable_id", Integer, primary_key=True), + Column("value", String(50)), ) The :class:`_schema.Table` @@ -344,12 +347,10 @@ class Table( if TYPE_CHECKING: @util.ro_non_memoized_property - def primary_key(self) -> PrimaryKeyConstraint: - ... + def primary_key(self) -> PrimaryKeyConstraint: ... @util.ro_non_memoized_property - def foreign_keys(self) -> Set[ForeignKey]: - ... + def foreign_keys(self) -> Set[ForeignKey]: ... 
_columns: DedupeColumnCollection[Column[Any]] @@ -393,26 +394,18 @@ def foreign_keys(self) -> Set[ForeignKey]: """ - _traverse_internals: _TraverseInternalsType = ( - TableClause._traverse_internals - + [("schema", InternalTraversal.dp_string)] - ) - if TYPE_CHECKING: @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... @util.ro_non_memoized_property def exported_columns( self, - ) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + ) -> ReadOnlyColumnCollection[str, Column[Any]]: ... @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... def _gen_cache_key( self, anon_map: anon_map, bindparams: List[BindParameter[Any]] @@ -466,11 +459,11 @@ def _new(cls, *args: Any, **kw: Any) -> Any: if key in metadata.tables: if not keep_existing and not extend_existing and bool(args): raise exc.InvalidRequestError( - "Table '%s' is already defined for this MetaData " + f"Table '{key}' is already defined for this MetaData " "instance. Specify 'extend_existing=True' " "to redefine " "options and columns on an " - "existing Table object." % key + "existing Table object." ) table = metadata.tables[key] if extend_existing: @@ -478,12 +471,12 @@ def _new(cls, *args: Any, **kw: Any) -> Any: return table else: if must_exist: - raise exc.InvalidRequestError("Table '%s' not defined" % (key)) + raise exc.InvalidRequestError(f"Table '{key}' not defined") table = object.__new__(cls) table.dispatch.before_parent_attach(table, metadata) metadata._add_table(name, schema, table) try: - table.__init__(name, metadata, *args, _no_init=False, **kw) + table.__init__(name, metadata, *args, _no_init=False, **kw) # type: ignore[misc] # noqa: E501 table.dispatch.after_parent_attach(table, metadata) return table except Exception: @@ -641,11 +634,13 @@ def __init__( :class:`_schema.Column` named "y":: - Table("mytable", metadata, - Column('y', Integer), - extend_existing=True, - autoload_with=engine - ) + Table( + "mytable", + metadata, + Column("y", Integer), + extend_existing=True, + autoload_with=engine, + ) .. seealso:: @@ -742,12 +737,12 @@ def listen_for_reflect(table, column_info): "handle the column reflection event" # ... + t = Table( - 'sometable', + "sometable", autoload_with=engine, - listeners=[ - ('column_reflect', listen_for_reflect) - ]) + listeners=[("column_reflect", listen_for_reflect)], + ) .. 
seealso:: @@ -955,8 +950,8 @@ def _init_existing(self, *args: Any, **kwargs: Any) -> None: if schema and schema != self.schema: raise exc.ArgumentError( - "Can't change schema of existing table from '%s' to '%s'", - (self.schema, schema), + f"Can't change schema of existing table " + f"from '{self.schema}' to '{schema}'", ) include_columns = kwargs.pop("include_columns", None) @@ -1354,7 +1349,7 @@ def to_metadata( m1 = MetaData() - user = Table('user', m1, Column('id', Integer, primary_key=True)) + user = Table("user", m1, Column("id", Integer, primary_key=True)) m2 = MetaData() user_copy = user.to_metadata(m2) @@ -1378,7 +1373,7 @@ def to_metadata( unless set explicitly:: - m2 = MetaData(schema='newschema') + m2 = MetaData(schema="newschema") # user_copy_one will have "newschema" as the schema name user_copy_one = user.to_metadata(m2, schema=None) @@ -1405,15 +1400,16 @@ def to_metadata( E.g.:: - def referred_schema_fn(table, to_schema, - constraint, referred_schema): - if referred_schema == 'base_tables': + def referred_schema_fn(table, to_schema, constraint, referred_schema): + if referred_schema == "base_tables": return referred_schema else: return to_schema - new_table = table.to_metadata(m2, schema="alt_schema", - referred_schema_fn=referred_schema_fn) + + new_table = table.to_metadata( + m2, schema="alt_schema", referred_schema_fn=referred_schema_fn + ) :param name: optional string name indicating the target table name. If not specified or None, the table name is retained. This allows @@ -1421,7 +1417,7 @@ def referred_schema_fn(table, to_schema, :class:`_schema.MetaData` target with a new name. - """ + """ # noqa: E501 if name is None: name = self.name @@ -1436,14 +1432,14 @@ def referred_schema_fn(table, to_schema, key = _get_table_key(name, actual_schema) if key in metadata.tables: util.warn( - "Table '%s' already exists within the given " - "MetaData - not copying." % self.description + f"Table '{self.description}' already exists within the given " + "MetaData - not copying." ) return metadata.tables[key] args = [] for col in self.columns: - args.append(col._copy(schema=actual_schema)) + args.append(col._copy(schema=actual_schema, _to_metadata=metadata)) table = Table( name, metadata, @@ -1519,7 +1515,8 @@ def __init__( name: Optional[str] = None, type_: Optional[_TypeEngineArgument[_T]] = None, autoincrement: _AutoIncrementType = "auto", - default: Optional[Any] = None, + default: Optional[Any] = _NoArg.NO_ARG, + insert_default: Optional[Any] = _NoArg.NO_ARG, doc: Optional[str] = None, key: Optional[str] = None, index: Optional[bool] = None, @@ -1531,7 +1528,7 @@ def __init__( onupdate: Optional[Any] = None, primary_key: bool = False, server_default: Optional[_ServerDefaultArgument] = None, - server_onupdate: Optional[FetchedValue] = None, + server_onupdate: Optional[_ServerOnUpdateArgument] = None, quote: Optional[bool] = None, system: bool = False, comment: Optional[str] = None, @@ -1552,7 +1549,7 @@ def __init__( unless they are a reserved word. Names with any number of upper case characters will be quoted and sent exactly. Note that this behavior applies even for databases which standardize upper - case names as case insensitive such as Oracle. + case names as case insensitive such as Oracle Database. 
The name field may be omitted at construction time and applied later, at any time before the Column is associated with a @@ -1565,10 +1562,10 @@ def __init__( as well, e.g.:: # use a type with arguments - Column('data', String(50)) + Column("data", String(50)) # use no arguments - Column('level', Integer) + Column("level", Integer) The ``type`` argument may be the second positional argument or specified by keyword. @@ -1624,8 +1621,8 @@ def __init__( will imply that database-specific keywords such as PostgreSQL ``SERIAL``, MySQL ``AUTO_INCREMENT``, or ``IDENTITY`` on SQL Server should also be rendered. Not every database backend has an - "implied" default generator available; for example the Oracle - backend always needs an explicit construct such as + "implied" default generator available; for example the Oracle Database + backend always needs an explicit construct such as :class:`.Identity` to be included with a :class:`.Column` in order for the DDL rendered to include auto-generating constructs to also be produced in the database. @@ -1670,8 +1667,12 @@ def __init__( # turn on autoincrement for this column despite # the ForeignKey() - Column('id', ForeignKey('other.id'), - primary_key=True, autoincrement='ignore_fk') + Column( + "id", + ForeignKey("other.id"), + primary_key=True, + autoincrement="ignore_fk", + ) It is typically not desirable to have "autoincrement" enabled on a column that refers to another via foreign key, as such a column is @@ -1699,7 +1700,7 @@ def __init__( is not included as this is unnecessary and not recommended by the database vendor. See the section :ref:`sqlite_autoincrement` for more background. - * Oracle - The Oracle dialect has no default "autoincrement" + * Oracle Database - The Oracle Database dialects have no default "autoincrement" feature available at this time, instead the :class:`.Identity` construct is recommended to achieve this (the :class:`.Sequence` construct may also be used). @@ -1716,10 +1717,10 @@ def __init__( (see `https://www.python.org/dev/peps/pep-0249/#lastrowid <https://www.python.org/dev/peps/pep-0249/#lastrowid>`_) - * PostgreSQL, SQL Server, Oracle - use RETURNING or an equivalent + * PostgreSQL, SQL Server, Oracle Database - use RETURNING or an equivalent construct when rendering an INSERT statement, and then retrieving the newly generated primary key values after execution - * PostgreSQL, Oracle for :class:`_schema.Table` objects that + * PostgreSQL, Oracle Database for :class:`_schema.Table` objects that set :paramref:`_schema.Table.implicit_returning` to False - for a :class:`.Sequence` only, the :class:`.Sequence` is invoked explicitly before the INSERT statement takes place so that the @@ -1756,6 +1757,11 @@ def __init__( :ref:`metadata_defaults_toplevel` + :param insert_default: An alias of :paramref:`.Column.default` + for compatibility with :func:`_orm.mapped_column`. + + .. versionadded:: 2.0.31 + :param doc: optional String that can be used by the ORM or similar to document attributes on the Python side. This attribute does **not** render SQL comments; use the @@ -1783,7 +1789,7 @@ def __init__( "some_table", metadata, Column("x", Integer), - Index("ix_some_table_x", "x") + Index("ix_some_table_x", "x"), ) To add the :paramref:`_schema.Index.unique` flag to the @@ -1865,14 +1871,22 @@ def __init__( String types will be emitted as-is, surrounded by single quotes:: - Column('x', Text, server_default="val") + Column("x", Text, server_default="val") + + will render: + + ..
sourcecode:: sql x TEXT DEFAULT 'val' A :func:`~sqlalchemy.sql.expression.text` expression will be rendered as-is, without quotes:: - Column('y', DateTime, server_default=text('NOW()')) + Column("y", DateTime, server_default=text("NOW()")) + + will render: + + .. sourcecode:: sql y DATETIME DEFAULT NOW() @@ -1887,20 +1901,21 @@ def __init__( from sqlalchemy.dialects.postgresql import array engine = create_engine( - 'postgresql+psycopg2://scott:tiger@localhost/mydatabase' + "postgresql+psycopg2://scott:tiger@localhost/mydatabase" ) metadata_obj = MetaData() tbl = Table( - "foo", - metadata_obj, - Column("bar", - ARRAY(Text), - server_default=array(["biz", "bang", "bash"]) - ) + "foo", + metadata_obj, + Column( + "bar", ARRAY(Text), server_default=array(["biz", "bang", "bash"]) + ), ) metadata_obj.create_all(engine) - The above results in a table created with the following SQL:: + The above results in a table created with the following SQL: + + .. sourcecode:: sql CREATE TABLE foo ( bar TEXT[] DEFAULT ARRAY['biz', 'bang', 'bash'] @@ -1965,12 +1980,7 @@ def __init__( :class:`_schema.UniqueConstraint` construct explicitly at the level of the :class:`_schema.Table` construct itself:: - Table( - "some_table", - metadata, - Column("x", Integer), - UniqueConstraint("x") - ) + Table("some_table", metadata, Column("x", Integer), UniqueConstraint("x")) The :paramref:`_schema.UniqueConstraint.name` parameter of the unique constraint object is left at its default value @@ -2068,7 +2078,7 @@ def __init__( name = quoted_name(name, quote) elif quote is not None: raise exc.ArgumentError( - "Explicit 'name' is required when " "sending 'quote' argument" + "Explicit 'name' is required when sending 'quote' argument" ) # name = None is expected to be an interim state @@ -2109,12 +2119,19 @@ def __init__( # otherwise, add DDL-related events self._set_type(self.type) - if default is not None: - if not isinstance(default, (ColumnDefault, Sequence)): - default = ColumnDefault(default) + if insert_default is not _NoArg.NO_ARG: + resolved_default = insert_default + elif default is not _NoArg.NO_ARG: + resolved_default = default + else: + resolved_default = None + + if resolved_default is not None: + if not isinstance(resolved_default, (ColumnDefault, Sequence)): + resolved_default = ColumnDefault(resolved_default) - self.default = default - l_args.append(default) + self.default = resolved_default + l_args.append(resolved_default) else: self.default = None @@ -2204,6 +2221,8 @@ def __init__( identity: Optional[Identity] def _set_type(self, type_: TypeEngine[Any]) -> None: + assert self.type._isnull or type_ is self.type + self.type = type_ if isinstance(self.type, SchemaEventTarget): self.type._set_parent_with_dispatch(self) @@ -2223,7 +2242,7 @@ def _onupdate_description_tuple(self) -> _DefaultDescriptionTuple: return _DefaultDescriptionTuple._from_column_default(self.onupdate) @util.memoized_property - def _gen_static_annotations_cache_key(self) -> bool: # type: ignore + def _gen_static_annotations_cache_key(self) -> bool: """special attribute used by cache key gen, if true, we will use a static cache key for the annotations dictionary, else we will generate a new cache key for annotations each time. 
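A minimal runnable sketch of the new ``insert_default`` alias and the resolution logic added in the hunks above; the table and column names here are illustrative assumptions, not part of the change itself::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata_obj = MetaData()

    # hypothetical table; ``insert_default`` is accepted as an alias of
    # ``default``, matching the parameter name used by mapped_column()
    tbl = Table(
        "widget",
        metadata_obj,
        Column("id", Integer, primary_key=True),
        Column("status", String(20), insert_default="new"),
    )

    # the value resolves onto Column.default, per the resolution code above
    assert tbl.c.status.default is not None
    assert tbl.c.status.default.arg == "new"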
@@ -2315,8 +2334,8 @@ def _set_parent( # type: ignore[override] existing = getattr(self, "table", None) if existing is not None and existing is not table: raise exc.ArgumentError( - "Column object '%s' already assigned to Table '%s'" - % (self.key, existing.description) + f"Column object '{self.key}' already " + f"assigned to Table '{existing.description}'" ) extra_remove = None @@ -2376,9 +2395,8 @@ def _set_parent( # type: ignore[override] table.primary_key._replace(self) elif self.key in table.primary_key: raise exc.ArgumentError( - "Trying to redefine primary-key column '%s' as a " - "non-primary-key column on table '%s'" - % (self.key, table.fullname) + f"Trying to redefine primary-key column '{self.key}' as a " + f"non-primary-key column on table '{table.fullname}'" ) if self.index: @@ -2462,14 +2480,16 @@ def _copy(self, **kw: Any) -> Column[Any]: dialect_option_key, dialect_option_value, ) in dialect_options.items(): - column_kwargs[ - dialect_name + "_" + dialect_option_key - ] = dialect_option_value + column_kwargs[dialect_name + "_" + dialect_option_key] = ( + dialect_option_value + ) server_default = self.server_default server_onupdate = self.server_onupdate if isinstance(server_default, (Computed, Identity)): # TODO: likely should be copied in all cases + # TODO: if a Sequence, we would need to transfer the Sequence + # .metadata as well args.append(server_default._copy(**kw)) server_default = server_onupdate = None @@ -2573,8 +2593,11 @@ def _merge(self, other: Column[Any]) -> None: new_onupdate = self.onupdate._copy() new_onupdate._set_parent(other) - if self.index and not other.index: - other.index = True + if self.index in (True, False) and other.index is None: + other.index = self.index + + if self.unique in (True, False) and other.unique is None: + other.unique = self.unique if self.doc and other.doc is None: other.doc = self.doc @@ -2582,9 +2605,6 @@ def _merge(self, other: Column[Any]) -> None: if self.comment and other.comment is None: other.comment = self.comment - if self.unique and not other.unique: - other.unique = True - for const in self.constraints: if not const._type_bound: new_const = const._copy() @@ -2598,6 +2618,8 @@ def _merge(self, other: Column[Any]) -> None: def _make_proxy( self, selectable: FromClause, + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], name: Optional[str] = None, key: Optional[str] = None, name_is_truncatable: bool = False, @@ -2635,19 +2657,23 @@ def _make_proxy( ) try: c = self._constructor( - coercions.expect( - roles.TruncatedLabelRole, name if name else self.name - ) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect( + roles.TruncatedLabelRole, name if name else self.name + ) + if name_is_truncatable + else (name or self.name) + ), self.type, # this may actually be ._proxy_key when the key is incoming key=key if key else name if name else self.key, primary_key=self.primary_key, nullable=self.nullable, - _proxies=list(compound_select_cols) - if compound_select_cols - else [self], + _proxies=( + list(compound_select_cols) + if compound_select_cols + else [self] + ), *fk, ) except TypeError as err: @@ -2663,10 +2689,13 @@ def _make_proxy( c._propagate_attrs = selectable._propagate_attrs if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) + if self.primary_key: - selectable.primary_key.add(c) # type: ignore + primary_key.add(c) + if fk: - selectable.foreign_keys.update(fk) # type: ignore + foreign_keys.update(fk) # type: ignore + return 
c.key, c @@ -2712,9 +2741,9 @@ def insert_sentinel( return Column( name=name, type_=type_api.INTEGERTYPE if type_ is None else type_, - default=default - if default is not None - else _InsertSentinelColumnDefault(), + default=( + default if default is not None else _InsertSentinelColumnDefault() + ), _omit_from_statements=omit_from_statements, insert_sentinel=True, ) @@ -2727,8 +2756,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): object, e.g.:: - t = Table("remote_table", metadata, - Column("remote_id", ForeignKey("main_table.id")) + t = Table( + "remote_table", + metadata, + Column("remote_id", ForeignKey("main_table.id")), ) Note that ``ForeignKey`` is only a marker object that defines @@ -2805,9 +2836,18 @@ def __init__( issuing DDL for this constraint. Typical values include CASCADE, DELETE and RESTRICT. + .. seealso:: + + :ref:`on_update_on_delete` + :param ondelete: Optional string. If set, emit ON DELETE when issuing DDL for this constraint. Typical values include CASCADE, - DELETE and RESTRICT. + SET NULL and RESTRICT. Some dialects may allow for additional + syntaxes. + + .. seealso:: + + :ref:`on_update_on_delete` :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when issuing DDL for this constraint. @@ -3028,7 +3068,7 @@ def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: m = self._get_colspec().split(".") if m is None: raise exc.ArgumentError( - "Invalid foreign key column specification: %s" % self._colspec + f"Invalid foreign key column specification: {self._colspec}" ) if len(m) == 1: tname = m.pop() @@ -3119,9 +3159,9 @@ def _link_to_col_by_colstring( if _column is None: raise exc.NoReferencedColumnError( "Could not initialize target column " - "for ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" - % (self._colspec, parenttable.name, table.name, key), + f"for ForeignKey '{self._colspec}' " + f"on table '{parenttable.name}': " + f"table '{table.name}' has no column named '{key}'", table.name, key, ) @@ -3159,14 +3199,14 @@ def column(self) -> Column[Any]: return self._resolve_column() @overload - def _resolve_column(self, *, raiseerr: Literal[True] = ...) -> Column[Any]: - ... + def _resolve_column( + self, *, raiseerr: Literal[True] = ... + ) -> Column[Any]: ... @overload def _resolve_column( self, *, raiseerr: bool = ... - ) -> Optional[Column[Any]]: - ... + ) -> Optional[Column[Any]]: ... def _resolve_column( self, *, raiseerr: bool = True @@ -3180,18 +3220,18 @@ def _resolve_column( if not raiseerr: return None raise exc.NoReferencedTableError( - "Foreign key associated with column '%s' could not find " - "table '%s' with which to generate a " - "foreign key to target column '%s'" - % (self.parent, tablekey, colname), + f"Foreign key associated with column " + f"'{self.parent}' could not find " + f"table '{tablekey}' with which to generate a " + f"foreign key to target column '{colname}'", tablekey, ) elif parenttable.key not in parenttable.metadata: if not raiseerr: return None raise exc.InvalidRequestError( - "Table %s is no longer associated with its " - "parent MetaData" % parenttable + f"Table {parenttable} is no longer associated with its " + "parent MetaData" ) else: table = parenttable.metadata.tables[tablekey] @@ -3283,18 +3323,15 @@ def _set_table(self, column: Column[Any], table: Table) -> None: def default_is_sequence( obj: Optional[DefaultGenerator], - ) -> TypeGuard[Sequence]: - ... + ) -> TypeGuard[Sequence]: ... 
def default_is_clause_element( obj: Optional[DefaultGenerator], - ) -> TypeGuard[ColumnElementColumnDefault]: - ... + ) -> TypeGuard[ColumnElementColumnDefault]: ... def default_is_scalar( obj: Optional[DefaultGenerator], - ) -> TypeGuard[ScalarElementColumnDefault]: - ... + ) -> TypeGuard[ScalarElementColumnDefault]: ... else: default_is_sequence = operator.attrgetter("is_sequence") @@ -3380,12 +3417,11 @@ class ColumnDefault(DefaultGenerator, ABC): For example, the following:: - Column('foo', Integer, default=50) + Column("foo", Integer, default=50) Is equivalent to:: - Column('foo', Integer, ColumnDefault(50)) - + Column("foo", Integer, ColumnDefault(50)) """ @@ -3394,21 +3430,18 @@ class ColumnDefault(DefaultGenerator, ABC): @overload def __new__( cls, arg: Callable[..., Any], for_update: bool = ... - ) -> CallableColumnDefault: - ... + ) -> CallableColumnDefault: ... @overload def __new__( cls, arg: ColumnElement[Any], for_update: bool = ... - ) -> ColumnElementColumnDefault: - ... + ) -> ColumnElementColumnDefault: ... # if I return ScalarElementColumnDefault here, which is what's actually # returned, mypy complains that # overloads overlap w/ incompatible return types. @overload - def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault: - ... + def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault: ... def __new__( cls, arg: Any = None, for_update: bool = False @@ -3550,8 +3583,7 @@ def _arg_is_typed(self) -> bool: class _CallableColumnDefaultProtocol(Protocol): - def __call__(self, context: ExecutionContext) -> Any: - ... + def __call__(self, context: ExecutionContext) -> Any: ... class CallableColumnDefault(ColumnDefault): @@ -3676,9 +3708,14 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): The :class:`.Sequence` is typically associated with a primary key column:: some_table = Table( - 'some_table', metadata, - Column('id', Integer, Sequence('some_table_seq', start=1), - primary_key=True) + "some_table", + metadata, + Column( + "id", + Integer, + Sequence("some_table_seq", start=1), + primary_key=True, + ), ) When CREATE TABLE is emitted for the above :class:`_schema.Table`, if the @@ -3789,11 +3826,11 @@ def __init__( :param cache: optional integer value; number of future values in the sequence which are calculated in advance. Renders the CACHE keyword - understood by Oracle and PostgreSQL. + understood by Oracle Database and PostgreSQL. :param order: optional boolean value; if ``True``, renders the - ORDER keyword, understood by Oracle, indicating the sequence is - definitively ordered. May be necessary to provide deterministic + ORDER keyword, understood by Oracle Database, indicating the sequence + is definitively ordered. May be necessary to provide deterministic ordering using Oracle RAC. :param data_type: The type to be returned by the sequence, for @@ -3938,10 +3975,10 @@ def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: def _not_a_column_expr(self) -> NoReturn: raise exc.InvalidRequestError( - "This %s cannot be used directly " + f"This {self.__class__.__name__} cannot be used directly " "as a column expression. Use func.next_value(sequence) " "to produce a 'next value' function that's usable " - "as a column element." % self.__class__.__name__ + "as a column element." 
) @@ -3954,7 +3991,7 @@ class FetchedValue(SchemaEventTarget): E.g.:: - Column('foo', Integer, FetchedValue()) + Column("foo", Integer, FetchedValue()) Would indicate that some trigger or default generator will create a new value for the ``foo`` column during an @@ -4020,11 +4057,11 @@ class DefaultClause(FetchedValue): For example, the following:: - Column('foo', Integer, server_default="50") + Column("foo", Integer, server_default="50") Is equivalent to:: - Column('foo', Integer, DefaultClause("50")) + Column("foo", Integer, DefaultClause("50")) """ @@ -4184,8 +4221,7 @@ class ColumnCollectionMixin: def _set_parent_with_dispatch( self, parent: SchemaEventTarget, **kw: Any - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -4204,6 +4240,10 @@ def __init__( ] = _gather_expressions if processed_expressions is not None: + + # this is expected to be an empty list + assert not processed_expressions + self._pending_colargs = [] for ( expr, @@ -4265,12 +4305,11 @@ def _col_attached(column: Column[Any], table: Table) -> None: table = columns[0].table others = [c for c in columns[1:] if c.table is not table] if others: + # black could not format this inline + other_str = ", ".join("'%s'" % c for c in others) raise exc.ArgumentError( - "Column(s) %s are not part of table '%s'." - % ( - ", ".join("'%s'" % c for c in others), - table.description, - ) + f"Column(s) {other_str} " + f"are not part of table '{table.description}'." ) @util.ro_memoized_property @@ -4399,9 +4438,9 @@ def _copy( dialect_option_key, dialect_option_value, ) in dialect_options.items(): - constraint_kwargs[ - dialect_name + "_" + dialect_option_key - ] = dialect_option_value + constraint_kwargs[dialect_name + "_" + dialect_option_key] = ( + dialect_option_value + ) assert isinstance(self.parent, Table) c = self.__class__( @@ -4592,12 +4631,21 @@ def __init__( :param name: Optional, the in-database name of the key. :param onupdate: Optional string. If set, emit ON UPDATE when - issuing DDL for this constraint. Typical values include CASCADE, - DELETE and RESTRICT. + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + .. seealso:: + + :ref:`on_update_on_delete` :param ondelete: Optional string. If set, emit ON DELETE when - issuing DDL for this constraint. Typical values include CASCADE, - DELETE and RESTRICT. + issuing DDL for this constraint. Typical values include CASCADE, + SET NULL and RESTRICT. Some dialects may allow for additional + syntaxes. + + .. seealso:: + + :ref:`on_update_on_delete` :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when issuing DDL for this constraint. 
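As a runnable illustration of the ``ondelete`` values now called out in the revised docstrings above (``SET NULL`` in particular), under the assumption of an ordinary two-table setup with illustrative names::

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

    metadata_obj = MetaData()

    parent = Table(
        "parent",
        metadata_obj,
        Column("id", Integer, primary_key=True),
    )

    child = Table(
        "child",
        metadata_obj,
        Column("id", Integer, primary_key=True),
        # renders "ON DELETE SET NULL" in the CREATE TABLE DDL emitted
        # for the generated foreign key constraint
        Column("parent_id", Integer, ForeignKey("parent.id", ondelete="SET NULL")),
    )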
@@ -4755,9 +4803,9 @@ def _validate_dest_table(self, table: Table) -> None: if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( - "ForeignKeyConstraint on %s(%s) refers to " - "multiple remote tables: %s and %s" - % (table.fullname, self._col_description, elem0, elem1) + f"ForeignKeyConstraint on " + f"{table.fullname}({self._col_description}) refers to " + f"multiple remote tables: {elem0} and {elem1}" ) @property @@ -4822,10 +4870,12 @@ def _copy( [ x._get_colspec( schema=schema, - table_name=target_table.name - if target_table is not None - and x._table_key() == x.parent.table.key - else None, + table_name=( + target_table.name + if target_table is not None + and x._table_key() == x.parent.table.key + else None + ), _is_copy=True, ) for x in self.elements @@ -4853,11 +4903,13 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): :class:`_schema.Column` objects corresponding to those marked with the :paramref:`_schema.Column.primary_key` flag:: - >>> my_table = Table('mytable', metadata, - ... Column('id', Integer, primary_key=True), - ... Column('version_id', Integer, primary_key=True), - ... Column('data', String(50)) - ... ) + >>> my_table = Table( + ... "mytable", + ... metadata, + ... Column("id", Integer, primary_key=True), + ... Column("version_id", Integer, primary_key=True), + ... Column("data", String(50)), + ... ) >>> my_table.primary_key PrimaryKeyConstraint( Column('id', Integer(), table=<mytable>, @@ -4871,13 +4923,14 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): the "name" of the constraint can also be specified, as well as other options which may be recognized by dialects:: - my_table = Table('mytable', metadata, - Column('id', Integer), - Column('version_id', Integer), - Column('data', String(50)), - PrimaryKeyConstraint('id', 'version_id', - name='mytable_pk') - ) + my_table = Table( + "mytable", + metadata, + Column("id", Integer), + Column("version_id", Integer), + Column("data", String(50)), + PrimaryKeyConstraint("id", "version_id", name="mytable_pk"), + ) The two styles of column-specification should generally not be mixed. A warning is emitted if the columns present in the @@ -4895,13 +4948,14 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): primary key column collection from the :class:`_schema.Table` based on the flags:: - my_table = Table('mytable', metadata, - Column('id', Integer, primary_key=True), - Column('version_id', Integer, primary_key=True), - Column('data', String(50)), - PrimaryKeyConstraint(name='mytable_pk', - mssql_clustered=True) - ) + my_table = Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("version_id", Integer, primary_key=True), + Column("data", String(50)), + PrimaryKeyConstraint(name="mytable_pk", mssql_clustered=True), + ) """ @@ -4943,17 +4997,20 @@ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: and table_pks and set(table_pks) != set(self._columns) ): + # black could not format these inline + table_pk_str = ", ".join("'%s'" % c.name for c in table_pks) + col_str = ", ".join("'%s'" % c.name for c in self._columns) + util.warn( - "Table '%s' specifies columns %s as primary_key=True, " - "not matching locally specified columns %s; setting the " - "current primary key columns to %s.
This warning " - "may become an exception in a future release" - % ( - table.name, - ", ".join("'%s'" % c.name for c in table_pks), - ", ".join("'%s'" % c.name for c in self._columns), - ", ".join("'%s'" % c.name for c in self._columns), - ) + f"Table '{table.name}' specifies columns " + f"{table_pk_str} as " + f"primary_key=True, " + f"not matching locally specified columns {col_str}; " + f"setting the " + f"current primary key columns to " + f"{col_str}. " + f"This warning " + f"may become an exception in a future release" ) table_pks[:] = [] @@ -5020,8 +5077,8 @@ def _validate_autoinc(col: Column[Any], autoinc_true: bool) -> bool: ): if autoinc_true: raise exc.ArgumentError( - "Column type %s on column '%s' is not " - "compatible with autoincrement=True" % (col.type, col) + f"Column type {col.type} on column '{col}' is not " + f"compatible with autoincrement=True" ) else: return False @@ -5064,9 +5121,9 @@ def _validate_autoinc(col: Column[Any], autoinc_true: bool) -> bool: _validate_autoinc(col, True) if autoinc is not None: raise exc.ArgumentError( - "Only one Column may be marked " - "autoincrement=True, found both %s and %s." - % (col.name, autoinc.name) + f"Only one Column may be marked " + f"autoincrement=True, found both " + f"{col.name} and {autoinc.name}." ) else: autoinc = col @@ -5095,19 +5152,21 @@ class Index( E.g.:: - sometable = Table("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)) - ) + sometable = Table( + "sometable", + metadata, + Column("name", String(50)), + Column("address", String(100)), + ) Index("some_index", sometable.c.name) For a no-frills, single column index, adding :class:`_schema.Column` also supports ``index=True``:: - sometable = Table("sometable", metadata, - Column("name", String(50), index=True) - ) + sometable = Table( + "sometable", metadata, Column("name", String(50), index=True) + ) For a composite index, multiple columns can be specified:: @@ -5126,22 +5185,26 @@ class Index( the names of the indexed columns can be specified as strings:: - Table("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - Index("some_index", "name", "address") - ) + Table( + "sometable", + metadata, + Column("name", String(50)), + Column("address", String(100)), + Index("some_index", "name", "address"), + ) To support functional or expression-based indexes in this form, the :func:`_expression.text` construct may be used:: from sqlalchemy import text - Table("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - Index("some_index", text("lower(name)")) - ) + Table( + "sometable", + metadata, + Column("name", String(50)), + Column("address", String(100)), + Index("some_index", text("lower(name)")), + ) .. seealso:: @@ -5237,9 +5300,9 @@ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: if self.table is not None and table is not self.table: raise exc.ArgumentError( - "Index '%s' is against table '%s', and " - "cannot be associated with table '%s'." - % (self.name, self.table.description, table.description) + f"Index '{self.name}' is against table " + f"'{self.table.description}', and " + f"cannot be associated with table '{table.description}'." 
) self.table = table table.indexes.add(self) @@ -5486,9 +5549,9 @@ def __init__( self.info = info self._schemas: Set[str] = set() self._sequences: Dict[str, Sequence] = {} - self._fk_memos: Dict[ - Tuple[str, Optional[str]], List[ForeignKey] - ] = collections.defaultdict(list) + self._fk_memos: Dict[Tuple[str, Optional[str]], List[ForeignKey]] = ( + collections.defaultdict(list) + ) tables: util.FacadeDict[str, Table] """A dictionary of :class:`_schema.Table` @@ -5620,6 +5683,38 @@ def sorted_tables(self) -> List[Table]: sorted(self.tables.values(), key=lambda t: t.key) # type: ignore ) + # overload needed to work around this mypy issue + # https://github.com/python/mypy/issues/17093 + @overload + def reflect( + self, + bind: Engine, + schema: Optional[str] = ..., + views: bool = ..., + only: Union[ + _typing_Sequence[str], Callable[[str, MetaData], bool], None + ] = ..., + extend_existing: bool = ..., + autoload_replace: bool = ..., + resolve_fks: bool = ..., + **dialect_kwargs: Any, + ) -> None: ... + + @overload + def reflect( + self, + bind: Connection, + schema: Optional[str] = ..., + views: bool = ..., + only: Union[ + _typing_Sequence[str], Callable[[str, MetaData], bool], None + ] = ..., + extend_existing: bool = ..., + autoload_replace: bool = ..., + resolve_fks: bool = ..., + **dialect_kwargs: Any, + ) -> None: ... + @util.preload_module("sqlalchemy.engine.reflection") def reflect( self, @@ -5774,9 +5869,10 @@ def reflect( missing = [name for name in only if name not in available] if missing: s = schema and (" schema '%s'" % schema) or "" + missing_str = ", ".join(missing) raise exc.InvalidRequestError( - "Could not reflect: requested table(s) not available " - "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing)) + f"Could not reflect: requested table(s) not available " + f"in {bind.engine!r}{s}: ({missing_str})" ) load = [ name @@ -5799,7 +5895,7 @@ def reflect( try: Table(name, self, **reflect_opts) except exc.UnreflectableTableError as uerr: - util.warn("Skipping table %s: %s" % (name, uerr)) + util.warn(f"Skipping table {name}: {uerr}") def create_all( self, @@ -5866,9 +5962,11 @@ class Computed(FetchedValue, SchemaItem): from sqlalchemy import Computed - Table('square', metadata_obj, - Column('side', Float, nullable=False), - Column('area', Float, Computed('side * side')) + Table( + "square", + metadata_obj, + Column("side", Float, nullable=False), + Column("area", Float, Computed("side * side")), ) See the linked documentation below for complete details. @@ -5973,9 +6071,11 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): from sqlalchemy import Identity - Table('foo', metadata_obj, - Column('id', Integer, Identity()) - Column('description', Text), + Table( + "foo", + metadata_obj, + Column("id", Integer, Identity()), + Column("description", Text), ) See the linked documentation below for complete details. @@ -6035,7 +6135,7 @@ def __init__( :param on_null: Set to ``True`` to specify ON NULL in conjunction with an ``always=False`` identity column. This option is only supported on - some backends, like Oracle. + some backends, like Oracle Database. :param start: the starting index of the sequence. :param increment: the increment value of the sequence.
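The two ``reflect()`` overloads added above accept either an :class:`.Engine` or a :class:`.Connection` as the ``bind``; a small sketch of both call styles follows, where the SQLite URL is a placeholder assumption::

    from sqlalchemy import MetaData, create_engine

    engine = create_engine("sqlite+pysqlite:///:memory:")
    metadata_obj = MetaData()

    # an Engine satisfies the first overload
    metadata_obj.reflect(bind=engine)

    # a Connection satisfies the second overload
    with engine.connect() as conn:
        metadata_obj.reflect(bind=conn)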
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 91b939e0af5..ef7605a64b9 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1,5 +1,5 @@ # sql/selectable.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -47,6 +47,7 @@ from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import _no_kw +from ._typing import _T from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement @@ -71,6 +72,7 @@ from .base import ColumnSet from .base import CompileState from .base import DedupeColumnCollection +from .base import DialectKWArgs from .base import Executable from .base import Generative from .base import HasCompileState @@ -101,9 +103,9 @@ from ..util.typing import Protocol from ..util.typing import Self + and_ = BooleanClauseList.and_ -_T = TypeVar("_T", bound=Any) if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -154,12 +156,10 @@ class _JoinTargetProtocol(Protocol): @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... + def _from_objects(self) -> List[FromClause]: ... @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... _JoinTargetElement = Union["FromClause", _JoinTargetProtocol] @@ -242,7 +242,11 @@ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: raise NotImplementedError() def _generate_fromclause_column_proxies( - self, fromclause: FromClause + self, + fromclause: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], ) -> None: """Populate columns into an :class:`.AliasedReturnsRows` object.""" @@ -284,7 +288,7 @@ class ExecutableReturnsRows(Executable, ReturnsRows): class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): - """base for executable statements that return rows.""" + """base for typed executable statements that return rows.""" class Selectable(ReturnsRows): @@ -394,8 +398,7 @@ def prefix_with( stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") # MySQL 5.7 optimizer hints - stmt = select(table).prefix_with( - "/*+ BKA(t1) */", dialect="mysql") + stmt = select(table).prefix_with("/*+ BKA(t1) */", dialect="mysql") Multiple prefixes can be specified by multiple calls to :meth:`_expression.HasPrefixes.prefix_with`. @@ -442,8 +445,13 @@ def suffix_with( E.g.:: - stmt = select(col1, col2).cte().suffix_with( - "cycle empno set y_cycle to 1 default 0", dialect="oracle") + stmt = ( + select(col1, col2) + .cte() + .suffix_with( + "cycle empno set y_cycle to 1 default 0", dialect="oracle" + ) + ) Multiple suffixes can be specified by multiple calls to :meth:`_expression.HasSuffixes.suffix_with`. @@ -465,9 +473,9 @@ def suffix_with( class HasHints: - _hints: util.immutabledict[ - Tuple[FromClause, str], str - ] = util.immutabledict() + _hints: util.immutabledict[Tuple[FromClause, str], str] = ( + util.immutabledict() + ) _statement_hints: Tuple[Tuple[str, str], ...]
= () _has_hints_traverse_internals: _TraverseInternalsType = [ @@ -475,14 +483,25 @@ class HasHints: ("_hints", InternalTraversal.dp_table_hint_list), ] + @_generative def with_statement_hint(self, text: str, dialect_name: str = "*") -> Self: """Add a statement hint to this :class:`_expression.Select` or other selectable object. - This method is similar to :meth:`_expression.Select.with_hint` - except that - it does not require an individual table, and instead applies to the - statement as a whole. + .. tip:: + + :meth:`_expression.Select.with_statement_hint` generally adds hints + **at the trailing end** of a SELECT statement. To place + dialect-specific hints such as optimizer hints at the **front** of + the SELECT statement after the SELECT keyword, use the + :meth:`_expression.Select.prefix_with` method for an open-ended + space, or for table-specific hints the + :meth:`_expression.Select.with_hint` may be used, which places + hints in a dialect-specific location. + + This method is similar to :meth:`_expression.Select.with_hint` except + that it does not require an individual table, and instead applies to + the statement as a whole. Hints here are specific to the backend database and may include directives such as isolation levels, file directives, fetch directives, @@ -494,7 +513,7 @@ def with_statement_hint(self, text: str, dialect_name: str = "*") -> Self: :meth:`_expression.Select.prefix_with` - generic SELECT prefixing which also can suit some database-specific HINT syntaxes such as - MySQL optimizer hints + MySQL or Oracle Database optimizer hints """ return self._with_hint(None, text, dialect_name) @@ -510,6 +529,17 @@ def with_hint( selectable to this :class:`_expression.Select` or other selectable object. + .. tip:: + + The :meth:`_expression.Select.with_hint` method adds hints that are + **specific to a single table** to a statement, in a location that + is **dialect-specific**. To add generic optimizer hints to the + **beginning** of a statement ahead of the SELECT keyword such as + for MySQL or Oracle Database, use the + :meth:`_expression.Select.prefix_with` method. To add optimizer + hints to the **end** of a statement such as for PostgreSQL, use the + :meth:`_expression.Select.with_statement_hint` method. + The text of the hint is rendered in the appropriate location for the database backend in use, relative to the given :class:`_schema.Table` or :class:`_expression.Alias` @@ -517,28 +547,33 @@ def with_hint( ``selectable`` argument. The dialect implementation typically uses Python string substitution syntax with the token ``%(name)s`` to render the name of - the table or alias. E.g. when using Oracle, the + the table or alias. E.g. when using Oracle Database, the following:: - select(mytable).\ - with_hint(mytable, "index(%(name)s ix_mytable)") + select(mytable).with_hint(mytable, "index(%(name)s ix_mytable)") - Would render SQL as:: + Would render SQL as: + + .. sourcecode:: sql select /*+ index(mytable ix_mytable) */ ... from mytable The ``dialect_name`` option will limit the rendering of a particular hint to a particular backend. Such as, to add hints for both Oracle - and Sybase simultaneously:: + Database and MSSql simultaneously:: - select(mytable).\ - with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\ - with_hint(mytable, "WITH INDEX ix_mytable", 'mssql') + select(mytable).with_hint( + mytable, "index(%(name)s ix_mytable)", "oracle" + ).with_hint(mytable, "WITH INDEX ix_mytable", "mssql") .. 
seealso:: :meth:`_expression.Select.with_statement_hint` + :meth:`_expression.Select.prefix_with` - generic SELECT prefixing + which also can suit some database-specific HINT syntaxes such as + MySQL or Oracle Database optimizer hints + """ return self._with_hint(selectable, text, dialect_name) @@ -641,11 +676,14 @@ def join( from sqlalchemy import join - j = user_table.join(address_table, - user_table.c.id == address_table.c.user_id) + j = user_table.join( + address_table, user_table.c.id == address_table.c.user_id + ) stmt = select(user_table).select_from(j) - would emit SQL along the lines of:: + would emit SQL along the lines of: + + .. sourcecode:: sql SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id @@ -691,15 +729,15 @@ def outerjoin( from sqlalchemy import outerjoin - j = user_table.outerjoin(address_table, - user_table.c.id == address_table.c.user_id) + j = user_table.outerjoin( + address_table, user_table.c.id == address_table.c.user_id + ) The above is equivalent to:: j = user_table.join( - address_table, - user_table.c.id == address_table.c.user_id, - isouter=True) + address_table, user_table.c.id == address_table.c.user_id, isouter=True + ) :param right: the right side of the join; this is any :class:`_expression.FromClause` object such as a @@ -721,7 +759,7 @@ def outerjoin( :class:`_expression.Join` - """ + """ # noqa: E501 return Join(self, right, onclause, True, full) @@ -732,7 +770,7 @@ def alias( E.g.:: - a2 = some_table.alias('a2') + a2 = some_table.alias("a2") The above code creates an :class:`_expression.Alias` object which can be used @@ -801,10 +839,17 @@ def description(self) -> str: return getattr(self, "name", self.__class__.__name__ + " object") def _generate_fromclause_column_proxies( - self, fromclause: FromClause + self, + fromclause: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], ) -> None: - fromclause._columns._populate_separate_keys( - col._make_proxy(fromclause) for col in self.c + columns._populate_separate_keys( + col._make_proxy( + fromclause, primary_key=primary_key, foreign_keys=foreign_keys + ) + for col in self.c ) @util.ro_non_memoized_property @@ -858,10 +903,30 @@ def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """ if "_columns" not in self.__dict__: - self._init_collections() - self._populate_column_collection() + self._setup_collections() return self._columns.as_readonly() + def _setup_collections(self) -> None: + assert "_columns" not in self.__dict__ + assert "primary_key" not in self.__dict__ + assert "foreign_keys" not in self.__dict__ + + _columns: ColumnCollection[Any, Any] = ColumnCollection() + primary_key = ColumnSet() + foreign_keys: Set[KeyedColumnElement[Any]] = set() + + self._populate_column_collection( + columns=_columns, + primary_key=primary_key, + foreign_keys=foreign_keys, + ) + + # assigning these three collections separately is not itself atomic, + # but greatly reduces the surface for problems + self._columns = _columns + self.primary_key = primary_key # type: ignore + self.foreign_keys = foreign_keys # type: ignore + @util.ro_non_memoized_property def entity_namespace(self) -> _EntityNamespace: """Return a namespace used for name-based access in SQL expressions. 
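To make the hint-placement rules from the ``with_statement_hint`` and ``with_hint`` tips added earlier in this file's diff concrete, a brief sketch contrasting the three methods; the hint strings themselves are dialect-specific and purely illustrative::

    from sqlalchemy import column, select, table

    mytable = table("mytable", column("x"))

    # table-specific hint, rendered in a dialect-specific position
    stmt = select(mytable).with_hint(
        mytable, "index(%(name)s ix_mytable)", "oracle"
    )

    # statement-wide hint, added at the trailing end of the SELECT
    stmt = select(mytable).with_statement_hint("NO_INDEX_SCAN")

    # generic prefix, placed directly after the SELECT keyword
    stmt = select(mytable).prefix_with("/*+ BKA(t1) */", dialect="mysql")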
@@ -869,7 +934,7 @@ def entity_namespace(self) -> _EntityNamespace: This is the namespace that is used to resolve "filter_by()" type expressions, such as:: - stmt.filter_by(address='some address') + stmt.filter_by(address="some address") It defaults to the ``.c`` collection, however internally it can be overridden using the "entity_namespace" annotation to deliver @@ -888,8 +953,7 @@ def primary_key(self) -> Iterable[NamedColumn[Any]]: iterable collection of :class:`_schema.Column` objects. """ - self._init_collections() - self._populate_column_collection() + self._setup_collections() return self.primary_key @util.ro_memoized_property @@ -906,8 +970,7 @@ def foreign_keys(self) -> Iterable[ForeignKey]: :attr:`_schema.Table.foreign_key_constraints` """ - self._init_collections() - self._populate_column_collection() + self._setup_collections() return self.foreign_keys def _reset_column_collection(self) -> None: @@ -931,20 +994,16 @@ def _reset_column_collection(self) -> None: def _select_iterable(self) -> _SelectIterable: return (c for c in self.c if not _never_select_column(c)) - def _init_collections(self) -> None: - assert "_columns" not in self.__dict__ - assert "primary_key" not in self.__dict__ - assert "foreign_keys" not in self.__dict__ - - self._columns = ColumnCollection() - self.primary_key = ColumnSet() # type: ignore - self.foreign_keys = set() # type: ignore - @property def _cols_populated(self) -> bool: return "_columns" in self.__dict__ - def _populate_column_collection(self) -> None: + def _populate_column_collection( + self, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], + ) -> None: """Called on subclasses to establish the .c collection. Each implementation has a different way of establishing @@ -988,8 +1047,7 @@ def _anonymous_fromclause( def self_group( self, against: Optional[OperatorType] = None - ) -> Union[FromGrouping, Self]: - ... + ) -> Union[FromGrouping, Self]: ... class NamedFromClause(FromClause): @@ -1013,7 +1071,7 @@ def table_valued(self) -> TableValuedColumn[Any]: A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that represents a complete row in a table. Support for this construct is backend dependent, and is supported in various forms by backends - such as PostgreSQL, Oracle and SQL Server. + such as PostgreSQL, Oracle Database and SQL Server. E.g.: @@ -1053,7 +1111,11 @@ class SelectLabelStyle(Enum): >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_NONE >>> table1 = table("table1", column("columna"), column("columnb")) >>> table2 = table("table2", column("columna"), column("columnc")) - >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_NONE)) + >>> print( + ... select(table1, table2) + ... .join(table2, true()) + ... .set_label_style(LABEL_STYLE_NONE) + ... ) {printsql}SELECT table1.columna, table1.columnb, table2.columna, table2.columnc FROM table1 JOIN table2 ON true @@ -1075,10 +1137,20 @@ class SelectLabelStyle(Enum): .. sourcecode:: pycon+sql - >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_TABLENAME_PLUS_COL + >>> from sqlalchemy import ( + ... table, + ... column, + ... select, + ... true, + ... LABEL_STYLE_TABLENAME_PLUS_COL, + ... 
) >>> table1 = table("table1", column("columna"), column("columnb")) >>> table2 = table("table2", column("columna"), column("columnc")) - >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)) + >>> print( + ... select(table1, table2) + ... .join(table2, true()) + ... .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + ... ) {printsql}SELECT table1.columna AS table1_columna, table1.columnb AS table1_columnb, table2.columna AS table2_columna, table2.columnc AS table2_columnc FROM table1 JOIN table2 ON true @@ -1104,10 +1176,20 @@ class SelectLabelStyle(Enum): .. sourcecode:: pycon+sql - >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_DISAMBIGUATE_ONLY + >>> from sqlalchemy import ( + ... table, + ... column, + ... select, + ... true, + ... LABEL_STYLE_DISAMBIGUATE_ONLY, + ... ) >>> table1 = table("table1", column("columna"), column("columnb")) >>> table2 = table("table2", column("columna"), column("columnc")) - >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)) + >>> print( + ... select(table1, table2) + ... .join(table2, true()) + ... .set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) + ... ) {printsql}SELECT table1.columna, table1.columnb, table2.columna AS columna_1, table2.columnc FROM table1 JOIN table2 ON true @@ -1245,26 +1327,30 @@ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool: def self_group( self, against: Optional[OperatorType] = None ) -> FromGrouping: - ... return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") - def _populate_column_collection(self) -> None: + def _populate_column_collection( + self, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], + ) -> None: sqlutil = util.preloaded.sql_util - columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ + _columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c ] - self.primary_key.extend( # type: ignore + primary_key.extend( # type: ignore sqlutil.reduce_columns( - (c for c in columns if c.primary_key), self.onclause + (c for c in _columns if c.primary_key), self.onclause ) ) - self._columns._populate_separate_keys( - (col._tq_key_label, col) for col in columns + columns._populate_separate_keys( + (col._tq_key_label, col) for col in _columns # type: ignore ) - self.foreign_keys.update( # type: ignore - itertools.chain(*[col.foreign_keys for col in columns]) + foreign_keys.update( + itertools.chain(*[col.foreign_keys for col in _columns]) # type: ignore # noqa: E501 ) def _copy_internals( @@ -1291,7 +1377,7 @@ def _copy_internals( def replace( obj: Union[BinaryExpression[Any], ColumnClause[Any]], **kw: Any, - ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: + ) -> Optional[KeyedColumnElement[Any]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem @@ -1506,7 +1592,9 @@ def select(self) -> Select[Any]: stmt = stmt.select() - The above will produce a SQL string resembling:: + The above will produce a SQL string resembling: + + .. 
sourcecode:: sql SELECT table_a.id, table_a.col, table_b.id, table_b.a_id FROM table_a JOIN table_b ON table_a.id = table_b.a_id @@ -1520,11 +1608,23 @@ def _anonymous_fromclause( ) -> TODO_Any: sqlutil = util.preloaded.sql_util if flat: - if name is not None: - raise exc.ArgumentError("Can't send name argument with flat") + if isinstance(self.left, (FromGrouping, Join)): + left_name = name # will recurse + else: + if name and isinstance(self.left, NamedFromClause): + left_name = f"{name}_{self.left.name}" + else: + left_name = name + if isinstance(self.right, (FromGrouping, Join)): + right_name = name # will recurse + else: + if name and isinstance(self.right, NamedFromClause): + right_name = f"{name}_{self.right.name}" + else: + right_name = name left_a, right_a = ( - self.left._anonymous_fromclause(flat=True), - self.right._anonymous_fromclause(flat=True), + self.left._anonymous_fromclause(name=left_name, flat=flat), + self.right._anonymous_fromclause(name=right_name, flat=flat), ) adapter = sqlutil.ClauseAdapter(left_a).chain( sqlutil.ClauseAdapter(right_a) @@ -1633,8 +1733,15 @@ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - def _populate_column_collection(self) -> None: - self.element._generate_fromclause_column_proxies(self) + def _populate_column_collection( + self, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], + ) -> None: + self.element._generate_fromclause_column_proxies( + self, columns, primary_key=primary_key, foreign_keys=foreign_keys + ) @util.ro_non_memoized_property def description(self) -> str: @@ -1686,7 +1793,7 @@ class Alias(roles.DMLTableRole, FromClauseAlias): Represents an alias, as typically applied to any table or sub-select within a SQL statement using the ``AS`` keyword (or - without the keyword on certain databases such as Oracle). + without the keyword on certain databases such as Oracle Database). This object is constructed from the :func:`_expression.alias` module level function as well as the :meth:`_expression.FromClause.alias` @@ -1728,7 +1835,9 @@ class TableValuedAlias(LateralFromClause, Alias): .. sourcecode:: pycon+sql >>> from sqlalchemy import select, func - >>> fn = func.json_array_elements_text('["one", "two", "three"]').table_valued("value") + >>> fn = func.json_array_elements_text('["one", "two", "three"]').table_valued( + ... "value" + ... ) >>> print(select(fn.c.value)) {printsql}SELECT anon_1.value FROM json_array_elements_text(:json_array_elements_text_1) AS anon_1 @@ -1847,8 +1956,9 @@ def render_derived( >>> print( ... select( - ... func.unnest(array(["one", "two", "three"])). - table_valued("x", with_ordinality="o").render_derived() + ... func.unnest(array(["one", "two", "three"])) + ... .table_valued("x", with_ordinality="o") + ... .render_derived() ... ) ... ) {printsql}SELECT anon_1.x, anon_1.o @@ -1862,9 +1972,7 @@ def render_derived( >>> print( ... select( - ... func.json_to_recordset( - ... '[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]' - ... ) + ... func.json_to_recordset('[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]') ... .table_valued(column("a", Integer), column("b", String)) ... .render_derived(with_types=True) ... 
) @@ -2073,11 +2181,26 @@ def _init( self._suffixes = _suffixes super()._init(selectable, name=name) - def _populate_column_collection(self) -> None: + def _populate_column_collection( + self, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], + ) -> None: if self._cte_alias is not None: - self._cte_alias._generate_fromclause_column_proxies(self) + self._cte_alias._generate_fromclause_column_proxies( + self, + columns, + primary_key=primary_key, + foreign_keys=foreign_keys, + ) else: - self.element._generate_fromclause_column_proxies(self) + self.element._generate_fromclause_column_proxies( + self, + columns, + primary_key=primary_key, + foreign_keys=foreign_keys, + ) def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: """Return an :class:`_expression.Alias` of this @@ -2103,7 +2226,7 @@ def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: _suffixes=self._suffixes, ) - def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2132,7 +2255,9 @@ def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: _suffixes=self._suffixes, ) - def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union_all( + self, *other: _SelectStatementForCompoundArgument[Any] + ) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. @@ -2256,9 +2381,9 @@ def _generate_columns_plus_names( repeated = False if not c._render_label_in_columns_clause: - effective_name = ( - required_label_name - ) = fallback_label_name = None + effective_name = required_label_name = fallback_label_name = ( + None + ) elif label_style_none: if TYPE_CHECKING: assert is_column_element(c) @@ -2270,9 +2395,9 @@ def _generate_columns_plus_names( assert is_column_element(c) if table_qualified: - required_label_name = ( - effective_name - ) = fallback_label_name = c._tq_label + required_label_name = effective_name = ( + fallback_label_name + ) = c._tq_label else: effective_name = fallback_label_name = c._non_anon_label required_label_name = None @@ -2303,9 +2428,9 @@ def _generate_columns_plus_names( else: fallback_label_name = c._anon_name_label else: - required_label_name = ( - effective_name - ) = fallback_label_name = expr_label + required_label_name = effective_name = ( + fallback_label_name + ) = expr_label if effective_name is not None: if TYPE_CHECKING: @@ -2319,13 +2444,13 @@ def _generate_columns_plus_names( # different column under the same name. 
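As a minimal sketch of the duplicate-column disambiguation that the relabeled assignment chains above implement (table and column names invented; the exact generated suffix is an implementation detail and may vary by version)::

    from sqlalchemy import column, select, table

    t = table("t", column("x"))

    # selecting the same column twice takes the dedupe path above;
    # the second occurrence receives an auto-generated label
    print(select(t.c.x, t.c.x))
    # SELECT t.x, t.x AS x__1
    # FROM t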
apply # disambiguating label if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._anon_tq_label + required_label_name = fallback_label_name = ( + c._anon_tq_label + ) else: - required_label_name = ( - fallback_label_name - ) = c._anon_name_label + required_label_name = fallback_label_name = ( + c._anon_name_label + ) if anon_for_dupe_key and required_label_name in names: # here, c._anon_tq_label is definitely unique to @@ -2340,14 +2465,14 @@ def _generate_columns_plus_names( # subsequent occurrences of the column so that the # original stays non-ambiguous if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) dedupe_hash += 1 else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_label_idx(dedupe_hash) + ) dedupe_hash += 1 repeated = True else: @@ -2356,14 +2481,14 @@ def _generate_columns_plus_names( # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) dedupe_hash += 1 else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_label_idx(dedupe_hash) + ) dedupe_hash += 1 repeated = True else: @@ -2421,16 +2546,20 @@ def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: E.g.:: from sqlalchemy import table, column, select - t = table('t', column('c1'), column('c2')) + + t = table("t", column("c1"), column("c2")) ins = t.insert().values({"c1": "x", "c2": "y"}).cte() stmt = select(t).add_cte(ins) - Would render:: + Would render: - WITH anon_1 AS - (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2)) + .. sourcecode:: sql + + WITH anon_1 AS ( + INSERT INTO t (c1, c2) VALUES (:param_1, :param_2) + ) SELECT t.c1, t.c2 FROM t @@ -2446,9 +2575,7 @@ def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: t = table("t", column("c1"), column("c2")) - delete_statement_cte = ( - t.delete().where(t.c.c1 < 1).cte("deletions") - ) + delete_statement_cte = t.delete().where(t.c.c1 < 1).cte("deletions") insert_stmt = insert(t).values({"c1": 1, "c2": 2}) update_statement = insert_stmt.on_conflict_do_update( @@ -2461,10 +2588,13 @@ def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: print(update_statement) - The above statement renders as:: + The above statement renders as: - WITH deletions AS - (DELETE FROM t WHERE t.c1 < %(c1_1)s) + .. 
sourcecode:: sql + + WITH deletions AS ( + DELETE FROM t WHERE t.c1 < %(c1_1)s + ) INSERT INTO t (c1, c2) VALUES (%(c1)s, %(c2)s) ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, c2 = excluded.c2 @@ -2488,10 +2618,8 @@ def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: :paramref:`.HasCTE.cte.nesting` - """ - opt = _CTEOpts( - nest_here, - ) + """ # noqa: E501 + opt = _CTEOpts(nest_here) for cte in ctes: cte = coercions.expect(roles.IsCTERole, cte) self._independent_ctes += (cte,) @@ -2559,95 +2687,123 @@ def cte( Example 1, non recursive:: - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) + from sqlalchemy import ( + Table, + Column, + String, + Integer, + MetaData, + select, + func, + ) metadata = MetaData() - orders = Table('orders', metadata, - Column('region', String), - Column('amount', Integer), - Column('product', String), - Column('quantity', Integer) + orders = Table( + "orders", + metadata, + Column("region", String), + Column("amount", Integer), + Column("product", String), + Column("quantity", Integer), ) - regional_sales = select( - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ).group_by(orders.c.region).cte("regional_sales") + regional_sales = ( + select(orders.c.region, func.sum(orders.c.amount).label("total_sales")) + .group_by(orders.c.region) + .cte("regional_sales") + ) - top_regions = select(regional_sales.c.region).\ - where( - regional_sales.c.total_sales > - select( - func.sum(regional_sales.c.total_sales) / 10 - ) - ).cte("top_regions") + top_regions = ( + select(regional_sales.c.region) + .where( + regional_sales.c.total_sales + > select(func.sum(regional_sales.c.total_sales) / 10) + ) + .cte("top_regions") + ) - statement = select( - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ).where(orders.c.region.in_( - select(top_regions.c.region) - )).group_by(orders.c.region, orders.c.product) + statement = ( + select( + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales"), + ) + .where(orders.c.region.in_(select(top_regions.c.region))) + .group_by(orders.c.region, orders.c.product) + ) result = conn.execute(statement).fetchall() Example 2, WITH RECURSIVE:: - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) + from sqlalchemy import ( + Table, + Column, + String, + Integer, + MetaData, + select, + func, + ) metadata = MetaData() - parts = Table('parts', metadata, - Column('part', String), - Column('sub_part', String), - Column('quantity', Integer), + parts = Table( + "parts", + metadata, + Column("part", String), + Column("sub_part", String), + Column("quantity", Integer), ) - included_parts = select(\ - parts.c.sub_part, parts.c.part, parts.c.quantity\ - ).\ - where(parts.c.part=='our part').\ - cte(recursive=True) + included_parts = ( + select(parts.c.sub_part, parts.c.part, parts.c.quantity) + .where(parts.c.part == "our part") + .cte(recursive=True) + ) incl_alias = included_parts.alias() parts_alias = parts.alias() included_parts = included_parts.union_all( select( - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity - ).\ - where(parts_alias.c.part==incl_alias.c.sub_part) + parts_alias.c.sub_part, parts_alias.c.part, parts_alias.c.quantity + ).where(parts_alias.c.part == incl_alias.c.sub_part) ) statement = select( - included_parts.c.sub_part, - 
func.sum(included_parts.c.quantity). - label('total_quantity') - ).\ - group_by(included_parts.c.sub_part) + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label("total_quantity"), + ).group_by(included_parts.c.sub_part) result = conn.execute(statement).fetchall() Example 3, an upsert using UPDATE and INSERT with CTEs:: from datetime import date - from sqlalchemy import (MetaData, Table, Column, Integer, - Date, select, literal, and_, exists) + from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + Date, + select, + literal, + and_, + exists, + ) metadata = MetaData() - visitors = Table('visitors', metadata, - Column('product_id', Integer, primary_key=True), - Column('date', Date, primary_key=True), - Column('count', Integer), + visitors = Table( + "visitors", + metadata, + Column("product_id", Integer, primary_key=True), + Column("date", Date, primary_key=True), + Column("count", Integer), ) # add 5 visitors for the product_id == 1 @@ -2657,31 +2813,31 @@ def cte( update_cte = ( visitors.update() - .where(and_(visitors.c.product_id == product_id, - visitors.c.date == day)) + .where( + and_(visitors.c.product_id == product_id, visitors.c.date == day) + ) .values(count=visitors.c.count + count) .returning(literal(1)) - .cte('update_cte') + .cte("update_cte") ) upsert = visitors.insert().from_select( [visitors.c.product_id, visitors.c.date, visitors.c.count], - select(literal(product_id), literal(day), literal(count)) - .where(~exists(update_cte.select())) + select(literal(product_id), literal(day), literal(count)).where( + ~exists(update_cte.select()) + ), ) connection.execute(upsert) Example 4, Nesting CTE (SQLAlchemy 1.4.24 and above):: - value_a = select( - literal("root").label("n") - ).cte("value_a") + value_a = select(literal("root").label("n")).cte("value_a") # A nested CTE with the same name as the root one - value_a_nested = select( - literal("nesting").label("n") - ).cte("value_a", nesting=True) + value_a_nested = select(literal("nesting").label("n")).cte( + "value_a", nesting=True + ) # Nesting CTEs takes ascendency locally # over the CTEs at a higher level @@ -2690,7 +2846,9 @@ def cte( value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b")) The above query will render the second CTE nested inside the first, - shown with inline parameters below as:: + shown with inline parameters below as: + + .. sourcecode:: sql WITH value_a AS @@ -2705,21 +2863,17 @@ def cte( The same CTE can be set up using the :meth:`.HasCTE.add_cte` method as follows (SQLAlchemy 2.0 and above):: - value_a = select( - literal("root").label("n") - ).cte("value_a") + value_a = select(literal("root").label("n")).cte("value_a") # A nested CTE with the same name as the root one - value_a_nested = select( - literal("nesting").label("n") - ).cte("value_a") + value_a_nested = select(literal("nesting").label("n")).cte("value_a") # Nesting CTEs takes ascendency locally # over the CTEs at a higher level value_b = ( - select(value_a_nested.c.n). - add_cte(value_a_nested, nest_here=True). 
- cte("value_b") + select(value_a_nested.c.n) + .add_cte(value_a_nested, nest_here=True) + .cte("value_b") ) value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b")) @@ -2734,9 +2888,7 @@ def cte( Column("right", Integer), ) - root_node = select(literal(1).label("node")).cte( - "nodes", recursive=True - ) + root_node = select(literal(1).label("node")).cte("nodes", recursive=True) left_edge = select(edge.c.left).join( root_node, edge.c.right == root_node.c.node @@ -2749,7 +2901,9 @@ def cte( subgraph = select(subgraph_cte) - The above query will render 2 UNIONs inside the recursive CTE:: + The above query will render 2 UNIONs inside the recursive CTE: + + .. sourcecode:: sql WITH RECURSIVE nodes(node) AS ( SELECT 1 AS node @@ -2767,7 +2921,7 @@ def cte( :meth:`_orm.Query.cte` - ORM version of :meth:`_expression.HasCTE.cte`. - """ + """ # noqa: E501 return CTE._construct( self, name=name, recursive=recursive, nesting=nesting ) @@ -2846,9 +3000,6 @@ class FromGrouping(GroupedElement, FromClause): def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) - def _init_collections(self) -> None: - pass - @util.ro_non_memoized_property def columns( self, @@ -2892,6 +3043,12 @@ def __getstate__(self) -> Dict[str, FromClause]: def __setstate__(self, state: Dict[str, FromClause]) -> None: self.element = state["element"] + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class NamedFromGrouping(FromGrouping, NamedFromClause): """represent a grouping of a named FROM clause @@ -2902,6 +3059,12 @@ class NamedFromGrouping(FromGrouping, NamedFromClause): inherit_cache = True + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. @@ -2912,10 +3075,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): from sqlalchemy import table, column - user = table("user", - column("id"), - column("name"), - column("description"), + user = table( + "user", + column("id"), + column("name"), + column("description"), ) The :class:`_expression.TableClause` construct serves as the base for @@ -2980,12 +3144,12 @@ def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): if TYPE_CHECKING: @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... + def columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... + def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... def __str__(self) -> str: if self.schema is not None: @@ -2996,9 +3160,6 @@ def __str__(self) -> str: def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: pass - def _init_collections(self) -> None: - pass - @util.ro_memoized_property def description(self) -> str: return self.name @@ -3021,7 +3182,7 @@ def insert(self) -> util.preloaded.sql_dml.Insert: E.g.:: - table.insert().values(name='foo') + table.insert().values(name="foo") See :func:`_expression.insert` for argument and usage information. @@ -3036,7 +3197,7 @@ def update(self) -> Update: E.g.:: - table.update().where(table.c.id==7).values(name='foo') + table.update().where(table.c.id == 7).values(name="foo") See :func:`_expression.update` for argument and usage information. 
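The docstring examples reformatted above all build DML from the lightweight ``table()`` construct; a self-contained sketch combining them (the ``account`` table is invented)::

    from sqlalchemy import column, table

    account = table("account", column("id"), column("name"))

    # each DML construct compiles standalone via str()
    print(account.insert().values(name="foo"))
    # INSERT INTO account (name) VALUES (:name)
    print(account.update().where(account.c.id == 7).values(name="foo"))
    # UPDATE account SET name=:name WHERE account.id = :id_1
    print(account.delete().where(account.c.id == 7))
    # DELETE FROM account WHERE account.id = :id_1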
@@ -3052,7 +3213,7 @@ def delete(self) -> Delete: E.g.:: - table.delete().where(table.c.id==7) + table.delete().where(table.c.id == 7) See :func:`_expression.delete` for argument and usage information. @@ -3073,6 +3234,7 @@ class ForUpdateArg(ClauseElement): ("nowait", InternalTraversal.dp_boolean), ("read", InternalTraversal.dp_boolean), ("skip_locked", InternalTraversal.dp_boolean), + ("key_share", InternalTraversal.dp_boolean), ] of: Optional[Sequence[ClauseElement]] @@ -3239,7 +3401,7 @@ def data(self, values: Sequence[Tuple[Any, ...]]) -> Self: E.g.:: - my_values = my_values.data([(1, 'value 1'), (2, 'value2')]) + my_values = my_values.data([(1, "value 1"), (2, "value2")]) :param values: a sequence (i.e. list) of tuples that map to the column expressions given in the :class:`_expression.Values` @@ -3259,16 +3421,23 @@ def scalar_values(self) -> ScalarValues: """ return ScalarValues(self._column_args, self._data, self.literal_binds) - def _populate_column_collection(self) -> None: + def _populate_column_collection( + self, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], + ) -> None: for c in self._column_args: if c.table is not None and c.table is not self: - _, c = c._make_proxy(self) + _, c = c._make_proxy( + self, primary_key=primary_key, foreign_keys=foreign_keys + ) else: # if the column was used in other contexts, ensure # no memoizations of other FROM clauses. # see test_values.py -> test_auto_proxy_select_direct_col c._reset_memoizations() - self._columns.add(c) + columns.add(c) c.table = self @util.ro_non_memoized_property @@ -3315,6 +3484,12 @@ def _column_types(self) -> List[TypeEngine[Any]]: def __clause_element__(self) -> ScalarValues: return self + if TYPE_CHECKING: + + def self_group( + self, against: Optional[OperatorType] = None + ) -> Self: ... + class SelectBase( roles.SelectStatementRole, @@ -3378,6 +3553,9 @@ def selected_columns( def _generate_fromclause_column_proxies( self, subquery: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] @@ -3578,7 +3756,9 @@ def subquery(self, name: Optional[str] = None) -> Subquery: stmt = select(table.c.id, table.c.name) - The above statement might look like:: + The above statement might look like: + + .. sourcecode:: sql SELECT table.id, table.name FROM table @@ -3589,7 +3769,9 @@ def subquery(self, name: Optional[str] = None) -> Subquery: subq = stmt.subquery() new_stmt = select(subq) - The above renders as:: + The above renders as: + + .. sourcecode:: sql SELECT anon_1.id, anon_1.name FROM (SELECT table.id, table.name FROM table) AS anon_1 @@ -3654,7 +3836,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): __visit_name__ = "select_statement_grouping" _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement) - ] + ] + SupportsCloneAnnotations._clone_annotations_traverse_internals _is_select_container = True @@ -3687,13 +3869,11 @@ def select_statement(self) -> _SB: return self.element def self_group(self, against: Optional[OperatorType] = None) -> Self: - ... return self if TYPE_CHECKING: - def _ungroup(self) -> _SB: - ... + def _ungroup(self) -> _SB: ... 
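The ``Values`` changes above only reroute how the column collection is populated; the public construct is unchanged. A sketch of its usage (names invented; rendering approximate)::

    from sqlalchemy import Integer, String, column, select, values

    my_values = values(
        column("id", Integer),
        column("name", String),
        name="my_values",
    ).data([(1, "value 1"), (2, "value 2")])

    # the VALUES construct is selectable like a table
    print(select(my_values))
    # SELECT my_values.id, my_values.name
    # FROM (VALUES (:param_1, :param_2), (:param_3, :param_4))
    #   AS my_values (id, name)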
# def _generate_columns_plus_names( # self, anon_for_dupe_key: bool @@ -3703,13 +3883,20 @@ def _ungroup(self) -> _SB: def _generate_fromclause_column_proxies( self, subquery: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] ] = None, ) -> None: self.element._generate_fromclause_column_proxies( - subquery, proxy_compound_columns=proxy_compound_columns + subquery, + columns, + proxy_compound_columns=proxy_compound_columns, + primary_key=primary_key, + foreign_keys=foreign_keys, ) @util.ro_non_memoized_property @@ -3736,8 +3923,12 @@ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects + def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: + # SelectStatementGrouping not generative: has no attribute '_generate' + raise NotImplementedError + -class GenerativeSelect(SelectBase, Generative): +class GenerativeSelect(DialectKWArgs, SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -3781,13 +3972,17 @@ def with_for_update( stmt = select(table).with_for_update(nowait=True) - On a database like PostgreSQL or Oracle, the above would render a - statement like:: + On a database like PostgreSQL or Oracle Database, the above would + render a statement like: + + .. sourcecode:: sql SELECT table.a, table.b FROM table FOR UPDATE NOWAIT on other backends, the ``nowait`` option is ignored and instead - would produce:: + would produce: + + .. sourcecode:: sql SELECT table.a, table.b FROM table FOR UPDATE @@ -3797,7 +3992,7 @@ def with_for_update( variants. :param nowait: boolean; will render ``FOR UPDATE NOWAIT`` on Oracle - and PostgreSQL dialects. + Database and PostgreSQL dialects. :param read: boolean; will render ``LOCK IN SHARE MODE`` on MySQL, ``FOR SHARE`` on PostgreSQL. On PostgreSQL, when combined with @@ -3806,13 +4001,13 @@ def with_for_update( :param of: SQL expression or list of SQL expression elements, (typically :class:`_schema.Column` objects or a compatible expression, for some backends may also be a table expression) which will render - into a ``FOR UPDATE OF`` clause; supported by PostgreSQL, Oracle, some - MySQL versions and possibly others. May render as a table or as a - column depending on backend. + into a ``FOR UPDATE OF`` clause; supported by PostgreSQL, Oracle + Database, some MySQL versions and possibly others. May render as a + table or as a column depending on backend. - :param skip_locked: boolean, will render ``FOR UPDATE SKIP LOCKED`` - on Oracle and PostgreSQL dialects or ``FOR SHARE SKIP LOCKED`` if - ``read=True`` is also specified. + :param skip_locked: boolean, will render ``FOR UPDATE SKIP LOCKED`` on + Oracle Database and PostgreSQL dialects or ``FOR SHARE SKIP LOCKED`` + if ``read=True`` is also specified. :param key_share: boolean, will render ``FOR NO KEY UPDATE``, or if combined with ``read=True`` will render ``FOR KEY SHARE``, @@ -3844,7 +4039,7 @@ def set_label_style(self, style: SelectLabelStyle) -> Self: :attr:`_sql.SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY`, :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`, and :attr:`_sql.SelectLabelStyle.LABEL_STYLE_NONE`. The default style is - :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`. + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY`. 
In modern SQLAlchemy, there is not generally a need to change the labeling style, as per-expression labels are more effectively used by @@ -3913,14 +4108,12 @@ def _offset_or_limit_clause( @overload def _offset_or_limit_clause_asint( self, clause: ColumnElement[Any], attrname: str - ) -> NoReturn: - ... + ) -> NoReturn: ... @overload def _offset_or_limit_clause_asint( self, clause: Optional[_OffsetLimitParam], attrname: str - ) -> Optional[int]: - ... + ) -> Optional[int]: ... def _offset_or_limit_clause_asint( self, clause: Optional[ColumnElement[Any]], attrname: str @@ -4016,14 +4209,15 @@ def fetch( count: _LimitOffsetType, with_ties: bool = False, percent: bool = False, + **dialect_kw: Any, ) -> Self: - """Return a new selectable with the given FETCH FIRST criterion + r"""Return a new selectable with the given FETCH FIRST criterion applied. - This is a numeric value which usually renders as - ``FETCH {FIRST | NEXT} [ count ] {ROW | ROWS} {ONLY | WITH TIES}`` - expression in the resulting select. This functionality is - is currently implemented for Oracle, PostgreSQL, MSSQL. + This is a numeric value which usually renders as ``FETCH {FIRST | NEXT} + [ count ] {ROW | ROWS} {ONLY | WITH TIES}`` expression in the resulting + select. This functionality is currently implemented for Oracle + Database, PostgreSQL, MSSQL. Use :meth:`_sql.GenerativeSelect.offset` to specify the offset. @@ -4047,6 +4241,11 @@ def fetch( :param percent: When ``True``, ``count`` represents the percentage of the total number of selected rows to return. Defaults to ``False`` + :param \**dialect_kw: Additional dialect-specific keyword arguments + may be accepted by dialects. + + .. versionadded:: 2.0.41 + .. seealso:: :meth:`_sql.GenerativeSelect.limit` @@ -4054,7 +4253,7 @@ def fetch( :meth:`_sql.GenerativeSelect.offset` """ - + self._validate_dialect_kwargs(dialect_kw) self._limit_clause = None if count is None: self._fetch_clause = self._fetch_clause_options = None @@ -4107,7 +4306,7 @@ def slice( For example, :: - stmt = select(User).order_by(User).id.slice(1, 3) + stmt = select(User).order_by(User.id).slice(1, 3) renders as @@ -4206,8 +4405,7 @@ def group_by( e.g.:: - stmt = select(table.c.name, func.max(table.c.stat)).\ group_by(table.c.name) + stmt = select(table.c.name, func.max(table.c.stat)).group_by(table.c.name) :param \*clauses: a series of :class:`_expression.ColumnElement` constructs @@ -4220,7 +4418,7 @@ def group_by( :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial` - """ + """ # noqa: E501 if not clauses and __first is None: self._group_by_clauses = () @@ -4260,7 +4458,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): +class CompoundSelect(HasCompileState, GenerativeSelect, TypedReturnsRows[_TP]): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations.
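A short sketch of the ``fetch()`` usage documented above, against the generic compiler (table invented; the new ``**dialect_kw`` argument is dialect-specific and not exercised here)::

    from sqlalchemy import column, select, table

    t = table("t", column("id"), column("score"))

    # WITH TIES is only meaningful together with an ORDER BY
    stmt = select(t).order_by(t.c.score.desc()).fetch(5, with_ties=True)
    print(stmt)
    # SELECT t.id, t.score FROM t ORDER BY t.score DESC
    # FETCH FIRST :param_1 ROWS WITH TIES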
@@ -4283,17 +4481,22 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): __visit_name__ = "compound_select" - _traverse_internals: _TraverseInternalsType = [ - ("selects", InternalTraversal.dp_clauseelement_list), - ("_limit_clause", InternalTraversal.dp_clauseelement), - ("_offset_clause", InternalTraversal.dp_clauseelement), - ("_fetch_clause", InternalTraversal.dp_clauseelement), - ("_fetch_clause_options", InternalTraversal.dp_plain_dict), - ("_order_by_clauses", InternalTraversal.dp_clauseelement_list), - ("_group_by_clauses", InternalTraversal.dp_clauseelement_list), - ("_for_update_arg", InternalTraversal.dp_clauseelement), - ("keyword", InternalTraversal.dp_string), - ] + SupportsCloneAnnotations._clone_annotations_traverse_internals + _traverse_internals: _TraverseInternalsType = ( + [ + ("selects", InternalTraversal.dp_clauseelement_list), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), + ("_fetch_clause", InternalTraversal.dp_clauseelement), + ("_fetch_clause_options", InternalTraversal.dp_plain_dict), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_list), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_list), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("keyword", InternalTraversal.dp_string), + ] + + SupportsCloneAnnotations._clone_annotations_traverse_internals + + HasCTE._has_ctes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + ) selects: List[SelectBase] @@ -4303,7 +4506,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): def __init__( self, keyword: _CompoundSelectKeyword, - *selects: _SelectStatementForCompoundArgument, + *selects: _SelectStatementForCompoundArgument[_TP], ): self.keyword = keyword self.selects = [ @@ -4317,38 +4520,38 @@ def __init__( @classmethod def _create_union( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod def _create_union_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod def _create_except( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod def _create_except_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod def _create_intersect( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod def _create_intersect_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) def _scalar_type(self) -> TypeEngine[Any]: @@ -4365,7 +4568,7 @@ def 
is_derived_from(self, fromclause: Optional[FromClause]) -> bool: return True return False - def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: + def set_label_style(self, style: SelectLabelStyle) -> Self: if self._label_style is not style: self = self._generate() select_0 = self.selects[0].set_label_style(style) @@ -4373,7 +4576,7 @@ def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: return self - def _ensure_disambiguated_names(self) -> CompoundSelect: + def _ensure_disambiguated_names(self) -> Self: new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -4384,6 +4587,9 @@ def _ensure_disambiguated_names(self) -> CompoundSelect: def _generate_fromclause_column_proxies( self, subquery: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] @@ -4424,7 +4630,11 @@ def _generate_fromclause_column_proxies( # i haven't tried to think what it means for compound nested in # compound select_0._generate_fromclause_column_proxies( - subquery, proxy_compound_columns=extra_col_iterator + subquery, + columns, + proxy_compound_columns=extra_col_iterator, + primary_key=primary_key, + foreign_keys=foreign_keys, ) def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: @@ -4485,13 +4695,14 @@ class default_select_compile_options(CacheableOptions): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Executable) -> Type[SelectState]: - ... + def get_plugin_class( + cls, statement: Executable + ) -> Type[SelectState]: ... def __init__( self, statement: Select[Any], - compiler: Optional[SQLCompiler], + compiler: SQLCompiler, **kw: Any, ): self.statement = statement @@ -4552,6 +4763,7 @@ def _column_naming_convention( cls, label_style: SelectLabelStyle ) -> _LabelConventionCallable: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL + dedupe = label_style is not LABEL_STYLE_NONE pa = prefix_anon_map() @@ -5105,6 +5317,7 @@ class Select( + HasHints._has_hints_traverse_internals + SupportsCloneAnnotations._clone_annotations_traverse_internals + Executable._executable_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals ) _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [ @@ -5126,7 +5339,9 @@ def _create_raw_select(cls, **kw: Any) -> Select[Any]: stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument[Any]): + def __init__( + self, *entities: _ColumnsClauseArgument[Any], **dialect_kw: Any + ): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5139,7 +5354,6 @@ def __init__(self, *entities: _ColumnsClauseArgument[Any]): ) for ent in entities ] - GenerativeSelect.__init__(self) def _scalar_type(self) -> TypeEngine[Any]: @@ -5177,21 +5391,17 @@ def _filter_by_zero( @overload def scalar_subquery( self: Select[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[Any]: - ... + ) -> ScalarSelect[Any]: ... @overload def scalar_subquery( self: Select[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... 
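The ``scalar_subquery()`` overloads above tighten typing only; runtime behavior is unchanged. A minimal usage sketch (names invented)::

    from sqlalchemy import column, func, select, table

    orders = table("orders", column("id"), column("amount"))

    # a scalar subquery embeds as a column expression in an enclosing query
    max_amount = select(func.max(orders.c.amount)).scalar_subquery()

    print(select(orders.c.id).where(orders.c.amount == max_amount))
    # SELECT orders.id FROM orders WHERE orders.amount =
    #   (SELECT max(orders.amount) AS max_1 FROM orders)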
def filter_by(self, **kwargs: Any) -> Self: r"""apply the given filtering criterion as a WHERE clause @@ -5291,11 +5501,17 @@ def join( E.g.:: - stmt = select(user_table).join(address_table, user_table.c.id == address_table.c.user_id) + stmt = select(user_table).join( + address_table, user_table.c.id == address_table.c.user_id + ) - The above statement generates SQL similar to:: + The above statement generates SQL similar to: - SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id + .. sourcecode:: sql + + SELECT user.id, user.name + FROM user + JOIN address ON user.id = address.user_id .. versionchanged:: 1.4 :meth:`_expression.Select.join` now creates a :class:`_sql.Join` object between a :class:`_sql.FromClause` @@ -5399,7 +5615,9 @@ def join_from( user_table, address_table, user_table.c.id == address_table.c.user_id ) - The above statement generates SQL similar to:: + The above statement generates SQL similar to: + + .. sourcecode:: sql SELECT user.id, user.name, address.id, address.email, address.user_id FROM user JOIN address ON user.id = address.user_id @@ -5534,8 +5752,9 @@ def get_final_froms(self) -> Sequence[FromClause]: :attr:`_sql.Select.columns_clause_froms` """ + compiler = self._default_compiler() - return self._compile_state_factory(self, None)._get_display_froms() + return self._compile_state_factory(self, compiler)._get_display_froms() @property @util.deprecated( @@ -5636,7 +5855,7 @@ def _copy_internals( def replace( obj: Union[BinaryExpression[Any], ColumnClause[Any]], **kw: Any, - ) -> Optional[KeyedColumnElement[ColumnElement[Any]]]: + ) -> Optional[KeyedColumnElement[Any]]: if isinstance(obj, ColumnClause) and obj.table in new_froms: newelem = new_froms[obj.table].corresponding_column(obj) return newelem @@ -5764,26 +5983,34 @@ def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: ) return woc - # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + # START OVERLOADED FUNCTIONS self.with_only_columns Select 1-8 ", *, maintain_column_froms: bool =..." # noqa: E501 # code within this block is **programmatically, - # statically generated** by tools/generate_sel_v1_overloads.py + # statically generated** by tools/generate_tuple_map_overloads.py @overload - def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: - ... + def with_only_columns( + self, __ent0: _TCCA[_T0], *, maintain_column_froms: bool = ... + ) -> Select[Tuple[_T0]]: ... @overload def with_only_columns( - self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> Select[Tuple[_T0, _T1]]: - ... + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1]]: ... @overload def with_only_columns( - self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> Select[Tuple[_T0, _T1, _T2]]: - ... + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2]]: ... @overload def with_only_columns( @@ -5792,8 +6019,9 @@ def with_only_columns( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: - ... + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ... @overload def with_only_columns( @@ -5803,8 +6031,9 @@ def with_only_columns( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: - ... 
+ *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... @overload def with_only_columns( @@ -5815,8 +6044,9 @@ def with_only_columns( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: - ... + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... @overload def with_only_columns( @@ -5828,8 +6058,9 @@ def with_only_columns( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: - ... + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... @overload def with_only_columns( @@ -5842,8 +6073,9 @@ def with_only_columns( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: - ... + *, + maintain_column_froms: bool = ..., + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... # END OVERLOADED FUNCTIONS self.with_only_columns @@ -5853,8 +6085,7 @@ def with_only_columns( *entities: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, **__kw: Any, - ) -> Select[Any]: - ... + ) -> Select[Any]: ... @_generative def with_only_columns( @@ -5990,11 +6221,29 @@ def having(self, *having: _ColumnExpressionArgument[bool]) -> Self: @_generative def distinct(self, *expr: _ColumnExpressionArgument[Any]) -> Self: r"""Return a new :func:`_expression.select` construct which - will apply DISTINCT to its columns clause. + will apply DISTINCT to the SELECT statement overall. + + E.g.:: + + from sqlalchemy import select + + stmt = select(users_table.c.id, users_table.c.name).distinct() + + The above would produce a statement resembling: + + .. sourcecode:: sql + + SELECT DISTINCT user.id, user.name FROM user + + The method also accepts an ``*expr`` parameter which produces the + PostgreSQL dialect-specific ``DISTINCT ON`` expression. Using this + parameter on other backends which don't support this syntax will + raise an error. :param \*expr: optional column expressions. When present, - the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)`` - construct. + the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)`` + construct. A deprecation warning and/or :class:`_exc.CompileError` + will be raised on other backends. .. deprecated:: 1.4 Using \*expr in other dialects is deprecated and will raise :class:`_exc.CompileError` in a future version.
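A sketch of the two ``distinct()`` forms described in the revised docstring; the expression form compiles only on PostgreSQL (names invented)::

    from sqlalchemy import column, select, table
    from sqlalchemy.dialects import postgresql

    t = table("t", column("a"), column("b"))

    print(select(t).distinct())
    # SELECT DISTINCT t.a, t.b FROM t

    # DISTINCT ON requires the PostgreSQL dialect at compile time
    print(
        select(t)
        .distinct(t.c.a)
        .order_by(t.c.a, t.c.b)
        .compile(dialect=postgresql.dialect())
    )
    # SELECT DISTINCT ON (t.a) t.a, t.b FROM t ORDER BY t.a, t.b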
@@ -6018,12 +6267,11 @@ def select_from(self, *froms: _FromClauseArgument) -> Self: E.g.:: - table1 = table('t1', column('a')) - table2 = table('t2', column('b')) - s = select(table1.c.a).\ - select_from( - table1.join(table2, table1.c.a==table2.c.b) - ) + table1 = table("t1", column("a")) + table2 = table("t2", column("b")) + s = select(table1.c.a).select_from( + table1.join(table2, table1.c.a == table2.c.b) + ) The "from" list is a unique set on the identity of each element, so adding an already present :class:`_schema.Table` @@ -6042,7 +6290,7 @@ def select_from(self, *froms: _FromClauseArgument) -> Self: if desired, in the case that the FROM clause cannot be fully derived from the columns clause:: - select(func.count('*')).select_from(table1) + select(func.count("*")).select_from(table1) """ @@ -6195,8 +6443,8 @@ def selected_columns( :class:`_expression.ColumnElement` objects are directly present as they were given, e.g.:: - col1 = column('q', Integer) - col2 = column('p', Integer) + col1 = column("q", Integer) + col2 = column("p", Integer) stmt = select(col1, col2) Above, ``stmt.selected_columns`` would be a collection that contains @@ -6211,7 +6459,8 @@ def selected_columns( criteria, e.g.:: def filter_on_id(my_select, id): - return my_select.where(my_select.selected_columns['id'] == id) + return my_select.where(my_select.selected_columns["id"] == id) + stmt = select(MyModel) @@ -6263,6 +6512,9 @@ def _ensure_disambiguated_names(self) -> Select[Any]: def _generate_fromclause_column_proxies( self, subquery: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] @@ -6280,6 +6532,8 @@ def _generate_fromclause_column_proxies( name=required_label_name, name_is_truncatable=True, compound_select_cols=extra_cols, + primary_key=primary_key, + foreign_keys=foreign_keys, ) for ( ( @@ -6305,6 +6559,8 @@ def _generate_fromclause_column_proxies( key=proxy_key, name=required_label_name, name_is_truncatable=True, + primary_key=primary_key, + foreign_keys=foreign_keys, ) for ( required_label_name, @@ -6316,7 +6572,7 @@ def _generate_fromclause_column_proxies( if is_column_element(c) ] - subquery._columns._populate_separate_keys(prox) + columns._populate_separate_keys(prox) def _needs_parens_for_grouping(self) -> bool: return self._has_row_limiting_clause or bool( @@ -6326,7 +6582,6 @@ def _needs_parens_for_grouping(self) -> bool: def self_group( self, against: Optional[OperatorType] = None ) -> Union[SelectStatementGrouping[Self], Self]: - ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -6344,8 +6599,8 @@ def self_group( return SelectStatementGrouping(self) def union( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -6363,8 +6618,8 @@ def union( return CompoundSelect._create_union(self, *other) def union_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. 
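The set-operation methods above now carry the statement's row type ``_TP`` through to ``CompoundSelect``; call patterns are unchanged. A sketch (names invented)::

    from sqlalchemy import column, select, table

    t = table("t", column("a"))

    stmt1 = select(t.c.a).where(t.c.a > 5)
    stmt2 = select(t.c.a).where(t.c.a < 2)

    # method form; the module-level union_all() is equivalent
    print(stmt1.union_all(stmt2))
    # SELECT t.a FROM t WHERE t.a > :a_1
    # UNION ALL SELECT t.a FROM t WHERE t.a < :a_2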
@@ -6382,8 +6637,8 @@ def union_all( return CompoundSelect._create_union_all(self, *other) def except_( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -6398,8 +6653,8 @@ def except_( return CompoundSelect._create_except(self, *other) def except_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6414,8 +6669,8 @@ def except_all( return CompoundSelect._create_except_all(self, *other) def intersect( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -6433,8 +6688,8 @@ def intersect( return CompoundSelect._create_intersect(self, *other) def intersect_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6518,27 +6773,12 @@ def where(self, crit: _ColumnExpressionArgument[bool]) -> Self: self.element = cast("Select[Any]", self.element).where(crit) return self - @overload - def self_group( - self: ScalarSelect[Any], against: Optional[OperatorType] = None - ) -> ScalarSelect[Any]: - ... - - @overload - def self_group( - self: ColumnElement[Any], against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: - ... - - def self_group( - self, against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: + def self_group(self, against: Optional[OperatorType] = None) -> Self: return self if TYPE_CHECKING: - def _ungroup(self) -> Select[Any]: - ... + def _ungroup(self) -> Select[Any]: ... @_generative def correlate( @@ -6669,14 +6909,16 @@ def _regroup( assert isinstance(return_value, SelectStatementGrouping) return return_value - def select(self) -> Select[Any]: + def select(self) -> Select[Tuple[bool]]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: stmt = exists(some_table.c.id).where(some_table.c.id == 5).select() - This will produce a statement resembling:: + This will produce a statement resembling: + + .. 
sourcecode:: sql SELECT EXISTS (SELECT id FROM some_table WHERE some_table = :param) AS anon_1 @@ -6792,10 +7034,14 @@ class was renamed _label_style = LABEL_STYLE_NONE - _traverse_internals: _TraverseInternalsType = [ - ("element", InternalTraversal.dp_clauseelement), - ("column_args", InternalTraversal.dp_clauseelement_list), - ] + SupportsCloneAnnotations._clone_annotations_traverse_internals + _traverse_internals: _TraverseInternalsType = ( + [ + ("element", InternalTraversal.dp_clauseelement), + ("column_args", InternalTraversal.dp_clauseelement_list), + ] + + SupportsCloneAnnotations._clone_annotations_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) _is_textual = True @@ -6878,6 +7124,9 @@ def bindparams( def _generate_fromclause_column_proxies( self, fromclause: FromClause, + columns: ColumnCollection[str, KeyedColumnElement[Any]], + primary_key: ColumnSet, + foreign_keys: Set[KeyedColumnElement[Any]], *, proxy_compound_columns: Optional[ Iterable[Sequence[ColumnElement[Any]]] @@ -6887,15 +7136,25 @@ def _generate_fromclause_column_proxies( assert isinstance(fromclause, Subquery) if proxy_compound_columns: - fromclause._columns._populate_separate_keys( - c._make_proxy(fromclause, compound_select_cols=extra_cols) + columns._populate_separate_keys( + c._make_proxy( + fromclause, + compound_select_cols=extra_cols, + primary_key=primary_key, + foreign_keys=foreign_keys, + ) for c, extra_cols in zip( self.column_args, proxy_compound_columns ) ) else: - fromclause._columns._populate_separate_keys( - c._make_proxy(fromclause) for c in self.column_args + columns._populate_separate_keys( + c._make_proxy( + fromclause, + primary_key=primary_key, + foreign_keys=foreign_keys, + ) + for c in self.column_args ) def _scalar_type(self) -> Union[TypeEngine[Any], Any]: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index ddee7767bc3..fc278678b2e 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1,14 +1,12 @@ # sql/sqltypes.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -"""SQL specific types. 
- -""" +"""SQL specific types.""" from __future__ import annotations import collections.abc as collections_abc @@ -21,7 +19,9 @@ from typing import Callable from typing import cast from typing import Dict +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence @@ -59,17 +59,22 @@ from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util import warn_deprecated +from ..util.typing import get_args from ..util.typing import is_literal +from ..util.typing import is_pep695 from ..util.typing import Literal -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _CreateDropBind from ._typing import _TypeEngineArgument + from .elements import ColumnElement from .operators import OperatorType from .schema import MetaData from .type_api import _BindProcessorType from .type_api import _ComparatorFactory + from .type_api import _LiteralProcessorType from .type_api import _MatchedOnType from .type_api import _ResultProcessorType from ..engine.interfaces import Dialect @@ -77,10 +82,10 @@ _T = TypeVar("_T", bound="Any") _CT = TypeVar("_CT", bound=Any) _TE = TypeVar("_TE", bound="TypeEngine[Any]") +_P = TypeVar("_P") class HasExpressionLookup(TypeEngineMixin): - """Mixin expression adaptations based on lookup tables. These rules are currently used by the numeric, integer and date types @@ -119,7 +124,6 @@ def _adapt_expression( class Concatenable(TypeEngineMixin): - """A mixin that marks a type as supporting 'concatenation', typically strings.""" @@ -168,7 +172,6 @@ def __getitem__(self, index): class String(Concatenable, TypeEngine[str]): - """The base for all string and character types. In SQL, corresponds to VARCHAR. @@ -205,7 +208,7 @@ def __init__( .. sourcecode:: pycon+sql >>> from sqlalchemy import cast, select, String - >>> print(select(cast('some string', String(collation='utf8')))) + >>> print(select(cast("some string", String(collation="utf8")))) {printsql}SELECT CAST(:param_1 AS VARCHAR COLLATE utf8) AS anon_1 .. note:: @@ -220,6 +223,11 @@ def __init__( self.length = length self.collation = collation + def _with_collation(self, collation): + new_type = self.copy() + new_type.collation = collation + return new_type + def _resolve_for_literal(self, value): # I was SO PROUD of my regex trick, but we dont need it. # re.search(r"[^\u0000-\u007F]", value) @@ -240,10 +248,14 @@ def process(value): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[str]]: return None - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[str]]: return None @property @@ -255,7 +267,6 @@ def get_dbapi_type(self, dbapi): class Text(String): - """A variably sized string type. In SQL, usually corresponds to CLOB or TEXT. In general, TEXT objects @@ -268,14 +279,13 @@ class Text(String): class Unicode(String): - """A variable length Unicode string type. The :class:`.Unicode` type is a :class:`.String` subclass that assumes input and output strings that may contain non-ASCII characters, and for some backends implies an underlying column type that is explicitly - supporting of non-ASCII data, such as ``NVARCHAR`` on Oracle and SQL - Server. 
This will impact the output of ``CREATE TABLE`` statements and + supporting of non-ASCII data, such as ``NVARCHAR`` on Oracle Database and + SQL Server. This will impact the output of ``CREATE TABLE`` statements and ``CAST`` functions at the dialect level. The character encoding used by the :class:`.Unicode` type that is used to @@ -306,23 +316,12 @@ class Unicode(String): :meth:`.DialectEvents.do_setinputsizes` - """ __visit_name__ = "unicode" - def __init__(self, length=None, **kwargs): - """ - Create a :class:`.Unicode` object. - - Parameters are the same as that of :class:`.String`. - - """ - super().__init__(length=length, **kwargs) - class UnicodeText(Text): - """An unbounded-length Unicode string type. See :class:`.Unicode` for details on the unicode @@ -336,18 +335,8 @@ class UnicodeText(Text): __visit_name__ = "unicode_text" - def __init__(self, length=None, **kwargs): - """ - Create a Unicode-converting Text type. - - Parameters are the same as that of :class:`_expression.TextClause`. - - """ - super().__init__(length=length, **kwargs) - class Integer(HasExpressionLookup, TypeEngine[int]): - """A type for ``int`` integers.""" __visit_name__ = "integer" @@ -355,8 +344,7 @@ class Integer(HasExpressionLookup, TypeEngine[int]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[Integer]: - ... + def _type_affinity(self) -> Type[Integer]: ... def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -397,7 +385,6 @@ def _expression_adaptations(self): class SmallInteger(Integer): - """A type for smaller ``int`` integers. Typically generates a ``SMALLINT`` in DDL, and otherwise acts like @@ -409,7 +396,6 @@ class SmallInteger(Integer): class BigInteger(Integer): - """A type for bigger ``int`` integers. Typically generates a ``BIGINT`` in DDL, and otherwise acts like @@ -424,7 +410,6 @@ class BigInteger(Integer): class Numeric(HasExpressionLookup, TypeEngine[_N]): - """Base for non-integer numeric types, such as ``NUMERIC``, ``FLOAT``, ``DECIMAL``, and other variants. @@ -461,8 +446,7 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[Numeric[_N]]: - ... + def _type_affinity(self) -> Type[Numeric[_N]]: ... _default_decimal_return_scale = 10 @@ -473,8 +457,7 @@ def __init__( scale: Optional[int] = ..., decimal_return_scale: Optional[int] = ..., asdecimal: Literal[True] = ..., - ): - ... + ): ... @overload def __init__( @@ -483,8 +466,7 @@ def __init__( scale: Optional[int] = ..., decimal_return_scale: Optional[int] = ..., asdecimal: Literal[False] = ..., - ): - ... + ): ... def __init__( self, @@ -580,9 +562,11 @@ def result_processor(self, dialect, coltype): # we're a "numeric", DBAPI returns floats, convert. return processors.to_decimal_processor_factory( decimal.Decimal, - self.scale - if self.scale is not None - else self._default_decimal_return_scale, + ( + self.scale + if self.scale is not None + else self._default_decimal_return_scale + ), ) else: if dialect.supports_native_decimal: @@ -627,7 +611,10 @@ class Float(Numeric[_N]): __visit_name__ = "float" - scale = None + if not TYPE_CHECKING: + # this is not in 2.1 branch, not clear if needed for 2.0 + # implementation + scale = None @overload def __init__( @@ -635,8 +622,7 @@ def __init__( precision: Optional[int] = ..., asdecimal: Literal[False] = ..., decimal_return_scale: Optional[int] = ..., - ): - ... + ): ... 
@overload def __init__( @@ -644,8 +630,7 @@ def __init__( precision: Optional[int] = ..., asdecimal: Literal[True] = ..., decimal_return_scale: Optional[int] = ..., - ): - ... + ): ... def __init__( self: Float[_N], @@ -661,16 +646,16 @@ def __init__( indicates a number of digits for the generic :class:`_sqltypes.Float` datatype. - .. note:: For the Oracle backend, the + .. note:: For the Oracle Database backend, the :paramref:`_sqltypes.Float.precision` parameter is not accepted - when rendering DDL, as Oracle does not support float precision + when rendering DDL, as Oracle Database does not support float precision specified as a number of decimal places. Instead, use the - Oracle-specific :class:`_oracle.FLOAT` datatype and specify the + Oracle Database-specific :class:`_oracle.FLOAT` datatype and specify the :paramref:`_oracle.FLOAT.binary_precision` parameter. This is new in version 2.0 of SQLAlchemy. To create a database agnostic :class:`_types.Float` that - separately specifies binary precision for Oracle, use + separately specifies binary precision for Oracle Database, use :meth:`_types.TypeEngine.with_variant` as follows:: from sqlalchemy import Column @@ -679,7 +664,7 @@ def __init__( Column( "float_data", - Float(5).with_variant(oracle.FLOAT(binary_precision=16), "oracle") + Float(5).with_variant(oracle.FLOAT(binary_precision=16), "oracle"), ) :param asdecimal: the same flag as that of :class:`.Numeric`, but @@ -753,7 +738,6 @@ def process(value): class DateTime( _RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.datetime] ): - """A type for ``datetime.datetime()`` objects. Date and time types return objects from the Python ``datetime`` @@ -782,7 +766,7 @@ def __init__(self, timezone: bool = False): to make use of the :class:`_types.TIMESTAMP` datatype directly when using this flag, as some databases include separate generic date/time-holding types distinct from the timezone-capable - TIMESTAMP datatype, such as Oracle. + TIMESTAMP datatype, such as Oracle Database. """ @@ -817,7 +801,6 @@ def _expression_adaptations(self): class Date(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.date]): - """A type for ``datetime.date()`` objects.""" __visit_name__ = "date" @@ -858,7 +841,6 @@ def _expression_adaptations(self): class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): - """A type for ``datetime.time()`` objects.""" __visit_name__ = "time" @@ -895,12 +877,19 @@ def literal_processor(self, dialect): class _Binary(TypeEngine[bytes]): - """Define base behavior for binary types.""" + length: Optional[int] + def __init__(self, length: Optional[int] = None): self.length = length + @util.ro_memoized_property + def _generic_type_affinity( + self, + ) -> Type[TypeEngine[bytes]]: + return LargeBinary + def literal_processor(self, dialect): def process(value): # TODO: this is useless for real world scenarios; implement @@ -959,7 +948,6 @@ def get_dbapi_type(self, dbapi): class LargeBinary(_Binary): - """A type for large binary byte data. The :class:`.LargeBinary` type corresponds to a large and/or unlengthed @@ -983,7 +971,6 @@ def __init__(self, length: Optional[int] = None): class SchemaType(SchemaEventTarget, TypeEngineMixin): - """Add capabilities to a type which allow for schema-level DDL to be associated with a type. 
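To see the ``with_variant()`` pattern from the revised ``Float`` docstring end to end, the table DDL can be compiled per dialect (table name invented)::

    from sqlalchemy import Column, Float, MetaData, Table
    from sqlalchemy.dialects import oracle
    from sqlalchemy.schema import CreateTable

    t = Table(
        "measurements",
        MetaData(),
        Column(
            "float_data",
            Float(5).with_variant(
                oracle.FLOAT(binary_precision=16), "oracle"
            ),
        ),
    )

    # binary precision applies on Oracle Database; other dialects
    # fall back to the generic Float(5)
    print(CreateTable(t).compile(dialect=oracle.dialect()))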
@@ -1044,7 +1031,7 @@ def __init__( if _adapted_from: self.dispatch = self.dispatch._join(_adapted_from.dispatch) - def _set_parent(self, column, **kw): + def _set_parent(self, parent, **kw): # set parent hook is when this type is associated with a column. # Column calls it for all SchemaEventTarget instances, either the # base type and/or variants in _variant_mapping. @@ -1058,7 +1045,7 @@ def _set_parent(self, column, **kw): # on_table/metadata_create/drop in this method, which is used by # "native" types with a separate CREATE/DROP e.g. Postgresql.ENUM - column._on_table_attach(util.portable_instancemethod(self._set_table)) + parent._on_table_attach(util.portable_instancemethod(self._set_table)) def _variant_mapping_for_set_table(self, column): if column.type._variant_mapping: @@ -1118,15 +1105,20 @@ def copy(self, **kw): return self.adapt( cast("Type[TypeEngine[Any]]", self.__class__), _create_events=True, + metadata=( + kw.get("_to_metadata", self.metadata) + if self.metadata is not None + else None + ), ) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -1135,21 +1127,23 @@ def adapt( kw.setdefault("_adapted_from", self) return super().adapt(cls, **kw) - def create(self, bind, checkfirst=False): + def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue CREATE DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.create(bind, checkfirst=checkfirst) - def drop(self, bind, checkfirst=False): + def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue DROP DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.drop(bind, checkfirst=checkfirst) - def _on_table_create(self, target, bind, **kw): + def _on_table_create( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1157,7 +1151,9 @@ def _on_table_create(self, target, bind, **kw): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_create(target, bind, **kw) - def _on_table_drop(self, target, bind, **kw): + def _on_table_drop( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1165,7 +1161,9 @@ def _on_table_drop(self, target, bind, **kw): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_drop(target, bind, **kw) - def _on_metadata_create(self, target, bind, **kw): + def _on_metadata_create( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1173,7 +1171,9 @@ def _on_metadata_create(self, target, bind, **kw): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_create(target, bind, **kw) - def _on_metadata_drop(self, target, bind, **kw): + def _on_metadata_drop( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1181,7 +1181,9 @@ def _on_metadata_drop(self, target, bind, **kw): if 
isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_drop(target, bind, **kw) - def _is_impl_for_variant(self, dialect, kw): + def _is_impl_for_variant( + self, dialect: Dialect, kw: Dict[str, Any] + ) -> Optional[bool]: variant_mapping = kw.pop("variant_mapping", None) if not variant_mapping: @@ -1198,7 +1200,7 @@ def _is_impl_for_variant(self, dialect, kw): # since PostgreSQL is the only DB that has ARRAY this can only # be integration tested by PG-specific tests - def _we_are_the_impl(typ): + def _we_are_the_impl(typ: SchemaType) -> bool: return ( typ is self or isinstance(typ, ARRAY) @@ -1211,6 +1213,11 @@ def _we_are_the_impl(typ): return True elif dialect.name not in variant_mapping: return _we_are_the_impl(variant_mapping["_default"]) + else: + return None + + +_EnumTupleArg = Union[Sequence[enum.Enum], Sequence[str]] class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): @@ -1249,15 +1256,14 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): import enum from sqlalchemy import Enum + class MyEnum(enum.Enum): one = 1 two = 2 three = 3 - t = Table( - 'data', MetaData(), - Column('value', Enum(MyEnum)) - ) + + t = Table("data", MetaData(), Column("value", Enum(MyEnum))) connection.execute(t.insert(), {"value": MyEnum.two}) assert connection.scalar(t.select()) is MyEnum.two @@ -1290,7 +1296,18 @@ class MyEnum(enum.Enum): __visit_name__ = "enum" - def __init__(self, *enums: object, **kw: Any): + values_callable: Optional[Callable[[Type[enum.Enum]], Sequence[str]]] + enum_class: Optional[Type[enum.Enum]] + _valid_lookup: Dict[Union[enum.Enum, str, None], Optional[str]] + _object_lookup: Dict[Optional[str], Union[enum.Enum, str, None]] + + @overload + def __init__(self, enums: Type[enum.Enum], **kw: Any) -> None: ... + + @overload + def __init__(self, *enums: str, **kw: Any) -> None: ... + + def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None: r"""Construct an enum. Keyword arguments which don't apply to a specific backend are ignored @@ -1422,7 +1439,7 @@ class was used, its name (converted to lower case) is used by .. versionchanged:: 2.0 This parameter now defaults to True. """ - self._enum_init(enums, kw) + self._enum_init(enums, kw) # type: ignore[arg-type] @property def _enums_argument(self): @@ -1431,7 +1448,7 @@ def _enums_argument(self): else: return self.enums - def _enum_init(self, enums, kw): + def _enum_init(self, enums: _EnumTupleArg, kw: Dict[str, Any]) -> None: """internal init for :class:`.Enum` and subclasses. 
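The two ``__init__`` overloads added above correspond to the two construction styles; a minimal sketch (the enum class and values are illustrative)::

    import enum

    from sqlalchemy import Enum

    class Status(enum.Enum):
        draft = 1
        published = 2

    Enum(Status)  # matches the Type[enum.Enum] overload
    Enum("draft", "published")  # matches the *enums: str overload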
friendly init helper used by subclasses to remove @@ -1490,15 +1507,19 @@ def _enum_init(self, enums, kw): _adapted_from=kw.pop("_adapted_from", None), ) - def _parse_into_values(self, enums, kw): + def _parse_into_values( + self, enums: _EnumTupleArg, kw: Any + ) -> Tuple[Sequence[str], _EnumTupleArg]: if not enums and "_enums" in kw: enums = kw.pop("_enums") if len(enums) == 1 and hasattr(enums[0], "__members__"): - self.enum_class = enums[0] + self.enum_class = enums[0] # type: ignore[assignment] + assert self.enum_class is not None _members = self.enum_class.__members__ + members: Mapping[str, enum.Enum] if self._omit_aliases is True: # remove aliases members = OrderedDict( @@ -1514,7 +1535,7 @@ def _parse_into_values(self, enums, kw): return values, objects else: self.enum_class = None - return enums, enums + return enums, enums # type: ignore[return-value] def _resolve_for_literal(self, value: Any) -> Enum: tv = type(value) @@ -1535,6 +1556,19 @@ def _resolve_for_python_type( native_enum = None + def process_literal(pt): + # for a literal, where we need to get its contents, parse it out. + enum_args = get_args(pt) + bad_args = [arg for arg in enum_args if not isinstance(arg, str)] + if bad_args: + raise exc.ArgumentError( + f"Can't create string-based Enum datatype from non-string " + f"values: {', '.join(repr(x) for x in bad_args)}. Please " + f"provide an explicit Enum datatype for this Python type" + ) + native_enum = False + return enum_args, native_enum + if not we_are_generic_form and python_type is matched_on: # if we have enumerated values, and the incoming python # type is exactly the one that matched in the type map, @@ -1543,16 +1577,32 @@ def _resolve_for_python_type( enum_args = self._enums_argument elif is_literal(python_type): - # for a literal, where we need to get its contents, parse it out. - enum_args = typing_get_args(python_type) - bad_args = [arg for arg in enum_args if not isinstance(arg, str)] - if bad_args: + enum_args, native_enum = process_literal(python_type) + elif is_pep695(python_type): + value = python_type.__value__ + if is_pep695(value): + new_value = value + while is_pep695(new_value): + new_value = new_value.__value__ + if is_literal(new_value): + value = new_value + warn_deprecated( + f"Mapping recursive TypeAliasType '{python_type}' " + "that resolves to a literal to generate an Enum is " + "deprecated. SQLAlchemy 2.1 will not support this " + "use case. Please avoid using recursive " + "TypeAliasType.", + "2.0", + ) + if not is_literal(value): raise exc.ArgumentError( - f"Can't create string-based Enum datatype from non-string " - f"values: {', '.join(repr(x) for x in bad_args)}. Please " - f"provide an explicit Enum datatype for this Python type" + f"Can't associate TypeAliasType '{python_type}' to an " + "Enum since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal['a', " + "'b']` are supported when generating Enums."
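A sketch of the alias forms this branch distinguishes, using the Python 3.12+ ``type`` statement (names are illustrative); only the direct form maps cleanly to an Enum, while the recursive form triggers the deprecation warning above::

    from typing import Literal

    # supported: a direct alias of a Literal
    type Status = Literal["draft", "published"]

    # deprecated: resolves to a Literal only through another alias;
    # per the warning above, SQLAlchemy 2.1 will reject this form
    type StatusAlias = Status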
) - native_enum = False + enum_args, native_enum = process_literal(value) + elif isinstance(python_type, type) and issubclass( python_type, enum.Enum ): @@ -1575,7 +1625,12 @@ def _resolve_for_python_type( self._generic_type_affinity(_enums=enum_args, **kw), # type: ignore # noqa: E501 ) - def _setup_for_values(self, values, objects, kw): + def _setup_for_values( + self, + values: Sequence[str], + objects: _EnumTupleArg, + kw: Any, + ) -> None: self.enums = list(values) self._valid_lookup = dict(zip(reversed(objects), reversed(values))) @@ -1590,14 +1645,14 @@ def _setup_for_values(self, values, objects, kw): ) @property - def sort_key_function(self): + def sort_key_function(self): # type: ignore[override] if self._sort_key_function is NO_ARG: return self._db_value_for_elem else: return self._sort_key_function @property - def native(self): + def native(self): # type: ignore[override] return self.native_enum def _db_value_for_elem(self, elem): @@ -1642,9 +1697,10 @@ def _adapt_expression( comparator_factory = Comparator - def _object_value_for_elem(self, elem): + def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]: try: - return self._object_lookup[elem] + # Value will not be None because key is not None + return self._object_lookup[elem] # type: ignore[return-value] except KeyError as err: raise LookupError( "'%s' is not among the defined enum values. " @@ -1702,10 +1758,10 @@ def adapt_to_emulated(self, impltype, **kw): assert "_enums" in kw return impltype(**kw) - def adapt(self, impltype, **kw): + def adapt(self, cls, **kw): kw["_enums"] = self._enums_argument kw["_disable_warnings"] = True - return super().adapt(impltype, **kw) + return super().adapt(cls, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False @@ -1886,7 +1942,6 @@ def compare_values(self, x, y): class Boolean(SchemaType, Emulated, TypeEngine[bool]): - """A bool datatype. :class:`.Boolean` typically uses BOOLEAN or SMALLINT on the DDL side, @@ -1942,6 +1997,13 @@ def __init__( if _adapted_from: self.dispatch = self.dispatch._join(_adapted_from.dispatch) + def copy(self, **kw): + # override SchemaType.copy() to not include to_metadata logic + return self.adapt( + cast("Type[TypeEngine[Any]]", self.__class__), + _create_events=True, + ) + def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False @@ -2044,13 +2106,11 @@ def _type_affinity(self) -> Type[Interval]: class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): - """A type for ``datetime.timedelta()`` objects. - The Interval type deals with ``datetime.timedelta`` objects. In - PostgreSQL and Oracle, the native ``INTERVAL`` type is used; for others, - the value is stored as a date which is relative to the "epoch" - (Jan. 1, 1970). + The Interval type deals with ``datetime.timedelta`` objects. In PostgreSQL + and Oracle Database, the native ``INTERVAL`` type is used; for others, the + value is stored as a date which is relative to the "epoch" (Jan. 1, 1970). Note that the ``Interval`` type does not currently provide date arithmetic operations on platforms which do not support interval types natively. Such @@ -2075,16 +2135,16 @@ def __init__( :param native: when True, use the actual INTERVAL type provided by the database, if - supported (currently PostgreSQL, Oracle). + supported (currently PostgreSQL, Oracle Database). Otherwise, represent the interval data as an epoch value regardless.
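A minimal sketch of the ``native`` flag documented above (table and column names are illustrative)::

    import datetime

    from sqlalchemy import Column, Interval, MetaData, Table

    t = Table("t", MetaData(), Column("span", Interval(native=False)))
    # with native=False, a value such as datetime.timedelta(days=2) is
    # persisted as an epoch-relative datetime (Jan. 1, 1970 + value)
    # on every backend, rather than as a native INTERVAL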
:param second_precision: For native interval types which support a "fractional seconds precision" parameter, - i.e. Oracle and PostgreSQL + i.e. Oracle Database and PostgreSQL :param day_precision: for native interval types which - support a "day precision" parameter, i.e. Oracle. + support a "day precision" parameter, i.e. Oracle Database. """ super().__init__() @@ -2194,15 +2254,16 @@ class JSON(Indexable, TypeEngine[Any]): The :class:`_types.JSON` type stores arbitrary JSON format data, e.g.:: - data_table = Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', JSON) + data_table = Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", JSON), ) with engine.connect() as conn: conn.execute( - data_table.insert(), - {"data": {"key1": "value1", "key2": "value2"}} + data_table.insert(), {"data": {"key1": "value1", "key2": "value2"}} ) **JSON-Specific Expression Operators** @@ -2212,7 +2273,7 @@ class JSON(Indexable, TypeEngine[Any]): * Keyed index operations:: - data_table.c.data['some key'] + data_table.c.data["some key"] * Integer index operations:: @@ -2220,7 +2281,7 @@ class JSON(Indexable, TypeEngine[Any]): * Path index operations:: - data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] + data_table.c.data[("key_1", "key_2", 5, ..., "key_n")] * Data casters for specific JSON element types, subsequent to an index or path operation being invoked:: @@ -2275,13 +2336,12 @@ class JSON(Indexable, TypeEngine[Any]): from sqlalchemy import cast, type_coerce from sqlalchemy import String, JSON - cast( - data_table.c.data['some_key'], String - ) == type_coerce(55, JSON) + + cast(data_table.c.data["some_key"], String) == type_coerce(55, JSON) The above case now works directly as:: - data_table.c.data['some_key'].as_integer() == 5 + data_table.c.data["some_key"].as_integer() == 5 For details on the previous comparison approach within the 1.3.x series, see the documentation for SQLAlchemy 1.2 or the included HTML @@ -2312,6 +2372,7 @@ class JSON(Indexable, TypeEngine[Any]): should be SQL NULL as opposed to JSON ``"null"``:: from sqlalchemy import null + conn.execute(table.insert(), {"json_value": null()}) To insert or select against a value that is JSON ``"null"``, use the @@ -2344,7 +2405,8 @@ class JSON(Indexable, TypeEngine[Any]): engine = create_engine( "sqlite://", - json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False)) + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), + ) .. versionchanged:: 1.3.7 @@ -2362,7 +2424,7 @@ class JSON(Indexable, TypeEngine[Any]): :class:`sqlalchemy.dialects.sqlite.JSON` - """ + """ # noqa: E501 __visit_name__ = "JSON" @@ -2396,8 +2458,7 @@ class JSON(Indexable, TypeEngine[Any]): transparent method is to use :func:`_expression.text`:: Table( - 'my_table', metadata, - Column('json_data', JSON, default=text("'null'")) + "my_table", metadata, Column("json_data", JSON, default=text("'null'")) ) While it is possible to use :attr:`_types.JSON.NULL` in this context, the @@ -2409,7 +2470,7 @@ class JSON(Indexable, TypeEngine[Any]): generated defaults. - """ + """ # noqa: E501 def __init__(self, none_as_null: bool = False): """Construct a :class:`_types.JSON` type. @@ -2422,6 +2483,7 @@ def __init__(self, none_as_null: bool = False): as SQL NULL:: from sqlalchemy import null + conn.execute(table.insert(), {"data": null()}) .. 
note:: @@ -2452,17 +2514,21 @@ class JSONElementType(TypeEngine[Any]): _integer = Integer() _string = String() - def string_bind_processor(self, dialect): + def string_bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[str]]: return self._string._cached_bind_processor(dialect) - def string_literal_processor(self, dialect): + def string_literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[str]]: return self._string._cached_literal_processor(dialect) - def bind_processor(self, dialect): + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: int_processor = self._integer._cached_bind_processor(dialect) string_processor = self.string_bind_processor(dialect) - def process(value): + def process(value: Optional[Any]) -> Any: if int_processor and isinstance(value, int): value = int_processor(value) elif string_processor and isinstance(value, str): @@ -2471,11 +2537,13 @@ def process(value): return process - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: int_processor = self._integer._cached_literal_processor(dialect) string_processor = self.string_literal_processor(dialect) - def process(value): + def process(value: Optional[Any]) -> Any: if int_processor and isinstance(value, int): value = int_processor(value) elif string_processor and isinstance(value, str): @@ -2526,6 +2594,8 @@ class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): __slots__ = () + type: JSON + def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( index, collections_abc.Sequence @@ -2545,104 +2615,112 @@ def _setup_getitem(self, index): index, expr=self.expr, operator=operators.json_getitem_op, - bindparam_type=JSON.JSONIntIndexType - if isinstance(index, int) - else JSON.JSONStrIndexType, + bindparam_type=( + JSON.JSONIntIndexType + if isinstance(index, int) + else JSON.JSONStrIndexType + ), ) operator = operators.json_getitem_op return operator, index, self.type def as_boolean(self): - """Cast an indexed value as boolean. + """Consider an indexed value as boolean. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. e.g.:: - stmt = select( - mytable.c.json_column['some_data'].as_boolean() - ).where( - mytable.c.json_column['some_data'].as_boolean() == True + stmt = select(mytable.c.json_column["some_data"].as_boolean()).where( + mytable.c.json_column["some_data"].as_boolean() == True ) .. versionadded:: 1.3.11 - """ + """ # noqa: E501 return self._binary_w_type(Boolean(), "as_boolean") def as_string(self): - """Cast an indexed value as string. + """Consider an indexed value as string. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. e.g.:: - stmt = select( - mytable.c.json_column['some_data'].as_string() - ).where( - mytable.c.json_column['some_data'].as_string() == - 'some string' + stmt = select(mytable.c.json_column["some_data"].as_string()).where( + mytable.c.json_column["some_data"].as_string() == "some string" ) .. versionadded:: 1.3.11 - """ + """ # noqa: E501 return self._binary_w_type(Unicode(), "as_string") def as_integer(self): - """Cast an indexed value as integer. + """Consider an indexed value as integer. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. 
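To make the ``type_coerce``-like behavior described above concrete, a short sketch contrasting a caster with an explicit ``cast()``; the table definition repeats the JSON examples above::

    from sqlalchemy import JSON, Column, Integer, MetaData, String, Table, cast

    data_table = Table(
        "data_table",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("data", JSON),
    )

    coerced = data_table.c.data["some_key"].as_string()  # no CAST emitted
    converted = cast(data_table.c.data["some_key"], String)  # CAST in SQL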
e.g.:: - stmt = select( - mytable.c.json_column['some_data'].as_integer() - ).where( - mytable.c.json_column['some_data'].as_integer() == 5 + stmt = select(mytable.c.json_column["some_data"].as_integer()).where( + mytable.c.json_column["some_data"].as_integer() == 5 ) .. versionadded:: 1.3.11 - """ + """ # noqa: E501 return self._binary_w_type(Integer(), "as_integer") def as_float(self): - """Cast an indexed value as float. + """Consider an indexed value as float. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. e.g.:: - stmt = select( - mytable.c.json_column['some_data'].as_float() - ).where( - mytable.c.json_column['some_data'].as_float() == 29.75 + stmt = select(mytable.c.json_column["some_data"].as_float()).where( + mytable.c.json_column["some_data"].as_float() == 29.75 ) .. versionadded:: 1.3.11 - """ + """ # noqa: E501 return self._binary_w_type(Float(), "as_float") def as_numeric(self, precision, scale, asdecimal=True): - """Cast an indexed value as numeric/decimal. + """Consider an indexed value as numeric/decimal. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. e.g.:: - stmt = select( - mytable.c.json_column['some_data'].as_numeric(10, 6) - ).where( - mytable.c. - json_column['some_data'].as_numeric(10, 6) == 29.75 + stmt = select(mytable.c.json_column["some_data"].as_numeric(10, 6)).where( + mytable.c.json_column["some_data"].as_numeric(10, 6) == 29.75 ) .. versionadded:: 1.4.0b2 - """ + """ # noqa: E501 return self._binary_w_type( Numeric(precision, scale, asdecimal=asdecimal), "as_numeric" ) def as_json(self): - """Cast an indexed value as JSON. + """Consider an indexed value as JSON. + + This is similar to using :class:`_sql.type_coerce`, and will + usually not apply a ``CAST()``. e.g.:: - stmt = select(mytable.c.json_column['some_data'].as_json()) + stmt = select(mytable.c.json_column["some_data"].as_json()) This is typically the default behavior of indexed elements in any case. @@ -2677,7 +2755,7 @@ def _binary_w_type(self, typ, method_name): def python_type(self): return dict - @property # type: ignore # mypy property bug + @property def should_evaluate_none(self): """Alias of :attr:`_types.JSON.none_as_null`""" return not self.none_as_null @@ -2739,7 +2817,7 @@ def process(value): class ARRAY( - SchemaEventTarget, Indexable, Concatenable, TypeEngine[Sequence[Any]] + SchemaEventTarget, Indexable, Concatenable, TypeEngine[Sequence[_T]] ): """Represent a SQL Array type. @@ -2760,26 +2838,21 @@ class ARRAY( An :class:`_types.ARRAY` type is constructed given the "type" of element:: - mytable = Table("mytable", metadata, - Column("data", ARRAY(Integer)) - ) + mytable = Table("mytable", metadata, Column("data", ARRAY(Integer))) The above type represents an N-dimensional array, meaning a supporting backend such as PostgreSQL will interpret values with any number of dimensions automatically. To produce an INSERT construct that passes in a 1-dimensional array of integers:: - connection.execute( - mytable.insert(), - {"data": [1,2,3]} - ) + connection.execute(mytable.insert(), {"data": [1, 2, 3]}) The :class:`_types.ARRAY` type can be constructed given a fixed number of dimensions:: - mytable = Table("mytable", metadata, - Column("data", ARRAY(Integer, dimensions=2)) - ) + mytable = Table( + "mytable", metadata, Column("data", ARRAY(Integer, dimensions=2)) + ) Sending a number of dimensions is optional, but recommended if the datatype is to represent arrays of more than one dimension. 
This number @@ -2804,22 +2877,21 @@ class ARRAY( dimension parameter will generally assume single-dimensional behaviors. SQL expressions of type :class:`_types.ARRAY` have support for "index" and - "slice" behavior. The Python ``[]`` operator works normally here, given - integer indexes or slices. Arrays default to 1-based indexing. - The operator produces binary expression + "slice" behavior. The ``[]`` operator produces expression constructs which will produce the appropriate SQL, both for SELECT statements:: select(mytable.c.data[5], mytable.c.data[2:7]) as well as UPDATE statements when the :meth:`_expression.Update.values` - method - is used:: + method is used:: - mytable.update().values({ - mytable.c.data[5]: 7, - mytable.c.data[2:7]: [1, 2, 3] - }) + mytable.update().values( + {mytable.c.data[5]: 7, mytable.c.data[2:7]: [1, 2, 3]} + ) + + Indexed access is one-based by default; + for zero-based index conversion, set :paramref:`_types.ARRAY.zero_indexes`. The :class:`_types.ARRAY` type also provides for the operators :meth:`.types.ARRAY.Comparator.any` and @@ -2838,6 +2910,7 @@ class ARRAY( from sqlalchemy import ARRAY from sqlalchemy.ext.mutable import MutableList + class SomeOrmClass(Base): # ... @@ -2865,11 +2938,60 @@ class SomeOrmClass(Base): """If True, Python zero-based indexes should be interpreted as one-based on the SQL expression side.""" - class Comparator( - Indexable.Comparator[Sequence[Any]], - Concatenable.Comparator[Sequence[Any]], + def __init__( + self, + item_type: _TypeEngineArgument[_T], + as_tuple: bool = False, + dimensions: Optional[int] = None, + zero_indexes: bool = False, ): + """Construct an :class:`_types.ARRAY`. + + E.g.:: + + Column("myarray", ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. This parameter is + not generally needed as a Python list corresponds well + to a SQL array. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This impacts how the array is declared + on the database, how it goes about interpreting Python and + result values, as well as how expression behavior in conjunction + with the "getitem" operator works. See the description at + :class:`_types.ARRAY` for additional detail. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and SQL one-based indexes, e.g. + a value of one will be added to all index values before passing + to the database. + + """ + if isinstance(item_type, ARRAY): + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + class Comparator( + Indexable.Comparator[Sequence[_T]], + Concatenable.Comparator[Sequence[_T]], + ): """Define comparison operations for :class:`_types.ARRAY`. 
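A minimal sketch of the one-based indexing and the ``zero_indexes`` conversion documented above (names are illustrative)::

    from sqlalchemy import ARRAY, Column, Integer, MetaData, Table, select

    mytable = Table(
        "mytable",
        MetaData(),
        Column("data", ARRAY(Integer, zero_indexes=True)),
    )

    # Python-style zero-based access; one is added to the index before
    # it is passed to the database, which uses one-based addressing
    stmt = select(mytable.c.data[0])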
More operators are available on the dialect-specific form @@ -2879,9 +3001,22 @@ class Comparator( __slots__ = () - type: ARRAY + type: ARRAY[_T] - def _setup_getitem(self, index): + @overload + def _setup_getitem( + self, index: int + ) -> Tuple[OperatorType, int, TypeEngine[Any]]: ... + + @overload + def _setup_getitem( + self, index: slice + ) -> Tuple[OperatorType, Slice, TypeEngine[Any]]: ... + + def _setup_getitem(self, index: Union[int, slice]) -> Union[ + Tuple[OperatorType, int, TypeEngine[Any]], + Tuple[OperatorType, Slice, TypeEngine[Any]], + ]: arr_type = self.type return_type: TypeEngine[Any] @@ -2907,17 +3042,26 @@ def _setup_getitem(self, index): return operators.getitem, index, return_type - def contains(self, *arg, **kw): + def contains(self, *arg: Any, **kw: Any) -> ColumnElement[bool]: + """``ARRAY.contains()`` not implemented for the base ARRAY type. + Use the dialect-specific ARRAY type. + + .. seealso:: + + :class:`_postgresql.ARRAY` - PostgreSQL specific version. + """ raise NotImplementedError( "ARRAY.contains() not implemented for the base " "ARRAY type; please use the dialect-specific ARRAY type" ) @util.preload_module("sqlalchemy.sql.elements") - def any(self, other, operator=None): + def any( + self, other: Any, operator: Optional[OperatorType] = None + ) -> ColumnElement[bool]: """Return ``other operator ANY (array)`` clause. - .. note:: This method is an :class:`_types.ARRAY` - specific + .. legacy:: This method is an :class:`_types.ARRAY` - specific construct that is now superseded by the :func:`_sql.any_` function, which features a different calling style. The :func:`_sql.any_` function is also mirrored at the method level @@ -2929,9 +3073,7 @@ def any(self, other, operator=None): from sqlalchemy.sql import operators conn.execute( - select(table.c.data).where( - table.c.data.any(7, operator=operators.lt) - ) + select(table.c.data).where(table.c.data.any(7, operator=operators.lt)) ) :param other: expression to be compared @@ -2945,15 +3087,14 @@ def any(self, other, operator=None): :meth:`.types.ARRAY.Comparator.all` - """ + """ # noqa: E501 elements = util.preloaded.sql_elements operator = operator if operator else operators.eq arr_type = self.type - # send plain BinaryExpression so that negate remains at None, - # leading to NOT expr for negation. - return elements.BinaryExpression( + return elements.CollectionAggregate._create_any(self.expr).operate( + operators.mirror(operator), coercions.expect( roles.BinaryElementRole, element=other, @@ -2961,19 +3102,19 @@ def any(self, other, operator=None): expr=self.expr, bindparam_type=arr_type.item_type, ), - elements.CollectionAggregate._create_any(self.expr), - operator, ) @util.preload_module("sqlalchemy.sql.elements") - def all(self, other, operator=None): + def all( + self, other: Any, operator: Optional[OperatorType] = None + ) -> ColumnElement[bool]: """Return ``other operator ALL (array)`` clause. - .. note:: This method is an :class:`_types.ARRAY` - specific - construct that is now superseded by the :func:`_sql.any_` + .. legacy:: This method is an :class:`_types.ARRAY` - specific + construct that is now superseded by the :func:`_sql.all_` function, which features a different calling style. The - :func:`_sql.any_` function is also mirrored at the method level - via the :meth:`_sql.ColumnOperators.any_` method. + :func:`_sql.all_` function is also mirrored at the method level + via the :meth:`_sql.ColumnOperators.all_` method. 
Usage of array-specific :meth:`_types.ARRAY.Comparator.all` is as follows:: @@ -2981,9 +3122,7 @@ def all(self, other, operator=None): from sqlalchemy.sql import operators conn.execute( - select(table.c.data).where( - table.c.data.all(7, operator=operators.lt) - ) + select(table.c.data).where(table.c.data.all(7, operator=operators.lt)) ) :param other: expression to be compared @@ -2997,15 +3136,14 @@ def all(self, other, operator=None): :meth:`.types.ARRAY.Comparator.any` - """ + """ # noqa: E501 elements = util.preloaded.sql_elements operator = operator if operator else operators.eq arr_type = self.type - # send plain BinaryExpression so that negate remains at None, - # leading to NOT expr for negation. - return elements.BinaryExpression( + return elements.CollectionAggregate._create_all(self.expr).operate( + operators.mirror(operator), coercions.expect( roles.BinaryElementRole, element=other, @@ -3013,80 +3151,32 @@ def all(self, other, operator=None): expr=self.expr, bindparam_type=arr_type.item_type, ), - elements.CollectionAggregate._create_all(self.expr), - operator, ) comparator_factory = Comparator - def __init__( - self, - item_type: _TypeEngineArgument[Any], - as_tuple: bool = False, - dimensions: Optional[int] = None, - zero_indexes: bool = False, - ): - """Construct an :class:`_types.ARRAY`. - - E.g.:: - - Column('myarray', ARRAY(Integer)) - - Arguments are: - - :param item_type: The data type of items of this array. Note that - dimensionality is irrelevant here, so multi-dimensional arrays like - ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as - ``ARRAY(ARRAY(Integer))`` or such. - - :param as_tuple=False: Specify whether return results - should be converted to tuples from lists. This parameter is - not generally needed as a Python list corresponds well - to a SQL array. - - :param dimensions: if non-None, the ARRAY will assume a fixed - number of dimensions. This impacts how the array is declared - on the database, how it goes about interpreting Python and - result values, as well as how expression behavior in conjunction - with the "getitem" operator works. See the description at - :class:`_types.ARRAY` for additional detail. - - :param zero_indexes=False: when True, index values will be converted - between Python zero-based and SQL one-based indexes, e.g. - a value of one will be added to all index values before passing - to the database. 
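Since both legacy methods above are superseded by the ``any_()`` / ``all_()`` functions, a sketch of the modern spellings (the table is illustrative; the comparison value is written on the left side)::

    from sqlalchemy import ARRAY, Column, Integer, MetaData, Table
    from sqlalchemy import all_, any_, select

    table = Table("t", MetaData(), Column("data", ARRAY(Integer)))

    # legacy: table.c.data.any(7, operator=operators.lt)
    stmt_any = select(table.c.data).where(7 < any_(table.c.data))

    # legacy: table.c.data.all(7, operator=operators.lt)
    stmt_all = select(table.c.data).where(7 < all_(table.c.data))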
- - """ - if isinstance(item_type, ARRAY): - raise ValueError( - "Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype" - ) - if isinstance(item_type, type): - item_type = item_type() - self.item_type = item_type - self.as_tuple = as_tuple - self.dimensions = dimensions - self.zero_indexes = zero_indexes - @property - def hashable(self): + def hashable(self) -> bool: # type: ignore[override] return self.as_tuple @property - def python_type(self): + def python_type(self) -> Type[Any]: return list - def compare_values(self, x, y): - return x == y + def compare_values(self, x: Any, y: Any) -> bool: + return x == y # type: ignore[no-any-return] - def _set_parent(self, column, outer=False, **kw): + def _set_parent( + self, parent: SchemaEventTarget, outer: bool = False, **kw: Any + ) -> None: """Support SchemaEventTarget""" if not outer and isinstance(self.item_type, SchemaEventTarget): - self.item_type._set_parent(column, **kw) + self.item_type._set_parent(parent, **kw) - def _set_parent_with_dispatch(self, parent): + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: """Support SchemaEventTarget""" super()._set_parent_with_dispatch(parent, outer=True) @@ -3094,17 +3184,19 @@ def _set_parent_with_dispatch(self, parent): if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: item_proc = self.item_type.dialect_impl(dialect).literal_processor( dialect ) if item_proc is None: return None - def to_str(elements): + def to_str(elements: Iterable[Any]) -> str: return f"[{', '.join(elements)}]" - def process(value): + def process(value: Sequence[Any]) -> str: inner = self._apply_item_processor( value, item_proc, self.dimensions, to_str ) @@ -3112,7 +3204,13 @@ def process(value): return process - def _apply_item_processor(self, arr, itemproc, dim, collection_callable): + def _apply_item_processor( + self, + arr: Sequence[Any], + itemproc: Optional[Callable[[Any], Any]], + dim: Optional[int], + collection_callable: Callable[[Iterable[Any]], _P], + ) -> _P: """Helper method that can be used by bind_processor(), literal_processor(), etc. to apply an item processor to elements of an array value, taking into account the 'dimensions' for this @@ -3143,14 +3241,16 @@ def _apply_item_processor(self, arr, itemproc, dim, collection_callable): return collection_callable(arr) else: return collection_callable( - self._apply_item_processor( - x, - itemproc, - dim - 1 if dim is not None else None, - collection_callable, + ( + self._apply_item_processor( + x, + itemproc, + dim - 1 if dim is not None else None, + collection_callable, + ) + if x is not None + else None ) - if x is not None - else None for x in arr ) @@ -3201,7 +3301,6 @@ def result_processor(self, dialect, coltype): class REAL(Float[_N]): - """The SQL REAL type. .. seealso:: @@ -3214,7 +3313,6 @@ class REAL(Float[_N]): class FLOAT(Float[_N]): - """The SQL FLOAT type. .. seealso:: @@ -3255,7 +3353,6 @@ class DOUBLE_PRECISION(Double[_N]): class NUMERIC(Numeric[_N]): - """The SQL NUMERIC type. .. seealso:: @@ -3268,7 +3365,6 @@ class NUMERIC(Numeric[_N]): class DECIMAL(Numeric[_N]): - """The SQL DECIMAL type. .. seealso:: @@ -3281,7 +3377,6 @@ class DECIMAL(Numeric[_N]): class INTEGER(Integer): - """The SQL INT or INTEGER type. .. 
seealso:: @@ -3297,7 +3392,6 @@ class INTEGER(Integer): class SMALLINT(SmallInteger): - """The SQL SMALLINT type. .. seealso:: @@ -3310,7 +3404,6 @@ class SMALLINT(SmallInteger): class BIGINT(BigInteger): - """The SQL BIGINT type. .. seealso:: @@ -3323,11 +3416,10 @@ class BIGINT(BigInteger): class TIMESTAMP(DateTime): - """The SQL TIMESTAMP type. - :class:`_types.TIMESTAMP` datatypes have support for timezone - storage on some backends, such as PostgreSQL and Oracle. Use the + :class:`_types.TIMESTAMP` datatypes have support for timezone storage on + some backends, such as PostgreSQL and Oracle Database. Use the :paramref:`~types.TIMESTAMP.timezone` argument in order to enable "TIMESTAMP WITH TIMEZONE" for these backends. @@ -3353,101 +3445,87 @@ def get_dbapi_type(self, dbapi): class DATETIME(DateTime): - """The SQL DATETIME type.""" __visit_name__ = "DATETIME" class DATE(Date): - """The SQL DATE type.""" __visit_name__ = "DATE" class TIME(Time): - """The SQL TIME type.""" __visit_name__ = "TIME" class TEXT(Text): - """The SQL TEXT type.""" __visit_name__ = "TEXT" class CLOB(Text): - """The CLOB type. - This type is found in Oracle and Informix. + This type is found in Oracle Database and Informix. """ __visit_name__ = "CLOB" class VARCHAR(String): - """The SQL VARCHAR type.""" __visit_name__ = "VARCHAR" class NVARCHAR(Unicode): - """The SQL NVARCHAR type.""" __visit_name__ = "NVARCHAR" class CHAR(String): - """The SQL CHAR type.""" __visit_name__ = "CHAR" class NCHAR(Unicode): - """The SQL NCHAR type.""" __visit_name__ = "NCHAR" class BLOB(LargeBinary): - """The SQL BLOB type.""" __visit_name__ = "BLOB" class BINARY(_Binary): - """The SQL BINARY type.""" __visit_name__ = "BINARY" class VARBINARY(_Binary): - """The SQL VARBINARY type.""" __visit_name__ = "VARBINARY" class BOOLEAN(Boolean): - """The SQL BOOLEAN type.""" __visit_name__ = "BOOLEAN" class NullType(TypeEngine[None]): - """An unknown type. :class:`.NullType` is used as a default type for those cases where @@ -3532,7 +3610,6 @@ class MatchType(Boolean): class Uuid(Emulated, TypeEngine[_UUID_RETURN]): - """Represent a database agnostic UUID datatype. For backends that have no "native" UUID datatype, the value will @@ -3560,14 +3637,13 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): t = Table( "t", metadata_obj, - Column('uuid_data', Uuid, primary_key=True), - Column("other_data", String) + Column("uuid_data", Uuid, primary_key=True), + Column("other_data", String), ) with engine.begin() as conn: conn.execute( - t.insert(), - {"uuid_data": uuid.uuid4(), "other_data", "some data"} + t.insert(), {"uuid_data": uuid.uuid4(), "other_data": "some data"} ) To have the :class:`_sqltypes.Uuid` datatype work with string-based @@ -3581,10 +3657,11 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): :class:`_sqltypes.UUID` - represents exactly the ``UUID`` datatype without any backend-agnostic behaviors. - """ + """ # noqa: E501 __visit_name__ = "uuid" + length: Optional[int] = None collation: Optional[str] = None @overload @@ -3592,16 +3669,14 @@ def __init__( self: Uuid[_python_UUID], as_uuid: Literal[True] = ..., native_uuid: bool = ..., - ): - ... + ): ... @overload def __init__( self: Uuid[str], as_uuid: Literal[False] = ..., native_uuid: bool = ..., - ): - ... + ): ... def __init__(self, as_uuid: bool = True, native_uuid: bool = True): """Construct a :class:`_sqltypes.Uuid` type. 
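A short sketch of the ``as_uuid`` overloads above; the flag selects whether ``uuid.UUID`` objects or plain strings cross the DBAPI boundary::

    import uuid

    from sqlalchemy import Uuid

    Uuid()  # as_uuid=True (default): uuid.UUID objects in and out
    Uuid(as_uuid=False)  # plain strings in and out

    value = uuid.uuid4()  # a suitable parameter for the default form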
@@ -3627,7 +3702,7 @@ def python_type(self): return _python_UUID if self.as_uuid else str @property - def native(self): + def native(self): # type: ignore[override] return self.native_uuid def coerce_compared_value(self, op, value): @@ -3638,7 +3713,9 @@ def coerce_compared_value(self, op, value): else: return super().coerce_compared_value(op, value) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[_UUID_RETURN]]: character_based_uuid = ( not dialect.supports_native_uuid or not self.native_uuid ) @@ -3724,7 +3801,6 @@ def process(value): class UUID(Uuid[_UUID_RETURN], type_api.NativeForEmulated): - """Represent the SQL UUID type. This is the SQL-native form of the :class:`_types.Uuid` database agnostic @@ -3748,12 +3824,10 @@ class UUID(Uuid[_UUID_RETURN], type_api.NativeForEmulated): __visit_name__ = "UUID" @overload - def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): - ... + def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): ... @overload - def __init__(self: UUID[str], as_uuid: Literal[False] = ...): - ... + def __init__(self: UUID[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): """Construct a :class:`_sqltypes.UUID` type. diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 5758dff3c43..13ad28996e0 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -1,5 +1,5 @@ # sql/traversals.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -80,16 +80,13 @@ class HasShallowCopy(HasTraverseInternals): if typing.TYPE_CHECKING: - def _generated_shallow_copy_traversal(self, other: Self) -> None: - ... + def _generated_shallow_copy_traversal(self, other: Self) -> None: ... def _generated_shallow_from_dict_traversal( self, d: Dict[str, Any] - ) -> None: - ... + ) -> None: ... - def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: - ... + def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ... 
@classmethod def _generate_shallow_copy( @@ -312,9 +309,11 @@ def visit_dml_ordered_values( # sequence of 2-tuples return [ ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key + ), clone(value, **kw), ) for key, value in element @@ -336,9 +335,11 @@ def visit_dml_multi_values( def copy(elem): if isinstance(elem, (list, tuple)): return [ - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value + ( + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + ) for value in elem ] elif isinstance(elem, dict): @@ -561,6 +562,8 @@ def compare( return False else: continue + elif right_child is None: + return False comparison = dispatch( left_attrname, left, left_child, right, right_child, **kw diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9cf4872d023..1e08ece5357 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1,18 +1,15 @@ -# sql/types_api.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# sql/type_api.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Base types API. - -""" +"""Base types API.""" from __future__ import annotations from enum import Enum -from types import ModuleType import typing from typing import Any from typing import Callable @@ -39,6 +36,7 @@ from .. import util from ..util.typing import Protocol from ..util.typing import Self +from ..util.typing import TypeAliasType from ..util.typing import TypedDict from ..util.typing import TypeGuard @@ -57,6 +55,7 @@ from .sqltypes import NUMERICTYPE as NUMERICTYPE # noqa: F401 from .sqltypes import STRINGTYPE as STRINGTYPE # noqa: F401 from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401 + from ..engine.interfaces import DBAPIModule from ..engine.interfaces import Dialect from ..util.typing import GenericProtocol @@ -66,8 +65,11 @@ _O = TypeVar("_O", bound=object) _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) +_RT = TypeVar("_RT", bound=Any) -_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]] +_MatchedOnType = Union[ + "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any] +] class _NoValueInList(Enum): @@ -80,23 +82,19 @@ class _NoValueInList(Enum): class _LiteralProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> str: - ... + def __call__(self, value: Any) -> str: ... class _BindProcessorType(Protocol[_T_con]): - def __call__(self, value: Optional[_T_con]) -> Any: - ... + def __call__(self, value: Optional[_T_con]) -> Any: ... class _ResultProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> Optional[_T_co]: - ... + def __call__(self, value: Any) -> Optional[_T_co]: ... class _SentinelProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> Optional[_T_co]: - ... + def __call__(self, value: Any) -> Optional[_T_co]: ... class _BaseTypeMemoDict(TypedDict): @@ -112,8 +110,9 @@ class _TypeMemoDict(_BaseTypeMemoDict, total=False): class _ComparatorFactory(Protocol[_T]): - def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]: - ... + def __call__( + self, expr: ColumnElement[_T] + ) -> TypeEngine.Comparator[_T]: ... 
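For reference, functions matching the shapes of the processor protocols above might look like the following; a sketch only, with illustrative names and conversions::

    from typing import Any, Optional

    def literal_proc(value: Any) -> str:  # _LiteralProcessorType
        return repr(value)

    def bind_proc(value: Optional[str]) -> Any:  # _BindProcessorType[str]
        return value.encode("utf-8") if value is not None else None

    def result_proc(value: Any) -> Optional[str]:  # _ResultProcessorType[str]
        return value.decode("utf-8") if value is not None else None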
class TypeEngine(Visitable, Generic[_T]): @@ -183,10 +182,27 @@ def __init__(self, expr: ColumnElement[_CT]): self.expr = expr self.type = expr.type + def __reduce__(self) -> Any: + return self.__class__, (self.expr,) + + @overload + def operate( + self, + op: OperatorType, + *other: Any, + result_type: Type[TypeEngine[_RT]], + **kwargs: Any, + ) -> ColumnElement[_RT]: ... + + @overload + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: ... + @util.preload_module("sqlalchemy.sql.default_comparator") def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[_CT]: + ) -> ColumnElement[Any]: default_comparator = util.preloaded.sql_default_comparator op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] if kwargs: @@ -297,9 +313,9 @@ def _adapt_expression( """ - _variant_mapping: util.immutabledict[ - str, TypeEngine[Any] - ] = util.EMPTY_DICT + _variant_mapping: util.immutabledict[str, TypeEngine[Any]] = ( + util.EMPTY_DICT + ) def evaluates_none(self) -> Self: """Return a copy of this type which has the @@ -308,11 +324,13 @@ def evaluates_none(self) -> Self: E.g.:: Table( - 'some_table', metadata, + "some_table", + metadata, Column( String(50).evaluates_none(), nullable=True, - server_default='no value') + server_default="no value", + ), ) The ORM uses this flag to indicate that a positive value of ``None`` @@ -574,18 +592,6 @@ class explicitly. """ return None - def _sentinel_value_resolver( - self, dialect: Dialect - ) -> Optional[_SentinelProcessorType[_T]]: - """Return an optional callable that will match parameter values - (post-bind processing) to result values - (pre-result-processing), for use in the "sentinel" feature. - - .. versionadded:: 2.0.10 - - """ - return None - @util.memoized_property def _has_bind_expression(self) -> bool: """memoized boolean, check if bind_expression is implemented. @@ -606,7 +612,7 @@ def compare_values(self, x: Any, y: Any) -> bool: return x == y # type: ignore[no-any-return] - def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]: + def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]: """Return the corresponding type object from the underlying DB-API, if any. 
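The ``__reduce__`` added above rebuilds a comparator from its expression; a minimal sketch of the pickle round trip it presumably enables (a detached column is used purely for illustration)::

    import pickle

    from sqlalchemy import Column, Integer

    col = Column("x", Integer)
    restored = pickle.loads(pickle.dumps(col.comparator))
    assert type(restored) is type(col.comparator)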
@@ -650,7 +656,7 @@ def with_variant( string_type = String() string_type = string_type.with_variant( - mysql.VARCHAR(collation='foo'), 'mysql', 'mariadb' + mysql.VARCHAR(collation="foo"), "mysql", "mariadb" ) The variant mapping indicates that when this type is @@ -767,6 +773,10 @@ def _resolve_for_python_type( return self + def _with_collation(self, collation: str) -> Self: + """set up error handling for the collate expression""" + raise NotImplementedError("this datatype does not support collation") + @util.ro_memoized_property def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: """Return a rudimental 'affinity' value expressing the general class @@ -933,18 +943,6 @@ def _cached_result_processor( d["result"][coltype] = rp return rp - def _cached_sentinel_value_processor( - self, dialect: Dialect - ) -> Optional[_SentinelProcessorType[_T]]: - try: - return dialect._type_memos[self]["sentinel"] - except KeyError: - pass - - d = self._dialect_info(dialect) - d["sentinel"] = bp = d["impl"]._sentinel_value_resolver(dialect) - return bp - def _cached_custom_processor( self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O] ) -> _O: @@ -999,9 +997,11 @@ def _static_cache_key( return (self.__class__,) + tuple( ( k, - self.__dict__[k]._static_cache_key - if isinstance(self.__dict__[k], TypeEngine) - else self.__dict__[k], + ( + self.__dict__[k]._static_cache_key + if isinstance(self.__dict__[k], TypeEngine) + else self.__dict__[k] + ), ) for k in names if k in self.__dict__ @@ -1010,12 +1010,12 @@ def _static_cache_key( ) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -1027,9 +1027,11 @@ def adapt( types with "implementation" types that are specific to a particular dialect. """ - return util.constructor_copy( + typ = util.constructor_copy( self, cast(Type[TypeEngine[Any]], cls), **kw ) + typ._variant_mapping = self._variant_mapping + return typ def coerce_compared_value( self, op: Optional[OperatorType], value: Any @@ -1108,26 +1110,21 @@ class TypeEngineMixin: @util.memoized_property def _static_cache_key( self, - ) -> Union[CacheConst, Tuple[Any, ...]]: - ... + ) -> Union[CacheConst, Tuple[Any, ...]]: ... @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload def adapt( self, cls: Type[TypeEngineMixin], **kw: Any - ) -> TypeEngine[Any]: - ... + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any - ) -> TypeEngine[Any]: - ... + ) -> TypeEngine[Any]: ... - def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: - ... + def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ... class ExternalType(TypeEngineMixin): @@ -1146,7 +1143,7 @@ class ExternalType(TypeEngineMixin): """ cache_ok: Optional[bool] = None - """Indicate if statements using this :class:`.ExternalType` are "safe to + '''Indicate if statements using this :class:`.ExternalType` are "safe to cache". The default value ``None`` will emit a warning and then not allow caching @@ -1187,12 +1184,12 @@ def __init__(self, choices): series of tuples. 
Given a previously un-cacheable type as:: class LookupType(UserDefinedType): - '''a custom type that accepts a dictionary as a parameter. + """a custom type that accepts a dictionary as a parameter. this is the non-cacheable version, as "self.lookup" is not hashable. - ''' + """ def __init__(self, lookup): self.lookup = lookup @@ -1200,8 +1197,7 @@ def __init__(self, lookup): def get_col_spec(self, **kw): return "VARCHAR(255)" - def bind_processor(self, dialect): - # ... works with "self.lookup" ... + def bind_processor(self, dialect): ... # works with "self.lookup" ... Where "lookup" is a dictionary. The type will not be able to generate a cache key:: @@ -1237,7 +1233,7 @@ def bind_processor(self, dialect): to the ".lookup" attribute:: class LookupType(UserDefinedType): - '''a custom type that accepts a dictionary as a parameter. + """a custom type that accepts a dictionary as a parameter. The dictionary is stored both as itself in a private variable, and published in a public variable as a sorted tuple of tuples, @@ -1245,7 +1241,7 @@ class LookupType(UserDefinedType): two equivalent dictionaries. Note it assumes the keys and values of the dictionary are themselves hashable. - ''' + """ cache_ok = True @@ -1254,15 +1250,12 @@ def __init__(self, lookup): # assume keys/values of "lookup" are hashable; otherwise # they would also need to be converted in some way here - self.lookup = tuple( - (key, lookup[key]) for key in sorted(lookup) - ) + self.lookup = tuple((key, lookup[key]) for key in sorted(lookup)) def get_col_spec(self, **kw): return "VARCHAR(255)" - def bind_processor(self, dialect): - # ... works with "self._lookup" ... + def bind_processor(self, dialect): ... # works with "self._lookup" ... Where above, the cache key for ``LookupType({"a": 10, "b": 20})`` will be:: @@ -1280,7 +1273,7 @@ def bind_processor(self, dialect): :ref:`sql_caching` - """ # noqa: E501 + ''' # noqa: E501 @util.non_memoized_property def _static_cache_key( @@ -1322,10 +1315,11 @@ class UserDefinedType( import sqlalchemy.types as types + class MyType(types.UserDefinedType): cache_ok = True - def __init__(self, precision = 8): + def __init__(self, precision=8): self.precision = precision def get_col_spec(self, **kw): @@ -1334,19 +1328,23 @@ def get_col_spec(self, **kw): def bind_processor(self, dialect): def process(value): return value + return process def result_processor(self, dialect, coltype): def process(value): return value + return process Once the type is made, it's immediately usable:: - table = Table('foo', metadata_obj, - Column('id', Integer, primary_key=True), - Column('data', MyType(16)) - ) + table = Table( + "foo", + metadata_obj, + Column("id", Integer, primary_key=True), + Column("data", MyType(16)), + ) The ``get_col_spec()`` method will in most cases receive a keyword argument ``type_expression`` which refers to the owning expression @@ -1391,6 +1389,10 @@ def coerce_compared_value( return self + if TYPE_CHECKING: + + def get_col_spec(self, **kw: Any) -> str: ... + class Emulated(TypeEngineMixin): """Mixin for base types that emulate the behavior of a DB-native type. @@ -1429,12 +1431,12 @@ def adapt_to_emulated( return super().adapt(impltype, **kw) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... 
def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -1511,7 +1513,7 @@ def adapt_emulated_to_native( class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): - """Allows the creation of types which add additional functionality + '''Allows the creation of types which add additional functionality to an existing type. This method is preferred to direct subclassing of SQLAlchemy's @@ -1522,10 +1524,11 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): import sqlalchemy.types as types + class MyType(types.TypeDecorator): - '''Prefixes Unicode values with "PREFIX:" on the way in and + """Prefixes Unicode values with "PREFIX:" on the way in and strips it off on the way out. - ''' + """ impl = types.Unicode @@ -1578,6 +1581,8 @@ class produces the same behavior each time, it may be set to ``True``. class MyEpochType(types.TypeDecorator): impl = types.Integer + cache_ok = True + epoch = datetime.date(1970, 1, 1) def process_bind_param(self, value, dialect): @@ -1615,6 +1620,7 @@ def coerce_compared_value(self, op, value): from sqlalchemy import JSON from sqlalchemy import TypeDecorator + class MyJsonType(TypeDecorator): impl = JSON @@ -1635,6 +1641,7 @@ def coerce_compared_value(self, op, value): from sqlalchemy import ARRAY from sqlalchemy import TypeDecorator + class MyArrayType(TypeDecorator): impl = ARRAY @@ -1643,8 +1650,7 @@ class MyArrayType(TypeDecorator): def coerce_compared_value(self, op, value): return self.impl.coerce_compared_value(op, value) - - """ + ''' __visit_name__ = "type_decorator" @@ -1740,20 +1746,48 @@ def reverse_operate( kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super().reverse_operate(op, other, **kwargs) + @staticmethod + def _reduce_td_comparator( + impl: TypeEngine[Any], expr: ColumnElement[_T] + ) -> Any: + return TypeDecorator._create_td_comparator_type(impl)(expr) + + @staticmethod + def _create_td_comparator_type( + impl: TypeEngine[Any], + ) -> _ComparatorFactory[Any]: + + def __reduce__(self: TypeDecorator.Comparator[Any]) -> Any: + return (TypeDecorator._reduce_td_comparator, (impl, self.expr)) + + return type( + "TDComparator", + (TypeDecorator.Comparator, impl.comparator_factory), # type: ignore # noqa: E501 + {"__reduce__": __reduce__}, + ) + @property def comparator_factory( # type: ignore # mypy properties bug self, ) -> _ComparatorFactory[Any]: if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: # type: ignore # noqa: E501 - return self.impl.comparator_factory + return self.impl_instance.comparator_factory else: # reconcile the Comparator class on the impl with that - # of TypeDecorator - return type( - "TDComparator", - (TypeDecorator.Comparator, self.impl.comparator_factory), # type: ignore # noqa: E501 - {}, + # of TypeDecorator. + # the use of multiple staticmethods is to support repeated + # pickling of the Comparator itself + return TypeDecorator._create_td_comparator_type(self.impl_instance) + + def _copy_with_check(self) -> Self: + tt = self.copy() + if not isinstance(tt, self.__class__): + raise AssertionError( + "Type object %s does not properly " + "implement the copy() method, it must " + "return an object of type %s" % (self, self.__class__) ) + return tt def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: if dialect.name in self._variant_mapping: @@ -1769,16 +1803,17 @@ def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: # to a copy of this TypeDecorator and return # that. 
typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect) - tt = self.copy() - if not isinstance(tt, self.__class__): - raise AssertionError( - "Type object %s does not properly " - "implement the copy() method, it must " - "return an object of type %s" % (self, self.__class__) - ) + tt = self._copy_with_check() tt.impl = tt.impl_instance = typedesc return tt + def _with_collation(self, collation: str) -> Self: + tt = self._copy_with_check() + tt.impl = tt.impl_instance = self.impl_instance._with_collation( + collation + ) + return tt + @util.ro_non_memoized_property def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]: return self.impl_instance._type_affinity @@ -2233,7 +2268,7 @@ def copy(self, **kw: Any) -> Self: instance.__dict__.update(self.__dict__) return instance - def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]: + def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]: """Return the DBAPI type object represented by this :class:`.TypeDecorator`. @@ -2280,13 +2315,13 @@ def __init__(self, *arg: Any, **kw: Any): @overload -def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE: - ... +def to_instance( + typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any +) -> _TE: ... @overload -def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: - ... +def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: ... def to_instance( @@ -2302,11 +2337,10 @@ def to_instance( def adapt_type( - typeobj: TypeEngine[Any], + typeobj: _TypeEngineArgument[Any], colspecs: Mapping[Type[Any], Type[TypeEngine[Any]]], ) -> TypeEngine[Any]: - if isinstance(typeobj, type): - typeobj = typeobj() + typeobj = to_instance(typeobj) for t in typeobj.__class__.__mro__[0:-1]: try: impltype = colspecs[t] diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 28480a5d437..9fc4e65d9b4 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,14 +1,12 @@ # sql/util.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -"""High level utilities which build upon other modules here. - -""" +"""High level utilities which build upon other modules here.""" from __future__ import annotations from collections import deque @@ -106,7 +104,7 @@ def join_condition( would produce an expression along the lines of:: - tablea.c.id==tableb.c.tablea_id + tablea.c.id == tableb.c.tablea_id The join is determined based on the foreign key relationships between the two selectables. If there are multiple ways @@ -268,7 +266,7 @@ def visit_binary_product( The function is of the form:: - def my_fn(binary, left, right) + def my_fn(binary, left, right): ... For each binary expression located which has a comparison operator, the product of "left" and @@ -277,12 +275,11 @@ def my_fn(binary, left, right) Hence an expression like:: - and_( - (a + b) == q + func.sum(e + f), - j == r - ) + and_((a + b) == q + func.sum(e + f), j == r) - would have the traversal:: + would have the traversal: + + .. 
sourcecode:: text a q a e @@ -350,9 +347,9 @@ def find_tables( ] = _visitors["lateral"] = tables.append if include_crud: - _visitors["insert"] = _visitors["update"] = _visitors[ - "delete" - ] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors["delete"] = ( + lambda ent: tables.append(ent.table) + ) if check_columns: @@ -367,7 +364,7 @@ def visit_column(column): return tables -def unwrap_order_by(clause): +def unwrap_order_by(clause: Any) -> Any: """Break up an 'order by' expression into individual column-expressions, without DESC/ASC/NULLS FIRST/NULLS LAST""" @@ -481,7 +478,7 @@ def surface_selectables(clause): stack.append(elem.element) -def surface_selectables_only(clause): +def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]: stack = [clause] while stack: elem = stack.pop() @@ -528,9 +525,7 @@ def bind_values(clause): E.g.:: - >>> expr = and_( - ... table.c.foo==5, table.c.foo==7 - ... ) + >>> expr = and_(table.c.foo == 5, table.c.foo == 7) >>> bind_values(expr) [5, 7] """ @@ -878,8 +873,7 @@ def reduce_columns( columns: Iterable[ColumnElement[Any]], *clauses: Optional[ClauseElement], **kw: bool, -) -> Sequence[ColumnElement[Any]]: - ... +) -> Sequence[ColumnElement[Any]]: ... @overload @@ -887,8 +881,7 @@ def reduce_columns( columns: _SelectIterable, *clauses: Optional[ClauseElement], **kw: bool, -) -> Sequence[Union[ColumnElement[Any], TextClause]]: - ... +) -> Sequence[Union[ColumnElement[Any], TextClause]]: ... def reduce_columns( @@ -1043,20 +1036,24 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): E.g.:: - table1 = Table('sometable', metadata, - Column('col1', Integer), - Column('col2', Integer) - ) - table2 = Table('someothertable', metadata, - Column('col1', Integer), - Column('col2', Integer) - ) + table1 = Table( + "sometable", + metadata, + Column("col1", Integer), + Column("col2", Integer), + ) + table2 = Table( + "someothertable", + metadata, + Column("col1", Integer), + Column("col2", Integer), + ) condition = table1.c.col1 == table2.c.col1 make an alias of table1:: - s = table1.alias('foo') + s = table1.alias("foo") calling ``ClauseAdapter(s).traverse(condition)`` converts condition to read:: @@ -1099,8 +1096,7 @@ def __init__( if TYPE_CHECKING: @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... # note this specializes the ReplacingExternalTraversal.traverse() # method to state @@ -1111,13 +1107,11 @@ def traverse(self, obj: Literal[None]) -> None: # FromClause but Mypy is not accepting those as compatible with # the base ReplacingExternalTraversal @overload - def traverse(self, obj: _ET) -> _ET: - ... + def traverse(self, obj: _ET) -> _ET: ... def traverse( self, obj: Optional[ExternallyTraversible] - ) -> Optional[ExternallyTraversible]: - ... + ) -> Optional[ExternallyTraversible]: ... def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET @@ -1219,23 +1213,18 @@ def replace( class _ColumnLookup(Protocol): @overload - def __getitem__(self, key: None) -> None: - ... + def __getitem__(self, key: None) -> None: ... @overload - def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: - ... + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ... @overload - def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: - ... + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ... @overload - def __getitem__(self, key: _ET) -> _ET: - ... 
+ def __getitem__(self, key: _ET) -> _ET: ... - def __getitem__(self, key: Any) -> Any: - ... + def __getitem__(self, key: Any) -> Any: ... class ColumnAdapter(ClauseAdapter): @@ -1333,12 +1322,10 @@ def wrap(self, adapter): return ac @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: _ET) -> _ET: - ... + def traverse(self, obj: _ET) -> _ET: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -1353,8 +1340,7 @@ def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: if TYPE_CHECKING: @property - def visitor_iterator(self) -> Iterator[ColumnAdapter]: - ... + def visitor_iterator(self) -> Iterator[ColumnAdapter]: ... adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index cccebe65ba8..27642851676 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,14 +1,11 @@ # sql/visitors.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Visitor/traversal interface and library functions. - - -""" +"""Visitor/traversal interface and library functions.""" from __future__ import annotations @@ -72,8 +69,7 @@ class _CompilerDispatchType(Protocol): - def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: - ... + def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: ... class Visitable: @@ -100,8 +96,7 @@ class Visitable: if typing.TYPE_CHECKING: - def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: - ... + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: ... def __init_subclass__(cls) -> None: if "__visit_name__" in cls.__dict__: @@ -493,8 +488,7 @@ def get_children( class _InternalTraversalDispatchType(Protocol): - def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: - ... + def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: ... class HasTraversalDispatch: @@ -602,13 +596,11 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: - def _annotate(self, values: _AnnotationDict) -> Self: - ... + def _annotate(self, values: _AnnotationDict) -> Self: ... def get_children( self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any - ) -> Iterable[ExternallyTraversible]: - ... + ) -> Iterable[ExternallyTraversible]: ... def _clone(self, **kw: Any) -> Self: """clone this element""" @@ -638,13 +630,11 @@ def _copy_internals( class _CloneCallableType(Protocol): - def __call__(self, element: _ET, **kw: Any) -> _ET: - ... + def __call__(self, element: _ET, **kw: Any) -> _ET: ... class _TraverseTransformCallableType(Protocol[_ET]): - def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: - ... + def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: ... _ExtT = TypeVar("_ExtT", bound="ExternalTraversal") @@ -680,12 +670,12 @@ def iterate( return iterate(obj, self.__traverse_options__) @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... 
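The recurring ``traverse()`` overload pairs in this module encode "``None`` in, ``None`` out; element in, same element type out" for the type checker, a correlation a single ``Optional`` signature could not express. Reduced to its essentials, with illustrative names only::

    from typing import Literal, Optional, TypeVar, overload

    _E = TypeVar("_E")


    @overload
    def visit(obj: Literal[None]) -> None: ...


    @overload
    def visit(obj: _E) -> _E: ...


    def visit(obj: Optional[_E]) -> Optional[_E]:
        # identity stub; a real traversal returns a rewritten element
        return obj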
def traverse( self, obj: Optional[ExternallyTraversible] @@ -746,12 +736,12 @@ def copy_and_process( return [self.traverse(x) for x in list_] @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -786,12 +776,12 @@ def replace( return None @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -866,8 +856,7 @@ def traverse_using( iterator: Iterable[ExternallyTraversible], obj: Literal[None], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... @overload @@ -875,8 +864,7 @@ def traverse_using( iterator: Iterable[ExternallyTraversible], obj: ExternallyTraversible, visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... def traverse_using( @@ -920,8 +908,7 @@ def traverse( obj: Literal[None], opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... @overload @@ -929,8 +916,7 @@ def traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... def traverse( @@ -945,11 +931,13 @@ def traverse( from sqlalchemy.sql import visitors - stmt = select(some_table).where(some_table.c.foo == 'bar') + stmt = select(some_table).where(some_table.c.foo == "bar") + def visit_bindparam(bind_param): print("found bound value: %s" % bind_param.value) + visitors.traverse(stmt, {}, {"bindparam": visit_bindparam}) The iteration of objects uses the :func:`.visitors.iterate` function, @@ -975,8 +963,7 @@ def cloned_traverse( obj: Literal[None], opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... # a bit of controversy here, as the clone of the lead element @@ -988,8 +975,7 @@ def cloned_traverse( obj: _ET, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> _ET: - ... +) -> _ET: ... def cloned_traverse( @@ -1088,8 +1074,7 @@ def replacement_traverse( obj: Literal[None], opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> None: - ... +) -> None: ... @overload @@ -1097,8 +1082,7 @@ def replacement_traverse( obj: _CE, opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> _CE: - ... +) -> _CE: ... @overload @@ -1106,8 +1090,7 @@ def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... 
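Complementing the ``traverse()`` docstring example above: ``replacement_traverse`` rewrites a statement rather than just visiting it; the callable returns a substitute element, or ``None`` to keep an element as-is. A rough sketch using public imports, with invented table and values::

    from sqlalchemy import column, literal, select, table
    from sqlalchemy.sql import visitors

    some_table = table("some_table", column("foo"))
    stmt = select(some_table).where(some_table.c.foo == "bar")


    def replace(element, **kw):
        # swap the bound literal "bar" for "baz"; returning None
        # leaves the element unchanged
        if getattr(element, "value", None) == "bar":
            return literal("baz")
        return None


    new_stmt = visitors.replacement_traverse(stmt, {}, replace)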
def replacement_traverse( diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index b218774b0d2..4e574bbb24e 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -1,5 +1,5 @@ # testing/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -83,6 +83,7 @@ from .util import resolve_lambda from .util import rowset from .util import run_as_contextmanager +from .util import skip_if_timeout from .util import teardown_events from .warnings import assert_warnings from .warnings import warn_test_suite diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e7b4161672c..719692125fb 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -1,5 +1,5 @@ # testing/assertions.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -274,8 +274,8 @@ def int_within_variance(expected, received, variance): ) -def eq_regex(a, b, msg=None): - assert re.match(b, a), msg or "%r !~ %r" % (a, b) +def eq_regex(a, b, msg=None, flags=0): + assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b) def eq_(a, b, msg=None): diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 3865497ff4c..81c7138c4b5 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -1,5 +1,5 @@ # testing/assertsql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -88,9 +88,9 @@ def _compile_dialect(self, execute_observed): dialect.supports_default_metavalue = True if self.enable_returning: - dialect.insert_returning = ( - dialect.update_returning - ) = dialect.delete_returning = True + dialect.insert_returning = dialect.update_returning = ( + dialect.delete_returning + ) = True dialect.use_insertmanyvalues = True dialect.supports_multivalues_insert = True dialect.update_returning_multifrom = True diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 4236dcf92e2..28470ba21c3 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -1,5 +1,5 @@ # testing/asyncio.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -24,16 +24,21 @@ import inspect from . import config -from ..util.concurrency import _util_async_run -from ..util.concurrency import _util_async_run_coroutine_function +from ..util.concurrency import _AsyncUtil # may be set to False if the # --disable-asyncio flag is passed to the test runner. 
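For background on the switch to ``_AsyncUtil`` below: the test harness runs sync-styled test functions inside a greenlet while an asyncio event loop resolves awaitables on their behalf. A simplified, self-contained illustration of that bridge, not the actual implementation::

    import asyncio

    import greenlet


    def await_(coro):
        # called from inside the child greenlet: hand the coroutine to
        # the parent, which awaits it and switches the result back
        return greenlet.getcurrent().parent.switch(coro)


    def run_in_greenlet(fn, *args, **kwargs):
        async def driver():
            child = greenlet.greenlet(fn)
            value = child.switch(*args, **kwargs)
            while not child.dead:
                # the child handed us a coroutine to await on its behalf
                value = child.switch(await value)
            return value

        return asyncio.run(driver())

Inside ``fn``, blocking-looking calls use ``await_(some_coroutine())`` in place of ``await``; when ``fn`` returns, its value falls out of the final switch.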
ENABLE_ASYNCIO = True +_async_util = _AsyncUtil() # it has lazy init so just always create one + + +def _shutdown(): + """called when the test finishes""" + _async_util.close() def _run_coroutine_function(fn, *args, **kwargs): - return _util_async_run_coroutine_function(fn, *args, **kwargs) + return _async_util.run(fn, *args, **kwargs) def _assume_async(fn, *args, **kwargs): @@ -50,7 +55,7 @@ def _assume_async(fn, *args, **kwargs): if not ENABLE_ASYNCIO: return fn(*args, **kwargs) - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) def _maybe_async_provisioning(fn, *args, **kwargs): @@ -69,7 +74,7 @@ def _maybe_async_provisioning(fn, *args, **kwargs): return fn(*args, **kwargs) if config.any_async: - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) else: return fn(*args, **kwargs) @@ -89,7 +94,7 @@ def _maybe_async(fn, *args, **kwargs): is_async = config._current.is_async if is_async: - return _util_async_run(fn, *args, **kwargs) + return _async_util.run_in_greenlet(fn, *args, **kwargs) else: return fn(*args, **kwargs) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 8430203dee2..2eec642b777 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -1,5 +1,5 @@ # testing/config.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -122,7 +122,9 @@ def combinations( passed, each argument combination is turned into a pytest.param() object, mapping the elements of the argument tuple to produce an id based on a character value in the same position within the string template using the - following scheme:: + following scheme: + + .. sourcecode:: text i - the given argument is a string that is part of the id only, don't pass it as an argument @@ -146,7 +148,7 @@ def combinations( (operator.ne, "ne"), (operator.gt, "gt"), (operator.lt, "lt"), - id_="na" + id_="na", ) def test_operator(self, opfunc, name): pass @@ -177,8 +179,7 @@ def __init__(self, case, argname, case_names): if typing.TYPE_CHECKING: - def __getattr__(self, key: str) -> bool: - ... + def __getattr__(self, key: str) -> bool: ... 
@property def name(self): @@ -229,14 +230,9 @@ def variation(argname_or_fn, cases=None): @testing.variation("querytyp", ["select", "subquery", "legacy_query"]) @testing.variation("lazy", ["select", "raise", "raise_on_sql"]) - def test_thing( - self, - querytyp, - lazy, - decl_base - ): + def test_thing(self, querytyp, lazy, decl_base): class Thing(decl_base): - __tablename__ = 'thing' + __tablename__ = "thing" # use name directly rel = relationship("Rel", lazy=lazy.name) @@ -251,7 +247,6 @@ class Thing(decl_base): else: querytyp.fail() - The variable provided is a slots object of boolean variables, as well as the name of the case itself under the attribute ".name" @@ -269,9 +264,11 @@ def go(self, request): else: argname = argname_or_fn cases_plus_limitations = [ - entry - if (isinstance(entry, tuple) and len(entry) == 2) - else (entry, None) + ( + entry + if (isinstance(entry, tuple) and len(entry) == 2) + else (entry, None) + ) for entry in cases ] @@ -280,9 +277,11 @@ def go(self, request): ) return combinations( *[ - (variation._name, variation, limitation) - if limitation is not None - else (variation._name, variation) + ( + (variation._name, variation, limitation) + if limitation is not None + else (variation._name, variation) + ) for variation, (case, limitation) in zip( variations, cases_plus_limitations ) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 749f9c160e8..51beed98b19 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -1,5 +1,5 @@ # testing/engines.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -289,8 +289,7 @@ def testing_engine( options: Optional[Dict[str, Any]] = None, asyncio: Literal[False] = False, transfer_staticpool: bool = False, -) -> Engine: - ... +) -> Engine: ... @typing.overload @@ -299,8 +298,7 @@ def testing_engine( options: Optional[Dict[str, Any]] = None, asyncio: Literal[True] = True, transfer_staticpool: bool = False, -) -> AsyncEngine: - ... +) -> AsyncEngine: ... 
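The ``testing_engine()`` overloads just above use ``Literal`` flag values so a type checker can infer ``Engine`` versus ``AsyncEngine`` from the ``asyncio=`` argument at the call site. The general recipe, stripped of SQLAlchemy specifics::

    from typing import Literal, Union, overload


    class SyncThing: ...


    class AsyncThing: ...


    @overload
    def make_thing(*, use_async: Literal[False] = False) -> SyncThing: ...


    @overload
    def make_thing(*, use_async: Literal[True]) -> AsyncThing: ...


    def make_thing(
        *, use_async: bool = False
    ) -> Union[SyncThing, AsyncThing]:
        return AsyncThing() if use_async else SyncThing()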
def testing_engine( @@ -332,16 +330,18 @@ def testing_engine( url = url or config.db.url url = make_url(url) - if options is None: - if config.db is None or url.drivername == config.db.url.drivername: - options = config.db_opts - else: - options = {} - elif config.db is not None and url.drivername == config.db.url.drivername: - default_opt = config.db_opts.copy() - default_opt.update(options) - engine = create_engine(url, **options) + if ( + config.db is None or url.drivername == config.db.url.drivername + ) and config.db_opts: + use_options = config.db_opts.copy() + else: + use_options = {} + + if options is not None: + use_options.update(options) + + engine = create_engine(url, **use_options) if sqlite_savepoint and engine.name == "sqlite": # apply SQLite savepoint workaround @@ -370,7 +370,12 @@ def do_begin(conn): True # enable event blocks, helps with profiling ) - if isinstance(engine.pool, pool.QueuePool): + if ( + isinstance(engine.pool, pool.QueuePool) + and "pool" not in use_options + and "pool_timeout" not in use_options + and "max_overflow" not in use_options + ): engine.pool._timeout = 0 engine.pool._max_overflow = 0 if use_reaper: diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 3c43f04613f..5bd4f7de240 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -1,5 +1,5 @@ # testing/entities.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 09cf5b3247a..d28e9d85e0c 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -1,5 +1,5 @@ # testing/exclusions.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -205,12 +205,12 @@ def _format_description(self, config, negate=False): if negate: bool_ = not negate return self.description % { - "driver": config.db.url.get_driver_name() - if config - else "", - "database": config.db.url.get_backend_name() - if config - else "", + "driver": ( + config.db.url.get_driver_name() if config else "" + ), + "database": ( + config.db.url.get_backend_name() if config else "" + ), "doesnt_support": "doesn't support" if bool_ else "does support", "does_support": "does support" if bool_ else "doesn't support", } @@ -392,8 +392,8 @@ def open(): # noqa return skip_if(BooleanPredicate(False, "mark as execute")) -def closed(): - return skip_if(BooleanPredicate(True, "marked as skip")) +def closed(reason="marked as skip"): + return skip_if(BooleanPredicate(True, reason)) def fails(reason=None): diff --git a/lib/sqlalchemy/testing/fixtures/__init__.py b/lib/sqlalchemy/testing/fixtures/__init__.py index 932051ce8ed..f2948dee8d3 100644 --- a/lib/sqlalchemy/testing/fixtures/__init__.py +++ b/lib/sqlalchemy/testing/fixtures/__init__.py @@ -1,5 +1,5 @@ # testing/fixtures/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/fixtures/base.py b/lib/sqlalchemy/testing/fixtures/base.py index 199ae7134ea..09d45a0a220 100644 --- a/lib/sqlalchemy/testing/fixtures/base.py +++ 
b/lib/sqlalchemy/testing/fixtures/base.py @@ -1,5 +1,5 @@ # testing/fixtures/base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py index 80e5ee07335..849df4dc30a 100644 --- a/lib/sqlalchemy/testing/fixtures/mypy.py +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -1,5 +1,5 @@ # testing/fixtures/mypy.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -86,9 +86,11 @@ def run(path, use_plugin=False, use_cachedir=None): "--config-file", os.path.join( use_cachedir, - "sqla_mypy_config.cfg" - if use_plugin - else "plain_mypy_config.cfg", + ( + "sqla_mypy_config.cfg" + if use_plugin + else "plain_mypy_config.cfg" + ), ), ] @@ -141,7 +143,9 @@ def _collect_messages(self, path): from sqlalchemy.ext.mypy.util import mypy_14 expected_messages = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + expected_re = re.compile( + r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)" + ) py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: current_assert_messages = [] @@ -159,9 +163,24 @@ def _collect_messages(self, path): if m: is_mypy = bool(m.group(1)) is_re = bool(m.group(2)) - is_type = bool(m.group(3)) + is_row = bool(m.group(3)) + is_type = bool(m.group(4)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5)) + if is_row: + expected_msg = re.sub( + r"Row\[([^\]]+)\]", + lambda m: f"tuple[{m.group(1)}, fallback=s" + f"qlalchemy.engine.row.{m.group(0)}]", + expected_msg, + ) + # For some reason it does not use or syntax (|) + expected_msg = re.sub( + r"Optional\[(.*)\]", + lambda m: f"Union[{m.group(1)}, None]", + expected_msg, + ) - expected_msg = re.sub(r"# noqa[:]? 
?.*", "", m.group(4)) if is_type: if not is_re: # the goal here is that we can cut-and-paste @@ -208,9 +227,11 @@ def _collect_messages(self, path): # skip first character which could be capitalized # "List item x not found" type of message expected_msg = expected_msg[0] + re.sub( - r"\b(List|Tuple|Dict|Set)\b" - if is_type - else r"\b(List|Tuple|Dict|Set|Type)\b", + ( + r"\b(List|Tuple|Dict|Set)\b" + if is_type + else r"\b(List|Tuple|Dict|Set|Type)\b" + ), lambda m: m.group(1).lower(), expected_msg[1:], ) @@ -239,7 +260,9 @@ def _collect_messages(self, path): return expected_messages - def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + def _check_output( + self, path, expected_messages, stdout: str, stderr, exitcode + ): not_located = [] filename = os.path.basename(path) if expected_messages: @@ -259,7 +282,8 @@ def _check_output(self, path, expected_messages, stdout, stderr, exitcode): ): while raw_lines: ol = raw_lines.pop(0) - if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + if not re.match(r".+\.py:\d+: note: +def .*", ol): + raw_lines.insert(0, ol) break elif re.match( r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I diff --git a/lib/sqlalchemy/testing/fixtures/orm.py b/lib/sqlalchemy/testing/fixtures/orm.py index da622c068cf..77cb243a808 100644 --- a/lib/sqlalchemy/testing/fixtures/orm.py +++ b/lib/sqlalchemy/testing/fixtures/orm.py @@ -1,5 +1,5 @@ # testing/fixtures/orm.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/fixtures/sql.py b/lib/sqlalchemy/testing/fixtures/sql.py index 911dddda312..44cf21c24fe 100644 --- a/lib/sqlalchemy/testing/fixtures/sql.py +++ b/lib/sqlalchemy/testing/fixtures/sql.py @@ -1,5 +1,5 @@ # testing/fixtures/sql.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -459,6 +459,10 @@ def __init__(self, cursor): # by not having the other methods we assert that those aren't being # used + @property + def description(self): + return self.cursor.description + def fetchall(self): rows = self.cursor.fetchall() rows = list(rows) @@ -466,22 +470,29 @@ def fetchall(self): return rows def _deliver_insertmanyvalues_batches( - cursor, statement, parameters, generic_setinputsizes, context + connection, + cursor, + statement, + parameters, + generic_setinputsizes, + context, ): if randomize_rows: cursor = RandomCursor(cursor) for batch in orig_dialect( - cursor, statement, parameters, generic_setinputsizes, context + connection, + cursor, + statement, + parameters, + generic_setinputsizes, + context, ): if warn_on_downgraded and batch.is_downgraded: util.warn("Batches were downgraded for sorted INSERT") yield batch - def _exec_insertmany_context( - dialect, - context, - ): + def _exec_insertmany_context(dialect, context): with mock.patch.object( dialect, "_deliver_insertmanyvalues_batches", diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 89155a84190..9317be63b8f 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -1,5 +1,5 @@ # testing/pickleable.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of 
SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py index e69de29bb2d..ce960be967d 100644 --- a/lib/sqlalchemy/testing/plugin/__init__.py +++ b/lib/sqlalchemy/testing/plugin/__init__.py @@ -0,0 +1,6 @@ +# testing/plugin/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index f93b8d3e629..2ad4d9915eb 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -1,3 +1,9 @@ +# testing/plugin/bootstrap.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors """ diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index f6a7f152b79..2dfa441413d 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -1,5 +1,5 @@ -# plugin/plugin_base.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# testing/plugin/plugin_base.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -90,7 +90,7 @@ def setup_options(make_option): action="append", type=str, dest="dburi", - help="Database uri. Multiple OK, " "first one is run by default.", + help="Database uri. Multiple OK, first one is run by default.", ) make_option( "--dbdriver", diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index a676e7e28d0..e5b63adf295 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -1,3 +1,9 @@ +# testing/plugin/pytestplugin.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from __future__ import annotations @@ -176,6 +182,12 @@ def pytest_sessionfinish(session): collect_types.dump_stats(session.config.option.dump_pyannotate) +def pytest_unconfigure(config): + from sqlalchemy.testing import asyncio + + asyncio._shutdown() + + def pytest_collection_finish(session): if session.config.option.dump_pyannotate: from pyannotate_runtime import collect_types @@ -258,7 +270,6 @@ def setup_test_classes(): for test_class in test_classes: # transfer legacy __backend__ and __sparse_backend__ symbols # to be markers - add_markers = set() if getattr(test_class.cls, "__backend__", False) or getattr( test_class.cls, "__only_on__", False ): @@ -663,9 +674,9 @@ def mark_base_test_class(self): "i": lambda obj: obj, "r": repr, "s": str, - "n": lambda obj: obj.__name__ - if hasattr(obj, "__name__") - else type(obj).__name__, + "n": lambda obj: ( + obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__ + ), } def combinations(self, *arg_sets, **kw): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 5471b1cfd48..0d90947e444 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -1,5 +1,5 @@ # testing/profiling.py 
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 0ff564e2455..3afcf119b27 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -1,3 +1,9 @@ +# testing/provision.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from __future__ import annotations @@ -68,6 +74,7 @@ def setup_config(db_url, options, file_config, follower_ident): # hooks dialect = sa_url.make_url(db_url).get_dialect() + dialect.load_provisioning() if follower_ident: @@ -101,7 +108,9 @@ def generate_db_urls(db_urls, extra_drivers): """Generate a set of URLs to test given configured URLs plus additional driver names. - Given:: + Given: + + .. sourcecode:: text --dburi postgresql://db1 \ --dburi postgresql://db2 \ @@ -109,7 +118,9 @@ def generate_db_urls(db_urls, extra_drivers): --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true Noting that the default postgresql driver is psycopg2, the output - would be:: + would be: + + .. sourcecode:: text postgresql+psycopg2://db1 postgresql+asyncpg://db1 @@ -126,6 +137,8 @@ def generate_db_urls(db_urls, extra_drivers): driver name. For example, to enable the async fallback option for asyncpg:: + .. sourcecode:: text + --dburi postgresql://db1 \ --dbdriver=asyncpg?async_fallback=true @@ -140,7 +153,10 @@ def generate_db_urls(db_urls, extra_drivers): ] for url_obj, dialect in urls_plus_dialects: - backend_to_driver_we_already_have[dialect.name].add(dialect.driver) + # use get_driver_name instead of dialect.driver to account for + # "_async" virtual drivers like oracledb and psycopg + driver_name = url_obj.get_driver_name() + backend_to_driver_we_already_have[dialect.name].add(driver_name) backend_to_driver_we_need = {} @@ -352,7 +368,7 @@ def update_db_opts(db_url, db_opts, options): def post_configure_engine(url, engine, follower_ident): """Perform extra steps after configuring an engine for testing. - (For the internal dialects, currently only used by sqlite, oracle) + (For the internal dialects, currently only used by sqlite, oracle, mssql) """ diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 5d1f3fb1663..fd64b1ffd43 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1,5 +1,5 @@ # testing/requirements.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -19,6 +19,7 @@ from __future__ import annotations +import os import platform from . import asyncio as _test_asyncio @@ -91,7 +92,9 @@ def unique_constraints_reflect_as_index(self): @property def table_value_constructor(self): - """Database / dialect supports a query like:: + """Database / dialect supports a query like: + + .. sourcecode:: sql SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...) 
AS some_table(col1, col2) @@ -796,6 +799,11 @@ def unique_constraint_reflection(self): """target dialect supports reflection of unique constraints""" return exclusions.open() + @property + def inline_check_constraint_reflection(self): + """target dialect supports reflection of inline check constraints""" + return exclusions.closed() + @property def check_constraint_reflection(self): """target dialect supports reflection of check constraints""" @@ -987,7 +995,9 @@ def binary_comparisons(self): @property def binary_literals(self): """target backend supports simple binary literals, e.g. an - expression like:: + expression like: + + .. sourcecode:: sql SELECT CAST('foo' AS BINARY) @@ -1093,6 +1103,11 @@ def go(config): return exclusions.only_if(go) + @property + def array_type(self): + """Target platform implements a native ARRAY type""" + return exclusions.closed() + @property def json_type(self): """target platform implements a native JSON type.""" @@ -1154,6 +1169,19 @@ def cast_precision_numerics_many_significant_digits(self): """ return self.precision_numerics_many_significant_digits + @property + def server_defaults(self): + """Target backend supports server side defaults for columns""" + + return exclusions.closed() + + @property + def expression_server_defaults(self): + """Target backend supports server side defaults with SQL expressions + for columns""" + + return exclusions.closed() + @property def implicit_decimal_binds(self): """target backend will return a selected Decimal as a Decimal, not @@ -1163,9 +1191,7 @@ def implicit_decimal_binds(self): expr = decimal.Decimal("15.7563") - value = e.scalar( - select(literal(expr)) - ) + value = e.scalar(select(literal(expr))) assert value == expr @@ -1333,7 +1359,9 @@ def update_where_target_in_subquery(self): present in a subquery in the WHERE clause. This is an ANSI-standard syntax that apparently MySQL can't handle, - such as:: + such as: + + .. sourcecode:: sql UPDATE documents SET flag=1 WHERE documents.title IN (SELECT max(documents.title) AS title @@ -1366,7 +1394,11 @@ def order_by_col_from_union(self): """target database supports ordering by a column from a SELECT inside of a UNION - E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id + E.g.: + + .. sourcecode:: sql + + (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id """ return exclusions.open() @@ -1376,7 +1408,9 @@ def order_by_label_with_expression(self): """target backend supports ORDER BY a column label within an expression. - Basically this:: + Basically this: + + .. sourcecode:: sql select data as foo from test order by foo || 'bar' @@ -1465,6 +1499,10 @@ def timing_intensive(self): return config.add_to_marker.timing_intensive + @property + def posix(self): + return exclusions.skip_if(lambda: os.name != "posix") + @property def memory_intensive(self): from . 
import config @@ -1506,6 +1544,27 @@ def check(config): return exclusions.skip_if(check) + @property + def up_to_date_typealias_type(self): + # this checks a particular quirk found in typing_extensions <=4.12.0 + # using older python versions like 3.10 or 3.9, we use TypeAliasType + # from typing_extensions which does not provide for sufficient + # introspection prior to 4.13.0 + def check(config): + import typing + import typing_extensions + + TypeAliasType = getattr( + typing, "TypeAliasType", typing_extensions.TypeAliasType + ) + TV = typing.TypeVar("TV") + TA_generic = TypeAliasType( # type: ignore + "TA_generic", typing.List[TV], type_params=(TV,) + ) + return hasattr(TA_generic[int], "__value__") + + return exclusions.only_if(check) + @property def python38(self): return exclusions.only_if( @@ -1530,6 +1589,32 @@ def python311(self): lambda: util.py311, "Python 3.11 or above required" ) + @property + def python312(self): + return exclusions.only_if( + lambda: util.py312, "Python 3.12 or above required" + ) + + @property + def fail_python314b1(self): + return exclusions.fails_if( + lambda: util.compat.py314b1, "Fails as of python 3.14.0b1" + ) + + @property + def not_python314(self): + """This requirement is interim to assist with backporting of + issue #12405. + + SQLAlchemy 2.0 still includes the ``await_fallback()`` method that + makes use of ``asyncio.get_event_loop_policy()``. This is removed + in SQLAlchemy 2.1. + + """ + return exclusions.skip_if( + lambda: util.py314, "Python 3.14 or above not supported" + ) + @property def cpython(self): return exclusions.only_if( @@ -1609,6 +1694,18 @@ def async_dialect(self): def asyncio(self): return self.greenlet + @property + def no_greenlet(self): + def go(config): + try: + import greenlet # noqa: F401 + except ImportError: + return True + else: + return False + + return exclusions.only_if(go) + @property def greenlet(self): def go(config): @@ -1763,3 +1860,34 @@ def materialized_views(self): def materialized_views_reflect_pk(self): """Target database reflect MATERIALIZED VIEWs pks.""" return exclusions.closed() + + @property + def supports_bitwise_or(self): + """Target database supports bitwise or""" + return exclusions.closed() + + @property + def supports_bitwise_and(self): + """Target database supports bitwise and""" + return exclusions.closed() + + @property + def supports_bitwise_not(self): + """Target database supports bitwise not""" + return exclusions.closed() + + @property + def supports_bitwise_xor(self): + """Target database supports bitwise xor""" + return exclusions.closed() + + @property + def supports_bitwise_shift(self): + """Target database supports bitwise left or right shift""" + return exclusions.closed() + + @property + def like_escapes(self): + """Target backend supports custom ESCAPE characters + with LIKE comparisons""" + return exclusions.open() diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 72ef9754ef5..0dd7de2029d 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -1,5 +1,5 @@ # testing/schema.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 30817e1e445..8435aa004f3 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,3 +1,9 @@ +# 
testing/suite/__init__.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php from .test_cte import * # noqa from .test_ddl import * # noqa from .test_deprecations import * # noqa diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index fb767e46354..4e4d420faa1 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -1,3 +1,9 @@ +# testing/suite/test_cte.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from .. import fixtures diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 35651170d12..c7e7d817d8e 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -1,3 +1,9 @@ +# testing/suite/test_ddl.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import random diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py index c453cbfed92..db0a9fc48db 100644 --- a/lib/sqlalchemy/testing/suite/test_deprecations.py +++ b/lib/sqlalchemy/testing/suite/test_deprecations.py @@ -1,3 +1,9 @@ +# testing/suite/test_deprecations.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from .. 
import fixtures diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 6edf93ffdc3..ebbb9e435a0 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,3 +1,9 @@ +# testing/suite/test_dialect.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors @@ -531,7 +537,7 @@ def test_round_trip_same_named_column( t.c[name].in_(["some name", "some other_name"]) ) - row = connection.execute(stmt).first() + connection.execute(stmt).first() @testing.fixture def multirow_fixture(self, metadata, connection): @@ -615,7 +621,7 @@ def go(stmt, executemany, id_param_name, expect_success): f"current server capabilities does not support " f".*RETURNING when executemany is used", ): - result = connection.execute( + connection.execute( stmt, [ {id_param_name: 1, "data": "d1"}, diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 09f24d356da..8467c351790 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -1,3 +1,9 @@ +# testing/suite/test_insert.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from decimal import Decimal @@ -486,9 +492,11 @@ def test_insert_w_floats( t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: @@ -545,6 +553,12 @@ def test_insert_w_floats( uuid.uuid4(), testing.requires.uuid_data_type, ), + ( + "generic_native_uuid_str", + Uuid(as_uuid=False, native_uuid=True), + str(uuid.uuid4()), + testing.requires.uuid_data_type, + ), ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type), ( "LargeBinary1", @@ -590,9 +604,11 @@ def test_imv_returning_datatypes( t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index f2ecf1cae95..d3d8b37dfa7 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1,5 +1,12 @@ +# testing/suite/test_reflection.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors +import contextlib import operator import re @@ -7,6 +14,7 @@ from .. import config from .. import engines from .. import eq_ +from .. import eq_regex from .. import expect_raises from .. import expect_raises_message from .. import expect_warnings @@ -16,6 +24,8 @@ from ..provision import temp_table_keyword_args from ..schema import Column from ..schema import Table +from ... import Boolean +from ... import DateTime from ... import event from ... 
import ForeignKey from ... import func @@ -213,6 +223,7 @@ def test_has_table_view_schema(self, connection): class HasIndexTest(fixtures.TablesTest): __backend__ = True + __requires__ = ("index_reflection",) @classmethod def define_tables(cls, metadata): @@ -291,6 +302,7 @@ class BizarroCharacterFKResolutionTest(fixtures.TestBase): """tests for #10275""" __backend__ = True + __requires__ = ("foreign_key_constraint_reflection",) @testing.combinations( ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" @@ -448,7 +460,7 @@ def test_get_table_options(self, name): is_true(isinstance(res, dict)) else: with expect_raises(NotImplementedError): - res = insp.get_table_options(name) + insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection @@ -467,11 +479,13 @@ def test_get_pk_constraint(self, name): assert insp.get_pk_constraint(name) @quote_fixtures + @testing.requires.foreign_key_constraint_reflection def test_get_foreign_keys(self, name): insp = inspect(config.db) assert insp.get_foreign_keys(name) @quote_fixtures + @testing.requires.index_reflection def test_get_indexes(self, name): insp = inspect(config.db) assert insp.get_indexes(name) @@ -1084,9 +1098,9 @@ def fk( "referred_columns": ref_col, "name": name, "options": mock.ANY, - "referred_schema": ref_schema - if ref_schema is not None - else tt(), + "referred_schema": ( + ref_schema if ref_schema is not None else tt() + ), "referred_table": ref_table, "comment": comment, } @@ -1940,6 +1954,8 @@ def test_get_unique_constraints(self, metadata, connection, use_schema): if dupe: names_that_duplicate_index.add(dupe) eq_(refl.pop("comment", None), None) + # ignore dialect_options + refl.pop("dialect_options", None) eq_(orig, refl) reflected_metadata = MetaData() @@ -2031,7 +2047,7 @@ def test_get_table_options(self, use_schema): is_true(isinstance(res, dict)) else: with expect_raises(NotImplementedError): - res = insp.get_table_options("users", schema=schema) + insp.get_table_options("users", schema=schema) @testing.combinations((True, testing.requires.schemas), False) def test_multi_get_table_options(self, use_schema): @@ -2047,7 +2063,7 @@ def test_multi_get_table_options(self, use_schema): eq_(res, exp) else: with expect_raises(NotImplementedError): - res = insp.get_multi_table_options() + insp.get_multi_table_options() @testing.fixture def get_multi_exp(self, connection): @@ -2448,62 +2464,158 @@ def test_get_columns_view_no_columns(self, connection, view_no_columns): class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): __backend__ = True - @testing.combinations( - (True, testing.requires.schemas), (False,), argnames="use_schema" - ) - @testing.requires.check_constraint_reflection - def test_get_check_constraints(self, metadata, connection, use_schema): - if use_schema: - schema = config.test_schema + @testing.fixture(params=[True, False]) + def use_schema_fixture(self, request): + if request.param: + return config.test_schema else: - schema = None + return None - Table( - "sa_cc", - metadata, - Column("a", Integer()), - sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), - sa.CheckConstraint( - "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing" - ), - schema=schema, - ) - Table( - "no_constraints", - metadata, - Column("data", sa.String(20)), - schema=schema, - ) + @testing.fixture() + def inspect_for_table(self, metadata, connection, use_schema_fixture): + @contextlib.contextmanager + def go(tablename): + yield use_schema_fixture, inspect(connection) - 
metadata.create_all(connection) + metadata.create_all(connection) - insp = inspect(connection) - reflected = sorted( - insp.get_check_constraints("sa_cc", schema=schema), - key=operator.itemgetter("name"), - ) + return go + def ck_eq(self, reflected, expected): # trying to minimize effect of quoting, parenthesis, etc. # may need to add more to this as new dialects get CHECK # constraint reflection support def normalize(sqltext): return " ".join( - re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + re.findall(r"and|\d|=|a|b|c|or|<|>", sqltext.lower(), re.I) ) - reflected = [ - {"name": item["name"], "sqltext": normalize(item["sqltext"])} - for item in reflected - ] - eq_( + reflected = sorted( + [ + {"name": item["name"], "sqltext": normalize(item["sqltext"])} + for item in reflected + ], + key=lambda item: (item["sqltext"]), + ) + + expected = sorted( + expected, + key=lambda item: (item["sqltext"]), + ) + eq_(reflected, expected) + + @testing.requires.check_constraint_reflection + def test_check_constraint_no_constraint(self, metadata, inspect_for_table): + with inspect_for_table("no_constraints") as (schema, inspector): + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) + + self.ck_eq( + inspector.get_check_constraints("no_constraints", schema=schema), + [], + ) + + @testing.requires.inline_check_constraint_reflection + @testing.combinations( + "my_inline", "MyInline", None, argnames="constraint_name" + ) + def test_check_constraint_inline( + self, metadata, inspect_for_table, constraint_name + ): + + with inspect_for_table("sa_cc") as (schema, inspector): + Table( + "sa_cc", + metadata, + Column("id", Integer(), primary_key=True), + Column( + "a", + Integer(), + sa.CheckConstraint( + "a > 1 AND a < 5", name=constraint_name + ), + ), + Column("data", String(50)), + schema=schema, + ) + + reflected = inspector.get_check_constraints("sa_cc", schema=schema) + + self.ck_eq( reflected, [ - {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"}, - {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + { + "name": constraint_name or mock.ANY, + "sqltext": "a > 1 and a < 5", + }, + ], + ) + + @testing.requires.check_constraint_reflection + @testing.combinations( + "my_ck_const", "MyCkConst", None, argnames="constraint_name" + ) + def test_check_constraint_standalone( + self, metadata, inspect_for_table, constraint_name + ): + with inspect_for_table("sa_cc") as (schema, inspector): + Table( + "sa_cc", + metadata, + Column("a", Integer()), + sa.CheckConstraint( + "a = 1 OR (a > 2 AND a < 5)", name=constraint_name + ), + schema=schema, + ) + + reflected = inspector.get_check_constraints("sa_cc", schema=schema) + + self.ck_eq( + reflected, + [ + { + "name": constraint_name or mock.ANY, + "sqltext": "a = 1 or a > 2 and a < 5", + }, + ], + ) + + @testing.requires.inline_check_constraint_reflection + def test_check_constraint_mixed(self, metadata, inspect_for_table): + with inspect_for_table("sa_cc") as (schema, inspector): + Table( + "sa_cc", + metadata, + Column("id", Integer(), primary_key=True), + Column("a", Integer(), sa.CheckConstraint("a > 1 AND a < 5")), + Column( + "b", + Integer(), + sa.CheckConstraint("b > 1 AND b < 5", name="my_inline"), + ), + Column("c", Integer()), + Column("data", String(50)), + sa.UniqueConstraint("data", name="some_uq"), + sa.CheckConstraint("c > 1 AND c < 5", name="cc1"), + sa.UniqueConstraint("c", name="some_c_uq"), + schema=schema, + ) + + reflected = inspector.get_check_constraints("sa_cc", schema=schema) + + 
self.ck_eq( + reflected, + [ + {"name": "cc1", "sqltext": "c > 1 and c < 5"}, + {"name": "my_inline", "sqltext": "b > 1 and b < 5"}, + {"name": mock.ANY, "sqltext": "a > 1 and a < 5"}, ], ) - no_cst = "no_constraints" - eq_(insp.get_check_constraints(no_cst, schema=schema), []) @testing.requires.indexes_with_expressions def test_reflect_expression_based_indexes(self, metadata, connection): @@ -2776,6 +2888,47 @@ def test_get_foreign_key_options( eq_(opts, expected) # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + @testing.combinations( + (Integer, sa.text("10"), r"'?10'?"), + (Integer, "10", r"'?10'?"), + (Boolean, sa.true(), r"1|true"), + ( + Integer, + sa.text("3 + 5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + ( + Integer, + sa.text("(3 * 5)"), + r"3\*5", + testing.requires.expression_server_defaults, + ), + (DateTime, func.now(), r"current_timestamp|now|getdate"), + ( + Integer, + sa.literal_column("3") + sa.literal_column("5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + argnames="datatype, default, expected_reg", + ) + @testing.requires.server_defaults + def test_server_defaults( + self, metadata, connection, datatype, default, expected_reg + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + reflected = inspect(connection).get_columns("t")[1]["default"] + reflected_sanitized = re.sub(r"[\(\) \']", "", reflected) + eq_regex(reflected_sanitized, expected_reg, flags=re.IGNORECASE) + class NormalizedNameTest(fixtures.TablesTest): __requires__ = ("denormalized_names",) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e439d6ca6d9..317195fd1e9 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -1,6 +1,13 @@ +# testing/suite/test_results.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import datetime +import re from .. import engines from .. 
import fixtures @@ -261,12 +268,16 @@ def _is_server_side(self, cursor): return isinstance(cursor, sscursor) elif self.engine.dialect.driver == "mariadbconnector": return not cursor.buffered + elif self.engine.dialect.driver == "mysqlconnector": + return "buffered" not in type(cursor).__name__.lower() elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): return cursor.server_side elif self.engine.dialect.driver == "pg8000": return getattr(cursor, "server_side", False) elif self.engine.dialect.driver == "psycopg": return bool(getattr(cursor, "name", False)) + elif self.engine.dialect.driver == "oracledb": + return getattr(cursor, "server_side", False) else: return False @@ -287,11 +298,26 @@ def _fixture(self, server_side_cursors): ) return self.engine + def stringify(self, str_): + return re.compile(r"SELECT (\d+)", re.I).sub( + lambda m: str(select(int(m.group(1))).compile(testing.db)), str_ + ) + @testing.combinations( - ("global_string", True, "select 1", True), - ("global_text", True, text("select 1"), True), + ("global_string", True, lambda stringify: stringify("select 1"), True), + ( + "global_text", + True, + lambda stringify: text(stringify("select 1")), + True, + ), ("global_expr", True, select(1), True), - ("global_off_explicit", False, text("select 1"), False), + ( + "global_off_explicit", + False, + lambda stringify: text(stringify("select 1")), + False, + ), ( "stmt_option", False, @@ -309,15 +335,22 @@ def _fixture(self, server_side_cursors): ( "for_update_string", True, - "SELECT 1 FOR UPDATE", + lambda stringify: stringify("SELECT 1 FOR UPDATE"), True, testing.skip_if(["sqlite", "mssql"]), ), - ("text_no_ss", False, text("select 42"), False), + ( + "text_no_ss", + False, + lambda stringify: text(stringify("select 42")), + False, + ), ( "text_ss_option", False, - text("select 42").execution_options(stream_results=True), + lambda stringify: text(stringify("select 42")).execution_options( + stream_results=True + ), True, ), id_="iaaa", @@ -328,6 +361,11 @@ def test_ss_cursor_status( ): engine = self._fixture(engine_ss_arg) with engine.begin() as conn: + if callable(statement): + statement = testing.resolve_lambda( + statement, stringify=self.stringify + ) + if isinstance(statement, str): result = conn.exec_driver_sql(statement) else: @@ -342,7 +380,7 @@ def test_conn_option(self): # should be enabled for this one result = conn.execution_options( stream_results=True - ).exec_driver_sql("select 1") + ).exec_driver_sql(self.stringify("select 1")) assert self._is_server_side(result.cursor) # the connection has autobegun, which means at the end of the @@ -396,7 +434,9 @@ def test_roundtrip_fetchall(self, metadata): test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) @@ -436,7 +476,9 @@ def test_roundtrip_fetchmany(self, metadata): test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py index 58295a5c531..59953fff59c 100644 --- a/lib/sqlalchemy/testing/suite/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -1,3 +1,9 @@ +# testing/suite/test_rowcount.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released 
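The `stringify()` fixture helper above exists because a bare `select 1` is not portable (Oracle, for one, requires a FROM clause), so `SELECT <n>` literals are rewritten through the dialect's own compiler. A minimal sketch with a stand-in compiler; the `oracle_like` lambda is invented for illustration and stands in for `select(n).compile(testing.db)`:

```python
import re

def stringify(str_: str, compile_select) -> str:
    # replace each "SELECT <n>" with the dialect-compiled form of select(<n>)
    return re.compile(r"SELECT (\d+)", re.I).sub(
        lambda m: compile_select(int(m.group(1))), str_
    )

oracle_like = lambda n: f"SELECT {n} FROM DUAL"  # stand-in compiler
print(stringify("select 1", oracle_like))        # SELECT 1 FROM DUAL
```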
under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from sqlalchemy import bindparam @@ -198,7 +204,7 @@ def test_raw_sql_rowcount(self, connection): def test_text_rowcount(self, connection): # test issue #3622, make sure eager rowcount is called for text result = connection.execute( - text("update employees set department='Z' " "where department='C'") + text("update employees set department='Z' where department='C'") ) eq_(result.rowcount, 3) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index a0aa147f9c0..d67d7698767 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -1,3 +1,9 @@ +# testing/suite/test_select.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import collections.abc as collections_abc @@ -1535,6 +1541,7 @@ def test_startswith_unescaped(self): col = self.tables.some_table.c.data self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + @testing.requires.like_escapes def test_startswith_autoescape(self): col = self.tables.some_table.c.data self._test(col.startswith("ab%c", autoescape=True), {3}) @@ -1546,10 +1553,12 @@ def test_startswith_sqlexpr(self): {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, ) + @testing.requires.like_escapes def test_startswith_escape(self): col = self.tables.some_table.c.data self._test(col.startswith("ab##c", escape="#"), {7}) + @testing.requires.like_escapes def test_startswith_autoescape_escape(self): col = self.tables.some_table.c.data self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3}) @@ -1565,14 +1574,17 @@ def test_endswith_sqlexpr(self): col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} ) + @testing.requires.like_escapes def test_endswith_autoescape(self): col = self.tables.some_table.c.data self._test(col.endswith("e%fg", autoescape=True), {6}) + @testing.requires.like_escapes def test_endswith_escape(self): col = self.tables.some_table.c.data self._test(col.endswith("e##fg", escape="#"), {9}) + @testing.requires.like_escapes def test_endswith_autoescape_escape(self): col = self.tables.some_table.c.data self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6}) @@ -1582,14 +1594,17 @@ def test_contains_unescaped(self): col = self.tables.some_table.c.data self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + @testing.requires.like_escapes def test_contains_autoescape(self): col = self.tables.some_table.c.data self._test(col.contains("b%cde", autoescape=True), {3}) + @testing.requires.like_escapes def test_contains_escape(self): col = self.tables.some_table.c.data self._test(col.contains("b##cde", escape="#"), {7}) + @testing.requires.like_escapes def test_contains_autoescape_escape(self): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) @@ -1765,7 +1780,7 @@ def define_tables(cls, metadata): ) def test_autoincrement_with_identity(self, connection): - res = connection.execute(self.tables.tbl.insert(), {"desc": "row"}) + connection.execute(self.tables.tbl.insert(), {"desc": "row"}) res = connection.execute(self.tables.tbl.select()).first() eq_(res, (1, "row")) @@ -1880,3 +1895,114 @@ def test_is_or_is_not_distinct_from( len(result), expected_row_count_for_is_not, ) + + +class 
WindowFunctionTest(fixtures.TablesTest): + __requires__ = ("window_functions",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("col1", Integer), + Column("col2", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [{"id": i, "col1": i, "col2": i * 5} for i in range(1, 50)], + ) + + def test_window(self, connection): + some_table = self.tables.some_table + rows = connection.execute( + select( + func.max(some_table.c.col2).over( + order_by=[some_table.c.col1.desc()] + ) + ).where(some_table.c.col1 < 20) + ).all() + + eq_(rows, [(95,) for i in range(19)]) + + def test_window_rows_between(self, connection): + some_table = self.tables.some_table + + # note the rows are part of the cache key right now, not handled + # as binds. this is issue #11515 + rows = connection.execute( + select( + func.max(some_table.c.col2).over( + order_by=[some_table.c.col1], + rows=(-5, 0), + ) + ) + ).all() + + eq_(rows, [(i,) for i in range(5, 250, 5)]) + + +class BitwiseTest(fixtures.TablesTest): + __backend__ = True + run_inserts = run_deletes = "once" + + inserted_data = [{"a": i, "b": i + 1} for i in range(10)] + + @classmethod + def define_tables(cls, metadata): + Table("bitwise", metadata, Column("a", Integer), Column("b", Integer)) + + @classmethod + def insert_data(cls, connection): + connection.execute(cls.tables.bitwise.insert(), cls.inserted_data) + + @testing.combinations( + ( + lambda a: a.bitwise_xor(5), + [i for i in range(10) if i != 5], + testing.requires.supports_bitwise_xor, + ), + ( + lambda a: a.bitwise_or(1), + list(range(10)), + testing.requires.supports_bitwise_or, + ), + ( + lambda a: a.bitwise_and(4), + list(range(4, 8)), + testing.requires.supports_bitwise_and, + ), + ( + lambda a: (a - 2).bitwise_not(), + [0], + testing.requires.supports_bitwise_not, + ), + ( + lambda a: a.bitwise_lshift(1), + list(range(1, 10)), + testing.requires.supports_bitwise_shift, + ), + ( + lambda a: a.bitwise_rshift(2), + list(range(4, 10)), + testing.requires.supports_bitwise_shift, + ), + argnames="case, expected", + ) + def test_bitwise(self, case, expected, connection): + tbl = self.tables.bitwise + + a = tbl.c.a + + op = testing.resolve_lambda(case, a=a) + + stmt = select(tbl).where(op > 0).order_by(a) + + res = connection.execute(stmt).mappings().all() + eq_(res, [self.inserted_data[i] for i in expected]) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index 43e2d066bba..f0e6575370b 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -1,3 +1,9 @@ +# testing/suite/test_sequence.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from .. 
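The expected index lists in `BitwiseTest` above can be sanity-checked with plain Python integers, since each lambda mirrors a SQL operator and the WHERE clause keeps only rows where the result is positive:

```python
# WHERE op(a) > 0 keeps a row; mirror each SQL operator with Python's ints
data = list(range(10))
assert [i for i in data if (i ^ 5) > 0] == [i for i in data if i != 5]
assert [i for i in data if (i | 1) > 0] == data
assert [i for i in data if (i & 4) > 0] == list(range(4, 8))
assert [i for i in data if ~(i - 2) > 0] == [0]   # ~x > 0 only when x < -1
assert [i for i in data if (i << 1) > 0] == list(range(1, 10))
assert [i for i in data if (i >> 2) > 0] == list(range(4, 10))
print("bitwise expectations verified")
```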
import config diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 5debb450f60..5f1bf75d504 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -1,3 +1,9 @@ +# testing/suite/test_types.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors @@ -26,6 +32,7 @@ from ... import cast from ... import Date from ... import DateTime +from ... import Enum from ... import Float from ... import Integer from ... import Interval @@ -292,6 +299,7 @@ def test_literal_complex(self, literal_round_trip): class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): __backend__ = True + __requires__ = ("binary_literals",) @classmethod def define_tables(cls, metadata): @@ -1476,6 +1484,7 @@ def default(self, o): return datatype, compare_value, p_s + @testing.requires.legacy_unconditional_json_extract @_index_fixtures(False) def test_index_typed_access(self, datatype, value): data_table = self.tables.data_table @@ -1497,6 +1506,7 @@ def test_index_typed_access(self, datatype, value): eq_(roundtrip, compare_value) is_(type(roundtrip), type(compare_value)) + @testing.requires.legacy_unconditional_json_extract @_index_fixtures(True) def test_index_typed_comparison(self, datatype, value): data_table = self.tables.data_table @@ -1521,6 +1531,7 @@ def test_index_typed_comparison(self, datatype, value): # make sure we get a row even if value is None eq_(row, (compare_value,)) + @testing.requires.legacy_unconditional_json_extract @_index_fixtures(True) def test_path_typed_comparison(self, datatype, value): data_table = self.tables.data_table @@ -1912,6 +1923,74 @@ def test_string_cast_crit_against_string_basic(self): ) +class EnumTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + enum_values = "a", "b", "a%", "b%percent", "réveillé" + + datatype = Enum(*enum_values, name="myenum") + + @classmethod + def define_tables(cls, metadata): + Table( + "enum_table", + metadata, + Column("id", Integer, primary_key=True), + Column("enum_data", cls.datatype), + ) + + @testing.combinations(*enum_values, argnames="data") + def test_round_trip(self, data, connection): + connection.execute( + self.tables.enum_table.insert(), {"id": 1, "enum_data": data} + ) + + eq_( + connection.scalar( + select(self.tables.enum_table.c.enum_data).where( + self.tables.enum_table.c.id == 1 + ) + ), + data, + ) + + def test_round_trip_executemany(self, connection): + connection.execute( + self.tables.enum_table.insert(), + [ + {"id": 1, "enum_data": "b%percent"}, + {"id": 2, "enum_data": "réveillé"}, + {"id": 3, "enum_data": "b"}, + {"id": 4, "enum_data": "a%"}, + ], + ) + + eq_( + connection.scalars( + select(self.tables.enum_table.c.enum_data).order_by( + self.tables.enum_table.c.id + ) + ).all(), + ["b%percent", "réveillé", "b", "a%"], + ) + + @testing.requires.insert_executemany_returning + def test_round_trip_executemany_returning(self, connection): + result = connection.execute( + self.tables.enum_table.insert().returning( + self.tables.enum_table.c.enum_data + ), + [ + {"id": 1, "enum_data": "b%percent"}, + {"id": 2, "enum_data": "réveillé"}, + {"id": 3, "enum_data": "b"}, + {"id": 4, "enum_data": "a%"}, + ], + ) + + eq_(result.scalars().all(), ["b%percent", "réveillé", "b", "a%"]) + + class UuidTest(_LiteralRoundTripFixture, 
fixtures.TablesTest): __backend__ = True @@ -2060,6 +2139,7 @@ class NativeUUIDTest(UuidTest): "DateHistoricTest", "StringTest", "BooleanTest", + "EnumTest", "UuidTest", "NativeUUIDTest", ) diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py index 01597893727..c8dd3350588 100644 --- a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py @@ -1,3 +1,9 @@ +# testing/suite/test_unicode_ddl.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index 2d13bda34ae..85a8d393391 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -1,3 +1,9 @@ +# testing/suite/test_update_delete.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from .. import fixtures @@ -87,9 +93,11 @@ def test_update_returning(self, connection, criteria): eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (2, "d2_new"), (3, "d3")] - if criteria.rows - else [(1, "d1"), (2, "d2"), (3, "d3")], + ( + [(1, "d1"), (2, "d2_new"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), ) @testing.variation("criteria", ["rows", "norows", "emptyin"]) @@ -120,9 +128,11 @@ def test_delete_returning(self, connection, criteria): eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (3, "d3")] - if criteria.rows - else [(1, "d1"), (2, "d2"), (3, "d3")], + ( + [(1, "d1"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), ) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index cf24b43a969..42f077108f5 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -1,5 +1,5 @@ # testing/util.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,13 +10,16 @@ from __future__ import annotations from collections import deque +import contextlib import decimal import gc from itertools import chain import random import sys from sys import getsizeof +import time import types +from typing import Any from . import config from . import mock @@ -251,18 +254,19 @@ def flag_combinations(*combinations): dict(lazy=False, passive=True), dict(lazy=False, passive=True, raiseload=True), ) - + def test_fn(lazy, passive, raiseload): ... would result in:: @testing.combinations( - ('', False, False, False), - ('lazy', True, False, False), - ('lazy_passive', True, True, False), - ('lazy_passive', True, True, True), - id_='iaaa', - argnames='lazy,passive,raiseload' + ("", False, False, False), + ("lazy", True, False, False), + ("lazy_passive", True, True, False), + ("lazy_passive", True, True, True), + id_="iaaa", + argnames="lazy,passive,raiseload", ) + def test_fn(lazy, passive, raiseload): ... 
""" @@ -517,3 +521,18 @@ def count_cache_key_tuples(tup): if elem: stack = list(elem) + [sentinel] + stack return num_elements + + +@contextlib.contextmanager +def skip_if_timeout(seconds: float, cleanup: Any = None): + + now = time.time() + yield + sec = time.time() - now + if sec > seconds: + try: + cleanup() + finally: + config.skip_test( + f"test took too long ({sec:.4f} seconds > {seconds})" + ) diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 6a2ac08e39e..9be0813b584 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -1,5 +1,5 @@ # testing/warnings.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index dfe6d2edb7c..bb2c2e11de3 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -1,13 +1,11 @@ # types.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Compatibility namespace for sqlalchemy.sql.types. - -""" +"""Compatibility namespace for sqlalchemy.sql.types.""" from __future__ import annotations diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c804f968878..1ccebc47fce 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -1,5 +1,5 @@ # util/__init__.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,6 @@ from collections import defaultdict as defaultdict from functools import partial as partial from functools import update_wrapper as update_wrapper -from typing import TYPE_CHECKING from . 
import preloaded as preloaded from ._collections import coerce_generator_arg as coerce_generator_arg @@ -49,7 +48,6 @@ from ._collections import WeakSequence as WeakSequence from .compat import anext_ as anext_ from .compat import arm as arm -from .compat import athrow as athrow from .compat import b as b from .compat import b64decode as b64decode from .compat import b64encode as b64encode @@ -66,6 +64,8 @@ from .compat import py310 as py310 from .compat import py311 as py311 from .compat import py312 as py312 +from .compat import py313 as py313 +from .compat import py314 as py314 from .compat import py38 as py38 from .compat import py39 as py39 from .compat import pypy as pypy @@ -157,3 +157,4 @@ from .langhelpers import warn_limited as warn_limited from .langhelpers import wrap_callable as wrap_callable from .preloaded import preload_module as preload_module +from .typing import is_non_string_iterable as is_non_string_iterable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index a0b1977ee50..c5e00a636d7 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -1,5 +1,5 @@ # util/_collections.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,6 @@ """Collection classes and helpers.""" from __future__ import annotations -import collections.abc as collections_abc import operator import threading import types @@ -17,6 +16,7 @@ from typing import Any from typing import Callable from typing import cast +from typing import Container from typing import Dict from typing import FrozenSet from typing import Generic @@ -36,6 +36,7 @@ import weakref from ._has_cy import HAS_CYEXTENSION +from .typing import is_non_string_iterable from .typing import Literal from .typing import Protocol @@ -79,8 +80,8 @@ def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: Example:: - >>> a = ['__tablename__', 'id', 'x', 'created_at'] - >>> b = ['id', 'name', 'data', 'y', 'created_at'] + >>> a = ["__tablename__", "id", "x", "created_at"] + >>> b = ["id", "name", "data", "y", "created_at"] >>> merge_lists_w_ordering(a, b) ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] @@ -227,12 +228,10 @@ def update(self, value: Dict[str, _T]) -> None: self._data.update(value) @overload - def get(self, key: str) -> Optional[_T]: - ... + def get(self, key: str) -> Optional[_T]: ... @overload - def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: - ... + def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... def get( self, key: str, default: Optional[Union[_DT, _T]] = None @@ -419,9 +418,7 @@ def coerce_generator_arg(arg: Any) -> List[Any]: def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: if x is None: return default # type: ignore - if not isinstance(x, collections_abc.Iterable) or isinstance( - x, (str, bytes) - ): + if not is_non_string_iterable(x): return [x] elif isinstance(x, list): return x @@ -429,15 +426,14 @@ def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: return list(x) -def has_intersection(set_, iterable): +def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool: r"""return True if any items of set\_ are present in iterable. Goes through special effort to ensure __hash__ is not called on items in iterable that don't support it. 
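`has_intersection()` is rewritten here from an eager `set_.intersection([...])` over a materialized list to a lazy `any()`, which short-circuits on the first hit while still skipping unhashable items (their `__hash__` attribute is `None`). Standalone sketch:

```python
def has_intersection(set_, iterable):
    # lazy: stops at the first hit; unhashable items (list, dict, ...) have
    # __hash__ set to None and are skipped before any `in` check happens
    return any(i in set_ for i in iterable if i.__hash__)

items = [["unhashable", "list"], 3, 7]
print(has_intersection({3, 4}, items))  # True, and 7 is never examined
print(has_intersection({99}, items))    # False
```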
""" - # TODO: optimize, write in C, etc. - return bool(set_.intersection([i for i in iterable if i.__hash__])) + return any(i in set_ for i in iterable if i.__hash__) def to_set(x): @@ -458,7 +454,9 @@ def to_column_set(x: Any) -> Set[Any]: return x -def update_copy(d, _new=None, **kw): +def update_copy( + d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any +) -> Dict[Any, Any]: """Copy the given dict and update with the given values.""" d = d.copy() @@ -522,12 +520,10 @@ def _inc_counter(self): return self._counter @overload - def get(self, key: _KT) -> Optional[_VT]: - ... + def get(self, key: _KT) -> Optional[_VT]: ... @overload - def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: - ... + def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... def get( self, key: _KT, default: Optional[Union[_VT, _T]] = None @@ -589,13 +585,11 @@ def _manage_size(self) -> None: class _CreateFuncType(Protocol[_T_co]): - def __call__(self) -> _T_co: - ... + def __call__(self) -> _T_co: ... class _ScopeFuncType(Protocol): - def __call__(self) -> Any: - ... + def __call__(self) -> Any: ... class ScopedRegistry(Generic[_T]): diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 71d10a68579..718c077c0da 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -1,5 +1,5 @@ # util/_concurrency_py3k.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -19,10 +19,14 @@ from typing import Optional from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from .langhelpers import memoized_property from .. import exc +from ..util import py311 +from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import Self from ..util.typing import TypeGuard _T = TypeVar("_T") @@ -33,8 +37,7 @@ class greenlet(Protocol): dead: bool gr_context: Optional[Context] - def __init__(self, fn: Callable[..., Any], driver: greenlet): - ... + def __init__(self, fn: Callable[..., Any], driver: greenlet): ... def throw(self, *arg: Any) -> Any: return None @@ -42,8 +45,7 @@ def throw(self, *arg: Any) -> Any: def switch(self, value: Any) -> Any: return None - def getcurrent() -> greenlet: - ... + def getcurrent() -> greenlet: ... else: from greenlet import getcurrent @@ -72,9 +74,10 @@ def is_exit_exception(e: BaseException) -> bool: class _AsyncIoGreenlet(greenlet): dead: bool + __sqlalchemy_greenlet_provider__ = True + def __init__(self, fn: Callable[..., Any], driver: greenlet): greenlet.__init__(self, fn, driver) - self.driver = driver if _has_gr_context: self.gr_context = driver.gr_context @@ -85,8 +88,7 @@ def __init__(self, fn: Callable[..., Any], driver: greenlet): def iscoroutine( awaitable: Awaitable[_T_co], - ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: - ... + ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ... else: iscoroutine = asyncio.iscoroutine @@ -99,6 +101,11 @@ def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None: awaitable.close() +def in_greenlet() -> bool: + current = getcurrent() + return getattr(current, "__sqlalchemy_greenlet_provider__", False) + + def await_only(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. 
@@ -110,7 +117,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T: """ # this is called in the context greenlet while running fn current = getcurrent() - if not isinstance(current, _AsyncIoGreenlet): + if not getattr(current, "__sqlalchemy_greenlet_provider__", False): _safe_cancel_awaitable(awaitable) raise exc.MissingGreenlet( @@ -122,7 +129,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T: # a coroutine to run. Once the awaitable is done, the driver greenlet # switches back to this greenlet with the result of awaitable that is # then returned to the caller (or raised as error) - return current.driver.switch(awaitable) # type: ignore[no-any-return] + return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501 def await_fallback(awaitable: Awaitable[_T]) -> _T: @@ -133,11 +140,16 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: :param awaitable: The coroutine to call. + .. deprecated:: 2.0.24 The ``await_fallback()`` function will be removed + in SQLAlchemy 2.1. Use :func:`_util.await_only` instead, running the + function / program / etc. within a top-level greenlet that is set up + using :func:`_util.greenlet_spawn`. + """ # this is called in the context greenlet while running fn current = getcurrent() - if not isinstance(current, _AsyncIoGreenlet): + if not getattr(current, "__sqlalchemy_greenlet_provider__", False): loop = get_event_loop() if loop.is_running(): _safe_cancel_awaitable(awaitable) @@ -149,7 +161,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: ) return loop.run_until_complete(awaitable) - return current.driver.switch(awaitable) # type: ignore[no-any-return] + return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501 async def greenlet_spawn( @@ -175,24 +187,21 @@ async def greenlet_spawn( # coroutine to wait. If the context is dead the function has # returned, and its result can be returned. switch_occurred = False - try: - result = context.switch(*args, **kwargs) - while not context.dead: - switch_occurred = True - try: - # wait for a coroutine from await_only and then return its - # result back to it. - value = await result - except BaseException: - # this allows an exception to be raised within - # the moderated greenlet so that it can continue - # its expected flow. - result = context.throw(*sys.exc_info()) - else: - result = context.switch(value) - finally: - # clean up to avoid cycle resolution by gc - del context.driver + result = context.switch(*args, **kwargs) + while not context.dead: + switch_occurred = True + try: + # wait for a coroutine from await_only and then return its + # result back to it. + value = await result + except BaseException: + # this allows an exception to be raised within + # the moderated greenlet so that it can continue + # its expected flow. 
+ result = context.throw(*sys.exc_info()) + else: + result = context.switch(value) + if _require_await and not switch_occurred: raise exc.AwaitRequired( "The current operation required an async execution but none was " @@ -218,34 +227,6 @@ def __exit__(self, *arg: Any, **kw: Any) -> None: self.mutex.release() -def _util_async_run_coroutine_function( - fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any -) -> Any: - """for test suite/ util only""" - - loop = get_event_loop() - if loop.is_running(): - raise Exception( - "for async run coroutine we expect that no greenlet or event " - "loop is running when we start out" - ) - return loop.run_until_complete(fn(*args, **kwargs)) - - -def _util_async_run( - fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any -) -> Any: - """for test suite/ util only""" - - loop = get_event_loop() - if not loop.is_running(): - return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) - else: - # allow for a wrapped test function to call another - assert isinstance(getcurrent(), _AsyncIoGreenlet) - return fn(*args, **kwargs) - - def get_event_loop() -> asyncio.AbstractEventLoop: """vendor asyncio.get_event_loop() for python 3.7 and above. @@ -258,3 +239,50 @@ def get_event_loop() -> asyncio.AbstractEventLoop: # avoid "During handling of the above exception, another exception..." pass return asyncio.get_event_loop_policy().get_event_loop() + + +if not TYPE_CHECKING and py311: + _Runner = asyncio.Runner +else: + + class _Runner: + """Runner implementation for test only""" + + _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]] + + def __init__(self) -> None: + self._loop = None + + def __enter__(self) -> Self: + self._lazy_init() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def close(self) -> None: + if self._loop: + try: + self._loop.run_until_complete( + self._loop.shutdown_asyncgens() + ) + finally: + self._loop.close() + self._loop = False + + def get_loop(self) -> asyncio.AbstractEventLoop: + """Return embedded event loop.""" + self._lazy_init() + assert self._loop + return self._loop + + def run(self, coro: Coroutine[Any, Any, _T]) -> _T: + self._lazy_init() + assert self._loop + return self._loop.run_until_complete(coro) + + def _lazy_init(self) -> None: + if self._loop is False: + raise RuntimeError("Runner is closed") + if self._loop is None: + self._loop = asyncio.new_event_loop() diff --git a/lib/sqlalchemy/util/_has_cy.py b/lib/sqlalchemy/util/_has_cy.py index 37f716ad3b9..21faed04e6b 100644 --- a/lib/sqlalchemy/util/_has_cy.py +++ b/lib/sqlalchemy/util/_has_cy.py @@ -1,4 +1,5 @@ -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# util/_has_cy.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 4f52d3bce67..f6aefcf67c3 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,5 +1,5 @@ # util/_py_collections.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -59,11 +59,9 @@ def __setattr__(self, key: str, value: Any) -> NoReturn: class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]): if TYPE_CHECKING: - def __new__(cls, *args: Any) -> Self: - ... 
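The `_Runner` shim above mirrors the `asyncio.Runner` API available on Python 3.11+ (hence `_Runner = asyncio.Runner` on that branch): context-manager entry/exit, `run()`, `get_loop()`, and `close()`, with a single event loop reused across `run()` calls. Equivalent stdlib usage on 3.11+:

```python
import asyncio

async def hello() -> str:
    return "hi"

with asyncio.Runner() as runner:
    print(runner.run(hello()))  # "hi"
    loop = runner.get_loop()    # the same loop is reused between run() calls
    print(runner.run(hello()))  # still "hi", on the same loop
```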
+ def __new__(cls, *args: Any) -> Self: ... - def __init__(cls, *args: Any): - ... + def __init__(cls, *args: Any): ... def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: self._immutable() @@ -148,12 +146,16 @@ def __ior__(self, __value: Any) -> NoReturn: # type: ignore def __or__( # type: ignore[override] self, __value: Mapping[_KT, _VT] ) -> immutabledict[_KT, _VT]: - return immutabledict(super().__or__(__value)) + return immutabledict( + super().__or__(__value), # type: ignore[call-overload] + ) def __ror__( # type: ignore[override] self, __value: Mapping[_KT, _VT] ) -> immutabledict[_KT, _VT]: - return immutabledict(super().__ror__(__value)) + return immutabledict( + super().__ror__(__value), # type: ignore[call-overload] + ) class OrderedSet(Set[_T]): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 98a0b65ec95..2ee47031184 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -1,5 +1,5 @@ # util/compat.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -19,8 +19,6 @@ import sys import typing from typing import Any -from typing import AsyncGenerator -from typing import Awaitable from typing import Callable from typing import Dict from typing import Iterable @@ -34,6 +32,9 @@ from typing import TypeVar +py314b1 = sys.version_info >= (3, 14, 0, "beta", 1) +py314 = sys.version_info >= (3, 14) +py313 = sys.version_info >= (3, 13) py312 = sys.version_info >= (3, 12) py311 = sys.version_info >= (3, 11) py310 = sys.version_info >= (3, 10) @@ -60,7 +61,7 @@ class FullArgSpec(typing.NamedTuple): varkw: Optional[str] defaults: Optional[Tuple[Any, ...]] kwonlyargs: List[str] - kwonlydefaults: Dict[str, Any] + kwonlydefaults: Optional[Dict[str, Any]] annotations: Dict[str, Any] @@ -102,24 +103,6 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec: ) -if py312: - # we are 95% certain this form of athrow works in former Python - # versions, however we are unable to get confirmation; - # see https://github.com/python/cpython/issues/105269 where have - # been unable to get a straight answer so far - def athrow( # noqa - gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any - ) -> Awaitable[_T_co]: - return gen.athrow(value) - -else: - - def athrow( # noqa - gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any - ) -> Awaitable[_T_co]: - return gen.athrow(typ, value, traceback) - - if py39: # python stubs don't have a public type for this. 
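On the `__or__`/`__ror__` overrides a few hunks up: PEP 584 dict union on a `dict` subclass returns a plain `dict`, so the result is re-wrapped to stay immutable. Assuming a SQLAlchemy install, the behavior looks roughly like this (the mutation error comes from SQLAlchemy's read-only container machinery):

```python
from sqlalchemy.util import immutabledict

base = immutabledict({"a": 1})
merged = base | {"b": 2}      # would degrade to a plain dict without __or__
print(type(merged).__name__)  # immutabledict
print(dict(merged))           # {'a': 1, 'b': 2}
try:
    merged["c"] = 3
except TypeError as err:      # mutation is rejected
    print("rejected:", err)
```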
not worth # making a protocol @@ -173,7 +156,7 @@ async def anext_(async_iterator, default=_NOT_PROVIDED): def importlib_metadata_get(group): ep = importlib_metadata.entry_points() - if not typing.TYPE_CHECKING and hasattr(ep, "select"): + if typing.TYPE_CHECKING or hasattr(ep, "select"): return ep.select(group=group) else: return ep.get(group, ()) diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 53a70070b76..006340f5bf3 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -1,5 +1,5 @@ # util/concurrency.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,11 +10,15 @@ import asyncio # noqa import typing +from typing import Any +from typing import Callable +from typing import Coroutine +from typing import TypeVar have_greenlet = False greenlet_error = None try: - import greenlet # type: ignore # noqa: F401 + import greenlet # type: ignore[import-untyped,unused-ignore] # noqa: F401,E501 except ImportError as e: greenlet_error = str(e) pass @@ -22,15 +26,47 @@ have_greenlet = True from ._concurrency_py3k import await_only as await_only from ._concurrency_py3k import await_fallback as await_fallback + from ._concurrency_py3k import in_greenlet as in_greenlet from ._concurrency_py3k import greenlet_spawn as greenlet_spawn from ._concurrency_py3k import is_exit_exception as is_exit_exception from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock - from ._concurrency_py3k import ( - _util_async_run as _util_async_run, - ) # noqa: F401 - from ._concurrency_py3k import ( - _util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa: F401, E501 - ) + from ._concurrency_py3k import _Runner + +_T = TypeVar("_T") + + +class _AsyncUtil: + """Asyncio util for test suite/ util only""" + + def __init__(self) -> None: + if have_greenlet: + self.runner = _Runner() + + def run( + self, + fn: Callable[..., Coroutine[Any, Any, _T]], + *args: Any, + **kwargs: Any, + ) -> _T: + """Run coroutine on the loop""" + return self.runner.run(fn(*args, **kwargs)) + + def run_in_greenlet( + self, fn: Callable[..., _T], *args: Any, **kwargs: Any + ) -> _T: + """Run sync function in greenlet. 
Support nested calls""" + if have_greenlet: + if self.runner.get_loop().is_running(): + return fn(*args, **kwargs) + else: + return self.runner.run(greenlet_spawn(fn, *args, **kwargs)) + else: + return fn(*args, **kwargs) + + def close(self) -> None: + if have_greenlet: + self.runner.close() + if not typing.TYPE_CHECKING and not have_greenlet: @@ -56,6 +92,9 @@ def await_only(thing): # type: ignore # noqa: F811 def await_fallback(thing): # type: ignore # noqa: F811 return thing + def in_greenlet(): # type: ignore # noqa: F811 + _not_implemented() + def greenlet_spawn(fn, *args, **kw): # type: ignore # noqa: F811 _not_implemented() diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 26d9924898b..88b68724038 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -1,5 +1,5 @@ # util/deprecations.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -205,10 +205,10 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: weak_identity_map=( "0.7", "the :paramref:`.Session.weak_identity_map parameter " - "is deprecated." + "is deprecated.", ) - ) + def some_function(**kwargs): ... """ diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9c56487c400..ebdd8ffa045 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1,5 +1,5 @@ # util/langhelpers.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -60,7 +60,85 @@ _HM = TypeVar("_HM", bound="hybridmethod[Any]") -if compat.py310: +if compat.py314: + # vendor a minimal form of get_annotations per + # https://github.com/python/cpython/issues/133684#issuecomment-2863841891 + + from annotationlib import call_annotate_function # type: ignore + from annotationlib import Format + + def _get_and_call_annotate(obj, format): # noqa: A002 + annotate = getattr(obj, "__annotate__", None) + if annotate is not None: + ann = call_annotate_function(annotate, format, owner=obj) + if not isinstance(ann, dict): + raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") + return ann + return None + + # this is ported from py3.13.0a7 + _BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__ # type: ignore # noqa: E501 + + def _get_dunder_annotations(obj): + if isinstance(obj, type): + try: + ann = _BASE_GET_ANNOTATIONS(obj) + except AttributeError: + # For static types, the descriptor raises AttributeError. + return {} + else: + ann = getattr(obj, "__annotations__", None) + if ann is None: + return {} + + if not isinstance(ann, dict): + raise ValueError( + f"{obj!r}.__annotations__ is neither a dict nor None" + ) + return dict(ann) + + def _vendored_get_annotations( + obj: Any, *, format: Format # noqa: A002 + ) -> Mapping[str, Any]: + """A sparse implementation of annotationlib.get_annotations()""" + + try: + ann = _get_dunder_annotations(obj) + except Exception: + pass + else: + if ann is not None: + return dict(ann) + + # But if __annotations__ threw a NameError, we try calling __annotate__ + ann = _get_and_call_annotate(obj, format) + if ann is None: + # If that didn't work either, we have a very weird object: + # evaluating + # __annotations__ threw NameError and there is no __annotate__. 
+ # In that case, + # we fall back to trying __annotations__ again. + ann = _get_dunder_annotations(obj) + + if ann is None: + if isinstance(obj, type) or callable(obj): + return {} + raise TypeError(f"{obj!r} does not have annotations") + + if not ann: + return {} + + return dict(ann) + + def get_annotations(obj: Any) -> Mapping[str, Any]: + # FORWARDREF has the effect of giving us ForwardRefs and not + # actually trying to evaluate the annotations. We need this so + # that the annotations act as much like + # "from __future__ import annotations" as possible, which is going + # away in future python as a separate mode + return _vendored_get_annotations(obj, format=Format.FORWARDREF) + +elif compat.py310: def get_annotations(obj: Any) -> Mapping[str, Any]: return inspect.get_annotations(obj) @@ -174,10 +252,11 @@ def string_or_unprintable(element: Any) -> str: return "unprintable element %r" % element -def clsname_as_plain_name(cls: Type[Any]) -> str: - return " ".join( - n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", cls.__name__) - ) +def clsname_as_plain_name( + cls: Type[Any], use_name: Optional[str] = None +) -> str: + name = use_name or cls.__name__ + return " ".join(n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", name)) def method_is_overridden( @@ -249,10 +328,30 @@ def decorate(fn: _Fn) -> _Fn: if not inspect.isfunction(fn) and not inspect.ismethod(fn): raise Exception("not a decoratable function") - spec = compat.inspect_getfullargspec(fn) - env: Dict[str, Any] = {} + # Python 3.14 defer creating __annotations__ until its used. + # We do not want to create __annotations__ now. + annofunc = getattr(fn, "__annotate__", None) + if annofunc is not None: + fn.__annotate__ = None # type: ignore[union-attr] + try: + spec = compat.inspect_getfullargspec(fn) + finally: + fn.__annotate__ = annofunc # type: ignore[union-attr] + else: + spec = compat.inspect_getfullargspec(fn) - spec = _update_argspec_defaults_into_env(spec, env) + # Do not generate code for annotations. + # update_wrapper() copies the annotation from fn to decorated. + # We use dummy defaults for code generation to avoid having + # copy of large globals for compiling. + # We copy __defaults__ and __kwdefaults__ from fn to decorated. 
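The decorator changes above guard against PEP 649 (Python 3.14) lazily evaluated annotations: introspecting a function would otherwise force its `__annotate__` hook to run. A hedged, version-agnostic sketch of that save/disable/restore dance (the helper name is invented):

```python
import inspect

def argspec_no_annotations(fn):
    """Introspect fn without forcing PEP 649 lazy annotations to evaluate."""
    annofunc = getattr(fn, "__annotate__", None)  # only present on 3.14+
    if annofunc is not None:
        fn.__annotate__ = None                    # temporarily disable the hook
        try:
            return inspect.getfullargspec(fn)
        finally:
            fn.__annotate__ = annofunc            # always restore it
    return inspect.getfullargspec(fn)

def f(x: "SomeForwardRef", y: int = 2):  # the annotation is never evaluated
    return x, y

print(argspec_no_annotations(f).args)  # ['x', 'y']
```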
+ empty_defaults = (None,) * len(spec.defaults or ()) + empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ()) + spec = spec._replace( + annotations={}, + defaults=empty_defaults, + kwonlydefaults=empty_kwdefaults, + ) names = ( tuple(cast("Tuple[str, ...]", spec[0])) @@ -297,41 +396,21 @@ def decorate(fn: _Fn) -> _Fn: % metadata ) - mod = sys.modules[fn.__module__] - env.update(vars(mod)) - env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__}) + env: Dict[str, Any] = { + targ_name: target, + fn_name: fn, + "__name__": fn.__module__, + } decorated = cast( types.FunctionType, _exec_code_in_env(code, env, fn.__name__), ) - decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ - - decorated.__wrapped__ = fn # type: ignore - return cast(_Fn, update_wrapper(decorated, fn)) - - return update_wrapper(decorate, target) + decorated.__defaults__ = fn.__defaults__ + decorated.__kwdefaults__ = fn.__kwdefaults__ # type: ignore + return update_wrapper(decorated, fn) # type: ignore[return-value] - -def _update_argspec_defaults_into_env(spec, env): - """given a FullArgSpec, convert defaults to be symbol names in an env.""" - - if spec.defaults: - new_defaults = [] - i = 0 - for arg in spec.defaults: - if type(arg).__module__ not in ("builtins", "__builtin__"): - name = "x%d" % i - env[name] = arg - new_defaults.append(name) - i += 1 - else: - new_defaults.append(arg) - elem = list(spec) - elem[3] = tuple(new_defaults) - return compat.FullArgSpec(*elem) - else: - return spec + return update_wrapper(decorate, target) # type: ignore[return-value] def _exec_code_in_env( @@ -384,6 +463,9 @@ def load(): self.impls[name] = load + def deregister(self, name: str) -> None: + del self.impls[name] + def _inspect_func_args(fn): try: @@ -411,15 +493,13 @@ def get_cls_kwargs( *, _set: Optional[Set[str]] = None, raiseerr: Literal[True] = ..., -) -> Set[str]: - ... +) -> Set[str]: ... @overload def get_cls_kwargs( cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False -) -> Optional[Set[str]]: - ... +) -> Optional[Set[str]]: ... def get_cls_kwargs( @@ -663,7 +743,9 @@ def format_argspec_init(method, grouped=True): """format_argspec_plus with considerations for typical __init__ methods Wraps format_argspec_plus with error handling strategies for typical - __init__ cases:: + __init__ cases: + + .. sourcecode:: text object.__init__ -> (self) other unreflectable (usually C) -> (self, *args, **kwargs) @@ -718,7 +800,9 @@ def decorate(cls): def getargspec_init(method): """inspect.getargspec with considerations for typical __init__ methods - Wraps inspect.getargspec with error handling for typical __init__ cases:: + Wraps inspect.getargspec with error handling for typical __init__ cases: + + .. sourcecode:: text object.__init__ -> (self) other unreflectable (usually C) -> (self, *args, **kwargs) @@ -1092,23 +1176,19 @@ def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None): self.__name__ = fget.__name__ @overload - def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: - ... + def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ... @overload - def __get__(self, obj: object, cls: Any) -> _T_co: - ... + def __get__(self, obj: object, cls: Any) -> _T_co: ... def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]: raise NotImplementedError() if TYPE_CHECKING: - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... 
+ def __delete__(self, instance: Any) -> None: ... def _reset(self, obj: Any) -> None: raise NotImplementedError() @@ -1247,12 +1327,10 @@ def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None): self.__name__ = fget.__name__ @overload - def __get__(self: _MA, obj: None, cls: Any) -> _MA: - ... + def __get__(self: _MA, obj: None, cls: Any) -> _MA: ... @overload - def __get__(self, obj: Any, cls: Any) -> _T: - ... + def __get__(self, obj: Any, cls: Any) -> _T: ... def __get__(self, obj, cls): if obj is None: @@ -1598,9 +1676,9 @@ def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]: class symbol(int): """A constant symbol. - >>> symbol('foo') is symbol('foo') + >>> symbol("foo") is symbol("foo") True - >>> symbol('foo') + >>> symbol("foo") A slight refinement of the MAGICCOOKIE=object() pattern. The primary @@ -1666,6 +1744,8 @@ def __init__( items: List[symbol] cls._items = items = [] for k, v in dict_.items(): + if re.match(r"^__.*__$", k): + continue if isinstance(v, int): sym = symbol(k, canonical=v) elif not k.startswith("_"): @@ -1959,12 +2039,15 @@ def chop_traceback( def attrsetter(attrname): - code = "def set(obj, value):" " obj.%s = value" % attrname + code = "def set(obj, value): obj.%s = value" % attrname env = locals().copy() exec(code, env) return env["set"] +_dunders = re.compile("^__.+__$") + + class TypingOnly: """A mixin class that marks a class as 'typing only', meaning it has absolutely no methods, attributes, or runtime functionality whatsoever. @@ -1975,15 +2058,9 @@ class TypingOnly: def __init_subclass__(cls) -> None: if TypingOnly in cls.__bases__: - remaining = set(cls.__dict__).difference( - { - "__module__", - "__doc__", - "__slots__", - "__orig_bases__", - "__annotations__", - } - ) + remaining = { + name for name in cls.__dict__ if not _dunders.match(name) + } if remaining: raise AssertionError( f"Class {cls} directly inherits TypingOnly but has " @@ -2216,3 +2293,11 @@ def has_compiled_ext(raise_=False): ) else: return False + + +class _Missing(enum.Enum): + Missing = enum.auto() + + +Missing = _Missing.Missing +MissingOr = Union[_T, Literal[_Missing.Missing]] diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index f3609c8e472..4ea9aa90f30 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -1,5 +1,5 @@ -# util/_preloaded.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# util/preloaded.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index b641c910c71..3fb01a9a9f8 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -1,5 +1,5 @@ # util/queue.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -57,8 +57,7 @@ class QueueCommon(Generic[_T]): maxsize: int use_lifo: bool - def __init__(self, maxsize: int = 0, use_lifo: bool = False): - ... + def __init__(self, maxsize: int = 0, use_lifo: bool = False): ... def empty(self) -> bool: raise NotImplementedError() @@ -242,8 +241,7 @@ class AsyncAdaptedQueue(QueueCommon[_T]): if typing.TYPE_CHECKING: @staticmethod - def await_(coroutine: Awaitable[Any]) -> _T: - ... + def await_(coroutine: Awaitable[Any]) -> _T: ... 
else: await_ = staticmethod(await_only) diff --git a/lib/sqlalchemy/util/tool_support.py b/lib/sqlalchemy/util/tool_support.py index 5a2fc3ba051..407c2d45075 100644 --- a/lib/sqlalchemy/util/tool_support.py +++ b/lib/sqlalchemy/util/tool_support.py @@ -1,5 +1,5 @@ # util/tool_support.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -27,6 +27,7 @@ from typing import Dict from typing import Iterator from typing import Optional +from typing import Union from . import compat @@ -121,7 +122,7 @@ def write_status(self, *text: str) -> None: sys.stderr.write(" ".join(text)) def write_output_file_from_text( - self, text: str, destination_path: str + self, text: str, destination_path: Union[str, Path] ) -> None: if self.args.check: self._run_diff(destination_path, source=text) @@ -129,7 +130,9 @@ def write_output_file_from_text( print(text) else: self.write_status(f"Writing {destination_path}...") - Path(destination_path).write_text(text) + Path(destination_path).write_text( + text, encoding="utf-8", newline="\n" + ) self.write_status("done\n") def write_output_file_from_tempfile( @@ -149,24 +152,24 @@ def write_output_file_from_tempfile( def _run_diff( self, - destination_path: str, + destination_path: Union[str, Path], *, source: Optional[str] = None, source_file: Optional[str] = None, ) -> None: if source_file: - with open(source_file) as tf: + with open(source_file, encoding="utf-8") as tf: source_lines = list(tf) elif source is not None: source_lines = source.splitlines(keepends=True) else: assert False, "source or source_file is required" - with open(destination_path) as dp: + with open(destination_path, encoding="utf-8") as dp: d = difflib.unified_diff( list(dp), source_lines, - fromfile=destination_path, + fromfile=Path(destination_path).as_posix(), tofile="", n=3, lineterm="\n", diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 8c6a663f602..82f22a01957 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -1,5 +1,5 @@ # util/topological.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -112,7 +112,7 @@ def find_cycles( todo.remove(node) break else: - node = stack.pop() + stack.pop() return output diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 3d15d43db76..794dd18591c 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,5 +1,5 @@ # util/typing.py -# Copyright (C) 2022 the SQLAlchemy authors and contributors +# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,12 +9,13 @@ from __future__ import annotations import builtins +from collections import deque +import collections.abc as collections_abc import re import sys import typing from typing import Any from typing import Callable -from typing import cast from typing import Dict from typing import ForwardRef from typing import Generic @@ -31,6 +32,8 @@ from typing import TypeVar from typing import Union +import typing_extensions + from . 
import compat if True: # zimports removes the tailing comments @@ -52,7 +55,9 @@ from typing_extensions import TypedDict as TypedDict # 3.8 from typing_extensions import TypeGuard as TypeGuard # 3.10 from typing_extensions import Self as Self # 3.11 - + from typing_extensions import TypeAliasType as TypeAliasType # 3.12 + from typing_extensions import Never as Never # 3.11 + from typing_extensions import LiteralString as LiteralString # 3.11 _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT") @@ -61,7 +66,6 @@ _VT = TypeVar("_VT") _VT_co = TypeVar("_VT_co", covariant=True) - if compat.py310: # why they took until py310 to put this in stdlib is beyond me, # I've been wanting it since py27 @@ -69,18 +73,17 @@ else: NoneType = type(None) # type: ignore -NoneFwd = ForwardRef("None") -typing_get_args = get_args -typing_get_origin = get_origin +def is_fwd_none(typ: Any) -> bool: + return isinstance(typ, ForwardRef) and typ.__forward_arg__ == "None" _AnnotationScanType = Union[ - Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]" + Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]" ] -class ArgsTypeProcotol(Protocol): +class ArgsTypeProtocol(Protocol): """protocol for types that have ``__args__`` there's no public interface for this AFAIK @@ -111,11 +114,9 @@ class GenericProtocol(Protocol[_T]): # copied from TypeShed, required in order to implement # MutableMapping.update() class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): - def keys(self) -> Iterable[_KT]: - ... + def keys(self) -> Iterable[_KT]: ... - def __getitem__(self, __k: _KT) -> _VT_co: - ... + def __getitem__(self, __k: _KT) -> _VT_co: ... # work around https://github.com/microsoft/pyright/issues/3025 @@ -155,7 +156,7 @@ def de_stringify_annotation( annotation = str_cleanup_fn(annotation, originating_module) annotation = eval_expression( - annotation, originating_module, locals_=locals_ + annotation, originating_module, locals_=locals_, in_class=cls ) if ( @@ -189,9 +190,51 @@ def de_stringify_annotation( ) return _copy_generic_annotation_with(annotation, elements) + return annotation # type: ignore +def fixup_container_fwd_refs( + type_: _AnnotationScanType, +) -> _AnnotationScanType: + """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')] + and similar for list, set + + """ + + if ( + is_generic(type_) + and get_origin(type_) + in ( + dict, + set, + list, + collections_abc.MutableSet, + collections_abc.MutableMapping, + collections_abc.MutableSequence, + collections_abc.Mapping, + collections_abc.Sequence, + ) + # fight, kick and scream to struggle to tell the difference between + # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT + # behave the same yet there is NO WAY to distinguish between which type + # it is using public attributes + and not re.match( + "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_) + ) + ): + # compat with py3.10 and earlier + return get_origin(type_).__class_getitem__( # type: ignore + tuple( + [ + ForwardRef(elem) if isinstance(elem, str) else elem + for elem in get_args(type_) + ] + ) + ) + return type_ + + def _copy_generic_annotation_with( annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...] 
) -> Type[_T]: @@ -208,6 +251,7 @@ def eval_expression( module_name: str, *, locals_: Optional[Mapping[str, Any]] = None, + in_class: Optional[Type[Any]] = None, ) -> Any: try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ @@ -218,7 +262,18 @@ def eval_expression( ) from ke try: - annotation = eval(expression, base_globals, locals_) + if in_class is not None: + cls_namespace = dict(in_class.__dict__) + cls_namespace.setdefault(in_class.__name__, in_class) + + # see #10899. We want the locals/globals to take precedence + # over the class namespace in this context, even though this + # is not the usual way variables would resolve. + cls_namespace.update(base_globals) + + annotation = eval(expression, cls_namespace, locals_) + else: + annotation = eval(expression, base_globals, locals_) except Exception as err: raise NameError( f"Could not de-stringify annotation {expression!r}" @@ -270,34 +325,18 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: return getattr(obj, "__name__", name) -def de_stringify_union_elements( - cls: Type[Any], - annotation: ArgsTypeProcotol, - originating_module: str, - locals_: Mapping[str, Any], - *, - str_cleanup_fn: Optional[Callable[[str, str], str]] = None, -) -> Type[Any]: - return make_union_type( - *[ - de_stringify_annotation( - cls, - anno, - originating_module, - {}, - str_cleanup_fn=str_cleanup_fn, - ) - for anno in annotation.__args__ - ] - ) +def is_pep593(type_: Optional[Any]) -> bool: + return type_ is not None and get_origin(type_) in _type_tuples.Annotated -def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: - return type_ is not None and typing_get_origin(type_) is Annotated +def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: + return isinstance(obj, collections_abc.Iterable) and not isinstance( + obj, (str, bytes) + ) -def is_literal(type_: _AnnotationScanType) -> bool: - return get_origin(type_) is Literal +def is_literal(type_: Any) -> bool: + return get_origin(type_) in _type_tuples.Literal def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: @@ -305,46 +344,99 @@ def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: # doesn't work in 3.8, 3.7 as it passes a closure, not an # object instance - # return isinstance(type_, NewType) + # isinstance(type, type_instances.NewType) def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: return hasattr(type_, "__args__") and hasattr(type_, "__origin__") +def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]: + # NOTE: a generic TAT does not instance check as TypeAliasType outside of + # python 3.10. For sqlalchemy use cases it's fine to consider it a TAT + # though. + # NOTE: things seems to work also without this additional check + if is_generic(type_): + return is_pep695(type_.__origin__) + return isinstance(type_, _type_instances.TypeAliasType) + + def flatten_newtype(type_: NewType) -> Type[Any]: super_type = type_.__supertype__ while is_newtype(super_type): super_type = super_type.__supertype__ - return super_type + return super_type # type: ignore[return-value] + + +def pep695_values(type_: _AnnotationScanType) -> Set[Any]: + """Extracts the value from a TypeAliasType, recursively exploring unions + and inner TypeAliasType to flatten them into a single set. + + Forward references are not evaluated, so no recursive exploration happens + into them. 
+ """ + _seen = set() + + def recursive_value(inner_type): + if inner_type in _seen: + # recursion are not supported (at least it's flagged as + # an error by pyright). Just avoid infinite loop + return inner_type + _seen.add(inner_type) + if not is_pep695(inner_type): + return inner_type + value = inner_type.__value__ + if not is_union(value): + return value + return [recursive_value(t) for t in value.__args__] + + res = recursive_value(type_) + if isinstance(res, list): + types = set() + stack = deque(res) + while stack: + t = stack.popleft() + if isinstance(t, list): + stack.extend(t) + else: + types.add(None if t is NoneType or is_fwd_none(t) else t) + return types + else: + return {res} def is_fwd_ref( - type_: _AnnotationScanType, check_generic: bool = False + type_: _AnnotationScanType, + check_generic: bool = False, + check_for_plain_string: bool = False, ) -> TypeGuard[ForwardRef]: - if isinstance(type_, ForwardRef): + if check_for_plain_string and isinstance(type_, str): + return True + elif isinstance(type_, _type_instances.ForwardRef): return True elif check_generic and is_generic(type_): - return any(is_fwd_ref(arg, True) for arg in type_.__args__) + return any( + is_fwd_ref( + arg, True, check_for_plain_string=check_for_plain_string + ) + for arg in type_.__args__ + ) else: return False @overload -def de_optionalize_union_types(type_: str) -> str: - ... +def de_optionalize_union_types(type_: str) -> str: ... @overload -def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: - ... +def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ... @overload def de_optionalize_union_types( type_: _AnnotationScanType, -) -> _AnnotationScanType: - ... +) -> _AnnotationScanType: ... def de_optionalize_union_types( @@ -353,16 +445,33 @@ def de_optionalize_union_types( """Given a type, filter out ``Union`` types that include ``NoneType`` to not include the ``NoneType``. + Contains extra logic to work on non-flattened unions, unions that contain + ``None`` (seen in py38, 37) + """ if is_fwd_ref(type_): - return de_optionalize_fwd_ref_union_types(type_) + return _de_optionalize_fwd_ref_union_types(type_, False) - elif is_optional(type_): - typ = set(type_.__args__) - - typ.discard(NoneType) - typ.discard(NoneFwd) + elif is_union(type_) and includes_none(type_): + if compat.py39: + typ = set(type_.__args__) + else: + # py38, 37 - unions are not automatically flattened, can contain + # None rather than NoneType + stack_of_unions = deque([type_]) + typ = set() + while stack_of_unions: + u_typ = stack_of_unions.popleft() + for elem in u_typ.__args__: + if is_union(elem): + stack_of_unions.append(elem) + else: + typ.add(elem) + + typ.discard(None) # type: ignore + + typ = {t for t in typ if t is not NoneType and not is_fwd_none(t)} return make_union_type(*typ) @@ -370,9 +479,21 @@ def de_optionalize_union_types( return type_ -def de_optionalize_fwd_ref_union_types( - type_: ForwardRef, -) -> _AnnotationScanType: +@overload +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: Literal[True] +) -> bool: ... + + +@overload +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: Literal[False] +) -> _AnnotationScanType: ... + + +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: bool +) -> Union[_AnnotationScanType, bool]: """return the non-optional type for Optional[], Union[None, ...], x|None, etc. without de-stringifying forward refs. 
@@ -384,68 +505,94 @@ def de_optionalize_fwd_ref_union_types( mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if mm: - if mm.group(1) == "Optional": - return ForwardRef(mm.group(2)) - elif mm.group(1) == "Union": - elements = re.split(r",\s*", mm.group(2)) - return make_union_type( - *[ForwardRef(elem) for elem in elements if elem != "None"] - ) + g1 = mm.group(1).split(".")[-1] + if g1 == "Optional": + return True if return_has_none else ForwardRef(mm.group(2)) + elif g1 == "Union": + if "[" in mm.group(2): + # cases like "Union[Dict[str, int], int, None]" + elements: list[str] = [] + current: list[str] = [] + ignore_comma = 0 + for char in mm.group(2): + if char == "[": + ignore_comma += 1 + elif char == "]": + ignore_comma -= 1 + elif ignore_comma == 0 and char == ",": + elements.append("".join(current).strip()) + current.clear() + continue + current.append(char) + else: + elements = re.split(r",\s*", mm.group(2)) + parts = [ForwardRef(elem) for elem in elements if elem != "None"] + if return_has_none: + return len(elements) != len(parts) + else: + return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501 else: - return type_ + return False if return_has_none else type_ pipe_tokens = re.split(r"\s*\|\s*", annotation) - if "None" in pipe_tokens: - return ForwardRef("|".join(p for p in pipe_tokens if p != "None")) + has_none = "None" in pipe_tokens + if return_has_none: + return has_none + if has_none: + anno_str = "|".join(p for p in pipe_tokens if p != "None") + return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501 return type_ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: - """Make a Union type. - - This is needed by :func:`.de_optionalize_union_types` which removes - ``NoneType`` from a ``Union``. + """Make a Union type.""" - """ - return cast(Any, Union).__getitem__(types) # type: ignore + return Union[types] # type: ignore -def expand_unions( - type_: Type[Any], include_union: bool = False, discard_none: bool = False -) -> Tuple[Type[Any], ...]: - """Return a type as a tuple of individual types, expanding for - ``Union`` types.""" +def includes_none(type_: Any) -> bool: + """Returns if the type annotation ``type_`` allows ``None``. + This function supports: + * forward refs + * unions + * pep593 - Annotated + * pep695 - TypeAliasType (does not support looking into + fw reference of other pep695) + * NewType + * plain types like ``int``, ``None``, etc + """ + if is_fwd_ref(type_): + return _de_optionalize_fwd_ref_union_types(type_, True) if is_union(type_): - typ = set(type_.__args__) - - if discard_none: - typ.discard(NoneType) - - if include_union: - return (type_,) + tuple(typ) # type: ignore - else: - return tuple(typ) # type: ignore - else: - return (type_,) + return any(includes_none(t) for t in get_args(type_)) + if is_pep593(type_): + return includes_none(get_args(type_)[0]) + if is_pep695(type_): + return any(includes_none(t) for t in pep695_values(type_)) + if is_newtype(type_): + return includes_none(type_.__supertype__) + try: + return type_ in (NoneType, None) or is_fwd_none(type_) + except TypeError: + # if type_ is Column, mapped_column(), etc. the use of "in" + # resolves to ``__eq__()`` which then gives us an expression object + # that can't resolve to boolean. 
just catch it all via exception + return False -def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: - return is_origin_of( - type_, - "Optional", - "Union", - "UnionType", +def is_a_type(type_: Any) -> bool: + return ( + isinstance(type_, type) + or hasattr(type_, "__origin__") + or type_.__module__ in ("typing", "typing_extensions") + or type(type_).__mro__[0].__module__ in ("typing", "typing_extensions") ) -def is_optional_union(type_: Any) -> bool: - return is_optional(type_) and NoneType in typing_get_args(type_) - - -def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: - return is_origin_of(type_, "Union") +def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]: + return is_origin_of(type_, "Union", "UnionType") def is_origin_of_cls( @@ -454,7 +601,7 @@ def is_origin_of_cls( """return True if the given type has an __origin__ that shares a base with the given class""" - origin = typing_get_origin(type_) + origin = get_origin(type_) if origin is None: return False @@ -467,7 +614,7 @@ def is_origin_of( """return True if the given type has an __origin__ with the given name and optional module.""" - origin = typing_get_origin(type_) + origin = get_origin(type_) if origin is None: return False @@ -488,14 +635,11 @@ def _get_type_name(type_: Type[Any]) -> str: class DescriptorProto(Protocol): - def __get__(self, instance: object, owner: Any) -> Any: - ... + def __get__(self, instance: object, owner: Any) -> Any: ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... _DESC = TypeVar("_DESC", bound=DescriptorProto) @@ -514,14 +658,11 @@ class DescriptorReference(Generic[_DESC]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _DESC: - ... + def __get__(self, instance: object, owner: Any) -> _DESC: ... - def __set__(self, instance: Any, value: _DESC) -> None: - ... + def __set__(self, instance: Any, value: _DESC) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... _DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True) @@ -537,14 +678,11 @@ class RODescriptorReference(Generic[_DESC_co]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _DESC_co: - ... + def __get__(self, instance: object, owner: Any) -> _DESC_co: ... - def __set__(self, instance: Any, value: Any) -> NoReturn: - ... + def __set__(self, instance: Any, value: Any) -> NoReturn: ... - def __delete__(self, instance: Any) -> NoReturn: - ... + def __delete__(self, instance: Any) -> NoReturn: ... _FN = TypeVar("_FN", bound=Optional[Callable[..., Any]]) @@ -561,14 +699,35 @@ class CallableReference(Generic[_FN]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _FN: - ... + def __get__(self, instance: object, owner: Any) -> _FN: ... + + def __set__(self, instance: Any, value: _FN) -> None: ... - def __set__(self, instance: Any, value: _FN) -> None: - ... + def __delete__(self, instance: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... 
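The ``_TypingInstances`` class introduced just below relies on a standard lazy-attribute idiom: ``__getattr__`` only fires when normal attribute lookup fails, so storing the computed tuple in ``self.__dict__`` turns every later access into an ordinary cached read. A minimal sketch of the idiom, with illustrative names that are not part of SQLAlchemy's API:

class LazyAttrs:
    def __getattr__(self, key):
        # only reached when "key" is not yet in self.__dict__
        value = ("computed", key)  # stand-in for the expensive lookup
        self.__dict__[key] = value  # cache it; __getattr__ won't fire again
        return value


la = LazyAttrs()
assert la.Literal is la.Literal  # second access is a plain __dict__ hit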
+class _TypingInstances: + def __getattr__(self, key: str) -> tuple[type, ...]: + types = tuple( + { + t + for t in [ + getattr(typing, key, None), + getattr(typing_extensions, key, None), + ] + if t is not None + } + ) + if not types: + raise AttributeError(key) + self.__dict__[key] = types + return types + + +_type_tuples = _TypingInstances() +if TYPE_CHECKING: + _type_instances = typing_extensions +else: + _type_instances = _type_tuples -# $def ro_descriptor_reference(fn: Callable[]) +LITERAL_TYPES = _type_tuples.Literal diff --git a/pyproject.toml b/pyproject.toml index 3cdf49301f7..31863651faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [build-system] - build-backend = "setuptools.build_meta" - requires = [ - "setuptools>=47", - "cython>=0.29.24; python_implementation == 'CPython'", # Skip cython when using pypy - ] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools>=61.0", + "cython>=0.29.24; platform_python_implementation == 'CPython'", # Skip cython when using pypy +] [tool.black] line-length = 79 @@ -45,6 +45,13 @@ filterwarnings = [ # sqlite3 warnings due to test/dialect/test_sqlite.py->test_native_datetime, # which is asserting that these deprecated-in-py312 handlers are functional "ignore:The default (date)?(time)?(stamp)? (adapter|converter):DeprecationWarning", + + # warning regarding using "fork" mode for multiprocessing when the parent + # has threads; using pytest-xdist introduces threads in the parent + # and we use multiprocessing in test/aaa_profiling/test_memusage.py where + # we require "fork" mode + # https://github.com/python/cpython/pull/100229#issuecomment-2704616288 + "ignore:This process .* is multi-threaded:DeprecationWarning", ] markers = [ "memory_intensive: memory / CPU intensive suite tests", @@ -65,6 +72,8 @@ reportTypedDictNotRequiredAccess = "warning" mypy_path = "./lib/" show_error_codes = true incremental = true +# would be nice to enable this but too many error are surfaceds +# enable_error_code = "ignore-without-code" [[tool.mypy.overrides]] @@ -80,7 +89,7 @@ strict = true [tool.cibuildwheel] test-requires = "pytest pytest-xdist" # remove user site, otherwise the local checkout has precedence, disabling cyextensions -test-command = "python -s -m pytest -c {project}/pyproject.toml -n2 -q --nomemory --notimingintensive --nomypy {project}/test" +test-command = "python -s -m pytest -c {project}/pyproject.toml -n4 -q --nomemory --notimingintensive --nomypy {project}/test" build = "*" # python 3.6 is no longer supported by sqlalchemy diff --git a/reap_dbs.py b/reap_dbs.py index 81f9b8f26ee..c6d2616e6da 100644 --- a/reap_dbs.py +++ b/reap_dbs.py @@ -1,4 +1,4 @@ -"""Drop Oracle, SQL Server databases that are left over from a +"""Drop Oracle Database, SQL Server databases that are left over from a multiprocessing test run. Currently the cx_Oracle driver seems to sometimes not release a @@ -10,6 +10,7 @@ database in process. 
""" + import logging import sys diff --git a/regen_callcounts.tox.ini b/regen_callcounts.tox.ini index 5f9c2aa99bc..9a98ce8efa7 100644 --- a/regen_callcounts.tox.ini +++ b/regen_callcounts.tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{310}-sqla_{cext,nocext}-db_{sqlite,postgresql,mysql,oracle,mssql} +envlist = py{311}-sqla_{cext,nocext}-db_{sqlite,postgresql,mysql,oracle,mssql} [testenv] deps=pytest @@ -7,8 +7,7 @@ deps=pytest mock db_postgresql: .[postgresql] db_mysql: .[mysql] - db_mysql: .[pymysql] - db_oracle: .[oracle] + db_oracle: .[oracle_oracledb] db_mssql: .[mssql] @@ -22,13 +21,13 @@ commands= db_{mssql}: {env:BASECOMMAND} {env:MSSQL:} {posargs} passenv= - ORACLE_HOME - NLS_LANG - TOX_POSTGRESQL - TOX_MYSQL - TOX_ORACLE - TOX_MSSQL - TOX_SQLITE + ORACLE_HOME + NLS_LANG + TOX_POSTGRESQL + TOX_MYSQL + TOX_ORACLE + TOX_MSSQL + TOX_SQLITE TOX_WORKERS # -E : ignore PYTHON* environment variables (such as PYTHONPATH) @@ -41,8 +40,8 @@ setenv= sqla_cext: REQUIRE_SQLALCHEMY_CEXT=1 db_sqlite: SQLITE={env:TOX_SQLITE:--db sqlite} db_postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} - db_mysql: MYSQL={env:TOX_MYSQL:--db mysql --db pymysql} - db_oracle: ORACLE={env:TOX_ORACLE:--db oracle} + db_mysql: MYSQL={env:TOX_MYSQL:--db mysql} + db_oracle: ORACLE={env:TOX_ORACLE:--db oracledb} db_mssql: MSSQL={env:TOX_MSSQL:--db mssql} diff --git a/setup.cfg b/setup.cfg index b797af4afc5..de35dd2e158 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,6 @@ license_files = LICENSE classifiers = Development Status :: 5 - Production/Stable Intended Audience :: Developers - License :: OSI Approved :: MIT License Operating System :: OS Independent Programming Language :: Python Programming Language :: Python :: 3 @@ -21,6 +20,8 @@ classifiers = Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 Programming Language :: Python :: Implementation :: CPython Programming Language :: Python :: Implementation :: PyPy Topic :: Database :: Front-Ends @@ -37,12 +38,12 @@ package_dir = install_requires = importlib-metadata;python_version<"3.8" - greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32')))))) - typing-extensions >= 4.2.0 + greenlet >= 1;(python_version<"3.14" and (platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32'))))))) + typing-extensions >= 4.6.0 [options.extras_require] asyncio = - greenlet!=0.4.17 + greenlet >= 1 mypy = mypy >= 0.910 mssql = pyodbc @@ -53,7 +54,7 @@ mysql = mysql_connector = mysql-connector-python mariadb_connector = - mariadb>=1.0.1,!=1.1.2,!=1.1.5 + mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 oracle = cx_oracle>=8 oracle_oracledb = @@ -104,9 +105,9 @@ enable-extensions = G # E203 is due to https://github.com/PyCQA/pycodestyle/issues/373 ignore = - A003, + A003,A005 D, - E203,E305,E711,E712,E721,E722,E741, + E203,E305,E701,E704,E711,E712,E721,E722,E741, N801,N802,N806, RST304,RST303,RST299,RST399, W503,W504,W601 @@ -135,13 +136,13 @@ requirement_cls = test.requirements:DefaultRequirements profile_file = test/profiles.txt # name of a "loopback" link set up on the oracle database. 
-# to create this, suppose your DB is scott/tiger@xe. You'd create it +# to create this, suppose your DB is scott/tiger@free. You'd create it # like: # create public database link test_link connect to scott identified by tiger -# using 'xe'; +# using 'free'; oracle_db_link = test_link # create public database link test_link2 connect to test_schema identified by tiger -# using 'xe'; +# using 'free'; oracle_db_link2 = test_link2 # host name of a postgres database that has the postgres_fdw extension. @@ -177,11 +178,13 @@ asyncmy = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 asyncmy_fallback = mysql+asyncmy://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true mariadb = mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test -mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes -mssql_async = mssql+aioodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes -pymssql = mssql+pymssql://scott:tiger^5HHH@mssql2017:1433/test -docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server +mysql_connector = mariadb+mysqlconnector://scott:tiger@127.0.0.1:3306/test +mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2022:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Encrypt=Optional +mssql_async = mssql+aioodbc://scott:tiger^5HHH@mssql2022:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Encrypt=Optional +pymssql = mssql+pymssql://scott:tiger^5HHH@mssql2022:1433/test +docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Encrypt=Optional oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe oracledb = oracle+oracledb://scott:tiger@oracle18c/xe -docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=XEPDB1 +oracledb_async = oracle+oracledb_async://scott:tiger@oracle18c/xe +docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=FREEPDB1 diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index fc6be0f0960..b6745e8b0b3 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from sqlalchemy import and_ +from sqlalchemy import ClauseElement from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import inspect @@ -20,8 +21,10 @@ from sqlalchemy import util from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects import registry from sqlalchemy.dialects import sqlite from sqlalchemy.engine import result +from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.processors import to_decimal_processor_factory from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes @@ -39,6 +42,7 @@ from sqlalchemy.orm.session import _sessions from sqlalchemy.sql import column from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.util import visit_binary_product from sqlalchemy.sql.visitors import cloned_traverse from sqlalchemy.sql.visitors import replacement_traverse @@ -219,10 +223,14 @@ def run_plain(*func_args): # return run_plain def run_in_process(*func_args): - queue 
= multiprocessing.Queue() - proc = multiprocessing.Process( - target=profile, args=(queue, func_args) - ) + # see + # https://docs.python.org/3.14/whatsnew/3.14.html + # #incompatible-changes - the default run type is no longer + # "fork", but since we are running closures in the process + # we need forked mode + ctx = multiprocessing.get_context("fork") + queue = ctx.Queue() + proc = ctx.Process(target=profile, args=(queue, func_args)) proc.start() while True: row = queue.get() @@ -390,7 +398,7 @@ def go(): @testing.add_to_marker.memory_intensive class MemUsageWBackendTest(fixtures.MappedTest, EnsureZeroed): - __requires__ = "cpython", "memory_process_intensive", "no_asyncio" + __requires__ = "cpython", "posix", "memory_process_intensive", "no_asyncio" __sparse_backend__ = True # ensure a pure growing test trips the assertion @@ -1192,6 +1200,22 @@ def go(): metadata.drop_all(self.engine) +class SomeFoo(DialectKWArgs, ClauseElement): + pass + + +class FooDialect(DefaultDialect): + construct_arguments = [ + ( + SomeFoo, + { + "bar": False, + "bat": False, + }, + ) + ] + + @testing.add_to_marker.memory_intensive class CycleTest(_fixtures.FixtureTest): __requires__ = ("cpython", "no_windows") @@ -1216,6 +1240,33 @@ def go(): go() + @testing.fixture + def foo_dialect(self): + registry.register("foo", __name__, "FooDialect") + + yield + registry.deregister("foo") + + def test_dialect_kwargs(self, foo_dialect): + + @assert_cycles() + def go(): + ff = SomeFoo() + + ff._validate_dialect_kwargs({"foo_bar": True}) + + eq_(ff.dialect_options["foo"]["bar"], True) + + eq_(ff.dialect_options["foo"]["bat"], False) + + eq_(ff.dialect_kwargs["foo_bar"], True) + eq_(ff.dialect_kwargs["foo_bat"], False) + + ff.dialect_kwargs["foo_bat"] = True + eq_(ff.dialect_options["foo"]["bat"], True) + + go() + def test_session_execute_orm(self): User, Address = self.classes("User", "Address") configure_mappers() diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 3a5a200d805..e02c7cae857 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -1,7 +1,9 @@ from sqlalchemy import and_ from sqlalchemy import ForeignKey +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import join +from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing @@ -13,10 +15,12 @@ from sqlalchemy.orm import join as orm_join from sqlalchemy.orm import joinedload from sqlalchemy.orm import Load +from sqlalchemy.orm import query_expression from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import with_expression from sqlalchemy.testing import fixtures from sqlalchemy.testing import profiling from sqlalchemy.testing.fixtures import fixture_session @@ -142,7 +146,6 @@ def go2(): class LoadManyToOneFromIdentityTest(fixtures.MappedTest): - """test overhead associated with many-to-one fetches. Prior to the refactor of LoadLazyAttribute and @@ -1315,3 +1318,112 @@ def go(): r = q.all() # noqa: F841 go() + + +class WithExpresionLoaderOptTest(fixtures.DeclarativeMappedTest): + # keep caching on with this test. 
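Stepping back to the ``run_in_process`` rewrite in test_memusage.py above: Python 3.14 moves the default start method away from "fork" (per the whatsnew link quoted in the hunk), and only a forked child inherits a closure target; other start methods pickle the target and fail on nested functions. A minimal POSIX-only sketch of requesting the fork context explicitly, under those assumptions:

import multiprocessing


def run_forked():
    extra = 10

    def work(q):
        # a closure: transferable to the child only under "fork",
        # where the child inherits it instead of pickling it
        q.put(32 + extra)

    ctx = multiprocessing.get_context("fork")  # explicit start method
    q = ctx.Queue()
    proc = ctx.Process(target=work, args=(q,))
    proc.start()
    assert q.get() == 42
    proc.join()


if __name__ == "__main__":
    run_forked()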
+ __requires__ = ("python_profiling_backend",) + + """test #11085""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class A(Base): + __tablename__ = "a" + + id = Column(Integer, Identity(), primary_key=True) + data = Column(String(30)) + bs = relationship("B") + + class B(Base): + __tablename__ = "b" + id = Column(Integer, Identity(), primary_key=True) + a_id = Column(ForeignKey("a.id")) + boolean = query_expression() + d1 = Column(String(30)) + d2 = Column(String(30)) + d3 = Column(String(30)) + d4 = Column(String(30)) + d5 = Column(String(30)) + d6 = Column(String(30)) + d7 = Column(String(30)) + + @classmethod + def insert_data(cls, connection): + A, B = cls.classes("A", "B") + + with Session(connection) as s: + s.add( + A( + bs=[ + B( + d1="x", + d2="x", + d3="x", + d4="x", + d5="x", + d6="x", + d7="x", + ) + ] + ) + ) + s.commit() + + def test_from_opt_no_cache(self): + A, B = self.classes("A", "B") + + @profiling.function_call_count(warmup=2) + def go(): + with Session( + testing.db.execution_options(compiled_cache=None) + ) as sess: + _ = sess.execute( + select(A).options( + selectinload(A.bs).options( + with_expression( + B.boolean, + and_( + B.d1 == "x", + B.d2 == "x", + B.d3 == "x", + B.d4 == "x", + B.d5 == "x", + B.d6 == "x", + B.d7 == "x", + ), + ) + ) + ) + ).scalars() + + go() + + def test_from_opt_after_cache(self): + A, B = self.classes("A", "B") + + @profiling.function_call_count(warmup=2) + def go(): + with Session(testing.db) as sess: + _ = sess.execute( + select(A).options( + selectinload(A.bs).options( + with_expression( + B.boolean, + and_( + B.d1 == literal_column("'x'"), + B.d2 == "x", + B.d3 == literal_column("'x'"), + B.d4 == "x", + B.d5 == literal_column("'x'"), + B.d6 == "x", + B.d7 == literal_column("'x'"), + ), + ) + ) + ) + ).scalars() + + go() diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index b4fb34d0259..6cfa0383d6d 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -4,6 +4,7 @@ import threading from sqlalchemy import exc +from sqlalchemy import testing from sqlalchemy.testing import async_test from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -80,6 +81,7 @@ def go(): with expect_raises_message(ValueError, "sync error"): await greenlet_spawn(go) + @testing.requires.not_python314 def test_await_fallback_no_greenlet(self): to_await = run1() await_fallback(to_await) @@ -264,3 +266,18 @@ def prime(): t.join() is_true(run[0]) + + +class GracefulNoGreenletTest(fixtures.TestBase): + __requires__ = ("no_greenlet",) + + def test_await_only_graceful(self): + async def async_fn(): + pass + + with expect_raises_message( + ValueError, + "the greenlet library is required to use this " + "function. 
No module named 'greenlet'", + ): + await_only(async_fn()) diff --git a/test/base/test_events.py b/test/base/test_events.py index 6f8456274f3..ccb53f2bb37 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -978,6 +978,9 @@ class TargetElement(BaseTarget): def __init__(self, parent): self.dispatch = self.dispatch._join(parent.dispatch) + def create(self): + return TargetElement(self) + def run_event(self, arg): list(self.dispatch.event_one) self.dispatch.event_one(self, arg) @@ -1044,6 +1047,38 @@ def test_parent_class_child_class(self): [call(element, 1), call(element, 2), call(element, 3)], ) + def test_join_twice(self): + """test #12289""" + + l1 = Mock() + l2 = Mock() + + first_target_element = self.TargetFactory().create() + second_target_element = first_target_element.create() + + event.listen(second_target_element, "event_one", l2) + event.listen(first_target_element, "event_one", l1) + + second_target_element.run_event(1) + eq_( + l1.mock_calls, + [call(second_target_element, 1)], + ) + eq_( + l2.mock_calls, + [call(second_target_element, 1)], + ) + + first_target_element.run_event(2) + eq_( + l1.mock_calls, + [call(second_target_element, 1), call(first_target_element, 2)], + ) + eq_( + l2.mock_calls, + [call(second_target_element, 1)], + ) + def test_parent_class_child_instance_apply_after(self): l1 = Mock() l2 = Mock() @@ -1271,6 +1306,107 @@ class Target: return Target + def test_two_subclasses_one_event(self): + """test #12216""" + + Target = self._fixture() + + class TargetSubclassOne(Target): + pass + + class TargetSubclassTwo(Target): + pass + + m1 = Mock() + + def my_event_one(x, y): + m1.my_event_one(x, y) + + event.listen(TargetSubclassOne, "event_one", my_event_one) + event.listen(TargetSubclassTwo, "event_one", my_event_one) + + t1 = TargetSubclassOne() + t2 = TargetSubclassTwo() + + t1.dispatch.event_one("x1a", "y1a") + t2.dispatch.event_one("x2a", "y2a") + + eq_( + m1.mock_calls, + [call.my_event_one("x1a", "y1a"), call.my_event_one("x2a", "y2a")], + ) + + event.remove(TargetSubclassOne, "event_one", my_event_one) + + t1.dispatch.event_one("x1b", "y1b") + t2.dispatch.event_one("x2b", "y2b") + + eq_( + m1.mock_calls, + [ + call.my_event_one("x1a", "y1a"), + call.my_event_one("x2a", "y2a"), + call.my_event_one("x2b", "y2b"), + ], + ) + + event.remove(TargetSubclassTwo, "event_one", my_event_one) + + t1.dispatch.event_one("x1c", "y1c") + t2.dispatch.event_one("x2c", "y2c") + + eq_( + m1.mock_calls, + [ + call.my_event_one("x1a", "y1a"), + call.my_event_one("x2a", "y2a"), + call.my_event_one("x2b", "y2b"), + ], + ) + + def test_two_subclasses_one_event_reg_cleanup(self): + """test #12216""" + + from sqlalchemy.event import registry + + Target = self._fixture() + + class TargetSubclassOne(Target): + pass + + class TargetSubclassTwo(Target): + pass + + m1 = Mock() + + def my_event_one(x, y): + m1.my_event_one(x, y) + + event.listen(TargetSubclassOne, "event_one", my_event_one) + event.listen(TargetSubclassTwo, "event_one", my_event_one) + + key1 = (id(TargetSubclassOne), "event_one", id(my_event_one)) + key2 = (id(TargetSubclassTwo), "event_one", id(my_event_one)) + + assert key1 in registry._key_to_collection + assert key2 in registry._key_to_collection + + del TargetSubclassOne + gc_collect() + + # the key remains because the gc routine would be based on deleting + # Target (I think) + assert key1 in registry._key_to_collection + assert key2 in registry._key_to_collection + + del TargetSubclassTwo + gc_collect() + + assert key1 in 
registry._key_to_collection + assert key2 in registry._key_to_collection + + # event.remove(TargetSubclassTwo, "event_one", my_event_one) + def test_clslevel(self): Target = self._fixture() @@ -1503,6 +1639,38 @@ def test_listener_collection_removed_cleanup(self): assert key not in registry._key_to_collection assert collection_ref not in registry._collection_to_key + @testing.requires.predictable_gc + def test_listener_collection_removed_cleanup_clslevel(self): + """test related to #12216""" + + from sqlalchemy.event import registry + + Target = self._fixture() + + m1 = Mock() + + event.listen(Target, "event_one", m1) + + key = (id(Target), "event_one", id(m1)) + + assert key in registry._key_to_collection + collection_ref = list(registry._key_to_collection[key])[0] + assert collection_ref in registry._collection_to_key + + t1 = Target() + t1.dispatch.event_one("t1") + + del t1 + + del Target + + gc_collect() + + # gc of a target class does not currently cause these collections + # to be cleaned up + assert key in registry._key_to_collection + assert collection_ref in registry._collection_to_key + def test_remove_not_listened(self): Target = self._fixture() diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index b920f25f0a5..d86322e12ee 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -6,9 +6,11 @@ import re import sys +from sqlalchemy.engine.url import make_url from sqlalchemy.testing import config from sqlalchemy.testing import fixtures from sqlalchemy.testing import requires +from sqlalchemy.testing import skip_test class DocTest(fixtures.TestBase): @@ -65,12 +67,9 @@ def _run_doctest(self, *fnames): doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL - | _get_allow_unicode_flag() ) runner = doctest.DocTestRunner( - verbose=None, - optionflags=optionflags, - checker=_get_unicode_checker(), + verbose=config.options.verbose >= 2, optionflags=optionflags ) parser = doctest.DocTestParser() globs = {"print_function": print} @@ -163,90 +162,28 @@ def test_orm_queryguide_select(self): ) def test_orm_queryguide_inheritance(self): - self._run_doctest( - "orm/queryguide/inheritance.rst", - ) + self._run_doctest("orm/queryguide/inheritance.rst") @requires.update_from def test_orm_queryguide_dml(self): - self._run_doctest( - "orm/queryguide/dml.rst", - ) + self._run_doctest("orm/queryguide/dml.rst") def test_orm_large_collections(self): - self._run_doctest( - "orm/large_collections.rst", - ) + self._run_doctest("orm/large_collections.rst") def test_orm_queryguide_columns(self): - self._run_doctest( - "orm/queryguide/columns.rst", - ) + self._run_doctest("orm/queryguide/columns.rst") def test_orm_quickstart(self): self._run_doctest("orm/quickstart.rst") - -# unicode checker courtesy pytest - - -def _get_unicode_checker(): - """ - Returns a doctest.OutputChecker subclass that takes in account the - ALLOW_UNICODE option to ignore u'' prefixes in strings. Useful - when the same doctest should run in Python 2 and Python 3. - - An inner class is used to avoid importing "doctest" at the module - level. 
- """ - if hasattr(_get_unicode_checker, "UnicodeOutputChecker"): - return _get_unicode_checker.UnicodeOutputChecker() - - import doctest - import re - - class UnicodeOutputChecker(doctest.OutputChecker): - """ - Copied from doctest_nose_plugin.py from the nltk project: - https://github.com/nltk/nltk - """ - - _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE) - - def check_output(self, want, got, optionflags): - res = doctest.OutputChecker.check_output( - self, want, got, optionflags - ) - if res: - return True - - if not (optionflags & _get_allow_unicode_flag()): - return False - - else: # pragma: no cover - # the code below will end up executed only in Python 2 in - # our tests, and our coverage check runs in Python 3 only - def remove_u_prefixes(txt): - return re.sub(self._literal_re, r"\1\2", txt) - - want = remove_u_prefixes(want) - got = remove_u_prefixes(got) - res = doctest.OutputChecker.check_output( - self, want, got, optionflags - ) - return res - - _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker - return _get_unicode_checker.UnicodeOutputChecker() - - -def _get_allow_unicode_flag(): - """ - Registers and returns the ALLOW_UNICODE flag. - """ - import doctest - - return doctest.register_optionflag("ALLOW_UNICODE") + @requires.greenlet + def test_asyncio(self): + try: + make_url("sqlite+aiosqlite://").get_dialect().import_dbapi() + except ImportError: + skip_test("missing aiosqile") + self._run_doctest("orm/extensions/asyncio.rst") # increase number to force pipeline run. 1 diff --git a/test/base/test_typing_utils.py b/test/base/test_typing_utils.py new file mode 100644 index 00000000000..51f5e13c418 --- /dev/null +++ b/test/base/test_typing_utils.py @@ -0,0 +1,644 @@ +# NOTE: typing implementation is full of heuristic so unit test it to avoid +# unexpected breakages. 
+ +import typing + +import typing_extensions + +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import requires +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import is_ +from sqlalchemy.util import py310 +from sqlalchemy.util import py312 +from sqlalchemy.util import py314 +from sqlalchemy.util import py38 +from sqlalchemy.util import typing as sa_typing + +TV = typing.TypeVar("TV") + + +def union_types(): + res = [typing.Union[int, str]] + if py310: + res.append(int | str) + return res + + +def null_union_types(): + res = [ + typing.Optional[typing.Union[int, str]], + typing.Union[int, str, None], + typing.Union[int, str, "None"], + ] + if py310: + res.append(int | str | None) + res.append(typing.Optional[int | str]) + res.append(typing.Union[int, str] | None) + res.append(typing.Optional[int] | str) + return res + + +def generic_unions(): + # remove new-style unions `int | str` that are not generic + res = union_types() + null_union_types() + if py310 and not py314: + new_ut = type(int | str) + res = [t for t in res if not isinstance(t, new_ut)] + return res + + +def make_fw_ref(anno: str) -> typing.ForwardRef: + return typing.Union[anno] + + +TypeAliasType = getattr( + typing, "TypeAliasType", typing_extensions.TypeAliasType +) + +TA_int = TypeAliasType("TA_int", int) +TAext_int = typing_extensions.TypeAliasType("TAext_int", int) +TA_union = TypeAliasType("TA_union", typing.Union[int, str]) +TAext_union = typing_extensions.TypeAliasType( + "TAext_union", typing.Union[int, str] +) +TA_null_union = TypeAliasType("TA_null_union", typing.Union[int, str, None]) +TAext_null_union = typing_extensions.TypeAliasType( + "TAext_null_union", typing.Union[int, str, None] +) +TA_null_union2 = TypeAliasType( + "TA_null_union2", typing.Union[int, str, "None"] +) +TAext_null_union2 = typing_extensions.TypeAliasType( + "TAext_null_union2", typing.Union[int, str, "None"] +) +TA_null_union3 = TypeAliasType( + "TA_null_union3", typing.Union[int, "typing.Union[None, bool]"] +) +TAext_null_union3 = typing_extensions.TypeAliasType( + "TAext_null_union3", typing.Union[int, "typing.Union[None, bool]"] +) +TA_null_union4 = TypeAliasType( + "TA_null_union4", typing.Union[int, "TA_null_union2"] +) +TAext_null_union4 = typing_extensions.TypeAliasType( + "TAext_null_union4", typing.Union[int, "TAext_null_union2"] +) +TA_union_ta = TypeAliasType("TA_union_ta", typing.Union[TA_int, str]) +TAext_union_ta = typing_extensions.TypeAliasType( + "TAext_union_ta", typing.Union[TAext_int, str] +) +TA_null_union_ta = TypeAliasType( + "TA_null_union_ta", typing.Union[TA_null_union, float] +) +TAext_null_union_ta = typing_extensions.TypeAliasType( + "TAext_null_union_ta", typing.Union[TAext_null_union, float] +) +TA_list = TypeAliasType( + "TA_list", typing.Union[int, str, typing.List["TA_list"]] +) +TAext_list = typing_extensions.TypeAliasType( + "TAext_list", typing.Union[int, str, typing.List["TAext_list"]] +) +# these below not valid. 
Verify that it does not cause exceptions in any case +TA_recursive = TypeAliasType("TA_recursive", typing.Union["TA_recursive", str]) +TAext_recursive = typing_extensions.TypeAliasType( + "TAext_recursive", typing.Union["TAext_recursive", str] +) +TA_null_recursive = TypeAliasType( + "TA_null_recursive", typing.Union[TA_recursive, None] +) +TAext_null_recursive = typing_extensions.TypeAliasType( + "TAext_null_recursive", typing.Union[TAext_recursive, None] +) +TA_recursive_a = TypeAliasType( + "TA_recursive_a", typing.Union["TA_recursive_b", int] +) +TAext_recursive_a = typing_extensions.TypeAliasType( + "TAext_recursive_a", typing.Union["TAext_recursive_b", int] +) +TA_recursive_b = TypeAliasType( + "TA_recursive_b", typing.Union["TA_recursive_a", str] +) +TAext_recursive_b = typing_extensions.TypeAliasType( + "TAext_recursive_b", typing.Union["TAext_recursive_a", str] +) +TA_generic = TypeAliasType("TA_generic", typing.List[TV], type_params=(TV,)) +TAext_generic = typing_extensions.TypeAliasType( + "TAext_generic", typing.List[TV], type_params=(TV,) +) +TA_generic_typed = TA_generic[int] +TAext_generic_typed = TAext_generic[int] +TA_generic_null = TypeAliasType( + "TA_generic_null", typing.Union[typing.List[TV], None], type_params=(TV,) +) +TAext_generic_null = typing_extensions.TypeAliasType( + "TAext_generic_null", + typing.Union[typing.List[TV], None], + type_params=(TV,), +) +TA_generic_null_typed = TA_generic_null[str] +TAext_generic_null_typed = TAext_generic_null[str] + + +def type_aliases(): + return [ + TA_int, + TAext_int, + TA_union, + TAext_union, + TA_null_union, + TAext_null_union, + TA_null_union2, + TAext_null_union2, + TA_null_union3, + TAext_null_union3, + TA_null_union4, + TAext_null_union4, + TA_union_ta, + TAext_union_ta, + TA_null_union_ta, + TAext_null_union_ta, + TA_list, + TAext_list, + TA_recursive, + TAext_recursive, + TA_null_recursive, + TAext_null_recursive, + TA_recursive_a, + TAext_recursive_a, + TA_recursive_b, + TAext_recursive_b, + TA_generic, + TAext_generic, + TA_generic_typed, + TAext_generic_typed, + TA_generic_null, + TAext_generic_null, + TA_generic_null_typed, + TAext_generic_null_typed, + ] + + +NT_str = typing.NewType("NT_str", str) +NT_null = typing.NewType("NT_null", None) +# this below is not valid. Verify that it does not cause exceptions in any case +NT_union = typing.NewType("NT_union", typing.Union[str, int]) + + +def new_types(): + return [NT_str, NT_null, NT_union] + + +A_str = typing_extensions.Annotated[str, "meta"] +A_null_str = typing_extensions.Annotated[ + typing.Union[str, None], "other_meta", "null" +] +A_union = typing_extensions.Annotated[typing.Union[str, int], "other_meta"] +A_null_union = typing_extensions.Annotated[ + typing.Union[str, int, None], "other_meta", "null" +] + + +def compare_type_by_string(a, b): + """python 3.14 has made ForwardRefs not really comparable or reliably + hashable. + + As we need to compare types here, including structures like + `Union["str", "int"]`, without having to dive into cpython's source code + each time a new release comes out, compare based on stringification, + which still presents changing rules but at least are easy to diagnose + and correct for different python versions. 
+ + See discussion at https://github.com/python/cpython/issues/129463 + for background + + """ + + if isinstance(a, (set, list)): + a = sorted(a, key=lambda x: str(x)) + if isinstance(b, (set, list)): + b = sorted(b, key=lambda x: str(x)) + + eq_(str(a), str(b)) + + +def annotated_l(): + return [A_str, A_null_str, A_union, A_null_union] + + +def all_types(): + return ( + union_types() + + null_union_types() + + type_aliases() + + new_types() + + annotated_l() + ) + + +def exec_code(code: str, *vars: str) -> typing.Any: + assert vars + scope = {} + exec(code, None, scope) + if len(vars) == 1: + return scope[vars[0]] + return [scope[name] for name in vars] + + +class TestTestingThings(fixtures.TestBase): + def test_unions_are_the_same(self): + # the point of this test is to reduce the cases to test since + # some symbols are the same in typing and typing_extensions. + # If a test starts failing then additional cases should be added, + # similar to what it's done for TypeAliasType + + # no need to test typing_extensions.Union, typing_extensions.Optional + is_(typing.Union, typing_extensions.Union) + is_(typing.Optional, typing_extensions.Optional) + + @requires.python312 + def test_make_type_alias_type(self): + # verify that TypeAliasType('foo', int) it the same as 'type foo = int' + x_type = exec_code("type x = int", "x") + x = typing.TypeAliasType("x", int) + + eq_(type(x_type), type(x)) + eq_(x_type.__name__, x.__name__) + eq_(x_type.__value__, x.__value__) + + def test_make_fw_ref(self): + compare_type_by_string(make_fw_ref("str"), typing.ForwardRef("str")) + compare_type_by_string( + make_fw_ref("str|int"), typing.ForwardRef("str|int") + ) + compare_type_by_string( + make_fw_ref("Optional[Union[str, int]]"), + typing.ForwardRef("Optional[Union[str, int]]"), + ) + + +class TestTyping(fixtures.TestBase): + def test_is_pep593(self): + eq_(sa_typing.is_pep593(str), False) + eq_(sa_typing.is_pep593(None), False) + eq_(sa_typing.is_pep593(typing_extensions.Annotated[int, "a"]), True) + if py310: + eq_(sa_typing.is_pep593(typing.Annotated[int, "a"]), True) + + for t in annotated_l(): + eq_(sa_typing.is_pep593(t), True) + for t in ( + union_types() + null_union_types() + type_aliases() + new_types() + ): + eq_(sa_typing.is_pep593(t), False) + + def test_is_literal(self): + if py38: + eq_(sa_typing.is_literal(typing.Literal["a"]), True) + eq_(sa_typing.is_literal(typing_extensions.Literal["a"]), True) + eq_(sa_typing.is_literal(None), False) + for t in all_types(): + eq_(sa_typing.is_literal(t), False) + + def test_is_newtype(self): + eq_(sa_typing.is_newtype(str), False) + + for t in new_types(): + eq_(sa_typing.is_newtype(t), True) + for t in ( + union_types() + null_union_types() + type_aliases() + annotated_l() + ): + eq_(sa_typing.is_newtype(t), False) + + def test_is_generic(self): + class W(typing.Generic[TV]): + pass + + eq_(sa_typing.is_generic(typing.List[int]), True) + eq_(sa_typing.is_generic(W), False) + eq_(sa_typing.is_generic(W[str]), True) + + if py312: + t = exec_code("class W[T]: pass", "W") + eq_(sa_typing.is_generic(t), False) + eq_(sa_typing.is_generic(t[int]), True) + + generics = [ + TA_generic_typed, + TAext_generic_typed, + TA_generic_null_typed, + TAext_generic_null_typed, + *annotated_l(), + *generic_unions(), + ] + + for t in all_types(): + if py314: + exp = any(t == k for k in generics) + else: + # use is since union compare equal between new/old style + exp = any(t is k for k in generics) + eq_(sa_typing.is_generic(t), exp, t) + + def test_is_pep695(self): + 
eq_(sa_typing.is_pep695(str), False) + for t in ( + union_types() + null_union_types() + new_types() + annotated_l() + ): + eq_(sa_typing.is_pep695(t), False) + for t in type_aliases(): + eq_(sa_typing.is_pep695(t), True) + + @requires.python38 + def test_pep695_value(self): + eq_(sa_typing.pep695_values(int), {int}) + eq_( + sa_typing.pep695_values(typing.Union[int, str]), + {typing.Union[int, str]}, + ) + + for t in ( + union_types() + null_union_types() + new_types() + annotated_l() + ): + eq_(sa_typing.pep695_values(t), {t}) + + eq_( + sa_typing.pep695_values(typing.Union[int, TA_int]), + {typing.Union[int, TA_int]}, + ) + eq_( + sa_typing.pep695_values(typing.Union[int, TAext_int]), + {typing.Union[int, TAext_int]}, + ) + + eq_(sa_typing.pep695_values(TA_int), {int}) + eq_(sa_typing.pep695_values(TAext_int), {int}) + eq_(sa_typing.pep695_values(TA_union), {int, str}) + eq_(sa_typing.pep695_values(TAext_union), {int, str}) + eq_(sa_typing.pep695_values(TA_null_union), {int, str, None}) + eq_(sa_typing.pep695_values(TAext_null_union), {int, str, None}) + eq_(sa_typing.pep695_values(TA_null_union2), {int, str, None}) + eq_(sa_typing.pep695_values(TAext_null_union2), {int, str, None}) + + compare_type_by_string( + sa_typing.pep695_values(TA_null_union3), + [int, typing.ForwardRef("typing.Union[None, bool]")], + ) + + compare_type_by_string( + sa_typing.pep695_values(TAext_null_union3), + {int, typing.ForwardRef("typing.Union[None, bool]")}, + ) + + compare_type_by_string( + sa_typing.pep695_values(TA_null_union4), + [int, typing.ForwardRef("TA_null_union2")], + ) + compare_type_by_string( + sa_typing.pep695_values(TAext_null_union4), + {int, typing.ForwardRef("TAext_null_union2")}, + ) + + eq_(sa_typing.pep695_values(TA_union_ta), {int, str}) + eq_(sa_typing.pep695_values(TAext_union_ta), {int, str}) + eq_(sa_typing.pep695_values(TA_null_union_ta), {int, str, None, float}) + + compare_type_by_string( + sa_typing.pep695_values(TAext_null_union_ta), + {int, str, None, float}, + ) + + compare_type_by_string( + sa_typing.pep695_values(TA_list), + [int, str, typing.List[typing.ForwardRef("TA_list")]], + ) + + compare_type_by_string( + sa_typing.pep695_values(TAext_list), + {int, str, typing.List[typing.ForwardRef("TAext_list")]}, + ) + + compare_type_by_string( + sa_typing.pep695_values(TA_recursive), + [str, typing.ForwardRef("TA_recursive")], + ) + compare_type_by_string( + sa_typing.pep695_values(TAext_recursive), + {typing.ForwardRef("TAext_recursive"), str}, + ) + compare_type_by_string( + sa_typing.pep695_values(TA_null_recursive), + [str, typing.ForwardRef("TA_recursive"), None], + ) + compare_type_by_string( + sa_typing.pep695_values(TAext_null_recursive), + {typing.ForwardRef("TAext_recursive"), str, None}, + ) + compare_type_by_string( + sa_typing.pep695_values(TA_recursive_a), + [int, typing.ForwardRef("TA_recursive_b")], + ) + compare_type_by_string( + sa_typing.pep695_values(TAext_recursive_a), + {typing.ForwardRef("TAext_recursive_b"), int}, + ) + compare_type_by_string( + sa_typing.pep695_values(TA_recursive_b), + [str, typing.ForwardRef("TA_recursive_a")], + ) + compare_type_by_string( + sa_typing.pep695_values(TAext_recursive_b), + {typing.ForwardRef("TAext_recursive_a"), str}, + ) + + @requires.up_to_date_typealias_type + def test_pep695_value_generics(self): + # generics + + eq_(sa_typing.pep695_values(TA_generic), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TAext_generic), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TA_generic_typed), {typing.List[TV]}) + 
eq_(sa_typing.pep695_values(TAext_generic_typed), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TA_generic_null), {None, typing.List[TV]}) + eq_( + sa_typing.pep695_values(TAext_generic_null), + {None, typing.List[TV]}, + ) + eq_( + sa_typing.pep695_values(TA_generic_null_typed), + {None, typing.List[TV]}, + ) + eq_( + sa_typing.pep695_values(TAext_generic_null_typed), + {None, typing.List[TV]}, + ) + + def test_is_fwd_ref(self): + eq_(sa_typing.is_fwd_ref(int), False) + eq_(sa_typing.is_fwd_ref(make_fw_ref("str")), True) + eq_(sa_typing.is_fwd_ref(typing.Union[str, int]), False) + eq_(sa_typing.is_fwd_ref(typing.Union["str", int]), False) + eq_(sa_typing.is_fwd_ref(typing.Union["str", int], True), True) + + for t in all_types(): + eq_(sa_typing.is_fwd_ref(t), False) + + def test_de_optionalize_union_types(self): + fn = sa_typing.de_optionalize_union_types + + eq_( + fn(typing.Optional[typing.Union[int, str]]), typing.Union[int, str] + ) + eq_(fn(typing.Union[int, str, None]), typing.Union[int, str]) + + eq_(fn(typing.Union[int, str, "None"]), typing.Union[int, str]) + + eq_(fn(make_fw_ref("None")), typing_extensions.Never) + eq_(fn(make_fw_ref("typing.Union[None]")), typing_extensions.Never) + eq_(fn(make_fw_ref("Union[None, str]")), typing.ForwardRef("str")) + + compare_type_by_string( + fn(make_fw_ref("Union[None, str, int]")), + typing.Union["str", "int"], + ) + + compare_type_by_string( + fn(make_fw_ref("Optional[int]")), typing.ForwardRef("int") + ) + + compare_type_by_string( + fn(make_fw_ref("typing.Optional[Union[int | str]]")), + typing.ForwardRef("Union[int | str]"), + ) + + for t in null_union_types(): + res = fn(t) + eq_(sa_typing.is_union(res), True) + eq_(type(None) not in res.__args__, True) + + for t in union_types() + type_aliases() + new_types() + annotated_l(): + eq_(fn(t), t) + + compare_type_by_string( + fn(make_fw_ref("Union[typing.Dict[str, int], int, None]")), + typing.Union[ + "typing.Dict[str, int]", + "int", + ], + ) + + def test_make_union_type(self): + eq_(sa_typing.make_union_type(int), int) + eq_(sa_typing.make_union_type(None), type(None)) + eq_(sa_typing.make_union_type(int, str), typing.Union[int, str]) + eq_( + sa_typing.make_union_type(int, typing.Optional[str]), + typing.Union[int, str, None], + ) + eq_( + sa_typing.make_union_type(int, typing.Union[str, bool]), + typing.Union[int, str, bool], + ) + eq_( + sa_typing.make_union_type(bool, TA_int, NT_str), + typing.Union[bool, TA_int, NT_str], + ) + eq_( + sa_typing.make_union_type(bool, TAext_int, NT_str), + typing.Union[bool, TAext_int, NT_str], + ) + + @requires.up_to_date_typealias_type + @requires.python38 + def test_includes_none_generics(self): + # TODO: these are false negatives + false_negatives = { + TA_null_union4, # does not evaluate FW ref + TAext_null_union4, # does not evaluate FW ref + } + for t in type_aliases() + new_types(): + if t in false_negatives: + exp = False + else: + exp = "null" in t.__name__ + eq_(sa_typing.includes_none(t), exp, str(t)) + + @requires.python38 + def test_includes_none(self): + eq_(sa_typing.includes_none(None), True) + eq_(sa_typing.includes_none(type(None)), True) + eq_(sa_typing.includes_none(typing.ForwardRef("None")), True) + eq_(sa_typing.includes_none(int), False) + for t in union_types(): + eq_(sa_typing.includes_none(t), False) + + for t in null_union_types(): + eq_(sa_typing.includes_none(t), True, str(t)) + + for t in annotated_l(): + eq_( + sa_typing.includes_none(t), + "null" in sa_typing.get_args(t), + str(t), + ) + # nested things + 
eq_(sa_typing.includes_none(typing.Union[int, "None"]), True) + eq_(sa_typing.includes_none(typing.Union[bool, TA_null_union]), True) + eq_( + sa_typing.includes_none(typing.Union[bool, TAext_null_union]), True + ) + eq_(sa_typing.includes_none(typing.Union[bool, NT_null]), True) + # nested fw + eq_( + sa_typing.includes_none( + typing.Union[int, "typing.Union[str, None]"] + ), + True, + ) + eq_( + sa_typing.includes_none( + typing.Union[int, "typing.Union[int, str]"] + ), + False, + ) + + # there are not supported. should return True + eq_( + sa_typing.includes_none(typing.Union[bool, "TA_null_union"]), False + ) + eq_( + sa_typing.includes_none(typing.Union[bool, "TAext_null_union"]), + False, + ) + eq_(sa_typing.includes_none(typing.Union[bool, "NT_null"]), False) + + def test_is_union(self): + eq_(sa_typing.is_union(str), False) + for t in union_types() + null_union_types(): + eq_(sa_typing.is_union(t), True) + for t in type_aliases() + new_types() + annotated_l(): + eq_(sa_typing.is_union(t), False) + + def test_TypingInstances(self): + is_(sa_typing._type_tuples, sa_typing._type_instances) + is_( + isinstance(sa_typing._type_instances, sa_typing._TypingInstances), + True, + ) + + # cached + is_( + sa_typing._type_instances.Literal, + sa_typing._type_instances.Literal, + ) + + for k in ["Literal", "Annotated", "TypeAliasType"]: + types = set() + ti = getattr(sa_typing._type_instances, k) + for lib in [typing, typing_extensions]: + lt = getattr(lib, k, None) + if lt is not None: + types.add(lt) + is_(lt in ti, True) + eq_(len(ti), len(types), k) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 7dcf0968a7c..de8712c8523 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,4 +1,5 @@ import copy +from decimal import Decimal import inspect from pathlib import Path import pickle @@ -31,6 +32,7 @@ from sqlalchemy.util import compat from sqlalchemy.util import FastIntFlag from sqlalchemy.util import get_callable_argspec +from sqlalchemy.util import is_non_string_iterable from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence @@ -1550,6 +1552,30 @@ def __ne__(self, other): return True +class MiscTest(fixtures.TestBase): + @testing.combinations( + (["one", "two", "three"], True), + (("one", "two", "three"), True), + ((), True), + ("four", False), + (252, False), + (Decimal("252"), False), + (b"four", False), + (iter("four"), True), + (b"", False), + ("", False), + (None, False), + ({"dict": "value"}, True), + ({}, True), + ({"set", "two"}, True), + (set(), True), + (util.immutabledict(), True), + (util.immutabledict({"key": "value"}), True), + ) + def test_non_string_iterable_check(self, fixture, expected): + is_(is_non_string_iterable(fixture), expected) + + class IdentitySetTest(fixtures.TestBase): obj_type = object diff --git a/test/base/test_warnings.py b/test/base/test_warnings.py index ee286a7bc9e..069835ff9ec 100644 --- a/test/base/test_warnings.py +++ b/test/base/test_warnings.py @@ -36,7 +36,7 @@ def test_warn_deprecated_limited_cap(self): messages.add(message) eq_(len(printouts), occurrences) - eq_(len(messages), cap) + assert cap / 2 < len(messages) <= cap class ClsWarningTest(fixtures.TestBase): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 74867ccbe21..eb4dba0a079 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -175,7 +175,7 @@ def test_insert(self): t = table("sometable", 
column("somecolumn")) self.assert_compile( t.insert(), - "INSERT INTO sometable (somecolumn) VALUES " "(:somecolumn)", + "INSERT INTO sometable (somecolumn) VALUES (:somecolumn)", ) def test_update(self): @@ -393,7 +393,11 @@ def test_update_to_select_schema(self): "check_post_param": {}, }, ), - (lambda t: t.c.foo.in_([None]), "sometable.foo IN (NULL)", {}), + ( + lambda t: t.c.foo.in_([None]), + "sometable.foo IN (__[POSTCOMPILE_foo_1])", + {}, + ), ) def test_strict_binds(self, expr, compiled, kw): """test the 'strict' compiler binds.""" @@ -702,9 +706,9 @@ def test_schema_single_token_bracketed( select(tbl), "SELECT %(name)s.test.id FROM %(name)s.test" % {"name": rendered_schema}, - schema_translate_map={None: schemaname} - if use_schema_translate - else None, + schema_translate_map=( + {None: schemaname} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -777,16 +781,20 @@ def test_force_schema_quoted_name_w_dot_case_sensitive( "test", metadata, Column("id", Integer, primary_key=True), - schema=quoted_name("Foo.dbo", True) - if not use_schema_translate - else None, + schema=( + quoted_name("Foo.dbo", True) + if not use_schema_translate + else None + ), ) self.assert_compile( select(tbl), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test", - schema_translate_map={None: quoted_name("Foo.dbo", True)} - if use_schema_translate - else None, + schema_translate_map=( + {None: quoted_name("Foo.dbo", True)} + if use_schema_translate + else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -804,9 +812,9 @@ def test_force_schema_quoted_w_dot_case_sensitive( self.assert_compile( select(tbl), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test", - schema_translate_map={None: "[Foo.dbo]"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "[Foo.dbo]"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -824,9 +832,9 @@ def test_schema_autosplit_w_dot_case_insensitive( self.assert_compile( select(tbl), "SELECT foo.dbo.test.id FROM foo.dbo.test", - schema_translate_map={None: "foo.dbo"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "foo.dbo"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -842,9 +850,9 @@ def test_schema_autosplit_w_dot_case_sensitive(self, use_schema_translate): self.assert_compile( select(tbl), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test", - schema_translate_map={None: "Foo.dbo"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "Foo.dbo"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -858,7 +866,7 @@ def test_delete_schema(self): ) self.assert_compile( tbl.delete().where(tbl.c.id == 1), - "DELETE FROM paj.test WHERE paj.test.id = " ":id_1", + "DELETE FROM paj.test WHERE paj.test.id = :id_1", ) s = select(tbl.c.id).where(tbl.c.id == 1) self.assert_compile( @@ -878,7 +886,7 @@ def test_delete_schema_multipart(self): ) self.assert_compile( tbl.delete().where(tbl.c.id == 1), - "DELETE FROM banana.paj.test WHERE " "banana.paj.test.id = :id_1", + "DELETE FROM banana.paj.test WHERE banana.paj.test.id = :id_1", ) s = select(tbl.c.id).where(tbl.c.id == 1) self.assert_compile( @@ -995,7 +1003,7 @@ def test_function(self): ) self.assert_compile( select(func.max(t.c.col1)), - "SELECT max(sometable.col1) AS max_1 FROM " "sometable", + "SELECT max(sometable.col1) AS 
max_1 FROM sometable", ) def test_function_overrides(self): @@ -1068,7 +1076,7 @@ def test_delete_returning(self): ) d = delete(table1).returning(table1.c.myid, table1.c.name) self.assert_compile( - d, "DELETE FROM mytable OUTPUT deleted.myid, " "deleted.name" + d, "DELETE FROM mytable OUTPUT deleted.myid, deleted.name" ) d = ( delete(table1) @@ -1941,7 +1949,7 @@ def test_identity_object_no_primary_key_non_nullable(self): ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(3,1)" ")", + "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(3,1))", ) def test_identity_separate_from_primary_key(self): diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index e87b9825f1b..26b7208ec8a 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -326,6 +326,7 @@ def test_pymssql_disconnect(self): "message 20006", # Write to the server failed "message 20017", # Unexpected EOF from the server "message 20047", # DBPROCESS is dead or not enabled + "The server failed to resume the transaction", ]: eq_(dialect.is_disconnect(error, None, None), True) diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index b68b21339ea..33f648b82a0 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -664,7 +664,7 @@ def test_scalar_strings_control(self, scalar_strings, connection): def test_scalar_strings_named_control(self, scalar_strings, connection): result = ( connection.exec_driver_sql( - "SELECT anon_1.my_string " "FROM scalar_strings() AS anon_1" + "SELECT anon_1.my_string FROM scalar_strings() AS anon_1" ) .scalars() .all() diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index b6a1d411a25..7222ba47ae3 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -389,7 +389,7 @@ def test_global_temp_different_collation( ): """test #8035""" - tname = f"##foo{random.randint(1,1000000)}" + tname = f"##foo{random.randint(1, 1000000)}" with temp_db_alt_collation_fixture.connect() as conn: conn.exec_driver_sql(f"CREATE TABLE {tname} (id int primary key)") @@ -1028,10 +1028,13 @@ def define_tables(cls, metadata): for i in range(col_num) ], ) - cls.view_str = ( - view_str - ) = "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" % ( - ",".join("long_named_column_number_%d" % i for i in range(col_num)) + cls.view_str = view_str = ( + "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" + % ( + ",".join( + "long_named_column_number_%d" % i for i in range(col_num) + ) + ) ) assert len(view_str) > 4000 diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index b2e05d951d0..4364872bafe 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -25,6 +25,7 @@ from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import Interval +from sqlalchemy import JSON from sqlalchemy import LargeBinary from sqlalchemy import literal from sqlalchemy import MetaData @@ -53,14 +54,20 @@ from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.dialects.mysql import insert from sqlalchemy.dialects.mysql import match +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.sql import column +from sqlalchemy.sql import delete from sqlalchemy.sql import table +from sqlalchemy.sql import update from 
sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import literal_column from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import eq_ignore_whitespace +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock @@ -182,7 +189,7 @@ def test_create_index_with_prefix(self): self.assert_compile( schema.CreateIndex(idx), - "CREATE FULLTEXT INDEX test_idx1 " "ON testtbl (data(10))", + "CREATE FULLTEXT INDEX test_idx1 ON testtbl (data(10))", ) def test_create_index_with_text(self): @@ -406,6 +413,105 @@ def test_create_pk_with_using(self): "PRIMARY KEY (data) USING btree)", ) + @testing.combinations( + (True, True, (10, 2, 2)), + (True, True, (10, 2, 1)), + (False, True, (10, 2, 0)), + (True, False, (8, 0, 14)), + (True, False, (8, 0, 13)), + (False, False, (8, 0, 12)), + argnames="has_brackets,is_mariadb,version", + ) + def test_create_server_default_with_function_using( + self, has_brackets, is_mariadb, version + ): + dialect = mysql.dialect(is_mariadb=is_mariadb) + dialect.server_version_info = version + + m = MetaData() + tbl = Table( + "testtbl", + m, + Column("time", DateTime, server_default=func.current_timestamp()), + Column("name", String(255), server_default="some str"), + Column( + "description", String(255), server_default=func.lower("hi") + ), + Column("data", JSON, server_default=func.json_object()), + Column( + "updated1", + DateTime, + server_default=text("now() on update now()"), + ), + Column( + "updated2", + DateTime, + server_default=text("now() On UpDate now()"), + ), + Column( + "updated3", + DateTime, + server_default=text("now() ON UPDATE now()"), + ), + Column( + "updated4", + DateTime, + server_default=text("now(3)"), + ), + Column( + "updated5", + DateTime, + server_default=text("nOW(3)"), + ), + Column( + "updated6", + DateTime, + server_default=text("notnow(1)"), + ), + Column( + "updated7", + DateTime, + server_default=text("CURRENT_TIMESTAMP(3)"), + ), + ) + + eq_(dialect._support_default_function, has_brackets) + + if has_brackets: + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE testtbl (" + "time DATETIME DEFAULT CURRENT_TIMESTAMP, " + "name VARCHAR(255) DEFAULT 'some str', " + "description VARCHAR(255) DEFAULT (lower('hi')), " + "data JSON DEFAULT (json_object()), " + "updated1 DATETIME DEFAULT now() on update now(), " + "updated2 DATETIME DEFAULT now() On UpDate now(), " + "updated3 DATETIME DEFAULT now() ON UPDATE now(), " + "updated4 DATETIME DEFAULT now(3), " + "updated5 DATETIME DEFAULT nOW(3), " + "updated6 DATETIME DEFAULT (notnow(1)), " + "updated7 DATETIME DEFAULT CURRENT_TIMESTAMP(3))", + dialect=dialect, + ) + else: + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE testtbl (" + "time DATETIME DEFAULT CURRENT_TIMESTAMP, " + "name VARCHAR(255) DEFAULT 'some str', " + "description VARCHAR(255) DEFAULT lower('hi'), " + "data JSON DEFAULT json_object(), " + "updated1 DATETIME DEFAULT now() on update now(), " + "updated2 DATETIME DEFAULT now() On UpDate now(), " + "updated3 DATETIME DEFAULT now() ON UPDATE now(), " + "updated4 DATETIME DEFAULT now(3), " + "updated5 DATETIME DEFAULT nOW(3), " + "updated6 DATETIME DEFAULT notnow(1), " + "updated7 DATETIME DEFAULT CURRENT_TIMESTAMP(3))", + dialect=dialect, + ) + def test_create_index_expr(self): m = MetaData() t1 = Table("foo", m, 
Column("x", Integer)) @@ -567,7 +673,6 @@ def test_groupby_rollup(self): class SQLTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests MySQL-dialect specific compilation.""" __dialect__ = mysql.dialect() @@ -674,6 +779,14 @@ def test_update_limit(self): .with_dialect_options(mysql_limit=5), "UPDATE t SET col1=%s LIMIT 5", ) + + # does not make sense but we want this to compile + self.assert_compile( + t.update() + .values({"col1": 123}) + .with_dialect_options(mysql_limit=0), + "UPDATE t SET col1=%s LIMIT 0", + ) self.assert_compile( t.update() .values({"col1": 123}) @@ -688,6 +801,39 @@ def test_update_limit(self): "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1", ) + def test_delete_limit(self): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + self.assert_compile(t.delete(), "DELETE FROM t") + self.assert_compile( + t.delete().with_dialect_options(mysql_limit=5), + "DELETE FROM t LIMIT 5", + ) + # does not make sense but we want this to compile + self.assert_compile( + t.delete().with_dialect_options(mysql_limit=0), + "DELETE FROM t LIMIT 0", + ) + self.assert_compile( + t.delete().with_dialect_options(mysql_limit=None), + "DELETE FROM t", + ) + self.assert_compile( + t.delete() + .where(t.c.col2 == 456) + .with_dialect_options(mysql_limit=1), + "DELETE FROM t WHERE t.col2 = %s LIMIT 1", + ) + + @testing.combinations((update,), (delete,)) + def test_update_delete_limit_int_only(self, crud_fn): + t = sql.table("t", sql.column("col1"), sql.column("col2")) + + with expect_raises(ValueError): + crud_fn(t).with_dialect_options(mysql_limit="not an int").compile( + dialect=mysql.dialect() + ) + def test_utc_timestamp(self): self.assert_compile(func.utc_timestamp(), "utc_timestamp()") @@ -877,7 +1023,7 @@ def test_too_long_index(self): self.assert_compile( schema.CreateIndex(ix1), - "CREATE INDEX %s " "ON %s (%s)" % (exp, tname, cname), + "CREATE INDEX %s ON %s (%s)" % (exp, tname, cname), ) def test_innodb_autoincrement(self): @@ -1128,6 +1274,31 @@ def test_from_values(self, version: Variation): self.assert_compile(stmt, expected_sql, dialect=dialect) + @testing.variation("version", ["mysql8", "all_others"]) + def test_from_select(self, version: Variation): + stmt = insert(self.table).from_select( + ["id", "bar"], + select(self.table.c.id, literal("bar2")), + ) + stmt = stmt.on_duplicate_key_update( + bar=stmt.inserted.bar, baz=stmt.inserted.baz + ) + + expected_sql = ( + "INSERT INTO foos (id, bar) SELECT foos.id, %s AS anon_1 " + "FROM foos " + "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)" + ) + if version.all_others: + dialect = None + elif version.mysql8: + dialect = mysql.dialect() + dialect._requires_alias_for_on_duplicate_key = True + else: + version.fail() + + self.assert_compile(stmt, expected_sql, dialect=dialect) + def test_from_literal(self): stmt = insert(self.table).values( [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] @@ -1225,6 +1396,25 @@ def test_mysql8_on_update_dont_dup_alias_name(self): dialect=dialect, ) + def test_on_update_instrumented_attribute_dict(self): + class Base(DeclarativeBase): + pass + + class T(Base): + __tablename__ = "table" + + foo: Mapped[int] = mapped_column(Integer, primary_key=True) + + q = insert(T).values(foo=1).on_duplicate_key_update({T.foo: 2}) + self.assert_compile( + q, + ( + "INSERT INTO `table` (foo) VALUES (%s) " + "ON DUPLICATE KEY UPDATE foo = %s" + ), + {"foo": 1, "param_1": 2}, + ) + class RegexpCommon(testing.AssertsCompiledSQL): def setup_test(self): diff --git a/test/dialect/mysql/test_dialect.py 
b/test/dialect/mysql/test_dialect.py index c50755df414..7e31c666f3a 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -257,21 +257,40 @@ def test_ssl_arguments(self, driver_name): ("read_timeout", 30), ("write_timeout", 30), ("client_flag", 1234), - ("local_infile", 1234), + ("local_infile", 1), + ("local_infile", True), + ("local_infile", False), ("use_unicode", False), ("charset", "hello"), + ("unix_socket", "somesocket"), + argnames="kwarg, value", ) - def test_normal_arguments_mysqldb(self, kwarg, value): - from sqlalchemy.dialects.mysql import mysqldb + @testing.combinations( + ("mysql+mysqldb", ()), + ("mysql+mariadbconnector", {"use_unicode", "charset"}), + ("mariadb+mariadbconnector", {"use_unicode", "charset"}), + ("mysql+pymysql", ()), + ( + "mysql+mysqlconnector", + {"read_timeout", "write_timeout", "local_infile"}, + ), + argnames="dialect_name,skip", + ) + def test_query_arguments(self, kwarg, value, dialect_name, skip): - dialect = mysqldb.dialect() - connect_args = dialect.create_connect_args( - make_url( - "mysql+mysqldb://scott:tiger@localhost:3306/test" - "?%s=%s" % (kwarg, value) - ) + if kwarg in skip: + return + + url_value = {True: "true", False: "false"}.get(value, value) + + url = make_url( + f"{dialect_name}://scott:tiger@" + f"localhost:3306/test?{kwarg}={url_value}" ) + dialect = url.get_dialect()() + + connect_args = dialect.create_connect_args(url) eq_(connect_args[1][kwarg], value) def test_mysqlconnector_buffered_arg(self): @@ -283,15 +302,19 @@ def test_mysqlconnector_buffered_arg(self): )[1] eq_(kw["buffered"], True) - kw = dialect.create_connect_args( - make_url("mysql+mysqlconnector://u:p@host/db?buffered=false") - )[1] - eq_(kw["buffered"], False) + # this is turned off for now due to + # https://bugs.mysql.com/bug.php?id=117548 + if dialect.supports_server_side_cursors: + kw = dialect.create_connect_args( + make_url("mysql+mysqlconnector://u:p@host/db?buffered=false") + )[1] + eq_(kw["buffered"], False) - kw = dialect.create_connect_args( - make_url("mysql+mysqlconnector://u:p@host/db") - )[1] - eq_(kw["buffered"], True) + kw = dialect.create_connect_args( + make_url("mysql+mysqlconnector://u:p@host/db") + )[1] + # defaults to False as of 2.0.39 + eq_(kw.get("buffered"), None) def test_mysqlconnector_raise_on_warnings_arg(self): from sqlalchemy.dialects.mysql import mysqlconnector @@ -320,8 +343,10 @@ def test_mysqlconnector_raise_on_warnings_arg(self): [ "mysql+mysqldb", "mysql+pymysql", + "mysql+mariadbconnector", "mariadb+mysqldb", "mariadb+pymysql", + "mariadb+mariadbconnector", ] ) def test_random_arg(self): diff --git a/test/dialect/mysql/test_for_update.py b/test/dialect/mysql/test_for_update.py index 5717a32997c..5c26d8eb6d5 100644 --- a/test/dialect/mysql/test_for_update.py +++ b/test/dialect/mysql/test_for_update.py @@ -3,6 +3,7 @@ See #4246 """ + import contextlib from sqlalchemy import Column @@ -89,7 +90,11 @@ def _assert_a_is_locked(self, should_be_locked): # set x/y > 10 try: alt_trans.execute(update(A).values(x=15, y=19)) - except (exc.InternalError, exc.OperationalError) as err: + except ( + exc.InternalError, + exc.OperationalError, + exc.DatabaseError, + ) as err: assert "Lock wait timeout exceeded" in str(err) assert should_be_locked else: @@ -102,7 +107,11 @@ def _assert_b_is_locked(self, should_be_locked): # set x/y > 10 try: alt_trans.execute(update(B).values(x=15, y=19)) - except (exc.InternalError, exc.OperationalError) as err: + except ( + exc.InternalError, + exc.OperationalError, + 
exc.DatabaseError, + ) as err: assert "Lock wait timeout exceeded" in str(err) assert should_be_locked else: diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index 5a4e6ca8d5c..35aebb470c3 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -3,6 +3,8 @@ from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy.dialects.mysql import insert @@ -63,6 +65,22 @@ def test_on_duplicate_key_update_multirow(self, connection): [(1, "ab", "bz", False)], ) + def test_on_duplicate_key_from_select(self, connection): + foos = self.tables.foos + conn = connection + conn.execute(insert(foos).values(dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).from_select( + ["id", "bar", "baz"], + select(foos.c.id, literal("bar2"), literal("baz2")), + ) + stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + + conn.execute(stmt) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "bar2", "bz", False)], + ) + def test_on_duplicate_key_update_singlerow(self, connection): foos = self.tables.foos conn = connection diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 9cbc38378fb..890c9edbf9d 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -5,17 +5,23 @@ from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Computed +from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import false from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text from sqlalchemy import true +from sqlalchemy.dialects.mysql import TIMESTAMP from sqlalchemy.testing import assert_raises from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ @@ -50,6 +56,60 @@ def test_is_boolean_symbols_despite_no_native(self, connection): ) +class ServerDefaultCreateTest(fixtures.TestBase): + __only_on__ = "mysql", "mariadb" + __backend__ = True + + @testing.combinations( + (Integer, text("10")), + (Integer, text("'10'")), + (Integer, "10"), + (Boolean, true()), + (Integer, text("3+5"), testing.requires.mysql_expression_defaults), + (Integer, text("3 + 5"), testing.requires.mysql_expression_defaults), + (Integer, text("(3 * 5)"), testing.requires.mysql_expression_defaults), + (DateTime, func.now()), + ( + Integer, + literal_column("3") + literal_column("5"), + testing.requires.mysql_expression_defaults, + ), + ( + DateTime, + text("now() ON UPDATE now()"), + ), + ( + DateTime, + text("now() on update now()"), + ), + ( + DateTime, + text("now() ON UPDATE now()"), + ), + ( + TIMESTAMP(fsp=3), + text("now(3)"), + testing.requires.mysql_fsp, + ), + ( + TIMESTAMP(fsp=3), + text("CURRENT_TIMESTAMP(3)"), + testing.requires.mysql_fsp, + ), + argnames="datatype, default", + ) + def test_create_server_defaults( + self, connection, metadata, datatype, default + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + class 
MatchTest(fixtures.TablesTest): __only_on__ = "mysql", "mariadb" __backend__ = True diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index f3d1f34599b..92cf3818e24 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -764,103 +764,152 @@ def test_system_views(self): view_names = dialect.get_view_names(connection, "information_schema") self.assert_("TABLES" in view_names) - def test_nullable_reflection(self, metadata, connection): - """test reflection of NULL/NOT NULL, in particular with TIMESTAMP - defaults where MySQL is inconsistent in how it reports CREATE TABLE. - - """ - meta = metadata - - # this is ideally one table, but older MySQL versions choke - # on the multiple TIMESTAMP columns - row = connection.exec_driver_sql( - "show variables like '%%explicit_defaults_for_timestamp%%'" - ).first() - explicit_defaults_for_timestamp = row[1].lower() in ("on", "1", "true") - - reflected = [] - for idx, cols in enumerate( + @testing.combinations( + ( [ - [ - "x INTEGER NULL", - "y INTEGER NOT NULL", - "z INTEGER", - "q TIMESTAMP NULL", - ], - ["p TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP"], - ["r TIMESTAMP NOT NULL"], - ["s TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"], - ["t TIMESTAMP"], - ["u TIMESTAMP DEFAULT CURRENT_TIMESTAMP"], - ] - ): - Table("nn_t%d" % idx, meta) # to allow DROP - - connection.exec_driver_sql( - """ - CREATE TABLE nn_t%d ( - %s - ) - """ - % (idx, ", \n".join(cols)) - ) - - reflected.extend( - { - "name": d["name"], - "nullable": d["nullable"], - "default": d["default"], - } - for d in inspect(connection).get_columns("nn_t%d" % idx) - ) - - if connection.dialect._is_mariadb_102: - current_timestamp = "current_timestamp()" - else: - current_timestamp = "CURRENT_TIMESTAMP" - - eq_( - reflected, + "x INTEGER NULL", + "y INTEGER NOT NULL", + "z INTEGER", + "q TIMESTAMP NULL", + ], [ {"name": "x", "nullable": True, "default": None}, {"name": "y", "nullable": False, "default": None}, {"name": "z", "nullable": True, "default": None}, {"name": "q", "nullable": True, "default": None}, - {"name": "p", "nullable": True, "default": current_timestamp}, + ], + ), + ( + ["p TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP"], + [ + { + "name": "p", + "nullable": True, + "default": "CURRENT_TIMESTAMP", + } + ], + ), + ( + ["r TIMESTAMP NOT NULL"], + [ { "name": "r", "nullable": False, - "default": None - if explicit_defaults_for_timestamp - else ( - "%(current_timestamp)s " - "ON UPDATE %(current_timestamp)s" - ) - % {"current_timestamp": current_timestamp}, - }, - {"name": "s", "nullable": False, "default": current_timestamp}, + "default": None, + "non_explicit_defaults_for_ts_default": ( + "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" + ), + } + ], + ), + ( + ["s TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"], + [ + { + "name": "s", + "nullable": False, + "default": "CURRENT_TIMESTAMP", + } + ], + ), + ( + ["t TIMESTAMP"], + [ { "name": "t", - "nullable": True - if explicit_defaults_for_timestamp - else False, - "default": None - if explicit_defaults_for_timestamp - else ( - "%(current_timestamp)s " - "ON UPDATE %(current_timestamp)s" - ) - % {"current_timestamp": current_timestamp}, - }, + "nullable": True, + "default": None, + "non_explicit_defaults_for_ts_nullable": False, + "non_explicit_defaults_for_ts_default": ( + "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" + ), + } + ], + ), + ( + ["u TIMESTAMP DEFAULT CURRENT_TIMESTAMP"], + [ { "name": "u", - "nullable": True - if 
explicit_defaults_for_timestamp - else False, - "default": current_timestamp, - }, + "nullable": True, + "non_explicit_defaults_for_ts_nullable": False, + "default": "CURRENT_TIMESTAMP", + } ], - ) + ), + ( + ["v INTEGER GENERATED ALWAYS AS (4711) VIRTUAL NOT NULL"], + [ + { + "name": "v", + "nullable": False, + "default": None, + } + ], + testing.requires.mysql_notnull_generated_columns, + ), + argnames="ddl_columns,expected_reflected", + ) + def test_nullable_reflection( + self, metadata, connection, ddl_columns, expected_reflected + ): + """test reflection of NULL/NOT NULL, in particular with TIMESTAMP + defaults where MySQL is inconsistent in how it reports CREATE TABLE. + + """ + row = connection.exec_driver_sql( + "show variables like '%%explicit_defaults_for_timestamp%%'" + ).first() + explicit_defaults_for_timestamp = row[1].lower() in ("on", "1", "true") + + def get_expected_default(er): + if ( + not explicit_defaults_for_timestamp + and "non_explicit_defaults_for_ts_default" in er + ): + default = er["non_explicit_defaults_for_ts_default"] + else: + default = er["default"] + + if default is not None and connection.dialect._is_mariadb_102: + default = default.replace( + "CURRENT_TIMESTAMP", "current_timestamp()" + ) + + return default + + def get_expected_nullable(er): + if ( + not explicit_defaults_for_timestamp + and "non_explicit_defaults_for_ts_nullable" in er + ): + return er["non_explicit_defaults_for_ts_nullable"] + else: + return er["nullable"] + + expected_reflected = [ + { + "name": er["name"], + "nullable": get_expected_nullable(er), + "default": get_expected_default(er), + } + for er in expected_reflected + ] + + Table("nullable_refl", metadata) + + cols_ddl = ", \n".join(ddl_columns) + connection.exec_driver_sql(f"CREATE TABLE nullable_refl ({cols_ddl})") + + reflected = [ + { + "name": d["name"], + "nullable": d["nullable"], + "default": d["default"], + } + for d in inspect(connection).get_columns("nullable_refl") + ] + eq_(reflected, expected_reflected) def test_reflection_with_unique_constraint(self, metadata, connection): insp = inspect(connection) @@ -1148,7 +1197,7 @@ def test_correct_for_mysql_bugs_88718_96365(self): dialect._casing = casing dialect.default_schema_name = "Test" connection = mock.Mock( - dialect=dialect, execute=lambda stmt, params: ischema + dialect=dialect, execute=lambda stmt: ischema ) dialect._correct_for_mysql_bugs_88718_96365(fkeys, connection) eq_( @@ -1508,7 +1557,7 @@ def test_fk_reflection(self): " CONSTRAINT `addresses_user_id_fkey` " "FOREIGN KEY (`user_id`) " "REFERENCES `users` (`id`) " - "ON DELETE CASCADE ON UPDATE SET NULL" + "ON DELETE SET DEFAULT ON UPDATE SET NULL" ) eq_( m.groups(), @@ -1518,7 +1567,7 @@ def test_fk_reflection(self): "`users`", "`id`", None, - "CASCADE", + "SET DEFAULT", "SET NULL", ), ) diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index 1d279e720db..2e5033ec571 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -21,6 +21,7 @@ from sqlalchemy import types as sqltypes from sqlalchemy import UnicodeText from sqlalchemy.dialects.mysql import base as mysql +from sqlalchemy.dialects.mysql import mariadb from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -385,7 +386,7 @@ def test_timestamp_fsp(self): mysql.MSTimeStamp(), DefaultClause( sql.text( - "'1999-09-09 09:09:09' " "ON UPDATE CURRENT_TIMESTAMP" + "'1999-09-09 09:09:09' ON UPDATE 
CURRENT_TIMESTAMP" ) ), ], @@ -398,7 +399,7 @@ def test_timestamp_fsp(self): mysql.MSTimeStamp, DefaultClause( sql.text( - "'1999-09-09 09:09:09' " "ON UPDATE CURRENT_TIMESTAMP" + "'1999-09-09 09:09:09' ON UPDATE CURRENT_TIMESTAMP" ) ), ], @@ -410,9 +411,7 @@ def test_timestamp_fsp(self): [ mysql.MSTimeStamp(), DefaultClause( - sql.text( - "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" - ) + sql.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") ), ], {}, @@ -423,9 +422,7 @@ def test_timestamp_fsp(self): [ mysql.MSTimeStamp, DefaultClause( - sql.text( - "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" - ) + sql.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") ), ], {"nullable": False}, @@ -478,6 +475,17 @@ def test_float_type_compile(self, type_, sql_text): self.assert_compile(type_, sql_text) +class INETMariadbTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = mariadb.MariaDBDialect() + + @testing.combinations( + (mariadb.INET4(), "INET4"), + (mariadb.INET6(), "INET6"), + ) + def test_mariadb_inet6(self, type_, res): + self.assert_compile(type_, res) + + class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): __dialect__ = mysql.dialect() __only_on__ = "mysql", "mariadb" @@ -1209,7 +1217,7 @@ def test_enum_compile(self): t1 = Table("sometable", MetaData(), Column("somecolumn", e1)) self.assert_compile( schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " "ENUM('x','y','z'))", + "CREATE TABLE sometable (somecolumn ENUM('x','y','z'))", ) t1 = Table( "sometable", diff --git a/test/dialect/oracle/_oracledb_mode.py b/test/dialect/oracle/_oracledb_mode.py index a02a5389b2c..d9c426b4bb9 100644 --- a/test/dialect/oracle/_oracledb_mode.py +++ b/test/dialect/oracle/_oracledb_mode.py @@ -5,7 +5,7 @@ def _get_version(conn): # this is the suggested way of finding the mode, from - # https://python-oracledb.readthedocs.io/en/latest/user_guide/tracing.html#vsessconinfo + # https://python-oracledb.readthedocs.io/en/latest/user_guide/tracing.html#finding-the-python-oracledb-mode sql = ( "SELECT UNIQUE CLIENT_DRIVER " "FROM V$SESSION_CONNECT_INFO " diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index c7a6858d4cb..7effcf3aa58 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -92,7 +92,7 @@ def test_owner(self): ) self.assert_compile( parent.join(child), - "ed.parent JOIN ed.child ON ed.parent.id = " "ed.child.parent_id", + "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id", ) def test_subquery(self): @@ -310,6 +310,17 @@ def test_simple_fetch_offset(self): checkparams={"param_1": 20, "param_2": 10}, ) + @testing.only_on("oracle>=23.4") + def test_fetch_type(self): + t = table("sometable", column("col1"), column("col2")) + s = select(t).fetch(2, oracle_fetch_approximate=True) + self.assert_compile( + s, + "SELECT sometable.col1, sometable.col2 FROM sometable " + "FETCH APPROX FIRST __[POSTCOMPILE_param_1] ROWS ONLY", + checkparams={"param_1": 2}, + ) + def test_limit_two(self): t = table("sometable", column("col1"), column("col2")) s = select(t).limit(10).offset(20).subquery() @@ -810,8 +821,8 @@ class MyType(TypeDecorator): def test_use_binds_for_limits_disabled_one_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." 
): dialect = oracle.OracleDialect( use_binds_for_limits=False, enable_offset_fetch=False @@ -829,8 +840,8 @@ def test_use_binds_for_limits_disabled_one_legacy(self): def test_use_binds_for_limits_disabled_two_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." ): dialect = oracle.OracleDialect( use_binds_for_limits=False, enable_offset_fetch=False @@ -849,8 +860,8 @@ def test_use_binds_for_limits_disabled_two_legacy(self): def test_use_binds_for_limits_disabled_three_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." ): dialect = oracle.OracleDialect( use_binds_for_limits=False, enable_offset_fetch=False @@ -871,8 +882,8 @@ def test_use_binds_for_limits_disabled_three_legacy(self): def test_use_binds_for_limits_enabled_one_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." ): dialect = oracle.OracleDialect( use_binds_for_limits=True, enable_offset_fetch=False @@ -890,8 +901,8 @@ def test_use_binds_for_limits_enabled_one_legacy(self): def test_use_binds_for_limits_enabled_two_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." ): dialect = oracle.OracleDialect( use_binds_for_limits=True, enable_offset_fetch=False @@ -911,8 +922,8 @@ def test_use_binds_for_limits_enabled_two_legacy(self): def test_use_binds_for_limits_enabled_three_legacy(self): t = table("sometable", column("col1"), column("col2")) with testing.expect_deprecated( - "The ``use_binds_for_limits`` Oracle dialect parameter is " - "deprecated." + "The ``use_binds_for_limits`` Oracle Database dialect parameter " + "is deprecated." 
): dialect = oracle.OracleDialect( use_binds_for_limits=True, enable_offset_fetch=False @@ -1183,7 +1194,7 @@ def test_outer_join_seven(self): q = select(table1.c.name).where(table1.c.name == "foo") self.assert_compile( q, - "SELECT mytable.name FROM mytable WHERE " "mytable.name = :name_1", + "SELECT mytable.name FROM mytable WHERE mytable.name = :name_1", dialect=oracle.dialect(use_ansi=False), ) @@ -1416,7 +1427,7 @@ def test_returning_update_computed_warning(self): ) with testing.expect_warnings( - "Computed columns don't work with Oracle UPDATE" + "Computed columns don't work with Oracle Database UPDATE" ): self.assert_compile( t1.update().values(id=1, foo=5).returning(t1.c.bar), @@ -1498,7 +1509,7 @@ def test_create_table_compress(self): ) self.assert_compile( schema.CreateTable(tbl2), - "CREATE TABLE testtbl2 (data INTEGER) " "COMPRESS FOR OLTP", + "CREATE TABLE testtbl2 (data INTEGER) COMPRESS FOR OLTP", ) def test_create_index_bitmap_compress(self): @@ -1552,7 +1563,7 @@ def test_column_computed_persisted_true(self): ) assert_raises_message( exc.CompileError, - r".*Oracle computed columns do not support 'stored' ", + r".*Oracle Database computed columns do not support 'stored' ", schema.CreateTable(t).compile, dialect=oracle.dialect(), ) @@ -1627,6 +1638,26 @@ def test_double_to_oracle_double(self): cast(column("foo"), d1), "CAST(foo AS DOUBLE PRECISION)" ) + @testing.combinations( + ("TEST_TABLESPACE", 'TABLESPACE "TEST_TABLESPACE"'), + ("test_tablespace", "TABLESPACE test_tablespace"), + ("TestTableSpace", 'TABLESPACE "TestTableSpace"'), + argnames="tablespace, expected_sql", + ) + def test_table_tablespace(self, tablespace, expected_sql): + m = MetaData() + + t = Table( + "table1", + m, + Column("x", Integer), + oracle_tablespace=tablespace, + ) + self.assert_compile( + schema.CreateTable(t), + f"CREATE TABLE table1 (x INTEGER) {expected_sql}", + ) + class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): def test_basic(self): @@ -1832,3 +1863,15 @@ def test_table_valued(self): "SELECT anon_1.string1, anon_1.string2 " "FROM TABLE (three_pairs()) anon_1", ) + + @testing.combinations(func.TABLE, func.table, func.Table) + def test_table_function(self, fn): + """Issue #12100 Use case is: + https://python-oracledb.readthedocs.io/en/latest/user_guide/bind.html#binding-a-large-number-of-items-in-an-in-list + """ + fn_call = fn("simulate_name_array") + stmt = select(1).select_from(fn_call) + self.assert_compile( + stmt, + f"SELECT 1 FROM {fn_call.name}(:{fn_call.name}_1)", + ) diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 93cf0b74578..1f8a23f70dc 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -36,6 +36,7 @@ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing.assertions import expect_raises_message +from sqlalchemy.testing.assertions import is_ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table @@ -69,6 +70,8 @@ def test_minimum_version(self): class OracleDbDialectTest(fixtures.TestBase): + __only_on__ = "oracle+oracledb" + def test_oracledb_version_parse(self): dialect = oracledb.OracleDialect_oracledb() @@ -84,19 +87,36 @@ def check(version): def test_minimum_version(self): with expect_raises_message( exc.InvalidRequestError, - "oracledb version 1 and above are supported", + r"oracledb version \(1,\) and above are supported", ): 
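# Hedged sketch of the version-floor check exercised here; parse_dbapi_version
# is a hypothetical helper, not SQLAlchemy API -- tuple comparison against a
# minimum such as (1,) is the core idea:
def parse_dbapi_version(version: str) -> tuple:
    # "7.1.0" -> (7, 1, 0); pre-release suffixes are simply dropped here
    return tuple(int(p) for p in version.split(".") if p.isdigit())

assert parse_dbapi_version("7.1.0") == (7, 1, 0)
assert parse_dbapi_version("0.1.5") < (1,) <= parse_dbapi_version("7.1.0")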
oracledb.OracleDialect_oracledb(dbapi=Mock(version="0.1.5")) dialect = oracledb.OracleDialect_oracledb(dbapi=Mock(version="7.1.0")) eq_(dialect.oracledb_ver, (7, 1, 0)) + def test_get_dialect(self): + u = url.URL.create("oracle://") + d = oracledb.OracleDialect_oracledb.get_dialect_cls(u) + is_(d, oracledb.OracleDialect_oracledb) + d = oracledb.OracleDialect_oracledb.get_async_dialect_cls(u) + is_(d, oracledb.OracleDialectAsync_oracledb) + d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u) + is_(d, oracledb.OracleDialectAsync_oracledb) + d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u) + is_(d, oracledb.OracleDialectAsync_oracledb) + + def test_async_version(self): + e = create_engine("oracle+oracledb_async://") + is_true(isinstance(e.dialect, oracledb.OracleDialectAsync_oracledb)) + class OracledbMode(fixtures.TestBase): __backend__ = True __only_on__ = "oracle+oracledb" def _run_in_process(self, fn, fn_kw=None): + if config.db.dialect.is_async: + config.skip_test("thick mode unsupported in async mode") ctx = get_context("spawn") queue = ctx.Queue() process = ctx.Process( @@ -202,6 +222,7 @@ def get_isolation_level(connection): testing.db.dialect.get_isolation_level(dbapi_conn), "READ COMMITTED", ) + conn.close() def test_graceful_failure_isolation_level_not_available(self): engine = engines.testing_engine() @@ -464,7 +485,7 @@ def test_computed_update_warning(self, connection): eq_(result.returned_defaults, (52,)) else: with testing.expect_warnings( - "Computed columns don't work with Oracle UPDATE" + "Computed columns don't work with Oracle Database UPDATE" ): result = conn.execute( test.update().values(foo=10).return_defaults() @@ -511,9 +532,7 @@ def setup_test_class(cls): def test_out_params(self, connection): result = connection.execute( - text( - "begin foo(:x_in, :x_out, :y_out, " ":z_out); end;" - ).bindparams( + text("begin foo(:x_in, :x_out, :y_out, :z_out); end;").bindparams( bindparam("x_in", Float), outparam("x_out", Integer), outparam("y_out", Float), @@ -537,7 +556,7 @@ def test_no_out_params_w_returning(self, connection, metadata): exc.InvalidRequestError, r"Using explicit outparam\(\) objects with " r"UpdateBase.returning\(\) in the same Core DML statement " - "is not supported in the Oracle dialect.", + "is not supported in the Oracle Database dialects.", ): connection.execute(stmt) @@ -662,7 +681,6 @@ def server_version_info(conn): dialect._get_server_version_info = server_version_info dialect.get_isolation_level = Mock() - dialect._check_unicode_returns = Mock() dialect._check_unicode_description = Mock() dialect._get_default_schema_name = Mock() dialect._detect_decimal_char = Mock() @@ -842,7 +860,7 @@ def test_basic(self): with testing.db.connect() as conn: eq_( conn.exec_driver_sql( - "/*+ this is a comment */ SELECT 1 FROM " "DUAL" + "/*+ this is a comment */ SELECT 1 FROM DUAL" ).fetchall(), [(1,)], ) @@ -860,6 +878,7 @@ def test_sequences_are_integers(self, connection): def test_limit_offset_for_update(self, metadata, connection): # oracle can't actually do the ROWNUM thing with FOR UPDATE # very well. + # Seems to be fixed in 23. t = Table( "t1", @@ -884,7 +903,7 @@ def test_limit_offset_for_update(self, metadata, connection): # as of #8221, this fails also. limit w/o order by is useless # in any case. 
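# (Context sketch, not part of this test: when a backend rejects FOR UPDATE
# combined directly with LIMIT/ROWNUM/FETCH, a common workaround is to select
# the candidate keys in an ordered, limited subquery and lock by key instead.)
from sqlalchemy import column, select, table

wt = table("t1", column("id"), column("data"))
ids = select(wt.c.id).order_by(wt.c.id).limit(2).scalar_subquery()
# locks only the two selected rows, without LIMIT in the locking query itself
locked_two_rows = wt.select().where(wt.c.id.in_(ids)).with_for_update()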
stmt = t.select().with_for_update().limit(2) - if testing.against("oracle>=12"): + if testing.against("oracle>=12") and testing.against("oracle<23"): with expect_raises_message(exc.DatabaseError, "ORA-02014"): connection.execute(stmt).fetchall() else: diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index 2a82c25d9fd..35735889488 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -21,6 +21,11 @@ from sqlalchemy import Unicode from sqlalchemy import UniqueConstraint from sqlalchemy.dialects import oracle +from sqlalchemy.dialects.oracle import VECTOR +from sqlalchemy.dialects.oracle import VectorDistanceType +from sqlalchemy.dialects.oracle import VectorIndexConfig +from sqlalchemy.dialects.oracle import VectorIndexType +from sqlalchemy.dialects.oracle import VectorStorageFormat from sqlalchemy.dialects.oracle.base import BINARY_DOUBLE from sqlalchemy.dialects.oracle.base import BINARY_FLOAT from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION @@ -684,6 +689,39 @@ def test_reflect_hidden_column(self): finally: conn.exec_driver_sql("DROP TABLE my_table") + def test_tablespace(self, connection, metadata): + tbl = Table( + "test_tablespace", + metadata, + Column("data", Integer), + oracle_tablespace="temp", + ) + metadata.create_all(connection) + + m2 = MetaData() + + tbl = Table("test_tablespace", m2, autoload_with=connection) + assert tbl.dialect_options["oracle"]["tablespace"] == "TEMP" + + @testing.only_on("oracle>=23.4") + def test_reflection_w_vector_column(self, connection, metadata): + tb1 = Table( + "test_vector", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + metadata.create_all(connection) + + m2 = MetaData() + + tb1 = Table("test_vector", m2, autoload_with=connection) + assert tb1.columns.keys() == ["id", "name", "embedding"] + class ViewReflectionTest(fixtures.TestBase): __only_on__ = "oracle" @@ -1166,6 +1204,42 @@ def obj_definition(obj): eq_(len(reflectedtable.constraints), 1) eq_(len(reflectedtable.indexes), 5) + @testing.only_on("oracle>=23.4") + def test_vector_index(self, metadata, connection): + tb1 = Table( + "test_vector", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + tb1.create(connection) + + ivf_index = Index( + "ivf_vector_index", + tb1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + ivf_index.create(connection) + + expected = [ + { + "name": "ivf_vector_index", + "column_names": ["embedding"], + "dialect_options": {}, + "unique": False, + }, + ] + eq_(inspect(connection).get_indexes("test_vector"), expected) + class DBLinkReflectionTest(fixtures.TestBase): __requires__ = ("oracle_test_dblink",) @@ -1227,7 +1301,7 @@ def _run_test(self, metadata, connection, specs, attributes): for attr in attributes: r_attr = getattr(reflected_type, attr) e_attr = getattr(expected_spec, attr) - col = f"c{i+1}" + col = f"c{i + 1}" eq_( r_attr, e_attr, @@ -1540,8 +1614,8 @@ def setup_test(self): (schema, "parent"): [], } self.options[schema] = { - (schema, "my_table"): {}, - (schema, "parent"): {}, + (schema, "my_table"): {"oracle_tablespace": "USERS"}, + (schema, "parent"): {"oracle_tablespace": 
"USERS"}, } def test_tables(self, connection): diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 82a81612e1e..331103e8f25 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1,3 +1,4 @@ +import array import datetime import decimal import os @@ -15,6 +16,7 @@ from sqlalchemy import exc from sqlalchemy import FLOAT from sqlalchemy import Float +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import LargeBinary from sqlalchemy import literal @@ -37,6 +39,11 @@ from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.dialects.oracle import cx_oracle from sqlalchemy.dialects.oracle import oracledb +from sqlalchemy.dialects.oracle import VECTOR +from sqlalchemy.dialects.oracle import VectorDistanceType +from sqlalchemy.dialects.oracle import VectorIndexConfig +from sqlalchemy.dialects.oracle import VectorIndexType +from sqlalchemy.dialects.oracle import VectorStorageFormat from sqlalchemy.sql import column from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises_message @@ -50,6 +57,7 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.util import b +from sqlalchemy.util.concurrency import await_fallback def exec_sql(conn, sql, *args, **kwargs): @@ -375,12 +383,13 @@ def test_interval_literal_processor(self, connection): def test_no_decimal_float_precision(self): with expect_raises_message( exc.ArgumentError, - "Oracle FLOAT types use 'binary precision', which does not " - "convert cleanly from decimal 'precision'. Please specify this " - "type with a separate Oracle variant, such as " + "Oracle Database FLOAT types use 'binary precision', which does " + "not convert cleanly from decimal 'precision'. Please specify " + "this type with a separate Oracle Database variant, such as " r"FLOAT\(precision=5\).with_variant\(oracle.FLOAT\(" r"binary_precision=16\), 'oracle'\), so that the Oracle " - "specific 'binary_precision' may be specified accurately.", + "Database specific 'binary_precision' may be specified " + "accurately.", ): FLOAT(5).compile(dialect=oracle.dialect()) @@ -571,7 +580,7 @@ def _dont_test_numeric_nan_decimal(self, metadata, connection): ) def test_numerics_broken_inspection(self, metadata, connection): - """Numeric scenarios where Oracle type info is 'broken', + """Numeric scenarios where Oracle Database type info is 'broken', returning us precision, scale of the form (0, 0) or (0, -127). We convert to Decimal and let int()/float() processors take over. 
@@ -950,6 +959,194 @@ def test_longstring(self, metadata, connection): finally: exec_sql(connection, "DROP TABLE Z_TEST") + @testing.only_on("oracle>=23.4") + def test_vector_dim(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) + + t1.create(connection) + eq_(t1.c.c1.type.dim, 3) + + @testing.only_on("oracle>=23.4") + def test_vector_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6, 7, 8, 5]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_insert_array(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=array.array("b", [6, 7, 8, 5])), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + + connection.execute(t1.delete().where(t1.c.id == 1)) + + connection.execute( + t1.insert(), dict(id=1, c1=array.array("b", [6, 7])) + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_multiformat_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6.12, 7.54, 8.33]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6.12, 7.54, 8.33]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_format(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) + + t1.create(connection) + eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32) + + @testing.only_on("oracle>=23.4") + def test_vector_hnsw_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + + t1.create(connection) + + hnsw_index = Index( + "hnsw_vector_index", t1.c.embedding, oracle_vector=True + ) + hnsw_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_ivf_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + + t1.create(connection) + ivf_index = Index( + "ivf_vector_index", + t1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + ivf_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, 
embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_l2_distance(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.INT8), + ), + ) + + t1.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10])) + connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3])) + connection.execute( + t1.insert(), + dict(id=3, embedding=[15, 16, 17]), + ) + + query_vector = [2, 3, 4] + res = connection.execute( + t1.select().order_by((t1.c.embedding.l2_distance(query_vector))) + ).first() + eq_(res.embedding, [1, 2, 3]) + class LOBFetchTest(fixtures.TablesTest): __only_on__ = "oracle" @@ -998,13 +1195,23 @@ def insert_data(cls, connection): for i in range(1, 11): connection.execute(binary_table.insert(), dict(id=i, data=stream)) + def _read_lob(self, engine, row): + if engine.dialect.is_async: + data = await_fallback(row._mapping["data"].read()) + bindata = await_fallback(row._mapping["bindata"].read()) + else: + data = row._mapping["data"].read() + bindata = row._mapping["bindata"].read() + return data, bindata + def test_lobs_without_convert(self): engine = testing_engine(options=dict(auto_convert_lobs=False)) t = self.tables.z_test with engine.begin() as conn: row = conn.execute(t.select().where(t.c.id == 1)).first() - eq_(row._mapping["data"].read(), "this is text 1") - eq_(row._mapping["bindata"].read(), b("this is binary 1")) + data, bindata = self._read_lob(engine, row) + eq_(data, "this is text 1") + eq_(bindata, b("this is binary 1")) def test_lobs_with_convert(self, connection): t = self.tables.z_test @@ -1028,17 +1235,13 @@ def test_lobs_without_convert_many_rows(self): results = result.fetchall() def go(): - eq_( - [ - dict( - id=row._mapping["id"], - data=row._mapping["data"].read(), - bindata=row._mapping["bindata"].read(), - ) - for row in results - ], - self.data, - ) + actual = [] + for row in results: + data, bindata = self._read_lob(engine, row) + actual.append( + dict(id=row._mapping["id"], data=data, bindata=bindata) + ) + eq_(actual, self.data) # this comes from cx_Oracle because these are raw # cx_Oracle.Variable objects diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py index ed3d63d8336..98410f72e89 100644 --- a/test/dialect/postgresql/test_async_pg_py3k.py +++ b/test/dialect/postgresql/test_async_pg_py3k.py @@ -10,11 +10,14 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy.dialects.postgresql import asyncpg as asyncpg_dialect from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy.testing import async_test from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.util import greenlet_spawn class AsyncPgTest(fixtures.TestBase): @@ -165,6 +168,112 @@ async def async_setup(engine, enums): ], ) + @testing.combinations( + None, + "read committed", + "repeatable read", + "serializable", + argnames="isolation_level", + ) + @async_test + async def test_honor_server_level_iso_setting( + self, async_testing_engine, isolation_level + ): + """test for #12159""" + + engine = async_testing_engine() + + arg, kw = engine.dialect.create_connect_args(engine.url) + + # 1. 
create an asyncpg.connection directly, set a session level + # isolation level on it (this is similar to server default isolation + # level) + raw_asyncpg_conn = await engine.dialect.dbapi.asyncpg.connect( + *arg, **kw + ) + + if isolation_level: + await raw_asyncpg_conn.execute( + f"set SESSION CHARACTERISTICS AS TRANSACTION " + f"isolation level {isolation_level}" + ) + + # 2. fetch it, confirm the setting took and matches + raw_iso_level = ( + await raw_asyncpg_conn.fetchrow("show transaction isolation level") + )[0] + if isolation_level: + eq_(raw_iso_level, isolation_level.lower()) + + # 3. build our pep-249 wrapper around asyncpg.connection + dbapi_conn = asyncpg_dialect.AsyncAdapt_asyncpg_connection( + engine.dialect.dbapi, + raw_asyncpg_conn, + ) + + # 4. show the isolation level inside of a query. this will + # call asyncpg.connection.transaction() in order to run the + # statement. + cursor = await greenlet_spawn(dbapi_conn.cursor) + await greenlet_spawn( + cursor.execute, "show transaction isolation level" + ) + row = cursor.fetchone() + + # 5. see that the raw iso level is maintained + eq_(row[0], raw_iso_level) + + await greenlet_spawn(dbapi_conn.close) + + @testing.variation("trans", ["commit", "rollback"]) + @async_test + async def test_dont_reset_open_transaction( + self, trans, async_testing_engine + ): + """test for #11819""" + + engine = async_testing_engine() + + control_conn = await engine.connect() + await control_conn.execution_options(isolation_level="AUTOCOMMIT") + + conn = await engine.connect() + txid_current = ( + await conn.exec_driver_sql("select txid_current()") + ).scalar() + + with expect_raises(exc.MissingGreenlet): + if trans.commit: + conn.sync_connection.connection.dbapi_connection.commit() + elif trans.rollback: + conn.sync_connection.connection.dbapi_connection.rollback() + else: + trans.fail() + + trans_exists = ( + await control_conn.exec_driver_sql( + f"SELECT count(*) FROM pg_stat_activity " + f"where backend_xid={txid_current}" + ) + ).scalar() + eq_(trans_exists, 1) + + if trans.commit: + await conn.commit() + elif trans.rollback: + await conn.rollback() + else: + trans.fail() + + trans_exists = ( + await control_conn.exec_driver_sql( + f"SELECT count(*) FROM pg_stat_activity " + f"where backend_xid={txid_current}" + ) + ).scalar() + eq_(trans_exists, 0) + await engine.dispose() + @async_test async def test_failed_commit_recover(self, metadata, async_testing_engine): Table("t1", metadata, Column("id", Integer, primary_key=True)) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 5851a86e6d6..2a763593b2e 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,3 +1,5 @@ +import random + from sqlalchemy import and_ from sqlalchemy import BigInteger from sqlalchemy import bindparam @@ -20,6 +22,7 @@ from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import null +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import Sequence @@ -35,6 +38,7 @@ from sqlalchemy import types as sqltypes from sqlalchemy import UniqueConstraint from sqlalchemy import update +from sqlalchemy import VARCHAR from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY @@ -52,6 +56,7 @@ from sqlalchemy.dialects.postgresql import TSRANGE from 
sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.dialects.postgresql.ranges import MultiRange from sqlalchemy.orm import aliased from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import Session @@ -61,6 +66,7 @@ from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises @@ -262,7 +268,7 @@ def test_generic_enum(self): ) self.assert_compile( postgresql.CreateEnumType(e2), - "CREATE TYPE someschema.somename AS ENUM " "('x', 'y', 'z')", + "CREATE TYPE someschema.somename AS ENUM ('x', 'y', 'z')", ) self.assert_compile(postgresql.DropEnumType(e1), "DROP TYPE somename") self.assert_compile( @@ -271,7 +277,7 @@ def test_generic_enum(self): t1 = Table("sometable", MetaData(), Column("somecolumn", e1)) self.assert_compile( schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " "somename)", + "CREATE TABLE sometable (somecolumn somename)", ) t1 = Table( "sometable", @@ -582,6 +588,19 @@ def test_create_table_with_oncommit_option(self): "CREATE TABLE atable (id INTEGER) ON COMMIT DROP", ) + def test_create_table_with_using_option(self): + m = MetaData() + tbl = Table( + "atable", + m, + Column("id", Integer), + postgresql_using="heap", + ) + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE atable (id INTEGER) USING heap", + ) + def test_create_table_with_multiple_options(self): m = MetaData() tbl = Table( @@ -591,10 +610,11 @@ def test_create_table_with_multiple_options(self): postgresql_tablespace="sometablespace", postgresql_with_oids=False, postgresql_on_commit="preserve_rows", + postgresql_using="heap", ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) WITHOUT OIDS " + "CREATE TABLE atable (id INTEGER) USING heap WITHOUT OIDS " "ON COMMIT PRESERVE ROWS TABLESPACE sometablespace", ) @@ -668,7 +688,7 @@ def test_create_index_with_ops(self): self.assert_compile( schema.CreateIndex(idx), - "CREATE INDEX test_idx1 ON testtbl " "(data text_pattern_ops)", + "CREATE INDEX test_idx1 ON testtbl (data text_pattern_ops)", dialect=postgresql.dialect(), ) self.assert_compile( @@ -711,7 +731,7 @@ def test_create_index_with_ops(self): unique=True, ) ), - "CREATE UNIQUE INDEX test_idx3 ON test_tbl " "(data3)", + "CREATE UNIQUE INDEX test_idx3 ON test_tbl (data3)", ), ( lambda tbl: schema.CreateIndex( @@ -774,6 +794,40 @@ def test_nulls_not_distinct(self, expr_fn, expected): expr = testing.resolve_lambda(expr_fn, tbl=tbl) self.assert_compile(expr, expected, dialect=dd) + @testing.combinations( + ( + lambda tbl: schema.AddConstraint( + UniqueConstraint(tbl.c.id, postgresql_include=[tbl.c.value]) + ), + "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)", + ), + ( + lambda tbl: schema.AddConstraint( + PrimaryKeyConstraint( + tbl.c.id, postgresql_include=[tbl.c.value, "misc"] + ) + ), + "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)", + ), + ( + lambda tbl: schema.CreateIndex( + Index("idx", tbl.c.id, postgresql_include=[tbl.c.value]) + ), + "CREATE INDEX idx ON foo (id) INCLUDE (value)", + ), + ) + def test_include(self, expr_fn, expected): + m = MetaData() + tbl = Table( + "foo", + m, + Column("id", Integer, nullable=False), + Column("value", Integer, nullable=False), + Column("misc", 
String), + ) + expr = testing.resolve_lambda(expr_fn, tbl=tbl) + self.assert_compile(expr, expected) + def test_create_index_with_labeled_ops(self): m = MetaData() tbl = Table( @@ -878,17 +932,17 @@ def test_create_index_with_using(self): self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX test_idx1 ON testtbl " "(data)", + "CREATE INDEX test_idx1 ON testtbl (data)", dialect=postgresql.dialect(), ) self.assert_compile( schema.CreateIndex(idx2), - "CREATE INDEX test_idx2 ON testtbl " "USING btree (data)", + "CREATE INDEX test_idx2 ON testtbl USING btree (data)", dialect=postgresql.dialect(), ) self.assert_compile( schema.CreateIndex(idx3), - "CREATE INDEX test_idx3 ON testtbl " "USING hash (data)", + "CREATE INDEX test_idx3 ON testtbl USING hash (data)", dialect=postgresql.dialect(), ) @@ -909,7 +963,7 @@ def test_create_index_with_with(self): self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX test_idx1 ON testtbl " "(data)", + "CREATE INDEX test_idx1 ON testtbl (data)", ) self.assert_compile( schema.CreateIndex(idx2), @@ -932,7 +986,7 @@ def test_create_index_with_using_unusual_conditions(self): schema.CreateIndex( Index("test_idx1", tbl.c.data, postgresql_using="GIST") ), - "CREATE INDEX test_idx1 ON testtbl " "USING gist (data)", + "CREATE INDEX test_idx1 ON testtbl USING gist (data)", ) self.assert_compile( @@ -974,7 +1028,7 @@ def test_create_index_with_tablespace(self): self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX test_idx1 ON testtbl " "(data)", + "CREATE INDEX test_idx1 ON testtbl (data)", dialect=postgresql.dialect(), ) self.assert_compile( @@ -1124,6 +1178,48 @@ def test_create_foreign_key_column_not_valid(self): ")", ) + def test_create_foreign_key_constraint_ondelete_column_list(self): + m = MetaData() + pktable = Table( + "pktable", + m, + Column("tid", Integer, primary_key=True), + Column("id", Integer, primary_key=True), + ) + fktable = Table( + "fktable", + m, + Column("tid", Integer), + Column("id", Integer), + Column("fk_id_del_set_null", Integer), + Column("fk_id_del_set_default", Integer, server_default=text("0")), + ForeignKeyConstraint( + columns=["tid", "fk_id_del_set_null"], + refcolumns=[pktable.c.tid, pktable.c.id], + ondelete="SET NULL (fk_id_del_set_null)", + ), + ForeignKeyConstraint( + columns=["tid", "fk_id_del_set_default"], + refcolumns=[pktable.c.tid, pktable.c.id], + ondelete="SET DEFAULT(fk_id_del_set_default)", + ), + ) + + self.assert_compile( + schema.CreateTable(fktable), + "CREATE TABLE fktable (" + "tid INTEGER, id INTEGER, " + "fk_id_del_set_null INTEGER, " + "fk_id_del_set_default INTEGER DEFAULT 0, " + "FOREIGN KEY(tid, fk_id_del_set_null)" + " REFERENCES pktable (tid, id)" + " ON DELETE SET NULL (fk_id_del_set_null), " + "FOREIGN KEY(tid, fk_id_del_set_default)" + " REFERENCES pktable (tid, id)" + " ON DELETE SET DEFAULT(fk_id_del_set_default)" + ")", + ) + def test_exclude_constraint_min(self): m = MetaData() tbl = Table("testtbl", m, Column("room", Integer, primary_key=True)) @@ -1715,6 +1811,15 @@ def test_for_update(self): "FOR UPDATE OF table1", ) + # test issue #12417 + subquery = select(table1.c.myid).with_for_update(of=table1).lateral() + statement = select(subquery.c.myid) + self.assert_compile( + statement, + "SELECT anon_1.myid FROM LATERAL (SELECT mytable.myid AS myid " + "FROM mytable FOR UPDATE OF mytable) AS anon_1", + ) + def test_for_update_with_schema(self): m = MetaData() table1 = Table( @@ -1919,6 +2024,14 @@ def test_array_literal_type(self): String, ) + @testing.combinations( + 
("with type_", Date, "ARRAY[]::DATE[]"), + ("no type_", None, "ARRAY[]"), + id_="iaa", + ) + def test_array_literal_empty(self, type_, expected): + self.assert_compile(postgresql.array([], type_=type_), expected) + def test_array_literal(self): self.assert_compile( func.array_dims( @@ -2069,7 +2182,7 @@ def test_update_array_slice(self): # default dialect does not, as DBAPIs may be doing this for us self.assert_compile( t.update().values({t.c.data[2:5]: [2, 3, 4]}), - "UPDATE t SET data[%s:%s]=" "%s", + "UPDATE t SET data[%s:%s]=%s", checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2}, dialect=PGDialect(paramstyle="format"), ) @@ -2125,7 +2238,7 @@ def test_from_only(self): tbl3 = Table("testtbl3", m, Column("id", Integer), schema="testschema") stmt = tbl3.select().with_hint(tbl3, "ONLY", "postgresql") expected = ( - "SELECT testschema.testtbl3.id FROM " "ONLY testschema.testtbl3" + "SELECT testschema.testtbl3.id FROM ONLY testschema.testtbl3" ) self.assert_compile(stmt, expected) @@ -2574,7 +2687,7 @@ def test_eager_grouping_flag(self, expr, expected, type_): self.assert_compile(expr, expected) - def test_custom_object_hook(self): + def test_range_custom_object_hook(self): # See issue #8884 from datetime import date @@ -2594,6 +2707,30 @@ def test_custom_object_hook(self): "WHERE usages.date <@ %(date_1)s::DATERANGE", ) + def test_multirange_custom_object_hook(self): + from datetime import date + + usages = table( + "usages", + column("id", Integer), + column("date", Date), + column("amount", Integer), + ) + period = MultiRange( + [ + Range(date(2022, 1, 1), (2023, 1, 1)), + Range(date(2024, 1, 1), (2025, 1, 1)), + ] + ) + stmt = select(func.sum(usages.c.amount)).where( + usages.c.date.op("<@")(period) + ) + self.assert_compile( + stmt, + "SELECT sum(usages.amount) AS sum_1 FROM usages " + "WHERE usages.date <@ %(date_1)s::DATEMULTIRANGE", + ) + def test_bitwise_xor(self): c1 = column("c1", Integer) c2 = column("c2", Integer) @@ -2660,6 +2797,11 @@ def define_tables(cls, metadata): (cls.table_with_metadata.c.description, "&&"), where=cls.table_with_metadata.c.description != "foo", ) + cls.excl_constr_anon_str = ExcludeConstraint( + (cls.table_with_metadata.c.name, "="), + (cls.table_with_metadata.c.description, "&&"), + where="description != 'foo'", + ) cls.goofy_index = Index( "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m" ) @@ -2678,6 +2820,69 @@ def define_tables(cls, metadata): Column("name", String(50), key="name_keyed"), ) + @testing.combinations( + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where=text("name = 'hi'") + ), + "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=[users.c.id], index_where=users.c.name == "hi" + ), + "ON CONFLICT (id) WHERE name = %(name_1)s DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where="name = 'hi'" + ), + exc.ArgumentError, + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=users.c.name == "hi", + ), + "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s " + "WHERE users.name = %(name_1)s", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=text("name = 'hi'"), + ), + "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s " + "WHERE name = 'hi'", + ), + ( + lambda users, stmt: 
stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where="name = 'hi'", + ), + exc.ArgumentError, + ), + ) + def test_assorted_arg_coercion(self, case, expected): + stmt = insert(self.tables.users) + + if isinstance(expected, type) and issubclass(expected, Exception): + with expect_raises(expected): + testing.resolve_lambda( + case, stmt=stmt, users=self.tables.users + ) + else: + self.assert_compile( + testing.resolve_lambda( + case, stmt=stmt, users=self.tables.users + ), + f"INSERT INTO users (id, name) VALUES (%(id)s, %(name)s) " + f"{expected}", + ) + @testing.combinations("control", "excluded", "dict") def test_set_excluded(self, scenario): """test #8014, sending all of .excluded to set""" @@ -3071,6 +3276,20 @@ def test_do_update_unnamed_exclude_constraint_target(self): "DO UPDATE SET name = excluded.name", ) + def test_do_update_unnamed_exclude_constraint_string_target(self): + i = insert(self.table1).values(dict(name="foo")) + i = i.on_conflict_do_update( + constraint=self.excl_constr_anon_str, + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != 'foo' " + "DO UPDATE SET name = excluded.name", + ) + def test_do_update_add_whereclause(self): i = insert(self.table1).values(dict(name="foo")) i = i.on_conflict_do_update( @@ -3091,6 +3310,26 @@ def test_do_update_add_whereclause(self): "AND mytable.description != %(description_2)s", ) + def test_do_update_str_index_where(self): + i = insert(self.table1).values(dict(name="foo")) + i = i.on_conflict_do_update( + constraint=self.excl_constr_anon_str, + set_=dict(name=i.excluded.name), + where=( + (self.table1.c.name != "brah") + & (self.table1.c.description != "brah") + ), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != 'foo' " + "DO UPDATE SET name = excluded.name " + "WHERE mytable.name != %(name_1)s " + "AND mytable.description != %(description_1)s", + ) + def test_do_update_add_whereclause_references_excluded(self): i = insert(self.table1).values(dict(name="foo")) i = i.on_conflict_do_update( @@ -3214,7 +3453,6 @@ def test_quote_raw_string_col(self): class DistinctOnTest(fixtures.MappedTest, AssertsCompiledSQL): - """Test 'DISTINCT' with SQL expression language and orm.Query with an emphasis on PG's 'DISTINCT ON' syntax. 
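For context on the class below: on the postgresql dialect, passing column expressions to .distinct() is what produces the DISTINCT ON clause these tests assert. A minimal sketch, assuming a toy table t (names here are illustrative, not part of the test suite; recent releases may steer toward a dialect-level construct instead):

from sqlalchemy import Column, Integer, MetaData, Table, select
from sqlalchemy.dialects import postgresql

t = Table("t", MetaData(), Column("id", Integer), Column("a", Integer))

# expressions given to .distinct() render as DISTINCT ON (...) on PG
stmt = select(t).distinct(t.c.a).order_by(t.c.a, t.c.id)
print(stmt.compile(dialect=postgresql.dialect()))
# SELECT DISTINCT ON (t.a) t.id, t.a FROM t ORDER BY t.a, t.id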
@@ -3283,7 +3521,7 @@ def test_query_plain(self): sess = Session() self.assert_compile( sess.query(self.table).distinct(), - "SELECT DISTINCT t.id AS t_id, t.a AS t_a, " "t.b AS t_b FROM t", + "SELECT DISTINCT t.id AS t_id, t.a AS t_a, t.b AS t_b FROM t", ) def test_query_on_columns(self): @@ -3368,7 +3606,6 @@ def test_distinct_on_subquery_named(self): class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests for full text searching""" __dialect__ = postgresql.dialect() @@ -3910,3 +4147,49 @@ def test_aggregate_order_by(self): ), compare_values=False, ) + + def test_array_equivalent_keys_one_element(self): + self._run_cache_key_equal_fixture( + lambda: ( + array([random.randint(0, 10)]), + array([random.randint(0, 10)], type_=Integer), + array([random.randint(0, 10)], type_=Integer), + ), + compare_values=False, + ) + + def test_array_equivalent_keys_two_elements(self): + self._run_cache_key_equal_fixture( + lambda: ( + array([random.randint(0, 10), random.randint(0, 10)]), + array( + [random.randint(0, 10), random.randint(0, 10)], + type_=Integer, + ), + array( + [random.randint(0, 10), random.randint(0, 10)], + type_=Integer, + ), + ), + compare_values=False, + ) + + def test_array_heterogeneous(self): + self._run_cache_key_fixture( + lambda: ( + array([], type_=Integer), + array([], type_=Text), + array([]), + array([random.choice(["t1", "t2", "t3"])]), + array( + [ + random.choice(["t1", "t2", "t3"]), + random.choice(["t1", "t2", "t3"]), + ] + ), + array([random.choice(["t1", "t2", "t3"])], type_=Text), + array([random.choice(["t1", "t2", "t3"])], type_=VARCHAR(30)), + array([random.randint(0, 10), random.randint(0, 10)]), + ), + compare_values=False, + ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index db2d5e73dc6..109101011fc 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -178,6 +178,28 @@ def test_range_frozen(self): with expect_raises(dataclasses.FrozenInstanceError): r1.lower = 8 # type: ignore + @testing.only_on("postgresql+asyncpg") + def test_asyncpg_terminate_catch(self): + """test for #11005""" + + with testing.db.connect() as connection: + emulated_dbapi_connection = connection.connection.dbapi_connection + + async def boom(): + raise OSError("boom") + + with mock.patch.object( + emulated_dbapi_connection, + "_connection", + mock.Mock(close=mock.Mock(return_value=boom())), + ) as mock_asyncpg_connection: + emulated_dbapi_connection.terminate() + + eq_( + mock_asyncpg_connection.mock_calls, + [mock.call.close(timeout=2), mock.call.terminate()], + ) + def test_version_parsing(self): def mock_conn(res): return mock.Mock( @@ -343,6 +365,7 @@ class Error(Exception): "SSL SYSCALL error: EOF detected", "SSL SYSCALL error: Operation timed out", "SSL SYSCALL error: Bad address", + "SSL SYSCALL error: Success", ]: eq_(dialect.is_disconnect(Error(error), None, None), True) @@ -721,7 +744,7 @@ def test_non_int_port_disallowed(self, dialect, url_string): "postgresql+psycopg2://USER:PASS@/DB" "?host=hostA,hostC&port=111,222,333", ), - ("postgresql+psycopg2://USER:PASS@/DB" "?host=hostA&port=111,222",), + ("postgresql+psycopg2://USER:PASS@/DB?host=hostA&port=111,222",), ( "postgresql+asyncpg://USER:PASS@/DB" "?host=hostA,hostB,hostC&port=111,333", @@ -1017,6 +1040,12 @@ class MiscBackendTest( __only_on__ = "postgresql" __backend__ = True + @testing.fails_on(["+psycopg2"]) + def test_empty_sql_string(self, connection): + + result = connection.exec_driver_sql("") + 
assert result._soft_closed + @testing.provide_metadata def test_date_reflection(self): metadata = self.metadata @@ -1219,9 +1248,9 @@ def test_readonly_flag_engine(self, testing_engine, pre_ping): def test_autocommit_pre_ping(self, testing_engine, autocommit): engine = testing_engine( options={ - "isolation_level": "AUTOCOMMIT" - if autocommit - else "SERIALIZABLE", + "isolation_level": ( + "AUTOCOMMIT" if autocommit else "SERIALIZABLE" + ), "pool_pre_ping": True, } ) @@ -1239,9 +1268,9 @@ def test_asyncpg_transactional_ping(self, testing_engine, autocommit): engine = testing_engine( options={ - "isolation_level": "AUTOCOMMIT" - if autocommit - else "SERIALIZABLE", + "isolation_level": ( + "AUTOCOMMIT" if autocommit else "SERIALIZABLE" + ), "pool_pre_ping": True, } ) @@ -1354,6 +1383,7 @@ def test_notice_logging(self): conn.exec_driver_sql("SELECT note('another note')") finally: trans.rollback() + conn.close() finally: log.removeHandler(buf) log.setLevel(lev) @@ -1549,61 +1579,62 @@ def test_numeric_raise(self, connection): stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric) assert_raises(exc.InvalidRequestError, connection.execute, stmt) - @testing.only_on("postgresql+psycopg2") - def test_serial_integer(self): - class BITD(TypeDecorator): - impl = Integer - - cache_ok = True - - def load_dialect_impl(self, dialect): - if dialect.name == "postgresql": - return BigInteger() - else: - return Integer() - - for version, type_, expected in [ - (None, Integer, "SERIAL"), - (None, BigInteger, "BIGSERIAL"), - ((9, 1), SmallInteger, "SMALLINT"), - ((9, 2), SmallInteger, "SMALLSERIAL"), - (None, postgresql.INTEGER, "SERIAL"), - (None, postgresql.BIGINT, "BIGSERIAL"), - ( - None, - Integer().with_variant(BigInteger(), "postgresql"), - "BIGSERIAL", - ), - ( - None, - Integer().with_variant(postgresql.BIGINT, "postgresql"), - "BIGSERIAL", - ), - ( - (9, 2), - Integer().with_variant(SmallInteger, "postgresql"), - "SMALLSERIAL", - ), - (None, BITD(), "BIGSERIAL"), - ]: - m = MetaData() + @testing.combinations( + (None, Integer, "SERIAL"), + (None, BigInteger, "BIGSERIAL"), + ((9, 1), SmallInteger, "SMALLINT"), + ((9, 2), SmallInteger, "SMALLSERIAL"), + (None, SmallInteger, "SMALLSERIAL"), + (None, postgresql.INTEGER, "SERIAL"), + (None, postgresql.BIGINT, "BIGSERIAL"), + ( + None, + Integer().with_variant(BigInteger(), "postgresql"), + "BIGSERIAL", + ), + ( + None, + Integer().with_variant(postgresql.BIGINT, "postgresql"), + "BIGSERIAL", + ), + ( + (9, 2), + Integer().with_variant(SmallInteger, "postgresql"), + "SMALLSERIAL", + ), + (None, "BITD()", "BIGSERIAL"), + argnames="version, type_, expected", + ) + def test_serial_integer(self, version, type_, expected, testing_engine): + if type_ == "BITD()": - t = Table("t", m, Column("c", type_, primary_key=True)) + class BITD(TypeDecorator): + impl = Integer - if version: - dialect = testing.db.dialect.__class__() - dialect._get_server_version_info = mock.Mock( - return_value=version - ) - dialect.initialize(testing.db.connect()) - else: - dialect = testing.db.dialect + cache_ok = True - ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) - eq_( - ddl_compiler.get_column_specification(t.c.c), - "c %s NOT NULL" % expected, - ) + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return BigInteger() + else: + return Integer() + + type_ = BITD() + t = Table("t", MetaData(), Column("c", type_, primary_key=True)) + + if version: + engine = testing_engine() + dialect = engine.dialect + 
dialect._get_server_version_info = mock.Mock(return_value=version) + engine.connect().close() # initialize the dialect + else: + dialect = testing.db.dialect + + ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + eq_( + ddl_compiler.get_column_specification(t.c.c), + "c %s NOT NULL" % expected, + ) @testing.requires.psycopg2_compatibility def test_initial_transaction_state_psycopg2(self): @@ -1698,3 +1729,37 @@ def test_get_dialect(self): def test_async_version(self): e = create_engine("postgresql+psycopg_async://") is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg)) + + @testing.skip_if(lambda c: c.db.dialect.is_async) + def test_client_side_cursor(self, testing_engine): + from psycopg import ClientCursor + + engine = testing_engine( + options={"connect_args": {"cursor_factory": ClientCursor}} + ) + + with engine.connect() as c: + res = c.execute(select(1, 2, 3)).one() + eq_(res, (1, 2, 3)) + with c.connection.driver_connection.cursor() as cursor: + is_true(isinstance(cursor, ClientCursor)) + + @config.async_test + @testing.skip_if(lambda c: not c.db.dialect.is_async) + async def test_async_client_side_cursor(self, testing_engine): + from psycopg import AsyncClientCursor + + engine = testing_engine( + options={"connect_args": {"cursor_factory": AsyncClientCursor}}, + asyncio=True, + ) + + async with engine.connect() as c: + res = (await c.execute(select(1, 2, 3))).one() + eq_(res, (1, 2, 3)) + async with ( + await c.get_raw_connection() + ).driver_connection.cursor() as cursor: + is_true(isinstance(cursor, AsyncClientCursor)) + + await engine.dispose() diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index a9320f2c503..691f6c39620 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -583,7 +583,10 @@ def test_on_conflict_do_update_exotic_targets_four(self, connection): [(43, "nameunique2", "name2@gmail.com", "not")], ) - def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection): + @testing.variation("string_index_elements", [True, False]) + def test_on_conflict_do_update_exotic_targets_four_no_pk( + self, connection, string_index_elements + ): users = self.tables.users_xtra self._exotic_targets_fixture(connection) @@ -591,7 +594,11 @@ def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection): # upsert on target login_email, not id i = insert(users) i = i.on_conflict_do_update( - index_elements=[users.c.login_email], + index_elements=( + ["login_email"] + if string_index_elements + else [users.c.login_email] + ), set_=dict( id=i.excluded.id, name=i.excluded.name, diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 8d8d9a7ec9d..fc68e08ed4d 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -26,6 +26,8 @@ from sqlalchemy import Time from sqlalchemy import true from sqlalchemy import tuple_ +from sqlalchemy import Uuid +from sqlalchemy import values from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import REGCONFIG @@ -977,7 +979,7 @@ def test_expression_pyformat(self, connection): if self._strs_render_bind_casts(connection): self.assert_compile( matchtable.c.title.match("somstr"), - "matchtable.title @@ " "plainto_tsquery(%(title_1)s::VARCHAR)", + "matchtable.title @@ plainto_tsquery(%(title_1)s::VARCHAR)", ) else: self.assert_compile( @@ 
-1005,7 +1007,7 @@ def test_expression_positional(self, connection): (func.to_tsquery,), (func.plainto_tsquery,), (func.phraseto_tsquery,), - (func.websearch_to_tsquery,), + (func.websearch_to_tsquery, testing.skip_if("postgresql < 11")), argnames="to_ts_func", ) @testing.variation("use_regconfig", [True, False, "literal"]) @@ -1238,10 +1240,9 @@ def test_tuple_containment(self, connection): class ExtractTest(fixtures.TablesTest): - """The rationale behind this test is that for many years we've had a system of embedding type casts into the expressions rendered by visit_extract() - on the postgreql platform. The reason for this cast is not clear. + on the postgresql platform. The reason for this cast is not clear. So here we try to produce a wide range of cases to ensure that these casts are not needed; see [ticket:2740]. @@ -1639,6 +1640,10 @@ def test_with_ordinality_star(self, connection): eq_(connection.execute(stmt).all(), [(4, 1), (3, 2), (2, 3), (1, 4)]) + def test_array_empty_with_type(self, connection): + stmt = select(postgresql.array([], type_=Integer)) + eq_(connection.execute(stmt).all(), [([],)]) + def test_plain_old_unnest(self, connection): fn = func.unnest( postgresql.array(["one", "two", "three", "four"]) @@ -1792,3 +1797,59 @@ def test_render_derived_quoting_straight_json(self, connection, cast_fn): stmt = select(fn.c.CaseSensitive, fn.c["the % value"]) eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")]) + + +class RequiresCastTest(fixtures.TablesTest): + __only_on__ = "postgresql" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("uuid", Uuid), + Column("j", JSON), + Column("jb", JSONB), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables["t"].insert(), + [ + {"id": 1, "uuid": "d24587a1-06d9-41df-b1c3-3f423b97a755"}, + {"id": 2, "uuid": "4b07e1c8-d60c-4ea8-9d01-d7cd01362224"}, + ], + ) + + def test_update_values(self, connection): + value = values( + Column("id", Integer), + Column("uuid", Uuid), + Column("j", JSON), + Column("jb", JSONB), + name="update_data", + ).data( + [ + ( + 1, + "8b6ec1ec-b979-4d0b-b2ce-9acc6e4c2943", + {"foo": 1}, + {"foo_jb": 1}, + ), + ( + 2, + "a2123bcb-7ea3-420a-8284-1db4b2759d79", + {"bar": 2}, + {"bar_jb": 2}, + ), + ] + ) + connection.execute( + self.tables["t"] + .update() + .values(uuid=value.c.uuid, j=value.c.j, jb=value.c.jb) + .where(self.tables["t"].c.id == value.c.id) + ) diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index ab4fa2c038d..5dd8e00070d 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -7,6 +7,7 @@ from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKeyConstraint from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import inspect @@ -20,10 +21,13 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text +from sqlalchemy import text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import base as postgresql +from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import ExcludeConstraint +from sqlalchemy.dialects.postgresql import INET from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.dialects.postgresql import 
INTERVAL from sqlalchemy.dialects.postgresql import pg_catalog @@ -34,6 +38,7 @@ from sqlalchemy.sql import ddl as sa_ddl from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import config from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_warns @@ -404,84 +409,164 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql > 8.3" __backend__ = True - @classmethod - def setup_test_class(cls): - with testing.db.begin() as con: - for ddl in [ - 'CREATE SCHEMA "SomeSchema"', - "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", - "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", - "CREATE TYPE testtype AS ENUM ('test')", - "CREATE DOMAIN enumdomain AS testtype", - "CREATE DOMAIN arraydomain AS INTEGER[]", - 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', - "CREATE DOMAIN nullable_domain AS TEXT CHECK " - "(VALUE IN('FOO', 'BAR'))", - "CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL", - "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK " - "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) " - "CHECK(VALUE != 22)", - ]: - try: - con.exec_driver_sql(ddl) - except exc.DBAPIError as e: - if "already exists" not in str(e): - raise e - con.exec_driver_sql( - "CREATE TABLE testtable (question integer, answer " - "testdomain)" - ) - con.exec_driver_sql( - "CREATE TABLE test_schema.testtable(question " - "integer, answer test_schema.testdomain, anything " - "integer)" - ) - con.exec_driver_sql( - "CREATE TABLE crosschema (question integer, answer " - "test_schema.testdomain)" + # these fixtures are all currently using individual test scope, + # on a connection that's in a transaction that's rolled back. + # previously, this test would build up all the domains / tables + # at the class level and commit them. PostgreSQL seems to be extremely + # fast at building up / tearing down domains / schemas etc within an + # uncommitted transaction so it seems OK to keep these at per-test + # scope. 
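The per-test fixture strategy described in the comment above works because PostgreSQL DDL is transactional: a domain created on a connection whose transaction rolls back leaves no trace. A minimal sketch of that property, assuming an illustrative engine URL (not part of the test suite):

from sqlalchemy import create_engine

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

with engine.connect() as conn:
    conn.exec_driver_sql("CREATE DOMAIN tmp_dom AS INTEGER")
    conn.rollback()  # transactional DDL: the domain vanishes here

with engine.connect() as conn:
    # succeeds again, since nothing leaked from the rolled-back transaction
    conn.exec_driver_sql("CREATE DOMAIN tmp_dom AS INTEGER")
    conn.rollback()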
+ + @testing.fixture() + def broken_nullable_domains(self): + if not testing.requires.postgresql_working_nullable_domains.enabled: + config.skip_test( + "reflection of nullable domains broken on PG 17.0-17.2" ) - con.exec_driver_sql( - "CREATE TABLE enum_test (id integer, data enumdomain)" - ) + @testing.fixture() + def testdomain(self, connection, broken_nullable_domains): + connection.exec_driver_sql( + "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42" + ) + yield + connection.exec_driver_sql("DROP DOMAIN testdomain") - con.exec_driver_sql( - "CREATE TABLE array_test (id integer, data arraydomain)" - ) + @testing.fixture + def testtable(self, connection, testdomain): + connection.exec_driver_sql( + "CREATE TABLE testtable (question integer, answer testdomain)" + ) + yield + connection.exec_driver_sql("DROP TABLE testtable") - con.exec_driver_sql( - "CREATE TABLE quote_test " - '(id integer, data "SomeSchema"."Quoted.Domain")' - ) - con.exec_driver_sql( - "CREATE TABLE nullable_domain_test " - "(not_nullable_domain_col nullable_domain not null," - "nullable_local not_nullable_domain)" - ) + @testing.fixture + def nullable_domains(self, connection, broken_nullable_domains): + connection.exec_driver_sql( + 'CREATE DOMAIN nullable_domain AS TEXT COLLATE "C" CHECK ' + "(VALUE IN('FOO', 'BAR'))" + ) + connection.exec_driver_sql( + "CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL" + ) + yield + connection.exec_driver_sql("DROP DOMAIN nullable_domain") + connection.exec_driver_sql("DROP DOMAIN not_nullable_domain") - @classmethod - def teardown_test_class(cls): - with testing.db.begin() as con: - con.exec_driver_sql("DROP TABLE testtable") - con.exec_driver_sql("DROP TABLE test_schema.testtable") - con.exec_driver_sql("DROP TABLE crosschema") - con.exec_driver_sql("DROP TABLE quote_test") - con.exec_driver_sql("DROP DOMAIN testdomain") - con.exec_driver_sql("DROP DOMAIN test_schema.testdomain") - con.exec_driver_sql("DROP TABLE enum_test") - con.exec_driver_sql("DROP DOMAIN enumdomain") - con.exec_driver_sql("DROP TYPE testtype") - con.exec_driver_sql("DROP TABLE array_test") - con.exec_driver_sql("DROP DOMAIN arraydomain") - con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') - con.exec_driver_sql('DROP SCHEMA "SomeSchema"') - - con.exec_driver_sql("DROP TABLE nullable_domain_test") - con.exec_driver_sql("DROP DOMAIN nullable_domain") - con.exec_driver_sql("DROP DOMAIN not_nullable_domain") - con.exec_driver_sql("DROP DOMAIN my_int") - - def test_table_is_reflected(self, connection): + @testing.fixture + def nullable_domain_table(self, connection, nullable_domains): + connection.exec_driver_sql( + "CREATE TABLE nullable_domain_test " + "(not_nullable_domain_col nullable_domain not null," + "nullable_local not_nullable_domain)" + ) + yield + connection.exec_driver_sql("DROP TABLE nullable_domain_test") + + @testing.fixture + def enum_domain(self, connection): + connection.exec_driver_sql("CREATE TYPE testtype AS ENUM ('test')") + connection.exec_driver_sql("CREATE DOMAIN enumdomain AS testtype") + yield + connection.exec_driver_sql("drop domain enumdomain") + connection.exec_driver_sql("drop type testtype") + + @testing.fixture + def enum_table(self, connection, enum_domain): + connection.exec_driver_sql( + "CREATE TABLE enum_test (id integer, data enumdomain)" + ) + yield + connection.exec_driver_sql("DROP TABLE enum_test") + + @testing.fixture + def array_domains(self, connection): + connection.exec_driver_sql("CREATE DOMAIN arraydomain AS INTEGER[]") + 
connection.exec_driver_sql( + "CREATE DOMAIN arraydomain_2d AS INTEGER[][]" + ) + connection.exec_driver_sql( + "CREATE DOMAIN arraydomain_3d AS INTEGER[][][]" + ) + yield + connection.exec_driver_sql("DROP DOMAIN arraydomain") + connection.exec_driver_sql("DROP DOMAIN arraydomain_2d") + connection.exec_driver_sql("DROP DOMAIN arraydomain_3d") + + @testing.fixture + def array_table(self, connection, array_domains): + connection.exec_driver_sql( + "CREATE TABLE array_test (" + "id integer, " + "datas arraydomain, " + "datass arraydomain_2d, " + "datasss arraydomain_3d" + ")" + ) + yield + connection.exec_driver_sql("DROP TABLE array_test") + + @testing.fixture + def some_schema(self, connection): + connection.exec_driver_sql('CREATE SCHEMA IF NOT EXISTS "SomeSchema"') + yield + connection.exec_driver_sql('DROP SCHEMA IF EXISTS "SomeSchema"') + + @testing.fixture + def quoted_schema_domain(self, connection, some_schema): + connection.exec_driver_sql( + 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0' + ) + yield + connection.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') + + @testing.fixture + def int_domain(self, connection): + connection.exec_driver_sql( + "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK " + "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) " + "CHECK(VALUE != 22)" + ) + yield + connection.exec_driver_sql("DROP DOMAIN my_int") + + @testing.fixture + def quote_table(self, connection, quoted_schema_domain): + connection.exec_driver_sql( + "CREATE TABLE quote_test " + '(id integer, data "SomeSchema"."Quoted.Domain")' + ) + yield + connection.exec_driver_sql("drop table quote_test") + + @testing.fixture + def testdomain_schema(self, connection): + connection.exec_driver_sql( + "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0" + ) + yield + connection.exec_driver_sql("DROP DOMAIN test_schema.testdomain") + + @testing.fixture + def testtable_schema(self, connection, testdomain_schema): + connection.exec_driver_sql( + "CREATE TABLE test_schema.testtable(question " + "integer, answer test_schema.testdomain, anything " + "integer)" + ) + yield + connection.exec_driver_sql("drop table test_schema.testtable") + + @testing.fixture + def crosschema_table(self, connection, testdomain_schema): + connection.exec_driver_sql( + "CREATE TABLE crosschema (question integer, answer " + f"{config.test_schema}.testdomain)" + ) + yield + connection.exec_driver_sql("DROP TABLE crosschema") + + def test_table_is_reflected(self, connection, testtable): metadata = MetaData() table = Table("testtable", metadata, autoload_with=connection) eq_( @@ -489,9 +574,11 @@ def test_table_is_reflected(self, connection): {"question", "answer"}, "Columns of reflected table didn't equal expected columns", ) - assert isinstance(table.c.answer.type, Integer) + assert isinstance(table.c.answer.type, DOMAIN) + eq_(table.c.answer.type.name, "testdomain") + assert isinstance(table.c.answer.type.data_type, Integer) - def test_nullable_from_domain(self, connection): + def test_nullable_from_domain(self, connection, nullable_domain_table): metadata = MetaData() table = Table( "nullable_domain_test", metadata, autoload_with=connection @@ -499,7 +586,7 @@ def test_nullable_from_domain(self, connection): is_(table.c.not_nullable_domain_col.nullable, False) is_(table.c.nullable_local.nullable, False) - def test_domain_is_reflected(self, connection): + def test_domain_is_reflected(self, connection, testtable): metadata = MetaData() table = Table("testtable", metadata, 
autoload_with=connection) eq_( @@ -511,29 +598,51 @@ def test_domain_is_reflected(self, connection): not table.columns.answer.nullable ), "Expected reflected column to not be nullable." - def test_enum_domain_is_reflected(self, connection): + def test_enum_domain_is_reflected(self, connection, enum_table): metadata = MetaData() table = Table("enum_test", metadata, autoload_with=connection) - eq_(table.c.data.type.enums, ["test"]) + assert isinstance(table.c.data.type, DOMAIN) + eq_(table.c.data.type.data_type.enums, ["test"]) - def test_array_domain_is_reflected(self, connection): + def test_array_domain_is_reflected(self, connection, array_table): metadata = MetaData() table = Table("array_test", metadata, autoload_with=connection) - eq_(table.c.data.type.__class__, ARRAY) - eq_(table.c.data.type.item_type.__class__, INTEGER) - def test_quoted_remote_schema_domain_is_reflected(self, connection): + def assert_is_integer_array_domain(domain, name): + # Postgres does not persist the dimensionality of the array. + # It's always treated as integer[] + assert isinstance(domain, DOMAIN) + assert domain.name == name + assert isinstance(domain.data_type, ARRAY) + assert isinstance(domain.data_type.item_type, INTEGER) + + array_domain = table.c.datas.type + assert_is_integer_array_domain(array_domain, "arraydomain") + + array_domain_2d = table.c.datass.type + assert_is_integer_array_domain(array_domain_2d, "arraydomain_2d") + + array_domain_3d = table.c.datasss.type + assert_is_integer_array_domain(array_domain_3d, "arraydomain_3d") + + def test_quoted_remote_schema_domain_is_reflected( + self, connection, quote_table + ): metadata = MetaData() table = Table("quote_test", metadata, autoload_with=connection) - eq_(table.c.data.type.__class__, INTEGER) + assert isinstance(table.c.data.type, DOMAIN) + eq_(table.c.data.type.name, "Quoted.Domain") + assert isinstance(table.c.data.type.data_type, Integer) - def test_table_is_reflected_test_schema(self, connection): + def test_table_is_reflected_test_schema( + self, connection, testtable_schema + ): metadata = MetaData() table = Table( "testtable", metadata, autoload_with=connection, - schema="test_schema", + schema=config.test_schema, ) eq_( set(table.columns.keys()), @@ -542,13 +651,13 @@ def test_table_is_reflected_test_schema(self): ) assert isinstance(table.c.anything.type, Integer) - def test_schema_domain_is_reflected(self, connection): + def test_schema_domain_is_reflected(self, connection, testtable_schema): metadata = MetaData() table = Table( "testtable", metadata, autoload_with=connection, - schema="test_schema", + schema=config.test_schema, ) eq_( str(table.columns.answer.server_default.arg), @@ -559,7 +668,9 @@ def test_schema_domain_is_reflected(self): table.columns.answer.nullable ), "Expected reflected column to be nullable." - def test_crosschema_domain_is_reflected(self, connection): + def test_crosschema_domain_is_reflected( + self, connection, crosschema_table + ): metadata = MetaData() table = Table("crosschema", metadata, autoload_with=connection) eq_( @@ -571,7 +682,7 @@ def test_crosschema_domain_is_reflected(self): table.columns.answer.nullable ), "Expected reflected column to be nullable." 
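Taken together, the tests above pin down the behavioral change in domain reflection: a column based on a PostgreSQL domain now reflects as DOMAIN wrapping the underlying type, rather than as the underlying type itself. A short sketch of how reflecting code can unwrap it, assuming the connection and testdomain/testtable fixtures from above:

from sqlalchemy import MetaData, Table
from sqlalchemy.dialects.postgresql import DOMAIN

table = Table("testtable", MetaData(), autoload_with=connection)
domain = table.c.answer.type
assert isinstance(domain, DOMAIN)  # formerly reflected as plain Integer
print(domain.name)                 # "testdomain"
print(domain.data_type)            # the wrapped type, e.g. INTEGER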
- def test_unknown_types(self, connection): + def test_unknown_types(self, connection, testtable): from sqlalchemy.dialects.postgresql import base ischema_names = base.PGDialect.ischema_names @@ -591,8 +702,17 @@ def warns(): finally: base.PGDialect.ischema_names = ischema_names - @property - def all_domains(self): + @testing.fixture + def all_domains( + self, + quoted_schema_domain, + array_domains, + enum_domain, + nullable_domains, + int_domain, + testdomain, + testdomain_schema, + ): return { "public": [ { @@ -603,6 +723,27 @@ def all_domains(self): "type": "integer[]", "default": None, "constraints": [], + "collation": None, + }, + { + "visible": True, + "name": "arraydomain_2d", + "schema": "public", + "nullable": True, + "type": "integer[]", + "default": None, + "constraints": [], + "collation": None, + }, + { + "visible": True, + "name": "arraydomain_3d", + "schema": "public", + "nullable": True, + "type": "integer[]", + "default": None, + "constraints": [], + "collation": None, }, { "visible": True, @@ -612,6 +753,7 @@ def all_domains(self): "type": "testtype", "default": None, "constraints": [], + "collation": None, }, { "visible": True, @@ -626,6 +768,7 @@ def all_domains(self): # autogenerated name by pg {"check": "VALUE <> 22", "name": "my_int_check"}, ], + "collation": None, }, { "visible": True, @@ -635,6 +778,7 @@ def all_domains(self): "type": "text", "default": None, "constraints": [], + "collation": "default", }, { "visible": True, @@ -651,6 +795,7 @@ def all_domains(self): "name": "nullable_domain_check", } ], + "collation": "C", }, { "visible": True, @@ -660,6 +805,7 @@ def all_domains(self): "type": "integer", "default": "42", "constraints": [], + "collation": None, }, ], "test_schema": [ @@ -671,6 +817,7 @@ def all_domains(self): "type": "integer", "default": "0", "constraints": [], + "collation": None, } ], "SomeSchema": [ @@ -682,30 +829,66 @@ def all_domains(self): "type": "integer", "default": "0", "constraints": [], + "collation": None, } ], } - def test_inspect_domains(self, connection): + def test_inspect_domains(self, connection, all_domains): inspector = inspect(connection) - eq_(inspector.get_domains(), self.all_domains["public"]) + domains = inspector.get_domains() + + domain_names = {d["name"] for d in domains} + expect_domain_names = {d["name"] for d in all_domains["public"]} + eq_(domain_names, expect_domain_names) + + eq_(domains, all_domains["public"]) - def test_inspect_domains_schema(self, connection): + def test_inspect_domains_schema(self, connection, all_domains): inspector = inspect(connection) eq_( inspector.get_domains("test_schema"), - self.all_domains["test_schema"], - ) - eq_( - inspector.get_domains("SomeSchema"), self.all_domains["SomeSchema"] + all_domains["test_schema"], ) + eq_(inspector.get_domains("SomeSchema"), all_domains["SomeSchema"]) - def test_inspect_domains_star(self, connection): + def test_inspect_domains_star(self, connection, all_domains): inspector = inspect(connection) - all_ = [d for dl in self.all_domains.values() for d in dl] + all_ = [d for dl in all_domains.values() for d in dl] all_ += inspector.get_domains("information_schema") exp = sorted(all_, key=lambda d: (d["schema"], d["name"])) - eq_(inspector.get_domains("*"), exp) + domains = inspector.get_domains("*") + + eq_(domains, exp) + + +class ArrayReflectionTest(fixtures.TablesTest): + __only_on__ = "postgresql >= 10" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "array_table", + metadata, + Column("id", INTEGER, 
primary_key=True), + Column("datas", ARRAY(INTEGER)), + Column("datass", ARRAY(INTEGER, dimensions=2)), + Column("datasss", ARRAY(INTEGER, dimensions=3)), + ) + + def test_array_table_is_reflected(self, connection): + metadata = MetaData() + table = Table("array_table", metadata, autoload_with=connection) + + def assert_is_integer_array(data_type): + assert isinstance(data_type, ARRAY) + # postgres treats all arrays as one-dimensional arrays + assert isinstance(data_type.item_type, INTEGER) + + assert_is_integer_array(table.c.datas.type) + assert_is_integer_array(table.c.datass.type) + assert_is_integer_array(table.c.datasss.type) class ReflectionTest( @@ -728,6 +911,56 @@ def test_reflected_primary_key_order(self, metadata, connection): subject = Table("subject", meta2, autoload_with=connection) eq_(subject.primary_key.columns.keys(), ["p2", "p1"]) + @testing.skip_if( + "postgresql < 15.0", "on delete with column list not supported" + ) + def test_reflected_foreign_key_ondelete_column_list( + self, metadata, connection + ): + meta1 = metadata + pktable = Table( + "pktable", + meta1, + Column("tid", Integer, primary_key=True), + Column("id", Integer, primary_key=True), + ) + Table( + "fktable", + meta1, + Column("tid", Integer), + Column("id", Integer), + Column("fk_id_del_set_null", Integer), + Column("fk_id_del_set_default", Integer, server_default=text("0")), + ForeignKeyConstraint( + name="fktable_tid_fk_id_del_set_null_fkey", + columns=["tid", "fk_id_del_set_null"], + refcolumns=[pktable.c.tid, pktable.c.id], + ondelete="SET NULL (fk_id_del_set_null)", + ), + ForeignKeyConstraint( + name="fktable_tid_fk_id_del_set_default_fkey", + columns=["tid", "fk_id_del_set_default"], + refcolumns=[pktable.c.tid, pktable.c.id], + ondelete="SET DEFAULT(fk_id_del_set_default)", + ), + ) + + meta1.create_all(connection) + meta2 = MetaData() + fktable = Table("fktable", meta2, autoload_with=connection) + fkey_set_null = next( + c + for c in fktable.foreign_key_constraints + if c.name == "fktable_tid_fk_id_del_set_null_fkey" + ) + eq_(fkey_set_null.ondelete, "SET NULL (fk_id_del_set_null)") + fkey_set_default = next( + c + for c in fktable.foreign_key_constraints + if c.name == "fktable_tid_fk_id_del_set_default_fkey" + ) + eq_(fkey_set_default.ondelete, "SET DEFAULT (fk_id_del_set_default)") + def test_pg_weirdchar_reflection(self, metadata, connection): meta1 = metadata subject = Table( @@ -1492,6 +1725,54 @@ def test_index_reflection_with_access_method(self, metadata, connection): "gin", ) + def test_index_reflection_with_operator_class(self, metadata, connection): + """reflect indexes with operator class on columns""" + + Table( + "t", + metadata, + Column("id", Integer, nullable=False), + Column("name", String), + Column("alias", String), + Column("addr1", INET), + Column("addr2", INET), + ) + metadata.create_all(connection) + + # 'name' and 'addr1' use a non-default operator, 'addr2' uses the + # default one, and 'alias' uses no operator. 
+ connection.exec_driver_sql( + "CREATE INDEX ix_t ON t USING btree" + " (name text_pattern_ops, alias, addr1 cidr_ops, addr2 inet_ops)" + ) + + ind = inspect(connection).get_indexes("t", None) + expected = [ + { + "unique": False, + "column_names": ["name", "alias", "addr1", "addr2"], + "name": "ix_t", + "dialect_options": { + "postgresql_ops": { + "addr1": "cidr_ops", + "name": "text_pattern_ops", + }, + }, + } + ] + if connection.dialect.server_version_info >= (11, 0): + expected[0]["include_columns"] = [] + expected[0]["dialect_options"]["postgresql_include"] = [] + eq_(ind, expected) + + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + r_ind = list(t1.indexes)[0] + eq_( + r_ind.dialect_options["postgresql"]["ops"], + {"name": "text_pattern_ops", "addr1": "cidr_ops"}, + ) + @testing.skip_if("postgresql < 15.0", "nullsnotdistinct not supported") def test_nullsnotdistinct(self, metadata, connection): Table( @@ -1541,6 +1822,7 @@ def test_nullsnotdistinct(self, metadata, connection): "column_names": ["y"], "name": "unq1", "dialect_options": { + "postgresql_include": [], "postgresql_nulls_not_distinct": True, }, "comment": None, @@ -2197,6 +2479,42 @@ def test_reflect_with_not_valid_check_constraint(self): ], ) + def test_reflect_with_no_inherit_check_constraint(self): + rows = [ + ("foo", "some name", "CHECK ((a IS NOT NULL)) NO INHERIT", None), + ( + "foo", + "some name", + "CHECK ((a IS NOT NULL)) NO INHERIT NOT VALID", + None, + ), + ] + conn = mock.Mock( + execute=lambda *arg, **kw: mock.MagicMock( + fetchall=lambda: rows, __iter__=lambda self: iter(rows) + ) + ) + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a IS NOT NULL", + "dialect_options": {"no_inherit": True}, + "comment": None, + }, + { + "name": "some name", + "sqltext": "a IS NOT NULL", + "dialect_options": {"not_valid": True, "no_inherit": True}, + "comment": None, + }, + ], + ) + def _apply_stm(self, connection, use_map): if use_map: return connection.execution_options( @@ -2337,6 +2655,51 @@ def all_none(): connection.execute(sa_ddl.DropConstraintComment(cst)) all_none() + @testing.skip_if("postgresql < 11.0", "not supported") + def test_reflection_constraints_with_include(self, connection, metadata): + Table( + "foo", + metadata, + Column("id", Integer, nullable=False), + Column("value", Integer, nullable=False), + Column("foo", String), + Column("arr", ARRAY(Integer)), + Column("bar", SmallInteger), + ) + metadata.create_all(connection) + connection.exec_driver_sql( + "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)" + ) + connection.exec_driver_sql( + "ALTER TABLE foo " + "ADD PRIMARY KEY (id) INCLUDE (arr, foo, bar, value)" + ) + + unq = inspect(connection).get_unique_constraints("foo") + expected_unq = [ + { + "column_names": ["id"], + "name": "foo_id_value_key", + "dialect_options": { + "postgresql_nulls_not_distinct": False, + "postgresql_include": ["value"], + }, + "comment": None, + } + ] + eq_(unq, expected_unq) + + pk = inspect(connection).get_pk_constraint("foo") + expected_pk = { + "comment": None, + "constrained_columns": ["id"], + "dialect_options": { + "postgresql_include": ["arr", "foo", "bar", "value"] + }, + "name": "foo_pkey", + } + eq_(pk, expected_pk) + class CustomTypeReflectionTest(fixtures.TestBase): class CustomType: diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 0a98ef5045f..0df48f6fd12 100644 --- 
a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -73,6 +73,8 @@ from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql.ranges import MultiRange from sqlalchemy.exc import CompileError from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import declarative_base @@ -92,6 +94,7 @@ from sqlalchemy.testing.assertions import ComparesTables from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.suite import test_types as suite @@ -529,6 +532,7 @@ def make_type(**kw): "check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text", } ], + "collation": "default", } ], ) @@ -1073,7 +1077,7 @@ def test_standalone_enum(self, connection, metadata): connection, "fourfivesixtype" ) - def test_reflection(self, metadata, connection): + def test_enum_type_reflection(self, metadata, connection): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata ) @@ -1155,7 +1159,7 @@ def process_result_value(self, value, dialect): "one", "two", "three", - native_enum=True # make sure this is True because + native_enum=True, # make sure this is True because # it should *not* take effect due to # the variant ).with_variant( @@ -1227,6 +1231,213 @@ def test_generic_w_some_other_variant(self, metadata, connection): ] +class DomainTest( + AssertsCompiledSQL, fixtures.TestBase, AssertsExecutionResults +): + __backend__ = True + __only_on__ = "postgresql > 8.3" + + @testing.requires.postgresql_working_nullable_domains + def test_domain_type_reflection(self, metadata, connection): + positive_int = DOMAIN( + "positive_int", Integer(), check="value > 0", not_null=True + ) + my_str = DOMAIN("my_string", Text(), collation="C", default="~~") + Table( + "table", + metadata, + Column("value", positive_int), + Column("str", my_str), + ) + + metadata.create_all(connection) + m2 = MetaData() + t2 = Table("table", m2, autoload_with=connection) + + vt = t2.c.value.type + is_true(isinstance(vt, DOMAIN)) + is_true(isinstance(vt.data_type, Integer)) + eq_(vt.name, "positive_int") + eq_(str(vt.check), "VALUE > 0") + is_(vt.default, None) + is_(vt.collation, None) + is_true(vt.constraint_name is not None) + is_true(vt.not_null) + is_false(vt.create_type) + + st = t2.c.str.type + is_true(isinstance(st, DOMAIN)) + is_true(isinstance(st.data_type, Text)) + eq_(st.name, "my_string") + is_(st.check, None) + is_true("~~" in st.default) + eq_(st.collation, "C") + is_(st.constraint_name, None) + is_false(st.not_null) + is_false(st.create_type) + + def test_domain_create_table(self, metadata, connection): + metadata = self.metadata + Email = DOMAIN( + name="email", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ) + PosInt = DOMAIN( + name="pos_int", + data_type=Integer, + not_null=True, + check=r"VALUE > 0", + ) + t1 = Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("email", Email), + Column("number", PosInt), + ) + t1.create(connection) + t1.create(connection, checkfirst=True) # check the create + connection.execute( + t1.insert(), {"email": "test@example.com", "number": 42} + ) + connection.execute(t1.insert(), {"email": "a@b.c", "number": 1}) + 
connection.execute( + t1.insert(), {"email": "example@gmail.co.uk", "number": 99} + ) + eq_( + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), + [ + (1, "test@example.com", 42), + (2, "a@b.c", 1), + (3, "example@gmail.co.uk", 99), + ], + ) + + @testing.combinations( + tuple( + [ + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + create_type=True, + ), + ] + ), + tuple( + [ + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + create_type=False, + ), + ] + ), + argnames="domain", + ) + def test_create_drop_domain_with_table(self, connection, metadata, domain): + table = Table("e1", metadata, Column("e1", domain)) + + def _domain_names(): + return {d["name"] for d in inspect(connection).get_domains()} + + assert "mytype" not in _domain_names() + + if domain.create_type: + table.create(connection) + assert "mytype" in _domain_names() + else: + with expect_raises(exc.ProgrammingError): + table.create(connection) + connection.rollback() + + domain.create(connection) + assert "mytype" in _domain_names() + table.create(connection) + + table.drop(connection) + if domain.create_type: + assert "mytype" not in _domain_names() + + @testing.combinations( + (Integer, "value > 0", 4), + (String, "value != ''", "hello world"), + ( + UUID, + "value != '{00000000-0000-0000-0000-000000000000}'", + uuid.uuid4(), + ), + ( + DateTime, + "value >= '2020-01-01T00:00:00'", + datetime.datetime.fromisoformat("2021-01-01T00:00:00.000"), + ), + argnames="domain_datatype, domain_check, value", + ) + def test_domain_roundtrip( + self, metadata, connection, domain_datatype, domain_check, value + ): + table = Table( + "domain_roundtrip_test", + metadata, + Column("id", Integer, primary_key=True), + Column( + "value", + DOMAIN("valuedomain", domain_datatype, check=domain_check), + ), + ) + table.create(connection) + + connection.execute(table.insert(), {"value": value}) + + results = connection.execute( + table.select().order_by(table.c.id) + ).fetchall() + eq_(results, [(1, value)]) + + @testing.combinations( + (DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True), 4, -4), + ( + DOMAIN("email", String, check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'"), + "e@xample.com", + "fred", + ), + argnames="domain,pass_value,fail_value", + ) + def test_check_constraint( + self, metadata, connection, domain, pass_value, fail_value + ): + table = Table("table", metadata, Column("value", domain)) + table.create(connection) + + connection.execute(table.insert(), {"value": pass_value}) + + # psycopg/psycopg2 raise IntegrityError, while pg8000 raises + # ProgrammingError + with expect_raises(exc.DatabaseError): + connection.execute(table.insert(), {"value": fail_value}) + + @testing.combinations( + (DOMAIN("nullable_domain", Integer, not_null=True), 1), + (DOMAIN("non_nullable_domain", Integer, not_null=False), 1), + argnames="domain,pass_value", + ) + def test_domain_nullable(self, metadata, connection, domain, pass_value): + table = Table("table", metadata, Column("value", domain)) + table.create(connection) + connection.execute(table.insert(), {"value": pass_value}) + + if domain.not_null: + # psycopg/psycopg2 raise IntegrityError, while pg8000 raises + # ProgrammingError + with expect_raises(exc.DatabaseError): + connection.execute(table.insert(), {"value": None}) + else: + connection.execute(table.insert(), {"value": None}) + + class DomainDDLEventTest(DDLEventWCreateHarness, fixtures.TestBase): __backend__ = True @@ -1555,6 +1766,10 @@ def 
test_reflection(self, metadata, connection): t1.create(connection) m2 = MetaData() t2 = Table("t1", m2, autoload_with=connection) + + eq_(t1.c.c1.type.__class__, postgresql.TIME) + eq_(t1.c.c4.type.__class__, postgresql.TIMESTAMP) + eq_(t2.c.c1.type.precision, None) eq_(t2.c.c2.type.precision, 5) eq_(t2.c.c3.type.precision, 5) @@ -3232,9 +3447,51 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_bit_compile(self, type_, expected): self.assert_compile(type_, expected) + @testing.combinations( + (psycopg.dialect(),), + (psycopg2.dialect(),), + (asyncpg.dialect(),), + (pg8000.dialect(),), + argnames="dialect", + id_="n", + ) + def test_network_address_cast(self, metadata, dialect): + t = Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("addr", postgresql.INET), + Column("addr2", postgresql.MACADDR), + Column("addr3", postgresql.CIDR), + Column("addr4", postgresql.MACADDR8), + ) + stmt = select(t.c.id).where( + t.c.addr == "127.0.0.1", + t.c.addr2 == "08:00:2b:01:02:03", + t.c.addr3 == "192.168.100.128/25", + t.c.addr4 == "08:00:2b:01:02:03:04:05", + ) + param, param2, param3, param4 = { + "format": ("%s", "%s", "%s", "%s"), + "numeric_dollar": ("$1", "$2", "$3", "$4"), + "pyformat": ( + "%(addr_1)s", + "%(addr2_1)s", + "%(addr3_1)s", + "%(addr4_1)s", + ), + }[dialect.paramstyle] + expected = ( + "SELECT addresses.id FROM addresses " + f"WHERE addresses.addr = {param} " + f"AND addresses.addr2 = {param2} " + f"AND addresses.addr3 = {param3} " + f"AND addresses.addr4 = {param4}" + ) + self.assert_compile(stmt, expected, dialect=dialect) -class SpecialTypesTest(fixtures.TablesTest, ComparesTables): +class SpecialTypesTest(fixtures.TablesTest, ComparesTables): """test DDL and reflection of PG-specific types""" __only_on__ = ("postgresql >= 8.3.0",) @@ -3287,6 +3544,34 @@ def test_reflection(self, special_types_table, connection): assert t.c.precision_interval.type.precision == 3 assert t.c.bitstring.type.length == 4 + @testing.combinations( + (postgresql.INET, "127.0.0.1"), + (postgresql.CIDR, "192.168.100.128/25"), + (postgresql.MACADDR, "08:00:2b:01:02:03"), + ( + postgresql.MACADDR8, + "08:00:2b:01:02:03:04:05", + testing.skip_if("postgresql < 10"), + ), + argnames="column_type, value", + id_="na", + ) + def test_network_address_round_trip( + self, connection, metadata, column_type, value + ): + t = Table( + "addresses", + metadata, + Column("name", String), + Column("value", column_type), + ) + t.create(connection) + connection.execute(t.insert(), {"name": "test", "value": value}) + eq_( + connection.scalar(select(t.c.name).where(t.c.value == value)), + "test", + ) + def test_tsvector_round_trip(self, connection, metadata): t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) t.create(connection) @@ -3325,7 +3610,6 @@ def test_bit_reflection(self, metadata, connection): class UUIDTest(fixtures.TestBase): - """Test postgresql-specific UUID cases. 
See also generic UUID tests in testing/suite/test_types @@ -3889,6 +4173,53 @@ def __init__(self, name, data): eq_(s.query(Data.data, Data).all(), [(d.data, d)]) +class RangeMiscTests(fixtures.TestBase): + @testing.combinations( + (Range(2, 7), INT4RANGE), + (Range(-10, 7), INT4RANGE), + (Range(None, -7), INT4RANGE), + (Range(33, None), INT4RANGE), + (Range(-2147483648, 2147483647), INT4RANGE), + (Range(-2147483648 - 1, 2147483647), INT8RANGE), + (Range(-2147483648, 2147483647 + 1), INT8RANGE), + (Range(-2147483648 - 1, None), INT8RANGE), + (Range(None, 2147483647 + 1), INT8RANGE), + ) + def test_resolve_for_literal(self, obj, type_): + """This tests that the int4 / int8 version is selected correctly by + _resolve_for_literal.""" + lit = literal(obj) + eq_(type(lit.type), type_) + + @testing.combinations( + (Range(2, 7), INT4MULTIRANGE), + (Range(-10, 7), INT4MULTIRANGE), + (Range(None, -7), INT4MULTIRANGE), + (Range(33, None), INT4MULTIRANGE), + (Range(-2147483648, 2147483647), INT4MULTIRANGE), + (Range(-2147483648 - 1, 2147483647), INT8MULTIRANGE), + (Range(-2147483648, 2147483647 + 1), INT8MULTIRANGE), + (Range(-2147483648 - 1, None), INT8MULTIRANGE), + (Range(None, 2147483647 + 1), INT8MULTIRANGE), + ) + def test_resolve_for_literal_multi(self, obj, type_): + """This tests that the int4 / int8 version is selected correctly by + _resolve_for_literal.""" + list_ = MultiRange([Range(-1, 1), obj, Range(7, 100)]) + lit = literal(list_) + eq_(type(lit.type), type_) + + def test_multirange_sequence(self): + plain = [Range(-1, 1), Range(42, 43), Range(7, 100)] + mr = MultiRange(plain) + is_true(issubclass(MultiRange, list)) + is_true(isinstance(mr, list)) + eq_(mr, plain) + eq_(str(mr), str(plain)) + eq_(repr(mr), repr(plain)) + ne_(mr, plain[1:]) + + class _RangeTests: _col_type = None "The concrete range class these tests are for." 
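# ----------------------------------------------------------------------
# [editorial sketch, not part of the patch] The RangeMiscTests above
# exercise two behaviors worth seeing in isolation: MultiRange is a thin
# ``list`` subclass, and an untyped literal() resolves Range / MultiRange
# values to the INT4 or INT8 (multi)range type depending on whether the
# bounds fit in 32 bits. A minimal standalone illustration, assuming a
# SQLAlchemy version that already contains these changes; no database
# connection is needed, since resolution happens when literal() is built:
from sqlalchemy import literal
from sqlalchemy.dialects.postgresql import INT4RANGE, INT8RANGE, Range
from sqlalchemy.dialects.postgresql.ranges import MultiRange

mr = MultiRange([Range(1, 3), Range(5, 9)])
assert isinstance(mr, list)              # plain list semantics...
assert mr == [Range(1, 3), Range(5, 9)]  # ...including equality

# bounds inside the 32-bit integer range resolve to INT4RANGE
assert type(literal(Range(2, 7)).type) is INT4RANGE
# a bound past 2**31 - 1 promotes the literal to INT8RANGE
assert type(literal(Range(None, 2**31)).type) is INT8RANGE
# ----------------------------------------------------------------------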
@@ -3969,9 +4300,11 @@ def test_data_str(self, fn, op): self._test_clause( fn(self.col, self._data_str()), f"data_table.range {op} %(range_1)s", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) @testing.combinations(*_all_fns, id_="as") @@ -3979,9 +4312,11 @@ def test_data_obj(self, fn, op): self._test_clause( fn(self.col, self._data_obj()), f"data_table.range {op} %(range_1)s::{self._col_str}", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) @testing.combinations(*_comparisons, id_="as") @@ -3989,9 +4324,11 @@ def test_data_str_any(self, fn, op): self._test_clause( fn(self.col, any_(array([self._data_str()]))), f"data_table.range {op} ANY (ARRAY[%(param_1)s])", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) def test_where_is_null(self): @@ -4112,12 +4449,14 @@ def test_basic_py_sanity(self): ) is_true(range_.contains(values["il"])) + is_true(values["il"] in range_) is_false( range_.contains(Range(lower=values["ll"], upper=values["ih"])) ) is_false(range_.contains(values["rh"])) + is_false(values["rh"] in range_) is_true(range_ == range_) is_false(range_ != range_) @@ -4165,6 +4504,7 @@ def test_contains_value( ) r, expected = connection.execute(q).first() eq_(r.contains(v), expected) + eq_(v in r, expected) _common_ranges_to_test = ( lambda r, e: Range(empty=True), @@ -4225,6 +4565,12 @@ def test_contains_range(self, connection, r1t, r2t): f"{r1}.contains({r2}): got {py_contains}," f" expected {pg_contains}", ) + r2_in_r1 = r2 in r1 + eq_( + r2_in_r1, + pg_contains, + f"{r2} in {r1}: got {r2_in_r1}, expected {pg_contains}", + ) py_contained = r1.contained_by(r2) eq_( py_contained, @@ -4238,6 +4584,12 @@ def test_contains_range(self, connection, r1t, r2t): f"{r2}.contains({r1}: got {r2.contains(r1)}," f" expected {pg_contained})", ) + r1_in_r2 = r1 in r2 + eq_( + r1_in_r2, + pg_contained, + f"{r1} in {r2}: got {r1_in_r2}, expected {pg_contained}", + ) @testing.combinations( *_common_ranges_to_test, @@ -4637,11 +4989,21 @@ def test_auto_cast_back_to_type(self, connection): Brought up in #8540. 
""" + # see also CompileTest::test_range_custom_object_hook data_obj = self._data_obj() stmt = select(literal(data_obj, type_=self._col_type)) round_trip = connection.scalar(stmt) eq_(round_trip, data_obj) + def test_auto_cast_back_to_type_without_type(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_range_custom_object_hook + data_obj = self._data_obj() + lit = literal(data_obj) + round_trip = connection.scalar(select(lit)) + eq_(round_trip, data_obj) + eq_(type(lit.type), self._col_type) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -5136,10 +5498,17 @@ def test_difference(self): ) -class _MultiRangeTypeRoundTrip(fixtures.TablesTest): +class _MultiRangeTypeRoundTrip(fixtures.TablesTest, _RangeTests): __requires__ = ("multirange_types",) __backend__ = True + @testing.fixture(params=(True, False), ids=["multirange", "plain_list"]) + def data_obj(self, request): + if request.param: + return MultiRange(self._data_obj()) + else: + return list(self._data_obj()) + @classmethod def define_tables(cls, metadata): # no reason ranges shouldn't be primary keys, @@ -5151,7 +5520,7 @@ def define_tables(cls, metadata): ) cls.col = table.c.range - def test_auto_cast_back_to_type(self, connection): + def test_auto_cast_back_to_type(self, connection, data_obj): """test that a straight pass of the range type without any context will send appropriate casting info so that the driver can round trip it. @@ -5166,11 +5535,29 @@ def test_auto_cast_back_to_type(self, connection): Brought up in #8540. """ - data_obj = self._data_obj() + # see also CompileTest::test_multirange_custom_object_hook stmt = select(literal(data_obj, type_=self._col_type)) round_trip = connection.scalar(stmt) eq_(round_trip, data_obj) + def test_auto_cast_back_to_type_without_type(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_multirange_custom_object_hook + data_obj = MultiRange(self._data_obj()) + lit = literal(data_obj) + round_trip = connection.scalar(select(lit)) + eq_(round_trip, data_obj) + eq_(type(lit.type), self._col_type) + + @testing.fails("no automatic adaptation of plain list") + def test_auto_cast_back_to_type_without_type_plain_list(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_multirange_custom_object_hook + data_obj = list(self._data_obj()) + lit = literal(data_obj) + r = connection.scalar(select(lit)) + eq_(type(r), list) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -5184,12 +5571,12 @@ def test_reflect(self, connection): def _assert_data(self, conn): data = conn.execute(select(self.tables.data_table.c.range)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) - def test_textual_round_trip_w_dialect_type(self, connection): + def test_textual_round_trip_w_dialect_type(self, connection, data_obj): """test #8690""" data_table = self.tables.data_table - data_obj = self._data_obj() connection.execute( self.tables.data_table.insert(), {"range": data_obj} ) @@ -5202,9 +5589,9 @@ def test_textual_round_trip_w_dialect_type(self, connection): eq_(data_obj, v2) - def test_insert_obj(self, connection): + def test_insert_obj(self, connection, data_obj): connection.execute( - self.tables.data_table.insert(), {"range": self._data_obj()} + self.tables.data_table.insert(), {"range": data_obj} ) self._assert_data(connection) @@ -5225,6 +5612,7 @@ def test_union_result_text(self, connection): range_ = 
self.tables.data_table.c.range data = connection.execute(select(range_ + range_)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) @testing.requires.psycopg_or_pg8000_compatibility def test_intersection_result_text(self, connection): @@ -5236,6 +5624,7 @@ def test_intersection_result_text(self, connection): range_ = self.tables.data_table.c.range data = connection.execute(select(range_ * range_)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) @testing.requires.psycopg_or_pg8000_compatibility def test_difference_result_text(self, connection): @@ -5247,6 +5636,7 @@ def test_difference_result_text(self, connection): range_ = self.tables.data_table.c.range data = connection.execute(select(range_ - range_)).fetchall() eq_(data, [([],)]) + eq_(type(data[0][0]), MultiRange) class _Int4MultiRangeTests: @@ -5257,11 +5647,7 @@ def _data_str(self): return "{[1,2), [3, 5), [9, 12)}" def _data_obj(self): - return [ - Range(1, 2), - Range(3, 5), - Range(9, 12), - ] + return [Range(1, 2), Range(3, 5), Range(9, 12)] class _Int8MultiRangeTests: @@ -5345,31 +5731,29 @@ class _DateTimeTZMultiRangeTests: _tstzs_delta = None def tstzs(self): - utc_now = cast( - func.current_timestamp().op("AT TIME ZONE")("utc"), - DateTime(timezone=True), + # note this was hitting DST issues when these tests were using a + # live date and running on or near 2024-03-09 :). hardcoded to a + # date a few days earlier + utc_now = datetime.datetime( + 2024, 3, 2, 14, 57, 50, 473566, tzinfo=datetime.timezone.utc ) if self._tstzs is None: - with testing.db.connect() as connection: - lower = connection.scalar(select(utc_now)) - upper = lower + datetime.timedelta(1) - self._tstzs = (lower, upper) + lower = utc_now + upper = lower + datetime.timedelta(1) + self._tstzs = (lower, upper) return self._tstzs def tstzs_delta(self): - utc_now = cast( - func.current_timestamp().op("AT TIME ZONE")("utc"), - DateTime(timezone=True), + utc_now = datetime.datetime( + 2024, 3, 2, 14, 57, 50, 473566, tzinfo=datetime.timezone.utc ) if self._tstzs_delta is None: - with testing.db.connect() as connection: - lower = connection.scalar( - select(utc_now) - ) + datetime.timedelta(3) - upper = lower + datetime.timedelta(2) - self._tstzs_delta = (lower, upper) + lower = utc_now + datetime.timedelta(3) + upper = lower + datetime.timedelta(2) + self._tstzs_delta = (lower, upper) + return self._tstzs_delta def _data_str(self): @@ -5461,6 +5845,17 @@ class DateTimeTZRMultiangeRoundTripTest( pass +class MultiRangeSequenceTest(fixtures.TestBase): + def test_methods(self): + plain = [Range(1, 3), Range(5, 9)] + multi = MultiRange(plain) + is_true(isinstance(multi, list)) + eq_(multi, plain) + ne_(multi, plain[:1]) + eq_(str(multi), str(plain)) + eq_(repr(multi), repr(plain)) + + class JSONTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" @@ -5887,7 +6282,7 @@ def setup_test(self): lambda self: self.jsoncol.has_all( {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}} ), - "test_table.test_column ?& %(test_column_1)s", + "test_table.test_column ?& %(test_column_1)s::JSONB", ), ( lambda self: self.jsoncol.has_all(self.any_), @@ -5905,7 +6300,7 @@ def setup_test(self): ), ( lambda self: self.jsoncol.contains({"k1": "r1v1"}), - "test_table.test_column @> %(test_column_1)s", + "test_table.test_column @> %(test_column_1)s::JSONB", ), ( lambda self: self.jsoncol.contains(self.any_), @@ -5913,7 +6308,7 @@ def setup_test(self): ), ( lambda self: self.jsoncol.contained_by({"foo": "1", "bar": 
None}), - "test_table.test_column <@ %(test_column_1)s", + "test_table.test_column <@ %(test_column_1)s::JSONB", ), ( lambda self: self.jsoncol.contained_by(self.any_), @@ -6279,9 +6674,11 @@ def test_imv_returning_datatypes( t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index d6e444bb301..0d8c671402d 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1,4 +1,5 @@ """SQLite-specific tests.""" + import datetime import json import os @@ -52,6 +53,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.types import Boolean from sqlalchemy.types import Date @@ -83,12 +85,12 @@ def test_boolean(self, connection, metadata): ) metadata.create_all(connection) for stmt in [ - "INSERT INTO bool_table (id, boo) " "VALUES (1, 'false');", - "INSERT INTO bool_table (id, boo) " "VALUES (2, 'true');", - "INSERT INTO bool_table (id, boo) " "VALUES (3, '1');", - "INSERT INTO bool_table (id, boo) " "VALUES (4, '0');", - "INSERT INTO bool_table (id, boo) " "VALUES (5, 1);", - "INSERT INTO bool_table (id, boo) " "VALUES (6, 0);", + "INSERT INTO bool_table (id, boo) VALUES (1, 'false');", + "INSERT INTO bool_table (id, boo) VALUES (2, 'true');", + "INSERT INTO bool_table (id, boo) VALUES (3, '1');", + "INSERT INTO bool_table (id, boo) VALUES (4, '0');", + "INSERT INTO bool_table (id, boo) VALUES (5, 1);", + "INSERT INTO bool_table (id, boo) VALUES (6, 0);", ]: connection.exec_driver_sql(stmt) @@ -652,7 +654,7 @@ def test_quoted_identifiers_functional_one(self): @testing.provide_metadata def test_quoted_identifiers_functional_two(self): - """ "test the edgiest of edge cases, quoted table/col names + """test the edgiest of edge cases, quoted table/col names that start and end with quotes. SQLite claims to have fixed this in @@ -740,7 +742,7 @@ def test_pool_class(self): ), ), ( - "sqlite:///file:path/to/database?" 
"mode=ro&uri=true", + "sqlite:///file:path/to/database?mode=ro&uri=true", ( ["file:path/to/database?mode=ro"], {"uri": True, "check_same_thread": False}, @@ -783,6 +785,16 @@ def test_column_computed(self, text, persisted): " y INTEGER GENERATED ALWAYS AS (x + 2)%s)" % text, ) + @testing.combinations( + (func.localtimestamp(),), + (func.now(),), + (func.char_length("test"),), + (func.aggregate_strings("abc", ","),), + argnames="fn", + ) + def test_builtin_functions_roundtrip(self, fn, connection): + connection.execute(select(fn)) + class AttachedDBTest(fixtures.TablesTest): __only_on__ = "sqlite" @@ -912,7 +924,6 @@ def test_col_targeting_union(self, connection): class SQLTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests SQLite-dialect specific compilation.""" __dialect__ = sqlite.dialect() @@ -968,7 +979,7 @@ def test_is_distinct_from(self): def test_localtime(self): self.assert_compile( - func.localtimestamp(), 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + func.localtimestamp(), "DATETIME(CURRENT_TIMESTAMP, 'localtime')" ) def test_constraints_with_schemas(self): @@ -1031,39 +1042,60 @@ def test_constraints_with_schemas(self): ")", ) - def test_column_defaults_ddl(self): + @testing.combinations( + ( + Boolean(create_constraint=True), + sql.false(), + "BOOLEAN DEFAULT 0, CHECK (x IN (0, 1))", + ), + ( + String(), + func.sqlite_version(), + "VARCHAR DEFAULT (sqlite_version())", + ), + (Integer(), func.abs(-5) + 17, "INTEGER DEFAULT (abs(-5) + 17)"), + ( + # test #12425 + String(), + func.now(), + "VARCHAR DEFAULT CURRENT_TIMESTAMP", + ), + ( + # test #12425 + String(), + func.datetime(func.now(), "localtime"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # test #12425 + String(), + text("datetime(CURRENT_TIMESTAMP, 'localtime')"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # default with leading spaces that should not be + # parenthesized + String, + text(" 'some default'"), + "VARCHAR DEFAULT 'some default'", + ), + (String, text("'some default'"), "VARCHAR DEFAULT 'some default'"), + argnames="datatype,default,expected", + ) + def test_column_defaults_ddl(self, datatype, default, expected): t = Table( "t", MetaData(), Column( "x", - Boolean(create_constraint=True), - server_default=sql.false(), + datatype, + server_default=default, ), ) self.assert_compile( CreateTable(t), - "CREATE TABLE t (x BOOLEAN DEFAULT (0), CHECK (x IN (0, 1)))", - ) - - t = Table( - "t", - MetaData(), - Column("x", String(), server_default=func.sqlite_version()), - ) - self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x VARCHAR DEFAULT (sqlite_version()))", - ) - - t = Table( - "t", - MetaData(), - Column("x", Integer(), server_default=func.abs(-5) + 17), - ) - self.assert_compile( - CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT (abs(-5) + 17))" + f"CREATE TABLE t (x {expected})", ) def test_create_partial_index(self): @@ -1144,6 +1176,28 @@ def test_create_table_without_rowid(self): "CREATE TABLE atable (id INTEGER) WITHOUT ROWID", ) + def test_create_table_strict(self): + m = MetaData() + table = Table("atable", m, Column("id", Integer), sqlite_strict=True) + self.assert_compile( + schema.CreateTable(table), + "CREATE TABLE atable (id INTEGER) STRICT", + ) + + def test_create_table_without_rowid_strict(self): + m = MetaData() + table = Table( + "atable", + m, + Column("id", Integer), + sqlite_with_rowid=False, + sqlite_strict=True, + ) + self.assert_compile( + schema.CreateTable(table), + "CREATE TABLE atable (id INTEGER) WITHOUT 
ROWID, STRICT", + ) + class OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = sqlite.dialect() @@ -1155,7 +1209,7 @@ def test_on_conflict_clause_column_not_null(self): self.assert_compile( schema.CreateColumn(c), - "test INTEGER NOT NULL " "ON CONFLICT FAIL", + "test INTEGER NOT NULL ON CONFLICT FAIL", dialect=sqlite.dialect(), ) @@ -1194,7 +1248,7 @@ def test_on_conflict_clause_unique_constraint_from_column(self): self.assert_compile( CreateTable(t), - "CREATE TABLE n (x VARCHAR(30), " "UNIQUE (x) ON CONFLICT FAIL)", + "CREATE TABLE n (x VARCHAR(30), UNIQUE (x) ON CONFLICT FAIL)", dialect=sqlite.dialect(), ) @@ -1314,7 +1368,6 @@ def test_on_conflict_clause_primary_key_constraint(self): class InsertTest(fixtures.TestBase, AssertsExecutionResults): - """Tests inserts and autoincrement.""" __only_on__ = "sqlite" @@ -1816,6 +1869,29 @@ def setup_test_class(cls): Table("q", meta, Column("id", Integer), PrimaryKeyConstraint("id")) + # intentional new line + Table( + "r", + meta, + Column("id", Integer), + Column("value", Integer), + Column("prefix", String), + CheckConstraint("id > 0"), + UniqueConstraint("prefix", name="prefix_named"), + # Constraint definition with newline and tab characters + CheckConstraint( + """((value > 0) AND \n\t(value < 100) AND \n\t + (value != 50))""", + name="ck_r_value_multiline", + ), + UniqueConstraint("value"), + # Constraint name with special chars and 'check' in the name + CheckConstraint("value IS NOT NULL", name="^check-r* #\n\t"), + PrimaryKeyConstraint("id", name="pk_name"), + # Constraint definition with special characters. + CheckConstraint("prefix NOT GLOB '*[^-. /#,]*'"), + ) + meta.create_all(conn) # will contain an "autoindex" @@ -1851,8 +1927,20 @@ def setup_test_class(cls): conn.exec_driver_sql( "CREATE TABLE cp (" - "q INTEGER check (q > 1 AND q < 6),\n" - "CONSTRAINT cq CHECK (q == 1 OR (q > 2 AND q < 5))\n" + "id INTEGER NOT NULL,\n" + "q INTEGER, \n" + "p INTEGER, \n" + "CONSTRAINT cq CHECK (p = 1 OR (p > 2 AND p < 5)),\n" + "PRIMARY KEY (id)\n" + ")" + ) + + conn.exec_driver_sql( + "CREATE TABLE cp_inline (\n" + "id INTEGER NOT NULL,\n" + "q INTEGER CHECK (q > 1 AND q < 6), \n" + "p INTEGER CONSTRAINT cq CHECK (p = 1 OR (p > 2 AND p < 5)),\n" + "PRIMARY KEY (id)\n" ")" ) @@ -1911,6 +1999,7 @@ def teardown_test_class(cls): "b", "a1", "a2", + "r", ]: conn.exec_driver_sql("drop table %s" % name) @@ -2426,6 +2515,27 @@ def test_unique_constraint_unnamed_normal_temporary( [{"column_names": ["x"], "name": None}], ) + def test_unique_constraint_mixed_into_ck(self, connection): + """test #11832""" + + inspector = inspect(connection) + eq_( + inspector.get_unique_constraints("r"), + [ + {"name": "prefix_named", "column_names": ["prefix"]}, + {"name": None, "column_names": ["value"]}, + ], + ) + + def test_primary_key_constraint_mixed_into_ck(self, connection): + """test #11832""" + + inspector = inspect(connection) + eq_( + inspector.get_pk_constraint("r"), + {"constrained_columns": ["id"], "name": "pk_name"}, + ) + def test_primary_key_constraint_named(self): inspector = inspect(testing.db) eq_( @@ -2447,16 +2557,46 @@ def test_primary_key_constraint_no_pk(self): {"constrained_columns": [], "name": None}, ) - def test_check_constraint(self): + def test_check_constraint_plain(self): inspector = inspect(testing.db) eq_( inspector.get_check_constraints("cp"), [ - {"sqltext": "q == 1 OR (q > 2 AND q < 5)", "name": "cq"}, + {"sqltext": "p = 1 OR (p > 2 AND p < 5)", "name": "cq"}, + ], + ) + + def 
test_check_constraint_inline_plain(self): + inspector = inspect(testing.db) + eq_( + inspector.get_check_constraints("cp_inline"), + [ + {"sqltext": "p = 1 OR (p > 2 AND p < 5)", "name": "cq"}, {"sqltext": "q > 1 AND q < 6", "name": None}, ], ) + @testing.fails("need to come up with new regex and/or DDL parsing") + def test_check_constraint_multiline(self): + """test for #11677""" + + inspector = inspect(testing.db) + eq_( + inspector.get_check_constraints("r"), + [ + {"sqltext": "value IS NOT NULL", "name": "^check-r* #\n\t"}, + # Triple-quote multi-line definition should have added a + # newline and whitespace: + { + "sqltext": "((value > 0) AND \n\t(value < 100) AND \n\t\n" + " (value != 50))", + "name": "ck_r_value_multiline", + }, + {"sqltext": "id > 0", "name": None}, + {"sqltext": "prefix NOT GLOB '*[^-. /#,]*'", "name": None}, + ], + ) + @testing.combinations( ("plain_name", "plain_name"), ("name with spaces", "name with spaces"), @@ -2466,17 +2606,27 @@ def test_check_constraint(self): argnames="colname,expected", ) @testing.combinations( - "uq", "uq_inline", "pk", "ix", argnames="constraint_type" + "uq", + "uq_inline", + "uq_inline_tab_before", # tab before column params + "uq_inline_tab_within", # tab within column params + "pk", + "ix", + argnames="constraint_type", ) def test_constraint_cols( self, colname, expected, constraint_type, connection, metadata ): - if constraint_type == "uq_inline": + if constraint_type.startswith("uq_inline"): + inline_create_sql = { + "uq_inline": "CREATE TABLE t (%s INTEGER UNIQUE)", + "uq_inline_tab_before": "CREATE TABLE t (%s\tINTEGER UNIQUE)", + "uq_inline_tab_within": "CREATE TABLE t (%s INTEGER\tUNIQUE)", + } + t = Table("t", metadata, Column(colname, Integer)) connection.exec_driver_sql( - """ - CREATE TABLE t (%s INTEGER UNIQUE) - """ + inline_create_sql[constraint_type] % connection.dialect.identifier_preparer.quote(colname) ) else: @@ -2494,7 +2644,12 @@ def test_constraint_cols( t.create(connection) - if constraint_type in ("uq", "uq_inline"): + if constraint_type in ( + "uq", + "uq_inline", + "uq_inline_tab_before", + "uq_inline_tab_within", + ): const = inspect(connection).get_unique_constraints("t")[0] eq_(const["column_names"], [expected]) elif constraint_type == "pk": @@ -2508,7 +2663,6 @@ def test_constraint_cols( class SavepointTest(fixtures.TablesTest): - """test that savepoints work when we use the correct event setup""" __only_on__ = "sqlite" @@ -2743,8 +2897,7 @@ def _only_on_py38_w_sqlite_39(): """in python 3.9 and above you can actually do:: @(testing.requires.python38 + testing.only_on("sqlite > 3.9")) - def test_determinsitic_parameter(self): - ... + def test_determinsitic_parameter(self): ... that'll be cool. until then... 
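# ----------------------------------------------------------------------
# [editorial note, not part of the patch] The constraint-reflection
# tests above (test_check_constraint_*, the #11832 unique/primary-key
# cases, and the xfail'ed multiline test for #11677) all trace back to
# the same SQLite quirk: CHECK constraints are not exposed through any
# PRAGMA, so the dialect has to parse them out of the original DDL that
# SQLite stores verbatim in sqlite_master. A stdlib-only sketch of the
# raw text those regexes have to cope with, reusing DDL from this patch:
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE cp_inline (\n"
    "id INTEGER NOT NULL,\n"
    "q INTEGER CHECK (q > 1 AND q < 6), \n"
    "p INTEGER CONSTRAINT cq CHECK (p = 1 OR (p > 2 AND p < 5)),\n"
    "PRIMARY KEY (id)\n"
    ")"
)
(ddl,) = conn.execute(
    "SELECT sql FROM sqlite_master WHERE name = 'cp_inline'"
).fetchone()
print(ddl)  # newlines and tabs survive exactly as originally written
# ----------------------------------------------------------------------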
@@ -2841,7 +2994,173 @@ def test_regexp_replace(self): ) -class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest): +class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase): + __dialect__ = "sqlite" + + @testing.combinations( + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where=text("name = 'hi'") + ), + "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where="name = 'hi'" + ), + exc.ArgumentError, + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=[users.c.id], index_where=users.c.name == "hi" + ), + "ON CONFLICT (id) WHERE name = __[POSTCOMPILE_name_1] DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=users.c.name == "hi", + ), + "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE users.name = ?", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=text("name = 'hi'"), + ), + "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE name = 'hi'", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where="name = 'hi'", + ), + exc.ArgumentError, + ), + argnames="case,expected", + ) + def test_assorted_arg_coercion(self, users, case, expected): + stmt = insert(users) + + if isinstance(expected, type) and issubclass(expected, Exception): + with expect_raises(expected): + testing.resolve_lambda(case, stmt=stmt, users=users), + else: + self.assert_compile( + testing.resolve_lambda(case, stmt=stmt, users=users), + f"INSERT INTO users (id, name) VALUES (?, ?) {expected}", + ) + + @testing.combinations("control", "excluded", "dict", argnames="scenario") + def test_set_excluded(self, scenario, users, users_w_key): + """test #8014, sending all of .excluded to set""" + + if scenario == "control": + + stmt = insert(users) + self.assert_compile( + stmt.on_conflict_do_update(set_=stmt.excluded), + "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + else: + + stmt = insert(users_w_key) + + if scenario == "excluded": + self.assert_compile( + stmt.on_conflict_do_update(set_=stmt.excluded), + "INSERT INTO users_w_key (id, name) VALUES (?, ?) " + "ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + else: + self.assert_compile( + stmt.on_conflict_do_update( + set_={ + "id": stmt.excluded.id, + "name_keyed": stmt.excluded.name_keyed, + } + ), + "INSERT INTO users_w_key (id, name) VALUES (?, ?) " + "ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + + def test_on_conflict_do_update_exotic_targets_six(self, users_xtra): + users = users_xtra + + unique_partial_index = schema.Index( + "idx_unique_partial_name", + users_xtra.c.name, + users_xtra.c.lets_index_this, + unique=True, + sqlite_where=users_xtra.c.lets_index_this == "unique_name", + ) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=unique_partial_index.columns, + index_where=unique_partial_index.dialect_options["sqlite"][ + "where" + ], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + # this test illustrates that the index_where clause can't use + # bound parameters, where we see below a literal_execute parameter is + # used (will be sent as literal to the DBAPI). 
SQLite otherwise + # fails here with "(sqlite3.OperationalError) ON CONFLICT clause does + # not match any PRIMARY KEY or UNIQUE constraint" if sent as a real + # bind parameter. + self.assert_compile( + i, + "INSERT INTO users_xtra (id, name, login_email, lets_index_this) " + "VALUES (?, ?, ?, ?) ON CONFLICT (name, lets_index_this) " + "WHERE lets_index_this = __[POSTCOMPILE_lets_index_this_1] " + "DO UPDATE " + "SET name = excluded.name, login_email = excluded.login_email", + ) + + @testing.fixture + def users(self): + metadata = MetaData() + return Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + + @testing.fixture + def users_w_key(self): + metadata = MetaData() + return Table( + "users_w_key", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), key="name_keyed"), + ) + + @testing.fixture + def users_xtra(self): + metadata = MetaData() + return Table( + "users_xtra", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("login_email", String(50)), + Column("lets_index_this", String(50)), + ) + + +class OnConflictTest(fixtures.TablesTest): __only_on__ = ("sqlite >= 3.24.0",) __backend__ = True @@ -2901,49 +3220,8 @@ def process_bind_param(self, value, dialect): ) def test_bad_args(self): - assert_raises( - ValueError, insert(self.tables.users).on_conflict_do_update - ) - - @testing.combinations("control", "excluded", "dict") - @testing.skip_if("+pysqlite_numeric") - @testing.skip_if("+pysqlite_dollar") - def test_set_excluded(self, scenario): - """test #8014, sending all of .excluded to set""" - - if scenario == "control": - users = self.tables.users - - stmt = insert(users) - self.assert_compile( - stmt.on_conflict_do_update(set_=stmt.excluded), - "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) - else: - users_w_key = self.tables.users_w_key - - stmt = insert(users_w_key) - - if scenario == "excluded": - self.assert_compile( - stmt.on_conflict_do_update(set_=stmt.excluded), - "INSERT INTO users_w_key (id, name) VALUES (?, ?) " - "ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) - else: - self.assert_compile( - stmt.on_conflict_do_update( - set_={ - "id": stmt.excluded.id, - "name_keyed": stmt.excluded.name_keyed, - } - ), - "INSERT INTO users_w_key (id, name) VALUES (?, ?) 
" - "ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) + with expect_raises(ValueError): + insert(self.tables.users).on_conflict_do_update() def test_on_conflict_do_no_call_twice(self): users = self.tables.users @@ -3464,7 +3742,7 @@ def test_on_conflict_do_update_no_row_actually_affected(self, connection): ) # The last inserted primary key should be 2 here - # it is taking the result from the the exotic fixture + # it is taking the result from the exotic fixture eq_(result.inserted_primary_key, (2,)) eq_( @@ -3568,3 +3846,100 @@ def test_get_temp_view_names(self, connection): eq_(res, ["sqlitetempview"]) finally: connection.exec_driver_sql("DROP VIEW sqlitetempview") + + +class ComputedReflectionTest(fixtures.TestBase): + __only_on__ = "sqlite" + __backend__ = True + + @classmethod + def setup_test_class(cls): + tables = [ + """CREATE TABLE test1 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') + );""", + """CREATE TABLE test2 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x'), + y VARCHAR GENERATED ALWAYS AS (s || 'y') + );""", + """CREATE TABLE test3 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) + );""", + """CREATE TABLE test4 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")), + y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")));""", + """CREATE TABLE test5 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED + );""", + """CREATE TABLE test6 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED, + y VARCHAR GENERATED ALWAYS AS (s || 'y') STORED + );""", + """CREATE TABLE test7 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED + );""", + """CREATE TABLE test8 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED, + y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")) STORED + );""", + ] + + with testing.db.begin() as conn: + for ct in tables: + conn.exec_driver_sql(ct) + + @classmethod + def teardown_test_class(cls): + with testing.db.begin() as conn: + for tn in cls.res: + conn.exec_driver_sql(f"DROP TABLE {tn}") + + res = { + "test1": {"x": {"text": "s || 'x'", "stored": False}}, + "test2": { + "x": {"text": "s || 'x'", "stored": False}, + "y": {"text": "s || 'y'", "stored": False}, + }, + "test3": {"x": {"text": 'INSTR(s, ",")', "stored": False}}, + "test4": { + "x": {"text": 'INSTR(s, ",")', "stored": False}, + "y": {"text": 'INSTR(x, ",")', "stored": False}, + }, + "test5": {"x": {"text": "s || 'x'", "stored": True}}, + "test6": { + "x": {"text": "s || 'x'", "stored": True}, + "y": {"text": "s || 'y'", "stored": True}, + }, + "test7": {"x": {"text": 'INSTR(s, ",")', "stored": True}}, + "test8": { + "x": {"text": 'INSTR(s, ",")', "stored": True}, + "y": {"text": 'INSTR(x, ",")', "stored": True}, + }, + } + + def test_reflection(self, connection): + meta = MetaData() + meta.reflect(connection) + eq_(len(meta.tables), len(self.res)) + for tbl in meta.tables.values(): + data = self.res[tbl.name] + seen = set() + for col in tbl.c: + if col.name not in data: + is_(col.computed, None) + else: + info = data[col.name] + seen.add(col.name) + msg = f"{tbl.name}-{col.name}" + is_true(bool(col.computed)) + eq_(col.computed.sqltext.text, info["text"], msg) + eq_(col.computed.persisted, info["stored"], msg) + eq_(seen, data.keys()) diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index f6fa21f29dd..30bf9e66f64 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -300,10 +300,6 @@ def 
test_connection_fairy_connection(self): is_(fairy.connection, fairy.dbapi_connection) -def select1(db): - return str(select(1).compile(dialect=db.dialect)) - - class ResetEventTest(fixtures.TestBase): def _fixture(self, **kw): dbapi = Mock() @@ -500,3 +496,21 @@ def test_implicit_returning_engine_parameter(self, implicit_returning): ) # parameter has no effect + + +class AsyncFallbackDeprecationTest(fixtures.TestBase): + __requires__ = ("greenlet",) + + def test_async_fallback_deprecated(self): + with assertions.expect_deprecated( + "The async_fallback dialect argument is deprecated and will be " + "removed in SQLAlchemy 2.1.", + ): + create_engine( + "postgresql+asyncpg://?async_fallback=True", module=mock.Mock() + ) + + def test_async_fallback_false_is_ok(self): + create_engine( + "postgresql+asyncpg://?async_fallback=False", module=mock.Mock() + ) diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 6080f3dc6d0..28541ca33a1 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -34,6 +34,7 @@ from sqlalchemy.engine import default from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Engine +from sqlalchemy.pool import AsyncAdaptedQueuePool from sqlalchemy.pool import NullPool from sqlalchemy.pool import QueuePool from sqlalchemy.sql import column @@ -554,7 +555,7 @@ def test_stmt_exception_pickleable_no_dbapi(self): "Older versions don't support cursor pickling, newer ones do", ) @testing.fails_on( - "mysql+mysqlconnector", + "+mysqlconnector", "Exception doesn't come back exactly the same from pickle", ) @testing.fails_on( @@ -1781,6 +1782,38 @@ def test_per_engine_plus_global(self, testing_engine): eq_(canary.be2.call_count, 1) eq_(canary.be3.call_count, 2) + @testing.requires.ad_hoc_engines + def test_option_engine_registration_issue_one(self): + """test #12289""" + + e1 = create_engine(testing.db.url) + e2 = e1.execution_options(foo="bar") + e3 = e2.execution_options(isolation_level="AUTOCOMMIT") + + eq_( + e3._execution_options, + {"foo": "bar", "isolation_level": "AUTOCOMMIT"}, + ) + + @testing.requires.ad_hoc_engines + def test_option_engine_registration_issue_two(self): + """test #12289""" + + e1 = create_engine(testing.db.url) + e2 = e1.execution_options(foo="bar") + + @event.listens_for(e2, "engine_connect") + def r1(*arg, **kw): + pass + + e3 = e2.execution_options(bat="hoho") + + @event.listens_for(e3, "engine_connect") + def r2(*arg, **kw): + pass + + eq_(e3._execution_options, {"foo": "bar", "bat": "hoho"}) + def test_emit_sql_in_autobegin(self, testing_engine): e1 = testing_engine(config.db_url) @@ -1939,13 +1972,10 @@ def go2(dbapi_conn, xyz): def test_new_exec_driver_sql_no_events(self): m1 = Mock() - def select1(db): - return str(select(1).compile(dialect=db.dialect)) - with testing.db.connect() as conn: event.listen(conn, "before_execute", m1.before_execute) event.listen(conn, "after_execute", m1.after_execute) - conn.exec_driver_sql(select1(testing.db)) + conn.exec_driver_sql(str(select(1).compile(testing.db))) eq_(m1.mock_calls, []) def test_add_event_after_connect(self, testing_engine): @@ -2411,7 +2441,15 @@ def test_dispose_event(self, testing_engine): @testing.combinations(True, False, argnames="close") def test_close_parameter(self, testing_engine, close): eng = testing_engine( - options=dict(pool_size=1, max_overflow=0, poolclass=QueuePool) + options=dict( + pool_size=1, + max_overflow=0, + poolclass=( + QueuePool + if not testing.db.dialect.is_async + else AsyncAdaptedQueuePool + ), + ) 
) conn = eng.connect() @@ -3654,12 +3692,12 @@ def mock_the_cursor(cursor, *arg): arg[-1].get_result_proxy = Mock(return_value=Mock(context=arg[-1])) return retval - m1.real_do_execute.side_effect = ( - m1.do_execute.side_effect - ) = mock_the_cursor - m1.real_do_executemany.side_effect = ( - m1.do_executemany.side_effect - ) = mock_the_cursor + m1.real_do_execute.side_effect = m1.do_execute.side_effect = ( + mock_the_cursor + ) + m1.real_do_executemany.side_effect = m1.do_executemany.side_effect = ( + mock_the_cursor + ) m1.real_do_execute_no_params.side_effect = ( m1.do_execute_no_params.side_effect ) = mock_the_cursor diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index a498ec85c83..119d5533201 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -990,6 +990,43 @@ def test_logging_token_option_connection(self, token_engine): c2.close() c3.close() + def test_logging_token_option_connection_updates(self, token_engine): + """test #11210""" + + eng = token_engine + + c1 = eng.connect().execution_options(logging_token="my_name_1") + + self._assert_token_in_execute(c1, "my_name_1") + + c1.execution_options(logging_token="my_name_2") + + self._assert_token_in_execute(c1, "my_name_2") + + c1.execution_options(logging_token=None) + + self._assert_no_tokens_in_execute(c1) + + c1.close() + + def test_logging_token_option_not_transactional(self, token_engine): + """test #11210""" + + eng = token_engine + + c1 = eng.connect() + + with c1.begin(): + self._assert_no_tokens_in_execute(c1) + + c1.execution_options(logging_token="my_name_1") + + self._assert_token_in_execute(c1, "my_name_1") + + self._assert_token_in_execute(c1, "my_name_1") + + c1.close() + def test_logging_token_option_engine(self, token_engine): eng = token_engine diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 4c144a4a31a..7c562bf39d1 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -62,13 +62,33 @@ class URLTest(fixtures.TestBase): "dbtype://username:password@hostspec/test database with@atsign", "dbtype://username:password@hostspec?query=but_no_db", "dbtype://username:password@hostspec:450?query=but_no_db", + "dbtype://username:password with spaces@hostspec:450?query=but_no_db", + "dbtype+apitype://username with space+and+plus:" + "password with space+and+plus@" + "hostspec:450?query=but_no_db", + "dbtype://user%25%26%7C:pass%25%26%7C@hostspec:499?query=but_no_db", + "dbtype://user🐍測試:pass🐍測試@hostspec:499?query=but_no_db", ) def test_rfc1738(self, text): u = url.make_url(text) assert u.drivername in ("dbtype", "dbtype+apitype") - assert u.username in ("username", None) - assert u.password in ("password", "apples/oranges", None) + assert u.username in ( + "username", + "user%&|", + "username with space+and+plus", + "user🐍測試", + None, + ) + assert u.password in ( + "password", + "password with spaces", + "password with space+and+plus", + "apples/oranges", + "pass%&|", + "pass🐍測試", + None, + ) assert u.host in ( "hostspec", "127.0.0.1", @@ -87,7 +107,8 @@ def test_rfc1738(self, text): "E:/work/src/LEM/db/hello.db", None, ), u.database - eq_(u.render_as_string(hide_password=False), text) + + eq_(url.make_url(u.render_as_string(hide_password=False)), u) def test_rfc1738_password(self): u = url.make_url("dbtype://user:pass word + other%3Awords@host/dbname") @@ -352,7 +373,7 @@ def test_create_engine_url_invalid(self): ( "foo1=bar1&foo2=bar21&foo2=bar22&foo3=bar31", "foo2=bar23&foo3=bar32&foo3=bar33", - 
"foo1=bar1&foo2=bar23&" "foo3=bar32&foo3=bar33", + "foo1=bar1&foo2=bar23&foo3=bar32&foo3=bar33", False, ), ) @@ -552,7 +573,7 @@ def test_engine_from_config(self): e = engine_from_config(config, module=dbapi, _initialize=False) assert e.pool._recycle == 50 assert e.url == url.make_url( - "postgresql+psycopg2://scott:tiger@somehost/test?foo" "z=somevalue" + "postgresql+psycopg2://scott:tiger@somehost/test?fooz=somevalue" ) assert e.echo is True @@ -770,6 +791,13 @@ def test_bad_args(self): module=mock_dbapi, ) + def test_cant_parse_str(self): + with expect_raises_message( + exc.ArgumentError, + r"^Could not parse SQLAlchemy URL from given URL string$", + ): + create_engine("notarealurl") + def test_urlattr(self): """test the url attribute on ``Engine``.""" diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 44c494bad4a..49736df9b65 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -460,8 +460,10 @@ def _checkin_event_fixture(self, _is_asyncio=False, _has_terminate=False): @event.listens_for(p, "reset") def reset(conn, rec, state): canary.append( - f"""reset_{'rollback_ok' - if state.asyncio_safe else 'no_rollback'}""" + f"""reset_{ + 'rollback_ok' + if state.asyncio_safe else 'no_rollback' + }""" ) @event.listens_for(p, "checkin") diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index a7883efa2fd..e1515a23a86 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1581,9 +1581,9 @@ def _run_with_retries(fn, context, cursor, statement, *arg, **kw): connection.rollback() time.sleep(retry_interval) - context.cursor = ( - cursor - ) = connection.connection.cursor() + context.cursor = cursor = ( + connection.connection.cursor() + ) else: raise else: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 003b457a51a..adb40370655 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,3 +1,4 @@ +import itertools import unicodedata import sqlalchemy as sa @@ -19,6 +20,8 @@ from sqlalchemy import testing from sqlalchemy import UniqueConstraint from sqlalchemy.engine import Inspector +from sqlalchemy.engine.reflection import cache +from sqlalchemy.sql.elements import quoted_name from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -2494,3 +2497,162 @@ def test_table_works_minus_fks(self, connection, tab_w_fks): "SELECT b_1.x, b_1.q, b_1.p, b_1.r, b_1.s, b_1.t " "FROM b AS b_1 JOIN a ON a.x = b_1.r", ) + + +class ReflectionCacheTest(fixtures.TestBase): + @testing.fixture(params=["arg", "kwarg"]) + def cache(self, connection, request): + dialect = connection.dialect + info_cache = {} + counter = itertools.count(1) + + @cache + def get_cached_name(self, connection, *args, **kw): + return next(counter) + + def get_cached_name_via_arg(name): + return get_cached_name( + dialect, connection, name, info_cache=info_cache + ) + + def get_cached_name_via_kwarg(name): + return get_cached_name( + dialect, connection, name=name, info_cache=info_cache + ) + + if request.param == "arg": + yield get_cached_name_via_arg + elif request.param == "kwarg": + yield get_cached_name_via_kwarg + else: + assert False + + @testing.fixture(params=[False, True]) + def quote(self, request): + yield request.param + + def test_single_string(self, cache): + # new value + eq_(cache("name1"), 1) + + # same value, counter not incremented + eq_(cache("name1"), 1) + + def 
test_multiple_string(self, cache): + # new value + eq_(cache("name1"), 1) + eq_(cache("name2"), 2) + + # same values, counter not incremented + eq_(cache("name1"), 1) + eq_(cache("name2"), 2) + + def test_single_quoted_name(self, cache, quote): + # new value + eq_(cache(quoted_name("name1", quote=quote)), 1) + + # same value, counter not incremented + eq_(cache(quoted_name("name1", quote=quote)), 1) + + def test_multiple_quoted_name(self, cache, quote): + # new value + eq_(cache(quoted_name("name1", quote=quote)), 1) + eq_(cache(quoted_name("name2", quote=quote)), 2) + + # same values, counter not incremented + eq_(cache(quoted_name("name1", quote=quote)), 1) + eq_(cache(quoted_name("name2", quote=quote)), 2) + + def test_single_quoted_name_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache("n1"), 2) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache("n1"), 2) + + def test_multiple_quoted_name_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache("n1"), 3) + eq_(cache("n2"), 4) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache("n1"), 3) + eq_(cache("n2"), 4) + + def test_single_quoted_name_false_true_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n1", quote=not quote)), 2) + eq_(cache("n1"), 3) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n1", quote=not quote)), 2) + eq_(cache("n1"), 3) + + def test_multiple_quoted_name_false_true_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache(quoted_name("n1", quote=not quote)), 3) + eq_(cache(quoted_name("n2", quote=not quote)), 4) + eq_(cache("n1"), 5) + eq_(cache("n2"), 6) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache(quoted_name("n1", quote=not quote)), 3) + eq_(cache(quoted_name("n2", quote=not quote)), 4) + eq_(cache("n1"), 5) + eq_(cache("n2"), 6) + + def test_multiple_quoted_name_false_true_and_string_arg_and_kwarg( + self, connection, quote + ): + dialect = connection.dialect + info_cache = {} + counter = itertools.count(1) + + @cache + def get_cached_name(self, connection, *args, **kw): + return next(counter) + + def cache_(*args, **kw): + return get_cached_name( + dialect, connection, *args, **kw, info_cache=info_cache + ) + + # new values + eq_(cache_(quoted_name("n1", quote=quote)), 1) + eq_(cache_(name=quoted_name("n1", quote=quote)), 2) + eq_(cache_(quoted_name("n2", quote=quote)), 3) + eq_(cache_(name=quoted_name("n2", quote=quote)), 4) + eq_(cache_(quoted_name("n1", quote=not quote)), 5) + eq_(cache_(name=quoted_name("n1", quote=not quote)), 6) + eq_(cache_(quoted_name("n2", quote=not quote)), 7) + eq_(cache_(name=quoted_name("n2", quote=not quote)), 8) + eq_(cache_("n1"), 9) + eq_(cache_(name="n1"), 10) + eq_(cache_("n2"), 11) + eq_(cache_(name="n2"), 12) + + # same values, counter not incremented + eq_(cache_(quoted_name("n1", quote=quote)), 1) + eq_(cache_(name=quoted_name("n1", quote=quote)), 2) + eq_(cache_(quoted_name("n2", quote=quote)), 3) + eq_(cache_(name=quoted_name("n2", 
quote=quote)), 4) + eq_(cache_(quoted_name("n1", quote=not quote)), 5) + eq_(cache_(name=quoted_name("n1", quote=not quote)), 6) + eq_(cache_(quoted_name("n2", quote=not quote)), 7) + eq_(cache_(name=quoted_name("n2", quote=not quote)), 8) + eq_(cache_("n1"), 9) + eq_(cache_(name="n1"), 10) + eq_(cache_("n2"), 11) + eq_(cache_(name="n2"), 12) diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 4ae87c4ad18..fb67c7434fe 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -12,6 +12,8 @@ from sqlalchemy.engine import characteristics from sqlalchemy.engine import default from sqlalchemy.engine import url +from sqlalchemy.pool import AsyncAdaptedQueuePool +from sqlalchemy.pool import QueuePool from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -345,9 +347,7 @@ def test_ctxmanager_interface(self, local_connection): assert not trans.is_active eq_( - connection.exec_driver_sql( - "select count(*) from " "users" - ).scalar(), + connection.exec_driver_sql("select count(*) from users").scalar(), 2, ) connection.rollback() @@ -473,7 +473,8 @@ def test_two_phase_transaction(self, local_connection): @testing.requires.two_phase_transactions @testing.requires.two_phase_recovery - def test_two_phase_recover(self): + @testing.variation("commit", [True, False]) + def test_two_phase_recover(self, commit): users = self.tables.users # 2020, still can't get this to work w/ modern MySQL or MariaDB. @@ -501,17 +502,29 @@ def test_two_phase_recover(self): [], ) # recover_twophase needs to be run in a new transaction - with testing.db.connect() as connection2: - recoverables = connection2.recover_twophase() - assert transaction.xid in recoverables - connection2.commit_prepared(transaction.xid, recover=True) + with testing.db.connect() as connection3: + # oracle transactions can't be recovered for commit after... + # about 1 second? 
with testing.skip_if_timeout( + 0.50, + cleanup=( + lambda: connection3.rollback_prepared( + transaction.xid, recover=True + ) + ), + ): + recoverables = connection3.recover_twophase() + assert transaction.xid in recoverables - eq_( - connection2.execute( - select(users.c.user_id).order_by(users.c.user_id) - ).fetchall(), - [(1,)], - ) + if commit: + connection3.commit_prepared(transaction.xid, recover=True) + res = [(1,)] + else: + connection3.rollback_prepared(transaction.xid, recover=True) + res = [] + + stmt = select(users.c.user_id).order_by(users.c.user_id) + eq_(connection3.execute(stmt).fetchall(), res) @testing.requires.two_phase_transactions def test_multiple_two_phase(self, local_connection): @@ -1347,10 +1360,17 @@ def test_connection_invalidated(self): eq_(c2.get_isolation_level(), self._default_isolation_level()) def test_per_connection(self): - from sqlalchemy.pool import QueuePool eng = testing_engine( - options=dict(poolclass=QueuePool, pool_size=2, max_overflow=0) + options=dict( + poolclass=( + QueuePool + if not testing.db.dialect.is_async + else AsyncAdaptedQueuePool + ), + pool_size=2, + max_overflow=0, + ) ) c1 = eng.connect() diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 7289d5494eb..05941a79a2a 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -3,6 +3,7 @@ import inspect as stdlib_inspect from unittest.mock import patch +from sqlalchemy import AssertionPool from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import delete @@ -11,11 +12,16 @@ from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import NullPool +from sqlalchemy import QueuePool from sqlalchemy import select +from sqlalchemy import SingletonThreadPool +from sqlalchemy import StaticPool from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import union_all from sqlalchemy.engine import cursor as _cursor from sqlalchemy.ext.asyncio import async_engine_from_config @@ -263,9 +269,16 @@ async def test_engine_eq_ne(self, async_engine): is_false(async_engine == None) - @async_test - async def test_no_attach_to_event_loop(self, testing_engine): - """test #6409""" + def test_no_attach_to_event_loop(self, testing_engine): + """test #6409 + + note that this test does not seem to trigger the bug originally + fixed in #6409 when using Python 3.10 and higher (the original issue + can be reproduced on 3.8 at least, based on my testing). It has been + simplified to no longer explicitly create a new loop, since + asyncio.run() already creates one. 
+ + """ import asyncio import threading @@ -273,9 +286,6 @@ async def test_no_attach_to_event_loop(self, testing_engine): errs = [] def go(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - async def main(): tasks = [task() for _ in range(2)] @@ -398,6 +408,12 @@ async def go(): eq_(m.mock_calls, []) + @async_test + async def test_statement_compile(self, async_engine): + stmt = str(select(1).compile(async_engine)) + async with async_engine.connect() as conn: + eq_(str(select(1).compile(conn)), stmt) + def test_clear_compiled_cache(self, async_engine): async_engine.sync_engine._compiled_cache["foo"] = "bar" eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") @@ -520,6 +536,77 @@ async def test_isolation_level(self, async_connection): eq_(isolation_level, "SERIALIZABLE") + @testing.combinations( + ( + AsyncAdaptedQueuePool, + True, + ), + ( + QueuePool, + False, + ), + (NullPool, True), + (SingletonThreadPool, False), + (StaticPool, True), + (AssertionPool, True), + argnames="pool_cls,should_work", + ) + @testing.variation("instantiate", [True, False]) + @async_test + async def test_pool_classes( + self, async_testing_engine, pool_cls, instantiate, should_work + ): + """test #8771""" + if instantiate: + if pool_cls in (QueuePool, AsyncAdaptedQueuePool): + pool = pool_cls(creator=testing.db.pool._creator, timeout=10) + else: + pool = pool_cls( + creator=testing.db.pool._creator, + ) + + options = {"pool": pool} + else: + if pool_cls in (QueuePool, AsyncAdaptedQueuePool): + options = {"poolclass": pool_cls, "pool_timeout": 10} + else: + options = {"poolclass": pool_cls} + + if not should_work: + with expect_raises_message( + exc.ArgumentError, + f"Pool class {pool_cls.__name__} " + "cannot be used with asyncio engine", + ): + async_testing_engine(options=options) + return + + e = async_testing_engine(options=options) + + if pool_cls is AssertionPool: + async with e.connect() as conn: + result = await conn.scalar(select(1)) + eq_(result, 1) + return + + async def go(): + async with e.connect() as conn: + result = await conn.scalar(select(1)) + eq_(result, 1) + return result + + eq_(await asyncio.gather(*[go() for i in range(10)]), [1] * 10) + + def test_cant_use_async_pool_w_create_engine(self): + """supplemental test for #8771""" + + with expect_raises_message( + exc.ArgumentError, + "Pool class AsyncAdaptedQueuePool " + "cannot be used with non-asyncio engine", + ): + create_engine("sqlite://", poolclass=AsyncAdaptedQueuePool) + @testing.requires.queue_pool @async_test async def test_dispose(self, async_engine): @@ -785,6 +872,27 @@ async def async_creator(x, y, *, z=None): finally: await greenlet_spawn(conn.close) + @testing.combinations("stream", "stream_scalars", argnames="method") + @async_test + async def test_server_side_required_for_scalars( + self, async_engine, method + ): + with mock.patch.object( + async_engine.dialect, "supports_server_side_cursors", False + ): + async with async_engine.connect() as c: + with expect_raises_message( + exc.InvalidRequestError, + "Cant use `stream` or `stream_scalars` with the current " + "dialect since it does not support server side cursors.", + ): + if method == "stream": + await c.stream(select(1)) + elif method == "stream_scalars": + await c.stream_scalars(select(1)) + else: + testing.fail(method) + class AsyncCreatePoolTest(fixtures.TestBase): @config.fixture @@ -865,15 +973,12 @@ async def test_sync_before_cursor_execute_engine(self, async_engine): async with async_engine.connect() as conn: sync_conn = 
conn.sync_connection - await conn.execute(text("select 1")) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, - [ - mock.call( - sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False - ) - ], + [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], ) @async_test @@ -886,15 +991,12 @@ async def test_sync_before_cursor_execute_connection(self, async_engine): event.listen( async_engine.sync_engine, "before_cursor_execute", canary ) - await conn.execute(text("select 1")) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, - [ - mock.call( - sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False - ) - ], + [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], ) @async_test @@ -932,6 +1034,9 @@ async def test_inspect_connection(self, async_engine): class AsyncResultTest(EngineFixture): + __backend__ = True + __requires__ = ("server_side_cursors", "async_dialect") + @async_test async def test_no_ss_cursor_w_execute(self, async_engine): users = self.tables.users @@ -1230,20 +1335,72 @@ async def test_one_multi_result(self, async_engine): ): await result.one() - @testing.combinations( - ("scalars",), ("stream_scalars",), argnames="filter_" - ) + @testing.combinations(("scalars",), ("stream_scalars",), argnames="case") @async_test - async def test_scalars(self, async_engine, filter_): + async def test_scalars(self, async_engine, case): users = self.tables.users + stmt = select(users).order_by(users.c.user_id) async with async_engine.connect() as conn: - if filter_ == "scalars": - result = (await conn.scalars(select(users))).all() - elif filter_ == "stream_scalars": - result = await (await conn.stream_scalars(select(users))).all() + if case == "scalars": + result = (await conn.scalars(stmt)).all() + elif case == "stream_scalars": + result = await (await conn.stream_scalars(stmt)).all() eq_(result, list(range(1, 20))) + @async_test + @testing.combinations(("stream",), ("stream_scalars",), argnames="case") + async def test_stream_fetch_many_not_complete(self, async_engine, case): + users = self.tables.users + big_query = select(users).join(users.alias("other"), true()) + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(big_query) + elif case == "stream_scalars": + result = await conn.stream_scalars(big_query) + + f1 = await result.fetchmany(5) + f2 = await result.fetchmany(10) + f3 = await result.fetchmany(7) + eq_(len(f1) + len(f2) + len(f3), 22) + + res = await result.fetchall() + eq_(len(res), 19 * 19 - 22) + + @async_test + @testing.combinations(("stream",), ("execute",), argnames="case") + async def test_cursor_close(self, async_engine, case): + users = self.tables.users + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(select(users)) + cursor = result._real_result.cursor + elif case == "execute": + result = await conn.execute(select(users)) + cursor = result.cursor + + await conn.run_sync(lambda _: cursor.close()) + + @async_test + @testing.variation("case", ["scalar_one", "scalar_one_or_none", "scalar"]) + async def test_stream_scalar(self, async_engine, case: testing.Variation): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).limit(1).order_by(users.c.user_name) + ) + + if case.scalar_one: + u1 = await result.scalar_one() + elif case.scalar_one_or_none: + u1 = await result.scalar_one_or_none() + elif 
case.scalar: + u1 = await result.scalar() + else: + case.fail() + + eq_(u1, 1) + class TextSyncDBAPI(fixtures.TestBase): __requires__ = ("asyncio",) @@ -1259,7 +1416,13 @@ def test_sync_dbapi_raises(self): def async_engine(self): engine = create_engine("sqlite:///:memory:", future=True) engine.dialect.is_async = True - return _async_engine.AsyncEngine(engine) + engine.dialect.supports_server_side_cursors = True + with mock.patch.object( + engine.dialect.execution_ctx_cls, + "create_server_side_cursor", + engine.dialect.execution_ctx_cls.create_default_cursor, + ): + yield _async_engine.AsyncEngine(engine) @async_test @combinations( @@ -1396,3 +1559,23 @@ def test_regen_trans_but_not_conn(self, connection_no_trans): async_t2 = async_conn.get_transaction() is_(async_t1, async_t2) + + +class PoolRegenTest(EngineFixture): + @testing.requires.queue_pool + @async_test + @testing.variation("do_dispose", [True, False]) + async def test_gather_after_dispose(self, testing_engine, do_dispose): + engine = testing_engine( + asyncio=True, options=dict(pool_size=10, max_overflow=10) + ) + + async def thing(engine): + async with engine.connect() as conn: + await conn.exec_driver_sql(str(select(1).compile(engine))) + + if do_dispose: + await engine.dispose() + + tasks = [thing(engine) for _ in range(10)] + await asyncio.gather(*tasks) diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index e38a0cc52a9..5f9bf2e089e 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -4,7 +4,6 @@ from typing import List from typing import Optional -from sqlalchemy import Column from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey @@ -39,6 +38,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_deprecated @@ -47,6 +47,7 @@ from sqlalchemy.testing.assertions import not_in from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.provision import normalize_sequence +from sqlalchemy.testing.schema import Column from .test_engine_py3k import AsyncFixture as _AsyncFixture from ...orm import _fixtures @@ -314,6 +315,7 @@ async def test_stream_partitions(self, async_session, kw): @testing.combinations("statement", "execute", argnames="location") @async_test + @testing.requires.server_side_cursors async def test_no_ss_cursor_w_execute(self, async_session, location): User = self.classes.User @@ -767,7 +769,9 @@ async def go(legacy_inactive_history_style): class A: __tablename__ = "a" - id = Column(Integer, primary_key=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) b = relationship( "B", uselist=False, @@ -779,7 +783,9 @@ class A: @registry.mapped class B: __tablename__ = "b" - id = Column(Integer, primary_key=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) a_id = Column(ForeignKey("a.id")) async with async_engine.begin() as conn: @@ -790,14 +796,8 @@ class B: return go @testing.combinations( - ( - "legacy_style", - True, - ), - ( - "new_style", - False, - ), + ("legacy_style", True), + ("new_style", False), argnames="_legacy_inactive_history_style", id_="ia", ) @@ -935,6 +935,38 @@ async def test_get_transaction(self, async_session): is_(async_session.get_transaction(), None) 
is_(async_session.get_nested_transaction(), None) + @async_test + async def test_get_transaction_gced(self, async_session): + """test #12471 + + this tests that the AsyncSessionTransaction is regenerated if + we don't have any reference to it beforehand. + + """ + is_(async_session.get_transaction(), None) + is_(async_session.get_nested_transaction(), None) + + await async_session.begin() + + trans = async_session.get_transaction() + is_not(trans, None) + is_(trans.session, async_session) + is_false(trans.nested) + is_( + trans.sync_transaction, + async_session.sync_session.get_transaction(), + ) + + await async_session.begin_nested() + nested = async_session.get_nested_transaction() + is_not(nested, None) + is_true(nested.nested) + is_(nested.session, async_session) + is_( + nested.sync_transaction, + async_session.sync_session.get_nested_transaction(), + ) + @async_test async def test_async_object_session(self, async_engine): User = self.classes.User diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index d6d059cbef9..e21881b3334 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -934,22 +934,25 @@ class ActualDocument(ContactDocument): self.assert_compile( session.query(Document), - "SELECT pjoin.id AS pjoin_id, pjoin.doctype AS pjoin_doctype, " - "pjoin.type AS pjoin_type, pjoin.send_method AS pjoin_send_method " - "FROM " - "(SELECT actual_documents.id AS id, " - "actual_documents.send_method AS send_method, " - "actual_documents.doctype AS doctype, " - "'actual' AS type FROM actual_documents) AS pjoin" - if use_strict_attrs - else "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " - "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " - "pjoin.type AS pjoin_type " - "FROM " - "(SELECT actual_documents.id AS id, " - "actual_documents.send_method AS send_method, " - "actual_documents.doctype AS doctype, " - "'actual' AS type FROM actual_documents) AS pjoin", + ( + "SELECT pjoin.id AS pjoin_id, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type, " + "pjoin.send_method AS pjoin_send_method " + "FROM " + "(SELECT actual_documents.id AS id, " + "actual_documents.send_method AS send_method, " + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin" + if use_strict_attrs + else "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " + "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type " + "FROM " + "(SELECT actual_documents.id AS id, " + "actual_documents.send_method AS send_method, " + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin" + ), ) @testing.combinations(True, False) diff --git a/test/ext/mypy/plugin_files/mapped_attr_assign.py b/test/ext/mypy/plugin_files/mapped_attr_assign.py index 06bc24d9eb0..c7244c27a61 100644 --- a/test/ext/mypy/plugin_files/mapped_attr_assign.py +++ b/test/ext/mypy/plugin_files/mapped_attr_assign.py @@ -3,6 +3,7 @@ """ + from typing import Optional from sqlalchemy import Column diff --git a/test/ext/mypy/plugin_files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py index cbdbf009a0e..146b96b2a73 100644 --- a/test/ext/mypy/plugin_files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -2,6 +2,7 @@ type checked. 
""" + from typing import List from sqlalchemy import Column diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index f1b36ac52bb..1d75137a042 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -1,6 +1,14 @@ import os +import pathlib import shutil +try: + from mypy.version import __version__ as _mypy_version_str +except ImportError: + _mypy_version = None +else: + _mypy_version = tuple(int(x) for x in _mypy_version_str.split(".")) + from sqlalchemy import testing from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -23,10 +31,22 @@ def _incremental_dirs(): return files +def _mypy_missing_or_incompatible(): + return not _mypy_version or _mypy_version > (1, 10, 1) + + class MypyPluginTest(fixtures.MypyTest): + @testing.skip_if( + _mypy_missing_or_incompatible, + "Mypy must be present and compatible (<= 1.10.1)", + ) @testing.combinations( - *[(pathname) for pathname in _incremental_dirs()], + *[ + (pathlib.Path(pathname).name, pathname) + for pathname in _incremental_dirs() + ], argnames="pathname", + id_="ia", ) @testing.requires.patch_library def test_incremental(self, mypy_runner, per_func_cachedir, pathname): @@ -70,6 +90,10 @@ def test_incremental(self, mypy_runner, per_func_cachedir, pathname): % (patchfile, result[0]), ) + @testing.skip_if( + _mypy_missing_or_incompatible, + "Mypy must be present and compatible (<= 1.10.1)", + ) @testing.combinations( *( (os.path.basename(path), path, True) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 87812c9ac63..7e2b31a9b5b 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -3830,11 +3830,11 @@ class User(decl_base): id: Mapped[int] = mapped_column(primary_key=True) - user_keyword_associations: Mapped[ - List[UserKeywordAssociation] - ] = relationship( - back_populates="user", - cascade="all, delete-orphan", + user_keyword_associations: Mapped[List[UserKeywordAssociation]] = ( + relationship( + back_populates="user", + cascade="all, delete-orphan", + ) ) keywords: AssociationProxy[list[str]] = association_proxy( @@ -3886,12 +3886,12 @@ class User(dc_decl_base): primary_key=True, repr=True, init=False ) - user_keyword_associations: Mapped[ - List[UserKeywordAssociation] - ] = relationship( - back_populates="user", - cascade="all, delete-orphan", - init=False, + user_keyword_associations: Mapped[List[UserKeywordAssociation]] = ( + relationship( + back_populates="user", + cascade="all, delete-orphan", + init=False, + ) ) if embed_in_field: diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index c84bc1c78eb..a3ba1189b3d 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -667,11 +667,14 @@ def _make_tables(self, e): m, Column("id", Integer, primary_key=True), Column("data", String(50)), - Column( - "t_%d_id" % (i - 1), ForeignKey("table_%d.id" % (i - 1)) - ) - if i > 4 - else None, + ( + Column( + "t_%d_id" % (i - 1), + ForeignKey("table_%d.id" % (i - 1)), + ) + if i > 4 + else None + ), ) m.drop_all(e) m.create_all(e) diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index aa03dabc903..707e02dac10 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -209,9 +209,11 @@ def sqlite_my_function(element, compiler, **kw): self.assert_compile( stmt, - "SELECT my_function(t1.q) AS my_function_1 FROM t1" - if named - else "SELECT my_function(t1.q) AS anon_1 FROM t1", + ( + "SELECT 
my_function(t1.q) AS my_function_1 FROM t1" + if named + else "SELECT my_function(t1.q) AS anon_1 FROM t1" + ), dialect="sqlite", ) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index dd5b7158296..47756c94958 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -169,7 +169,8 @@ def __sa_instrumentation_manager__(cls): ) # This proves SA can handle a class with non-string dict keys - if util.cpython: + # Since Python 3.13, non-string keys raise a runtime warning. + if util.cpython and not util.py313: locals()[42] = 99 # Don't remove this line! def __init__(self, **kwargs): @@ -760,7 +761,6 @@ class C: class ExtendedEventsTest(_ExtBase, fixtures.ORMTest): - """Allow custom Events implementations.""" @modifies_instrumentation_finders diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 3ff49fc82fe..4d579fa0c1d 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -51,7 +51,7 @@ class ShardTest: @classmethod def define_tables(cls, metadata): - global db1, db2, db3, db4, weather_locations, weather_reports + global weather_locations cls.tables.ids = ids = Table( "ids", metadata, Column("nextid", Integer, nullable=False) diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 69e9c133515..f6ad0de8d4d 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -7,6 +7,7 @@ from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy import literal_column from sqlalchemy import Numeric @@ -21,6 +22,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym +from sqlalchemy.orm.context import ORMSelectCompileState from sqlalchemy.sql import coercions from sqlalchemy.sql import operators from sqlalchemy.sql import roles @@ -423,6 +425,21 @@ def name(self): return A + @testing.fixture + def _unnamed_expr_matches_col_fixture(self): + Base = declarative_base() + + class A(Base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + foo = Column(String) + + @hybrid.hybrid_property + def bar(self): + return self.foo + + return A + def test_access_from_unmapped(self): """test #9519""" @@ -497,6 +514,41 @@ def test_labeling_for_unnamed(self, _unnamed_expr_fixture): "a.lastname AS name FROM a) AS anon_1", ) + @testing.variation("pre_populate_col_proxy", [True, False]) + def test_labeling_for_unnamed_matches_col( + self, _unnamed_expr_matches_col_fixture, pre_populate_col_proxy + ): + """test #11728""" + + A = _unnamed_expr_matches_col_fixture + + if pre_populate_col_proxy: + pre_stmt = select(A.id, A.foo) + pre_stmt.subquery().c + + stmt = select(A.id, A.bar) + self.assert_compile( + stmt, + "SELECT a.id, a.foo FROM a", + ) + + compile_state = ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) + eq_( + compile_state._column_naming_convention( + LABEL_STYLE_DISAMBIGUATE_ONLY, legacy=False + )(list(stmt.inner_columns)[1]), + "bar", + ) + eq_(stmt.subquery().c.keys(), ["id", "bar"]) + + self.assert_compile( + select(stmt.subquery()), + "SELECT anon_1.id, anon_1.foo FROM " + "(SELECT a.id AS id, a.foo AS foo FROM a) AS anon_1", + ) + def test_labeling_for_unnamed_tablename_plus_col( self, _unnamed_expr_fixture ): diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index dffdac8d842..42378477786 100644
--- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -542,7 +542,7 @@ def test_coerce_raise(self): data={1, 2, 3}, ) - def test_in_place_mutation(self): + def test_in_place_mutation_int(self): sess = fixture_session() f1 = Foo(data=[1, 2]) @@ -554,7 +554,19 @@ def test_in_place_mutation(self): eq_(f1.data, [3, 2]) - def test_in_place_slice_mutation(self): + def test_in_place_mutation_str(self): + sess = fixture_session() + + f1 = Foo(data=["one", "two"]) + sess.add(f1) + sess.commit() + + f1.data[0] = "three" + sess.commit() + + eq_(f1.data, ["three", "two"]) + + def test_in_place_slice_mutation_int(self): sess = fixture_session() f1 = Foo(data=[1, 2, 3, 4]) @@ -566,6 +578,18 @@ def test_in_place_slice_mutation(self): eq_(f1.data, [1, 5, 6, 4]) + def test_in_place_slice_mutation_str(self): + sess = fixture_session() + + f1 = Foo(data=["one", "two", "three", "four"]) + sess.add(f1) + sess.commit() + + f1.data[1:3] = "five", "six" + sess.commit() + + eq_(f1.data, ["one", "five", "six", "four"]) + def test_del_slice(self): sess = fixture_session() @@ -1240,6 +1264,12 @@ class Foo(Mixin, Base): __tablename__ = "foo" id = Column(Integer, primary_key=True) + def test_in_place_mutation_str(self): + """this test is hardcoded to integer, skip strings""" + + def test_in_place_slice_mutation_str(self): + """this test is hardcoded to integer, skip strings""" + class MutableListWithScalarPickleTest( _MutableListTestBase, fixtures.MappedTest diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index 90c7f385789..98e2a8207f9 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -70,7 +70,7 @@ def _setup(self, test_collection_class): """Build a relationship situation using the given test_collection_class factory""" - global metadata, slides_table, bullets_table, Slide, Bullet + global slides_table, bullets_table, Slide, Bullet slides_table = Table( "test_Slides", diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index a52c59e2d34..fb92c752a67 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import scoped_session from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.entities import ComparableEntity @@ -279,6 +280,44 @@ def test_unicode(self): dialect="default", ) + @combinations( + ( + lambda: func.max(users.c.name).over(range_=(None, 0)), + "max(users.name) OVER (RANGE BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(range_=(0, None)), + "max(users.name) OVER (RANGE BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), + ( + lambda: func.max(users.c.name).over(rows=(None, 0)), + "max(users.name) OVER (ROWS BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(rows=(0, None)), + "max(users.name) OVER (ROWS BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), + ( + lambda: func.max(users.c.name).over(groups=(None, 0)), + "max(users.name) OVER (GROUPS BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(groups=(0, None)), + "max(users.name) OVER (GROUPS BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), + ) + def test_over(self, over_fn, sql): + o = over_fn() + self.assert_compile(o, sql) + ol = serializer.loads(serializer.dumps(o), 
users.metadata) + self.assert_compile(ol, sql) + class ColumnPropertyWParamTest( AssertsCompiledSQL, fixtures.DeclarativeMappedTest @@ -331,7 +370,3 @@ def test_deserailize_colprop(self): "CAST(left(test.some_id, :left_2) AS INTEGER) = :param_1", checkparams={"left_1": 6, "left_2": 6, "param_1": 123456}, ) - - -if __name__ == "__main__": - testing.main() diff --git a/test/orm/declarative/test_abs_import_only.py b/test/orm/declarative/test_abs_import_only.py index e1447364e66..287240575c8 100644 --- a/test/orm/declarative/test_abs_import_only.py +++ b/test/orm/declarative/test_abs_import_only.py @@ -64,9 +64,9 @@ class Foo(decl_base): if construct.Mapped: bars: orm.Mapped[typing.List[Bar]] = orm.relationship() elif construct.WriteOnlyMapped: - bars: orm.WriteOnlyMapped[ - typing.List[Bar] - ] = orm.relationship() + bars: orm.WriteOnlyMapped[typing.List[Bar]] = ( + orm.relationship() + ) elif construct.DynamicMapped: bars: orm.DynamicMapped[typing.List[Bar]] = orm.relationship() else: diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 7085b2af9f6..1f31544e065 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -35,6 +35,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import MappedColumn from sqlalchemy.orm import Mapper from sqlalchemy.orm import registry @@ -930,6 +931,42 @@ class User(BaseUser): # Check to see if __init_subclass__ works in supported versions eq_(UserType._set_random_keyword_used_here, True) + @testing.variation( + "basetype", + ["DeclarativeBase", "DeclarativeBaseNoMeta", "MappedAsDataclass"], + ) + def test_kw_support_in_declarative_base(self, basetype): + """test #10732""" + + if basetype.DeclarativeBase: + + class Base(DeclarativeBase): + pass + + elif basetype.DeclarativeBaseNoMeta: + + class Base(DeclarativeBaseNoMeta): + pass + + elif basetype.MappedAsDataclass: + + class Base(MappedAsDataclass): + pass + + else: + basetype.fail() + + class Mixin: + def __init_subclass__(cls, random_keyword: bool, **kw) -> None: + super().__init_subclass__(**kw) + cls._set_random_keyword_used_here = random_keyword + + class User(Base, Mixin, random_keyword=True): + __tablename__ = "user" + id_ = Column(Integer, primary_key=True) + + eq_(User._set_random_keyword_used_here, True) + def test_declarative_base_bad_registry(self): with assertions.expect_raises_message( exc.InvalidRequestError, @@ -1350,7 +1387,7 @@ class User(Base): assert_raises_message( sa.exc.ArgumentError, - "Can't add additional column 'foo' when " "specifying __table__", + "Can't add additional column 'foo' when specifying __table__", go, ) @@ -1788,7 +1825,7 @@ class Foo(Base, ComparableEntity): assert_raises_message( exc.InvalidRequestError, - "'addresses' is not an instance of " "ColumnProperty", + "'addresses' is not an instance of ColumnProperty", configure_mappers, ) @@ -1917,7 +1954,7 @@ class Bar(Base, ComparableEntity): assert_raises_message( AttributeError, - "does not have a mapped column named " "'__table__'", + "does not have a mapped column named '__table__'", configure_mappers, ) @@ -2471,7 +2508,7 @@ class User(Base, ComparableEntity): def test_oops(self): with testing.expect_warnings( - "Ignoring declarative-like tuple value of " "attribute 'name'" + "Ignoring declarative-like tuple value of attribute 'name'" ): class User(Base, ComparableEntity): diff --git 
a/test/orm/declarative/test_clsregistry.py b/test/orm/declarative/test_clsregistry.py index ffc8528125c..0cf775e4d27 100644 --- a/test/orm/declarative/test_clsregistry.py +++ b/test/orm/declarative/test_clsregistry.py @@ -230,7 +230,7 @@ def test_dupe_classes_cleanout(self): del f2 gc_collect() - eq_(len(clsregistry._registries), 1) + eq_(len(clsregistry._registries), 0) def test_dupe_classes_name_race(self): """test the race condition that the class was garbage " diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index cbe08f30e17..53a9366c3a7 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -27,6 +27,7 @@ from sqlalchemy import JSON from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import column_property @@ -76,6 +77,7 @@ def dc_decl_base(self, request, metadata): if request.param == "(MAD, DB)": class Base(MappedAsDataclass, DeclarativeBase): + _mad_before = True metadata = _md type_annotation_map = { str: String().with_variant(String(50), "mysql", "mariadb") @@ -84,6 +86,7 @@ class Base(MappedAsDataclass, DeclarativeBase): else: # test #8665 by reversing the order of the classes class Base(DeclarativeBase, MappedAsDataclass): + _mad_before = False metadata = _md type_annotation_map = { str: String().with_variant(String(50), "mysql", "mariadb") @@ -156,6 +159,8 @@ class B(dc_decl_base): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 def test_generic_class(self): """further test for #8665""" @@ -179,9 +184,9 @@ class GenericSetting( JSON, init=True, default_factory=lambda: {} ) - new_instance: GenericSetting[ # noqa: F841 - Dict[str, Any] - ] = GenericSetting(key="x", value={"foo": "bar"}) + new_instance: GenericSetting[Dict[str, Any]] = ( # noqa: F841 + GenericSetting(key="x", value={"foo": "bar"}) + ) def test_no_anno_doesnt_go_into_dc( self, dc_decl_base: Type[MappedAsDataclass] @@ -300,6 +305,8 @@ class B: a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation("dc_type", ["decorator", "superclass"]) def test_dataclass_fn(self, dc_type: Variation): annotations = {} @@ -374,6 +381,9 @@ def test_combine_args_from_pep593(self, decl_base: Type[DeclarativeBase]): dataclass defaults """ + + # anno only: global intpk, str30, s_str30, user_fk + intpk = Annotated[int, mapped_column(primary_key=True)] str30 = Annotated[ str, mapped_column(String(30), insert_default=func.foo()) @@ -683,6 +693,27 @@ class A(dc_decl_base): eq_(fas.args, ["self", "id"]) eq_(fas.kwonlyargs, ["data"]) + @testing.combinations(True, False, argnames="unsafe_hash") + def test_hash_attribute( + self, dc_decl_base: Type[MappedAsDataclass], unsafe_hash + ): + class A(dc_decl_base, unsafe_hash=unsafe_hash): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, hash=False) + data: Mapped[str] = mapped_column(hash=True) + + a = A(id=1, data="x") + if not unsafe_hash or not dc_decl_base._mad_before: + with expect_raises(TypeError): + a_hash1 = 
hash(a) + else: + a_hash1 = hash(a) + a.id = 41 + eq_(hash(a), a_hash1) + a.data = "y" + ne_(hash(a), a_hash1) + @testing.requires.python310 def test_kw_only_dataclass_constant( self, dc_decl_base: Type[MappedAsDataclass] @@ -742,6 +773,21 @@ class Mixin(MappedAsDataclass): class Foo(Mixin): bar_value: Mapped[float] = mapped_column(default=78) + def test_MappedAsDataclass_table_provided(self, registry): + """test #11973""" + + with expect_raises_message( + exc.InvalidRequestError, + "Class .*Foo.* already defines a '__table__'. " + "ORM Annotated Dataclasses do not support a pre-existing " + "'__table__' element", + ): + + @registry.mapped_as_dataclass + class Foo: + __table__ = Table("foo", registry.metadata) + foo: Mapped[float] + def test_dataclass_exception_wrapped(self, dc_decl_base): with expect_raises_message( exc.InvalidRequestError, @@ -1133,6 +1179,8 @@ class Child(Mixin): c1 = Child() eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 def test_abstract_is_dc(self): collected_annotations = {} @@ -1154,6 +1202,8 @@ class Child(Mixin): eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}}) eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation("check_annotations", [True, False]) def test_abstract_is_dc_w_mapped(self, check_annotations): if check_annotations: @@ -1217,6 +1267,8 @@ class Child(Mixin, Parent): eq_regex(repr(Child(a=5, b=6, c=7)), r".*\.Child\(c=7\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation( "dataclass_scope", ["on_base", "on_mixin", "on_base_class", "on_sub_class"], @@ -1798,9 +1850,10 @@ def test_attribute_options(self, use_arguments, construct): "default_factory": list, "compare": True, "kw_only": False, + "hash": False, } exp = interfaces._AttributeOptions( - False, False, False, list, True, False + False, False, False, list, True, False, False ) else: kw = {} @@ -1822,7 +1875,13 @@ def test_ro_attribute_options(self, use_arguments, construct): "compare": True, } exp = interfaces._AttributeOptions( - False, False, _NoArg.NO_ARG, _NoArg.NO_ARG, True, _NoArg.NO_ARG + False, + False, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + True, + _NoArg.NO_ARG, + _NoArg.NO_ARG, ) else: kw = {} diff --git a/test/orm/declarative/test_dc_transforms_future_anno_sync.py b/test/orm/declarative/test_dc_transforms_future_anno_sync.py new file mode 100644 index 00000000000..8701990526f --- /dev/null +++ b/test/orm/declarative/test_dc_transforms_future_anno_sync.py @@ -0,0 +1,2212 @@ +"""This file is automatically generated from the file +'test/orm/declarative/test_dc_transforms.py' +by the 'tools/sync_test_files.py' script. + +Do not edit manually, any change will be lost. 
+""" # noqa: E501 + +from __future__ import annotations + +import contextlib +import dataclasses +from dataclasses import InitVar +import functools +import inspect as pyinspect +from itertools import product +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TypeVar +from unittest import mock + +from typing_extensions import Annotated + +from sqlalchemy import BigInteger +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import func +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import JSON +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.orm import column_property +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import deferred +from sqlalchemy.orm import interfaces +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import MappedColumn +from sqlalchemy.orm import query_expression +from sqlalchemy.orm import registry +from sqlalchemy.orm import registry as _RegistryType +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.orm import synonym +from sqlalchemy.sql.base import _NoArg +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_regex +from sqlalchemy.testing import expect_deprecated +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ +from sqlalchemy.testing import Variation +from sqlalchemy.util import compat + + +def _dataclass_mixin_warning(clsname, attrnames): + return testing.expect_deprecated( + rf"When transforming .* to a dataclass, attribute\(s\) " + rf"{attrnames} originates from superclass .*{clsname}" + ) + + +class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): + @testing.fixture(params=["(MAD, DB)", "(DB, MAD)"]) + def dc_decl_base(self, request, metadata): + _md = metadata + + if request.param == "(MAD, DB)": + + class Base(MappedAsDataclass, DeclarativeBase): + _mad_before = True + metadata = _md + type_annotation_map = { + str: String().with_variant(String(50), "mysql", "mariadb") + } + + else: + # test #8665 by reversing the order of the classes + class Base(DeclarativeBase, MappedAsDataclass): + _mad_before = False + metadata = _md + type_annotation_map = { + str: String().with_variant(String(50), "mysql", "mariadb") + } + + yield Base + Base.registry.dispose() + + def test_basic_constructor_repr_base_cls( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = 
mapped_column(primary_key=True, init=False) + data: Mapped[str] + a_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("a.id"), init=False + ) + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + def test_generic_class(self): + """further test for #8665""" + + T_Value = TypeVar("T_Value") + + class SomeBaseClass(DeclarativeBase): + pass + + class GenericSetting( + MappedAsDataclass, SomeBaseClass, Generic[T_Value] + ): + __tablename__ = "xx" + + id: Mapped[int] = mapped_column( + Integer, primary_key=True, init=False + ) + + key: Mapped[str] = mapped_column(String, init=True) + + value: Mapped[T_Value] = mapped_column( + JSON, init=True, default_factory=lambda: {} + ) + + new_instance: GenericSetting[Dict[str, Any]] = ( # noqa: F841 + GenericSetting(key="x", value={"foo": "bar"}) + ) + + def test_no_anno_doesnt_go_into_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class User(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "user" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + username: Mapped[str] + password: Mapped[str] + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + default_factory=list + ) + + class Address(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "address" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + # should not be in the dataclass constructor + user_id = mapped_column(ForeignKey(User.id)) + + email_address: Mapped[str] + + a1 = Address("email@address") + eq_(a1.email_address, "email@address") + + def test_warn_on_non_dc_mixin(self): + class _BaseMixin: + create_user: Mapped[int] = mapped_column() + update_user: Mapped[Optional[int]] = mapped_column( + default=None, init=False + ) + + class Base(DeclarativeBase, MappedAsDataclass, _BaseMixin): + pass + + class SubMixin: + foo: Mapped[str] + bar: Mapped[str] = mapped_column() + + with _dataclass_mixin_warning( + "_BaseMixin", "'create_user', 'update_user'" + ), _dataclass_mixin_warning("SubMixin", "'foo', 'bar'"): + + class User(SubMixin, Base): + __tablename__ = "sys_user" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + username: Mapped[str] = mapped_column(String) + password: Mapped[str] = mapped_column(String) + + def test_basic_constructor_repr_cls_decorator( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = 
mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass() + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + + # note a_id isn't included because it wasn't annotated + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', x=None), " + "some_module.B(id=None, data='data2', x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation("dc_type", ["decorator", "superclass"]) + def test_dataclass_fn(self, dc_type: Variation): + annotations = {} + + def dc_callable(kls, **kw) -> Type[Any]: + annotations[kls] = kls.__annotations__ + return dataclasses.dataclass(kls, **kw) # type: ignore + + if dc_type.decorator: + reg = registry() + + @reg.mapped_as_dataclass(dataclass_callable=dc_callable) + class MappedClass: + __tablename__ = "mapped_class" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + eq_(annotations, {MappedClass: {"id": int, "name": str}}) + + elif dc_type.superclass: + + class Base(DeclarativeBase): + pass + + class Mixin(MappedAsDataclass, dataclass_callable=dc_callable): + id: Mapped[int] = mapped_column(primary_key=True) + + class MappedClass(Mixin, Base): + __tablename__ = "mapped_class" + name: Mapped[str] + + eq_( + annotations, + {Mixin: {"id": int}, MappedClass: {"id": int, "name": str}}, + ) + else: + dc_type.fail() + + def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="d1") + data2: Mapped[str] = mapped_column(default_factory=lambda: "d2") + + a1 = A() + eq_(a1.data, "d1") + eq_(a1.data2, "d2") + + def test_default_factory_vs_collection_class( + self, dc_decl_base: Type[MappedAsDataclass] + ): + # this is currently the error raised by dataclasses. 
We can instead + # do this validation ourselves, but overall I don't know that we + # can hit every validation and rule that's in dataclasses + with expect_raises_message( + ValueError, "cannot specify both default and default_factory" + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column( + default="d1", default_factory=lambda: "d2" + ) + + def test_combine_args_from_pep593(self, decl_base: Type[DeclarativeBase]): + """test that we can set up column-level defaults separate from + dataclass defaults + + """ + + global intpk, str30, s_str30, user_fk + + intpk = Annotated[int, mapped_column(primary_key=True)] + str30 = Annotated[ + str, mapped_column(String(30), insert_default=func.foo()) + ] + s_str30 = Annotated[ + str, + mapped_column(String(30), server_default="some server default"), + ] + user_fk = Annotated[int, mapped_column(ForeignKey("user_account.id"))] + + class User(MappedAsDataclass, decl_base): + __tablename__ = "user_account" + + # we need this case for dataclasses that can't derive things + # from Annotated yet at the typing level + id: Mapped[intpk] = mapped_column(init=False) + name_none: Mapped[Optional[str30]] = mapped_column(default=None) + name: Mapped[str30] = mapped_column(default="hi") + name2: Mapped[s_str30] = mapped_column(default="there") + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + back_populates="user", default_factory=list + ) + + class Address(MappedAsDataclass, decl_base): + __tablename__ = "address" + + id: Mapped[intpk] = mapped_column(init=False) + email_address: Mapped[str] + user_id: Mapped[user_fk] = mapped_column(init=False) + user: Mapped[Optional["User"]] = relationship( + back_populates="addresses", default=None + ) + + is_true(User.__table__.c.id.primary_key) + is_true(User.__table__.c.name_none.default.arg.compare(func.foo())) + is_true(User.__table__.c.name.default.arg.compare(func.foo())) + eq_(User.__table__.c.name2.server_default.arg, "some server default") + + is_true(Address.__table__.c.user_id.references(User.__table__.c.id)) + u1 = User() + eq_(u1.name_none, None) + eq_(u1.name, "hi") + eq_(u1.name2, "there") + + def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]): + class Person(dc_decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + name: Mapped[str] + type: Mapped[str] = mapped_column(init=False) + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "engineer"} + + e1 = Engineer("nm", "st", "en", "pl") + eq_(e1.name, "nm") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + def test_non_mapped_fields_wo_mapped_or_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + a1 = A("data", "ctrl_one", 5) + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + }, + ) + + def 
test_non_mapped_fields_wo_mapped_or_dc_w_inherits( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + class B(A): + b_data: Mapped[str] = mapped_column(default="bd") + + # ensure we didn't break the dataclasses contract of removing Field + # issue #8880 + eq_(A.__dict__["some_field"], 5) + assert "ctrl_one" not in A.__dict__ + + b1 = B(data="data", ctrl_one="ctrl_one", some_field=5, b_data="x") + eq_( + dataclasses.asdict(b1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + "b_data": "x", + }, + ) + + def test_init_var(self, dc_decl_base: Type[MappedAsDataclass]): + class User(dc_decl_base): + __tablename__ = "user_account" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + name: Mapped[str] + + password: InitVar[str] + repeat_password: InitVar[str] + + password_hash: Mapped[str] = mapped_column( + init=False, nullable=False + ) + + def __post_init__(self, password: str, repeat_password: str): + if password != repeat_password: + raise ValueError("passwords do not match") + + self.password_hash = f"some hash... {password}" + + u1 = User(name="u1", password="p1", repeat_password="p1") + eq_(u1.password_hash, "some hash... p1") + self.assert_compile( + select(User), + "SELECT user_account.id, user_account.name, " + "user_account.password_hash FROM user_account", + ) + + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): + """We will be telling users "this is a dataclass that is also + mapped". Therefore, they will want *any* kind of attribute to do what + it would normally do in a dataclass, including normal types without any + field and explicit use of dataclasses.field(). Additionally, we'd like + ``Mapped`` to mean "persist this attribute". So the absence of + ``Mapped`` should also mean something too.
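+ + (As the assertions below show, only the ``Mapped`` annotations ``id`` and ``data`` are persisted; ``select(A)`` compiles to "SELECT a.id, a.data FROM a", while ``ctrl_one`` and ``some_field`` remain plain dataclass fields.)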
+ + """ + + class A(dc_decl_base): + __tablename__ = "a" + + ctrl_one: str = dataclasses.field() + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + some_field: int = dataclasses.field(default=5) + + some_none_field: Optional[str] = dataclasses.field(default=None) + + some_other_int_field: int = 10 + + # some field is part of the constructor + a1 = A("ctrlone", "datafield") + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 5, + "some_none_field": None, + "some_other_int_field": 10, + }, + ) + + a2 = A( + "ctrlone", + "datafield", + some_field=7, + some_other_int_field=12, + some_none_field="x", + ) + eq_( + dataclasses.asdict(a2), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 7, + "some_none_field": "x", + "some_other_int_field": 12, + }, + ) + + # only Mapped[] is mapped + self.assert_compile(select(A), "SELECT a.id, a.data FROM a") + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=[ + "self", + "ctrl_one", + "data", + "some_field", + "some_none_field", + "some_other_int_field", + ], + varargs=None, + varkw=None, + defaults=(5, None, 10), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]): + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(MappedAsDataclass, Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "engineer"} + + e1 = Engineer("st", "en", "pl") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + eq_( + pyinspect.getfullargspec(Person.__init__), + # the boring **kw __init__ + pyinspect.FullArgSpec( + args=["self"], + varargs=None, + varkw="kwargs", + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + eq_( + pyinspect.getfullargspec(Engineer.__init__), + # the exciting dataclasses __init__ + pyinspect.FullArgSpec( + args=["self", "status", "engineer_name", "primary_language"], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_compare(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, compare=False) + data: Mapped[str] + + a1 = A(id=0, data="foo") + a2 = A(id=1, data="foo") + eq_(a1, a2) + + @testing.requires.python310 + def test_kw_only_attribute(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(kw_only=True) + + fas = pyinspect.getfullargspec(A.__init__) + eq_(fas.args, ["self", "id"]) + eq_(fas.kwonlyargs, ["data"]) + + @testing.combinations(True, False, argnames="unsafe_hash") + def test_hash_attribute( + self, dc_decl_base: Type[MappedAsDataclass], unsafe_hash + ): + class A(dc_decl_base, unsafe_hash=unsafe_hash): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, hash=False) + data: Mapped[str] 
= mapped_column(hash=True) + + a = A(id=1, data="x") + if not unsafe_hash or not dc_decl_base._mad_before: + with expect_raises(TypeError): + a_hash1 = hash(a) + else: + a_hash1 = hash(a) + a.id = 41 + eq_(hash(a), a_hash1) + a.data = "y" + ne_(hash(a), a_hash1) + + @testing.requires.python310 + def test_kw_only_dataclass_constant( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class Mixin(MappedAsDataclass): + a: Mapped[int] = mapped_column(primary_key=True) + b: Mapped[int] = mapped_column(default=1) + + class Child(Mixin, dc_decl_base): + __tablename__ = "child" + + _: dataclasses.KW_ONLY + c: Mapped[int] + + c1 = Child(1, c=5) + eq_(c1, Child(a=1, b=1, c=5)) + + def test_mapped_column_overrides(self, dc_decl_base): + """test #8688""" + + class TriggeringMixin(MappedAsDataclass): + mixin_value: Mapped[int] = mapped_column(BigInteger) + + class NonTriggeringMixin(MappedAsDataclass): + mixin_value: Mapped[int] + + class Foo(dc_decl_base, TriggeringMixin): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + foo_value: Mapped[float] = mapped_column(default=78) + + class Bar(dc_decl_base, NonTriggeringMixin): + __tablename__ = "bar" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + bar_value: Mapped[float] = mapped_column(default=78) + + f1 = Foo(mixin_value=5) + eq_(f1.foo_value, 78) + + b1 = Bar(mixin_value=5) + eq_(b1.bar_value, 78) + + def test_mixing_MappedAsDataclass_with_decorator_raises(self, registry): + """test #9211""" + + class Mixin(MappedAsDataclass): + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + with expect_raises_message( + exc.InvalidRequestError, + "Class .*Foo.* is already a dataclass; ensure that " + "base classes / decorator styles of establishing dataclasses " + "are not being mixed. ", + ): + + @registry.mapped_as_dataclass + class Foo(Mixin): + bar_value: Mapped[float] = mapped_column(default=78) + + def test_MappedAsDataclass_table_provided(self, registry): + """test #11973""" + + with expect_raises_message( + exc.InvalidRequestError, + "Class .*Foo.* already defines a '__table__'. 
" + "ORM Annotated Dataclasses do not support a pre-existing " + "'__table__' element", + ): + + @registry.mapped_as_dataclass + class Foo: + __table__ = Table("foo", registry.metadata) + foo: Mapped[float] + + def test_dataclass_exception_wrapped(self, dc_decl_base): + with expect_raises_message( + exc.InvalidRequestError, + r"Python dataclasses error encountered when creating dataclass " + r"for \'Foo\': .*Please refer to Python dataclasses.*", + ) as ec: + + class Foo(dc_decl_base): + id: Mapped[int] = mapped_column(primary_key=True, init=False) + foo_value: Mapped[float] = mapped_column(default=78) + foo_no_value: Mapped[float] = mapped_column() + __tablename__ = "foo" + + is_true(isinstance(ec.error.__cause__, TypeError)) + + def test_dataclass_default(self, dc_decl_base): + """test for #9879""" + + def c10(): + return 10 + + def c20(): + return 20 + + class A(dc_decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + def_init: Mapped[int] = mapped_column(default=42) + call_init: Mapped[int] = mapped_column(default_factory=c10) + def_no_init: Mapped[int] = mapped_column(default=13, init=False) + call_no_init: Mapped[int] = mapped_column( + default_factory=c20, init=False + ) + + a = A(id=100) + eq_(a.def_init, 42) + eq_(a.call_init, 10) + eq_(a.def_no_init, 13) + eq_(a.call_no_init, 20) + + fields = {f.name: f for f in dataclasses.fields(A)} + eq_(fields["def_init"].default, 42) + eq_(fields["call_init"].default_factory, c10) + eq_(fields["def_no_init"].default, dataclasses.MISSING) + ne_(fields["def_no_init"].default_factory, dataclasses.MISSING) + eq_(fields["call_no_init"].default_factory, c20) + + def test_dataclass_default_callable(self, dc_decl_base): + """test for #9936""" + + def cd(): + return 42 + + with expect_deprecated( + "Callable object passed to the ``default`` parameter for " + "attribute 'value' in a ORM-mapped Dataclasses context is " + "ambiguous, and this use will raise an error in a future " + "release. If this callable is intended to produce Core level ", + "Callable object passed to the ``default`` parameter for " + "attribute 'no_init' in a ORM-mapped Dataclasses context is " + "ambiguous, and this use will raise an error in a future " + "release. 
If this callable is intended to produce Core level ", + ): + + class A(dc_decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + value: Mapped[int] = mapped_column(default=cd) + no_init: Mapped[int] = mapped_column(default=cd, init=False) + + a = A(id=100) + is_false("no_init" in a.__dict__) + eq_(a.value, cd) + eq_(a.no_init, None) + + fields = {f.name: f for f in dataclasses.fields(A)} + eq_(fields["value"].default, cd) + eq_(fields["no_init"].default, cd) + + +class RelationshipDefaultFactoryTest(fixtures.TestBase): + def test_list(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs[0].data, "hi") + + def test_set(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: {B(data="hi")} + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs.pop().data, "hi") + + def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + # old school collection mismatch error FTW + with expect_raises_message( + TypeError, "Incompatible collection type: list is not set-like" + ): + A() + + def test_one_to_one_example(self, dc_decl_base: Type[MappedAsDataclass]): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Parent(dc_decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + child: Mapped["Child"] = relationship( # noqa: F821 + back_populates="parent", default=None + ) + + class Child(dc_decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + parent_id: Mapped[int] = mapped_column( + ForeignKey("parent.id"), init=False + ) + parent: Mapped["Parent"] = relationship( + back_populates="child", default=None + ) + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + + p2 = Parent() + is_(p2.child, None) + + def test_replace_operation_works_w_history_etc( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + 
a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + registry.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")]) + sess.add(a1) + sess.commit() + + a2 = dataclasses.replace(a1, x=12, bs=[B("b4")]) + + assert a1 in sess + assert not sess.is_modified(a1, include_collections=True) + assert a2 not in sess + eq_(inspect(a2).attrs.x.history, ([12], (), ())) + sess.add(a2) + sess.commit() + + eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12]) + eq_( + sess.scalars(select(B.data).order_by(B.id)).all(), + ["b1", "b2", "b3", "b4"], + ) + + def test_post_init(self, registry: _RegistryType): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(init=False) + + def __post_init__(self): + self.data = "some data" + + a1 = A() + eq_(a1.data, "some data") + + def test_no_field_args_w_new_style(self, registry: _RegistryType): + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + def test_no_field_args_w_new_style_two(self, registry: _RegistryType): + @dataclasses.dataclass + class Base: + pass + + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A(Base): + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + +class DataclassesForNonMappedClassesTest(fixtures.TestBase): + """test for cases added in #9179""" + + def test_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6)), r".*\.Child\(a=5, b=6\)") + + def test_base_is_dc_plus_options(self): + class Parent(MappedAsDataclass, DeclarativeBase, unsafe_hash=True): + a: int + + class Child(Parent, repr=False): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(5, 6) + eq_(hash(c1), hash(Child(5, 6))) + + # still reprs, because base has a repr, but b not included + eq_regex(repr(c1), r".*\.Child\(a=5\)") + + def test_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(a=5, b=6) + eq_regex(repr(c1), r".*\.Child\(b=6\)") + + def test_base_is_dc_field(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int = dataclasses.field(default=10) + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True, default=7) + + c1 = Child(a=5, b=6) + eq_regex(repr(c1), r".*\.Child\(a=5, b=6\)") + + c1 = Child(b=6) + eq_regex(repr(c1), r".*\.Child\(a=10, b=6\)") + + c1 = Child() + eq_regex(repr(c1), r".*\.Child\(a=10, b=7\)") + + def test_abstract_and_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Mixin(Parent): + 
__abstract__ = True + b: int + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)") + + def test_abstract_and_base_is_dc_plus_options(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Mixin(Parent, unsafe_hash=True): + __abstract__ = True + b: int + + class Child(Mixin, repr=False): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_(hash(Child(5, 6, 7)), hash(Child(5, 6, 7))) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6\)") + + def test_abstract_and_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + class Mixin(Parent): + __abstract__ = True + b: InitVar[int] + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(a=5, b=6, c=7) + eq_regex(repr(c1), r".*\.Child\(c=7\)") + + def test_abstract_and_base_is_dc_field(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int = dataclasses.field(default=10) + + class Mixin(Parent): + __abstract__ = True + b: int = dataclasses.field(default=7) + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True, default=9) + + c1 = Child(b=6, c=7) + eq_regex(repr(c1), r".*\.Child\(a=10, b=6, c=7\)") + + c1 = Child() + eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + def test_abstract_is_dc(self): + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + + class Parent(DeclarativeBase): + a: int + + class Mixin(MappedAsDataclass, Parent, dataclass_callable=check_args): + __abstract__ = True + b: int + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}}) + eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation("check_annotations", [True, False]) + def test_abstract_is_dc_w_mapped(self, check_annotations): + if check_annotations: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + + class_kw = {"dataclass_callable": check_args} + else: + class_kw = {} + + class Parent(DeclarativeBase): + a: int + + class Mixin(MappedAsDataclass, Parent, **class_kw): + __abstract__ = True + b: Mapped[int] = mapped_column() + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + if check_annotations: + # note: current dataclasses process adds Field() object to Child + # based on attributes which include those from Mixin. This means + # the annotations of Child are also augmented while we do + # dataclasses collection. 
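+ # for contrast, a minimal sketch using plain stdlib dataclasses
+ # (_PlainMx / _PlainCh are illustrative names only, not part of the
+ # mapping above): stdlib dataclasses collect inherited fields via
+ # the MRO without writing them into the subclass __annotations__,
+ # whereas the assertion just below checks the augmented ORM form
+ @dataclasses.dataclass
+ class _PlainMx:
+ b: int
+
+ @dataclasses.dataclass
+ class _PlainCh(_PlainMx):
+ c: int
+
+ eq_(_PlainCh.__annotations__, {"c": int})
+ eq_([f.name for f in dataclasses.fields(_PlainCh)], ["b", "c"])
+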
+ eq_( + collected_annotations, + {Mixin: {"b": int}, Child: {"b": int, "c": int}}, + ) + eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + + def test_mixin_and_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + @dataclasses.dataclass + class Mixin: + b: int + + class Child(Mixin, Parent): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)") + + def test_mixin_and_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + @dataclasses.dataclass + class Mixin: + b: InitVar[int] + + class Child(Mixin, Parent): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(a=5, b=6, c=7)), r".*\.Child\(c=7\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation( + "dataclass_scope", + ["on_base", "on_mixin", "on_base_class", "on_sub_class"], + ) + @testing.variation( + "test_alternative_callable", + [True, False], + ) + def test_mixin_w_inheritance( + self, dataclass_scope, test_alternative_callable + ): + """test #9226""" + + expected_annotations = {} + + if test_alternative_callable: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = getattr( + cls, "__annotations__", {} + ) + return dataclasses.dataclass(cls, **kw) + + klass_kw = {"dataclass_callable": check_args} + else: + klass_kw = {} + + if dataclass_scope.on_base: + + class Base(MappedAsDataclass, DeclarativeBase, **klass_kw): + pass + + expected_annotations[Base] = {} + else: + + class Base(DeclarativeBase): + pass + + if dataclass_scope.on_mixin: + + class Mixin(MappedAsDataclass, **klass_kw): + @declared_attr.directive + @classmethod + def __tablename__(cls) -> str: + return cls.__name__.lower() + + @declared_attr.directive + @classmethod + def __mapper_args__(cls) -> Dict[str, Any]: + return { + "polymorphic_identity": cls.__name__, + "polymorphic_on": "polymorphic_type", + } + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + init=False, + ) + + expected_annotations[Mixin] = {} + + non_dc_mixin = contextlib.nullcontext + + else: + + class Mixin: + @declared_attr.directive + @classmethod + def __tablename__(cls) -> str: + return cls.__name__.lower() + + @declared_attr.directive + @classmethod + def __mapper_args__(cls) -> Dict[str, Any]: + return { + "polymorphic_identity": cls.__name__, + "polymorphic_on": "polymorphic_type", + } + + if dataclass_scope.on_base or dataclass_scope.on_base_class: + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + init=False, + ) + + else: + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + ) + + non_dc_mixin = functools.partial( + _dataclass_mixin_warning, "Mixin", "'polymorphic_type'" + ) + + if dataclass_scope.on_base_class: + with non_dc_mixin(): + + class Book(Mixin, MappedAsDataclass, Base, **klass_kw): + id: Mapped[int] = mapped_column( + Integer, + primary_key=True, + init=False, + ) + + else: + if dataclass_scope.on_base: + local_non_dc_mixin = non_dc_mixin + else: + local_non_dc_mixin = contextlib.nullcontext + + with local_non_dc_mixin(): + + 
class Book(Mixin, Base): + if not dataclass_scope.on_sub_class: + id: Mapped[int] = mapped_column( # noqa: A001 + Integer, primary_key=True, init=False + ) + else: + id: Mapped[int] = mapped_column( # noqa: A001 + Integer, + primary_key=True, + ) + + if MappedAsDataclass in Book.__mro__: + expected_annotations[Book] = {"id": int, "polymorphic_type": str} + + if dataclass_scope.on_sub_class: + with non_dc_mixin(): + + class Novel(MappedAsDataclass, Book, **klass_kw): + id: Mapped[int] = mapped_column( # noqa: A001 + ForeignKey("book.id"), + primary_key=True, + init=False, + ) + description: Mapped[Optional[str]] + + else: + with non_dc_mixin(): + + class Novel(Book): + id: Mapped[int] = mapped_column( + ForeignKey("book.id"), + primary_key=True, + init=False, + ) + description: Mapped[Optional[str]] + + expected_annotations[Novel] = {"id": int, "description": Optional[str]} + + if test_alternative_callable: + eq_(collected_annotations, expected_annotations) + + n1 = Novel("the description") + eq_(n1.description, "the description") + + +class DataclassArgsTest(fixtures.TestBase): + dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") + if compat.py310: + dc_arg_names += ("match_args", "kw_only") + + @testing.fixture(params=product(dc_arg_names, (True, False))) + def dc_argument_fixture(self, request: Any, registry: _RegistryType): + name, use_defaults = request.param + + args = {n: n == name for n in self.dc_arg_names} + if args["order"]: + args["eq"] = True + if use_defaults: + default = { + "init": True, + "repr": True, + "eq": True, + "order": False, + "unsafe_hash": False, + } + if compat.py310: + default |= {"match_args": True, "kw_only": False} + to_apply = {k: v for k, v in args.items() if v} + effective = {**default, **to_apply} + return to_apply, effective + else: + return args, args + + @testing.fixture(params=["mapped_column", "synonym", "deferred"]) + def mapped_expr_constructor(self, request): + name = request.param + + if name == "mapped_column": + yield mapped_column(default=7, init=True) + elif name == "synonym": + yield synonym("some_int", default=7, init=True) + elif name == "deferred": + yield deferred(Column(Integer), default=7, init=True) + + def test_attrs_rejected_if_not_a_dc( + self, mapped_expr_constructor, decl_base: Type[DeclarativeBase] + ): + if isinstance(mapped_expr_constructor, MappedColumn): + unwanted_args = "'init'" + else: + unwanted_args = "'default', 'init'" + with expect_raises_message( + exc.ArgumentError, + r"Attribute 'x' on class .*A.* includes dataclasses " + r"argument\(s\): " + rf"{unwanted_args} but class does not specify SQLAlchemy native " + "dataclass configuration", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] = mapped_expr_constructor + + def _assert_cls(self, cls, dc_arguments): + if dc_arguments["init"]: + + def create(data, x): + if dc_arguments.get("kw_only"): + return cls(data=data, x=x) + else: + return cls(data, x) + + else: + + def create(data, x): + a1 = cls() + a1.data = data + a1.x = x + return a1 + + for n in self.dc_arg_names: + if dc_arguments[n]: + getattr(self, f"_assert_{n}")(cls, create, dc_arguments) + else: + getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments) + + if dc_arguments["init"]: + a1 = cls(data="some data") + eq_(a1.x, 7) + + a1 = create("some data", 15) + some_int = a1.some_int + eq_( + dataclasses.asdict(a1), + {"data": "some data", "id": None, "some_int": some_int, "x": 15}, + ) + eq_(dataclasses.astuple(a1), 
(None, "some data", some_int, 15)) + + def _assert_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + hash(a1) + + def _assert_not_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + + if dc_arguments["eq"]: + with expect_raises(TypeError): + hash(a1) + else: + hash(a1) + + def _assert_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a3) + ne_(a1, a2) + + def _assert_not_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a1) + ne_(a1, a3) + ne_(a1, a2) + + def _assert_order(self, cls, create, dc_arguments): + is_false(create("g", 10) < create("b", 7)) + + is_true(create("g", 10) > create("b", 7)) + + is_false(create("g", 10) <= create("b", 7)) + + is_true(create("g", 10) >= create("b", 7)) + + eq_( + list(sorted([create("g", 10), create("g", 5), create("b", 7)])), + [ + create("b", 7), + create("g", 5), + create("g", 10), + ], + ) + + def _assert_not_order(self, cls, create, dc_arguments): + with expect_raises(TypeError): + create("g", 10) < create("b", 7) + + with expect_raises(TypeError): + create("g", 10) > create("b", 7) + + with expect_raises(TypeError): + create("g", 10) <= create("b", 7) + + with expect_raises(TypeError): + create("g", 10) >= create("b", 7) + + def _assert_repr(self, cls, create, dc_arguments): + assert "__repr__" in cls.__dict__ + a1 = create("some data", 12) + eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)") + + def _assert_not_repr(self, cls, create, dc_arguments): + assert "__repr__" not in cls.__dict__ + + # if a superclass has __repr__, then we still get repr. + # so can't test this + # a1 = create("some data", 12) + # eq_regex(repr(a1), r"<.*A object at 0x.*>") + + def _assert_init(self, cls, create, dc_arguments): + if not dc_arguments.get("kw_only", False): + a1 = cls("some data", 5) + + eq_(a1.data, "some data") + eq_(a1.x, 5) + + a2 = cls(data="some data", x=5) + eq_(a2.data, "some data") + eq_(a2.x, 5) + + a3 = cls(data="some data") + eq_(a3.data, "some data") + eq_(a3.x, 7) + + def _assert_not_init(self, cls, create, dc_arguments): + with expect_raises(TypeError): + cls("Some data", 5) + + # we run real "dataclasses" on the class. so with init=False, it + # doesn't touch what was there, and the SQLA default constructor + # gets put on. 
+ a1 = cls(data="some data") + eq_(a1.data, "some data") + eq_(a1.x, None) + + a1 = cls() + eq_(a1.data, None) + + # no constructor, it sets None for x...ok + eq_(a1.x, None) + + def _assert_match_args(self, cls, create, dc_arguments): + if not dc_arguments["kw_only"]: + is_true(len(cls.__match_args__) > 0) + + def _assert_not_match_args(self, cls, create, dc_arguments): + is_false(hasattr(cls, "__match_args__")) + + def _assert_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_(fas.args, ["self"]) + eq_( + len(fas.kwonlyargs), + len(pyinspect.signature(cls.__init__).parameters) - 1, + ) + + def _assert_not_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_( + len(fas.args), + len(pyinspect.signature(cls.__init__).parameters), + ) + eq_(fas.kwonlyargs, []) + + def test_dc_arguments_decorator( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + @registry.mapped_as_dataclass(**dc_argument_fixture[0]) + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_base( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + reg = registry + + class Base( + MappedAsDataclass, DeclarativeBase, **dc_argument_fixture[0] + ): + registry = reg + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_perclass( + self, + dc_argument_fixture, + mapped_expr_constructor, + decl_base: Type[DeclarativeBase], + ): + class A(MappedAsDataclass, decl_base, **dc_argument_fixture[0]): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_override_base(self, registry: _RegistryType): + reg = registry + + class Base(MappedAsDataclass, DeclarativeBase, init=False, order=True): + registry = reg + + class A(Base, init=True, repr=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_column(default=7) + + effective = { + "init": True, + "repr": False, + "eq": True, + "order": True, + "unsafe_hash": False, + } + if compat.py310: + effective |= {"match_args": True, "kw_only": False} + self._assert_cls(A, effective) + + def test_dc_base_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + class Base(MappedAsDataclass, DeclarativeBase, slots=True): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + class A(Base2, slots=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def 
test_dc_decorator_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class Base(DeclarativeBase): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class A(Base2): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def test_dc_raise_for_slots( + self, + registry: _RegistryType, + decl_base: Type[DeclarativeBase], + ): + reg = registry + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + class A(MappedAsDataclass, decl_base): + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots' are not accepted", + ): + + class Base(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + _sa_apply_dc_transforms = {"slots": True} + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + @reg.mapped + class C: + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + @testing.variation("use_arguments", [True, False]) + @testing.combinations( + mapped_column, + lambda **kw: synonym("some_int", **kw), + lambda **kw: deferred(Column(Integer), **kw), + lambda **kw: composite("foo", **kw), + lambda **kw: relationship("Foo", **kw), + lambda **kw: association_proxy("foo", "bar", **kw), + argnames="construct", + ) + def test_attribute_options(self, use_arguments, construct): + if use_arguments: + kw = { + "init": False, + "repr": False, + "default": False, + "default_factory": list, + "compare": True, + "kw_only": False, + "hash": False, + } + exp = interfaces._AttributeOptions( + False, False, False, list, True, False, False + ) + else: + kw = {} + exp = interfaces._DEFAULT_ATTRIBUTE_OPTIONS + + prop = construct(**kw) + eq_(prop._attribute_options, exp) + + @testing.variation("use_arguments", [True, False]) + @testing.combinations( + lambda **kw: column_property(Column(Integer), **kw), + lambda **kw: query_expression(**kw), + argnames="construct", + ) + def test_ro_attribute_options(self, use_arguments, construct): + if use_arguments: + kw = { + "repr": False, + "compare": True, + } + exp = interfaces._AttributeOptions( + False, + False, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + True, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ) + else: + kw = {} + exp = interfaces._DEFAULT_READONLY_ATTRIBUTE_OPTIONS + + prop = construct(**kw) + eq_(prop._attribute_options, exp) + + +class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests for #8718""" + + __dialect__ = "default" + + @testing.fixture + def model(self): + def go(use_mixin, use_inherits, mad_setup, dataclass_kw): + if use_mixin: + if mad_setup == "dc, mad": + + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): + pass + + elif mad_setup == "mad, dc": + + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): + pass + + elif mad_setup == "subclass": + + class BaseEntity(DeclarativeBase): + pass + + class IdMixin(MappedAsDataclass): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == 
"subclass": + + class A( + IdMixin, MappedAsDataclass, BaseEntity, **dataclass_kw + ): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(IdMixin, BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + if mad_setup == "dc, mad": + + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "mad, dc": + + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "subclass": + + class BaseEntity(MappedAsDataclass, DeclarativeBase): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == "subclass": + + class A(BaseEntity, **dataclass_kw): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + if use_inherits: + + class B(A): + __mapper_args__ = { + "polymorphic_identity": "b", + } + b_data: Mapped[str] = mapped_column(String, init=False) + + return B + else: + return A + + yield go + + @testing.combinations("inherits", "plain", argnames="use_inherits") + @testing.combinations("mixin", "base", argnames="use_mixin") + @testing.combinations( + "mad, dc", "dc, mad", "subclass", argnames="mad_setup" + ) + def test_mapping(self, model, use_inherits, use_mixin, mad_setup): + target_cls = model( + use_inherits=use_inherits == "inherits", + use_mixin=use_mixin == "mixin", + mad_setup=mad_setup, + dataclass_kw={}, + ) + + obj = target_cls() + assert "id" not in obj.__dict__ + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(dc_decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + graph_id: Mapped[int] = mapped_column( + ForeignKey("graph.id"), init=False + ) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1"), default=None + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2"), default=None + ) + + class Graph(dc_decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + edges: Mapped[List[Edge]] = relationship() + + Point.__qualname__ = "mymodel.Point" + Edge.__qualname__ = "mymodel.Edge" + Graph.__qualname__ = "mymodel.Graph" + g = Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + eq_( + repr(g), + "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, " + "graph_id=None, start=mymodel.Point(x=1, y=2), " + "end=mymodel.Point(x=3, y=4)), " 
+ "mymodel.Edge(id=None, graph_id=None, " + "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])", + ) + + def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(dc_decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, init=False, repr=False + ) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, + mapped_column(), + mapped_column(), + mapped_column("zip"), + default=None, + ) + + Address.__qualname__ = "mymodule.Address" + User.__qualname__ = "mymodule.User" + u = User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + u2 = User("u2") + eq_( + repr(u), + "mymodule.User(name='user 1', " + "address=mymodule.Address(street='123 anywhere street', " + "state='NY', zip_='12345'))", + ) + eq_(repr(u2), "mymodule.User(name='u2', address=None)") + + +class ReadOnlyAttrTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests related to #9628""" + + __dialect__ = "default" + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_default_behavior( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf") + + m1 = MyClass(data="foo") + eq_(m1, MyClass(data="foo")) + ne_(m1, MyClass(data="bar")) + + eq_regex( + repr(m1), + r".*MyClass\(id=None, data='foo', const=None\)", + ) + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_no_repr_behavior( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf", repr=False) + + m1 = MyClass(data="foo") + + eq_regex( + repr(m1), + r".*MyClass\(id=None, data='foo'\)", + ) + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_enable_compare( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf", compare=True) + + m1 = MyClass(data="foo") + eq_(m1, MyClass(data="foo")) + ne_(m1, MyClass(data="bar")) + + m2 = MyClass(data="foo") + m2.const = "some const" + ne_(m2, MyClass(data="foo")) + m3 = MyClass(data="foo") + m3.const = "some const" + eq_(m2, m3) diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index c5b908cd822..1b633d1bcf0 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -1067,7 +1067,6 @@ class Person(decl_base): target_id = Column(Integer, primary_key=True) class Engineer(Person): - """single table inheritance""" if decl_type.legacy: @@ -1084,7 +1083,6 @@ def target_id(cls): ) class Manager(Person): - """single table inheritance""" if decl_type.legacy: @@ -1468,7 +1466,6 @@ class A(a_1): class OverlapColPrecedenceTest(DeclarativeTestBase): - """test #1892 cases when declarative does column precedence.""" def _run_test(self, Engineer, e_id, p_id): diff 
--git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 900133df593..d670e96dcbf 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -7,6 +7,7 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing @@ -98,6 +99,159 @@ class Foo(Base): self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo") + @testing.variation("base_type", ["generate_base", "subclass"]) + @testing.variation("attrname", ["table", "tablename"]) + @testing.variation("position", ["base", "abstract"]) + @testing.variation("assert_no_extra_cols", [True, False]) + def test_declared_attr_on_base( + self, registry, base_type, attrname, position, assert_no_extra_cols + ): + """test #11509""" + + if position.abstract: + if base_type.generate_base: + SuperBase = registry.generate_base() + + class Base(SuperBase): + __abstract__ = True + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + elif base_type.subclass: + + class SuperBase(DeclarativeBase): + pass + + class Base(SuperBase): + __abstract__ = True + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + else: + base_type.fail() + else: + if base_type.generate_base: + + class Base: + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + Base = registry.generate_base(cls=Base) + elif base_type.subclass: + + class Base(DeclarativeBase): + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + else: + base_type.fail() + + if attrname.table and assert_no_extra_cols: + with expect_raises_message( + sa.exc.ArgumentError, + "Can't add additional column 'data' when specifying __table__", + ): + + class MyNopeClass(Base): + data = Column(String) + + return + + class MyClass(Base): + if attrname.tablename: + id = Column(Integer, primary_key=True) # noqa: A001 + + class MyOtherClass(Base): + if attrname.tablename: + id = Column(Integer, primary_key=True) # noqa: A001 + + t = Table( + "my_override", + Base.metadata, + Column("id", Integer, primary_key=True), + ) + + class MyOverrideClass(Base): + __table__ = t + + Base.registry.configure() + + # __table__ was assigned + assert isinstance(MyClass.__dict__["__table__"], schema.Table) + assert isinstance(MyOtherClass.__dict__["__table__"], schema.Table) + + eq_(MyClass.__table__.name, "MyClass") + eq_(MyClass.__table__.c.keys(), ["id"]) + + eq_(MyOtherClass.__table__.name, "MyOtherClass") + 
eq_(MyOtherClass.__table__.c.keys(), ["id"])
+
+ is_(MyOverrideClass.__table__, t)
+
 def test_simple_wbase(self):
 class MyMixin:
 id = Column(
@@ -672,11 +826,9 @@ def target(cls):
 return relationship("Other")
 class Engineer(Mixin, Person):
-
 """single table inheritance"""
 class Manager(Mixin, Person):
-
 """single table inheritance"""
 class Other(Base):
@@ -1324,7 +1476,7 @@ class Model(Base, ColumnMixin):
 assert_raises_message(
 sa.exc.ArgumentError,
- "Can't add additional column 'tada' when " "specifying __table__",
+ "Can't add additional column 'tada' when specifying __table__",
 go,
 )
diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py
index 833518a4275..9b0d4f334bc 100644
--- a/test/orm/declarative/test_tm_future_annotations.py
+++ b/test/orm/declarative/test_tm_future_annotations.py
@@ -1,13 +1,14 @@
"""This file includes annotation-sensitive tests while having
``from __future__ import annotations`` in effect.
-Only tests that don't have an equivalent in ``test_typed_mappings`` are
-specified here. All test from ``test_typed_mappings`` are copied over to
+Only tests that don't have an equivalent in ``test_typed_mapping`` are
+specified here. All tests from ``test_typed_mapping`` are copied over to
the ``test_tm_future_annotations_sync`` by the ``sync_test_file`` script.
"""
from __future__ import annotations
+import enum
from typing import ClassVar
from typing import Dict
from typing import List
@@ -29,8 +30,13 @@
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
+from sqlalchemy.orm.util import _cleanup_mapped_str_annotation
+from sqlalchemy.sql import sqltypes
+from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
from .test_typed_mapping import expect_annotation_syntax_error
from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
@@ -45,6 +51,89 @@ class M3:
 pass
+class AnnoUtilTest(fixtures.TestBase):
+ @testing.combinations(
+ ("Mapped[Address]", 'Mapped["Address"]'),
+ ('Mapped["Address"]', 'Mapped["Address"]'),
+ ("Mapped['Address']", "Mapped['Address']"),
+ ("Mapped[Address | None]", 'Mapped["Address | None"]'),
+ ("Mapped[None | Address]", 'Mapped["None | Address"]'),
+ ('Mapped["Address | None"]', 'Mapped["Address | None"]'),
+ ("Mapped['None | Address']", "Mapped['None | Address']"),
+ ('Mapped["Address" | "None"]', 'Mapped["Address" | "None"]'),
+ ('Mapped["None" | "Address"]', 'Mapped["None" | "Address"]'),
+ ("Mapped[A_]", 'Mapped["A_"]'),
+ ("Mapped[_TypingLiteral]", 'Mapped["_TypingLiteral"]'),
+ ("Mapped[datetime.datetime]", 'Mapped["datetime.datetime"]'),
+ ("Mapped[List[Edge]]", 'Mapped[List["Edge"]]'),
+ (
+ "Mapped[collections.abc.MutableSequence[B]]",
+ 'Mapped[collections.abc.MutableSequence["B"]]',
+ ),
+ ("Mapped[typing.Sequence[B]]", 'Mapped[typing.Sequence["B"]]'),
+ ("Mapped[dict[str, str]]", 'Mapped[dict["str", "str"]]'),
+ ("Mapped[Dict[str, str]]", 'Mapped[Dict["str", "str"]]'),
+ ("Mapped[list[str]]", 'Mapped[list["str"]]'),
+ ("Mapped[dict[str, str] | None]", "Mapped[dict[str, str] | None]"),
+ ("Mapped[Optional[anno_str_mc]]", 'Mapped[Optional["anno_str_mc"]]'),
+ (
+ "Mapped[Optional[Dict[str, str]]]",
+ 'Mapped[Optional[Dict["str", "str"]]]',
+ ),
+ (
+ "Mapped[Optional[Union[Decimal, 
float]]]", + 'Mapped[Optional[Union["Decimal", "float"]]]', + ), + ( + "Mapped[Optional[Union[list[int], list[str]]]]", + "Mapped[Optional[Union[list[int], list[str]]]]", + ), + ("Mapped[TestType[str]]", 'Mapped[TestType["str"]]'), + ("Mapped[TestType[str, str]]", 'Mapped[TestType["str", "str"]]'), + ("Mapped[Union[A, None]]", 'Mapped[Union["A", "None"]]'), + ("Mapped[Union[Decimal, float]]", 'Mapped[Union["Decimal", "float"]]'), + ( + "Mapped[Union[Decimal, float, None]]", + 'Mapped[Union["Decimal", "float", "None"]]', + ), + ( + "Mapped[Union[Dict[str, str], None]]", + "Mapped[Union[Dict[str, str], None]]", + ), + ("Mapped[Union[float, Decimal]]", 'Mapped[Union["float", "Decimal"]]'), + ( + "Mapped[Union[list[int], list[str]]]", + "Mapped[Union[list[int], list[str]]]", + ), + ( + "Mapped[Union[list[int], list[str], None]]", + "Mapped[Union[list[int], list[str], None]]", + ), + ( + "Mapped[Union[None, Dict[str, str]]]", + "Mapped[Union[None, Dict[str, str]]]", + ), + ( + "Mapped[Union[None, list[int], list[str]]]", + "Mapped[Union[None, list[int], list[str]]]", + ), + ("Mapped[A | None]", 'Mapped["A | None"]'), + ("Mapped[Decimal | float]", 'Mapped["Decimal | float"]'), + ("Mapped[Decimal | float | None]", 'Mapped["Decimal | float | None"]'), + ( + "Mapped[list[int] | list[str] | None]", + "Mapped[list[int] | list[str] | None]", + ), + ("Mapped[None | dict[str, str]]", "Mapped[None | dict[str, str]]"), + ( + "Mapped[None | list[int] | list[str]]", + "Mapped[None | list[int] | list[str]]", + ), + ) + def test_cleanup_mapped_str_annotation(self, given, expected): + eq_(_cleanup_mapped_str_annotation(given, __name__), expected) + + class MappedColumnTest(_MappedColumnTest): def test_fully_qualified_mapped_name(self, decl_base): """test #8853, regression caused by #8759 ;) @@ -92,11 +181,11 @@ def make_class() -> None: ll = list + def make_class() -> None: x: ll[int] = [1, 2, 3] - """ # noqa: E501 class Foo(decl_base): @@ -112,6 +201,85 @@ class Foo(decl_base): select(Foo), "SELECT foo.id, foo.data, foo.data2 FROM foo" ) + def test_type_favors_outer(self, decl_base): + """test #10899, that we maintain favoring outer names vs. inner. + this is for backwards compatibility as well as what people + usually expect regarding the names of attributes in the class. + + """ + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + uuid: Mapped[uuid.UUID] = mapped_column() + + is_true(isinstance(User.__table__.c.uuid.type, sqltypes.Uuid)) + + def test_type_inline_cls_qualified(self, decl_base): + """test #10899, where we test that we can refer to the class name + directly to refer to class-bound elements. + + """ + + class User(decl_base): + __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[User.Role] + + is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + + def test_type_inline_disambiguate(self, decl_base): + """test #10899, where we test that we can refer to an inner name + that's not in conflict directly without qualification. 
+ + """ + + class User(decl_base): + __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[Role] + + is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + eq_(User.__table__.c.role.type.name, "role") # and not 'enum' + + def test_type_inner_can_be_qualified(self, decl_base): + """test #10899, same test as that of Role, using it to qualify against + a global variable with the same name. + + """ + + global SomeGlobalName + SomeGlobalName = None + + class User(decl_base): + __tablename__ = "user" + + class SomeGlobalName(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[User.SomeGlobalName] + + is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.SomeGlobalName) + def test_indirect_mapped_name_local_level(self, decl_base): """test #8759. diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index ec5f5e82097..5b17e3e6e54 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -13,6 +13,7 @@ from decimal import Decimal import enum import inspect as _py_inspect +import re import typing from typing import Any from typing import cast @@ -29,8 +30,12 @@ from typing import Union import uuid +import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal +from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType +from typing_extensions import TypedDict from sqlalchemy import BIGINT from sqlalchemy import BigInteger @@ -38,6 +43,7 @@ from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -62,16 +68,23 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import DynamicMapped +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm import foreign from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship +from sqlalchemy.orm import remote from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped +from sqlalchemy.orm.attributes import CollectionAttributeImpl from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict +from sqlalchemy.orm.dynamic import DynamicAttributeImpl from sqlalchemy.orm.properties import MappedColumn +from sqlalchemy.orm.writeonly import WriteOnlyAttributeImpl from sqlalchemy.schema import CreateTable from sqlalchemy.sql.base import _NoArg from sqlalchemy.sql.sqltypes import Enum @@ -85,11 +98,76 @@ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import requires from sqlalchemy.testing import Variation +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.fixtures import fixture_session from 
sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated +TV = typing.TypeVar("TV") + + +class _SomeDict1(TypedDict): + type: Literal["1"] + + +class _SomeDict2(TypedDict): + type: Literal["2"] + + +_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2] + +_StrTypeAlias: TypeAlias = str + + +if compat.py38: + _TypingLiteral = typing.Literal["a", "b"] +_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] + +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604) + +TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType) + +_StrPep695 = TypeAliasType("_StrPep695", str) +_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str) +_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,)) +_TypingGenericPep695 = TypingTypeAliasType( + "_TypingGenericPep695", List[TV], type_params=(TV,) +) +_GenericPep695Typed = _GenericPep695[int] +_TypingGenericPep695Typed = _TypingGenericPep695[int] +_UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2]) +strtypalias_keyword = TypeAliasType( + "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})] +) +if compat.py310: + strtypalias_keyword_nested = TypeAliasType( + "strtypalias_keyword_nested", + int | Annotated[str, mapped_column(info={"hi": "there"})], + ) +strtypalias_ta: TypeAlias = Annotated[str, mapped_column(info={"hi": "there"})] +strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] +_Literal695 = TypeAliasType( + "_Literal695", Literal["to-do", "in-progress", "done"] +) +_TypingLiteral695 = TypingTypeAliasType( + "_TypingLiteral695", Literal["to-do", "in-progress", "done"] +) +_RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695) + def expect_annotation_syntax_error(name): return expect_raises_message( @@ -163,6 +241,46 @@ class Foo(decl_base): else: eq_(Foo.__table__.c.data.default.arg, 5) + def test_type_inline_declaration(self, decl_base): + """test #10899""" + + class User(decl_base): + __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[Role] + + is_true(isinstance(User.__table__.c.role.type, Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + eq_(User.__table__.c.role.type.name, "role") # and not 'enum' + + def test_type_uses_inner_when_present(self, decl_base): + """test #10899, that we use inner name when appropriate""" + + class Role(enum.Enum): + foo = "foo" + bar = "bar" + + class User(decl_base): + __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[Role] + + is_true(isinstance(User.__table__.c.role.type, Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + eq_(User.__table__.c.role.type.name, "role") # and not 'enum' + def test_legacy_declarative_base(self): typ = VARCHAR(50) Base 
= declarative_base(type_annotation_map={str: typ}) @@ -177,6 +295,43 @@ class MyClass(Base): is_(MyClass.__table__.c.data.type, typ) is_true(MyClass.__table__.c.id.primary_key) + @testing.variation("style", ["none", "lambda_", "string", "direct"]) + def test_foreign_annotation_propagates_correctly(self, decl_base, style): + """test #10597""" + + class Parent(decl_base): + __tablename__ = "parent" + id: Mapped[int] = mapped_column(primary_key=True) + + class Child(decl_base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + + if style.none: + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + else: + parent_id: Mapped[int] = mapped_column() + + if style.lambda_: + parent: Mapped[Parent] = relationship( + primaryjoin=lambda: remote(Parent.id) + == foreign(Child.parent_id), + ) + elif style.string: + parent: Mapped[Parent] = relationship( + primaryjoin="remote(Parent.id) == " + "foreign(Child.parent_id)", + ) + elif style.direct: + parent: Mapped[Parent] = relationship( + primaryjoin=remote(Parent.id) == foreign(parent_id), + ) + elif style.none: + parent: Mapped[Parent] = relationship() + + assert Child.__mapper__.attrs.parent.strategy.use_get + @testing.combinations( (BIGINT(),), (BIGINT,), @@ -475,19 +630,179 @@ class User(decl_base): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[MyClass] = mapped_column() - def test_construct_lhs_sqlalchemy_type(self, decl_base): - with expect_raises_message( - sa_exc.ArgumentError, - "The type provided inside the 'data' attribute Mapped " - "annotation is the SQLAlchemy type .*BigInteger.*. Expected " - "a Python type instead", - ): + @testing.variation( + "argtype", + [ + "type", + ("column", testing.requires.python310), + ("mapped_column", testing.requires.python310), + "column_class", + "ref_to_type", + ("ref_to_column", testing.requires.python310), + ], + ) + def test_construct_lhs_sqlalchemy_type(self, decl_base, argtype): + """test for #12329. - class User(decl_base): - __tablename__ = "users" + of note here are all the different messages we have for when the + wrong thing is put into Mapped[], and in fact in #12329 we added + another one. - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[BigInteger] = mapped_column() + This is a lot of different messages, but at the same time they + occur at different places in the interpretation of types. If + we were to centralize all these messages, we'd still likely end up + doing distinct messages for each scenario, so instead we added + a new ArgumentError subclass MappedAnnotationError that provides + some commonality to all of these cases. + + + """ + expect_future_annotations = "annotations" in globals() + + if argtype.type: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # properties.py -> _init_column_for_annotation, type is + # a SQL type + "The type provided inside the 'data' attribute Mapped " + "annotation is the SQLAlchemy type .*BigInteger.*. Expected " + "a Python type instead", + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[BigInteger] = mapped_column() + + elif argtype.column: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # util.py -> _extract_mapped_subtype + ( + re.escape( + "Could not interpret annotation " + "Mapped[Column('q', BigInteger)]." 
+ )
+ if expect_future_annotations
+ # properties.py -> _init_column_for_annotation, object is
+ # not a SQL type or a python type, it's just some object
+ else re.escape(
+ "The object provided inside the 'data' attribute "
+ "Mapped annotation is not a Python type, it's the "
+ "object Column('q', BigInteger(), table=None). "
+ "Expected a Python type."
+ )
+ ),
+ ):
+
+ class User(decl_base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[Column("q", BigInteger)] = ( # noqa: F821
+ mapped_column()
+ )
+
+ elif argtype.mapped_column:
+ with expect_raises_message(
+ orm_exc.MappedAnnotationError,
+ # properties.py -> _init_column_for_annotation, object is
+ # not a SQL type or a python type, it's just some object
+ # interestingly, this raises at the same point for both
+ # future annotations mode and legacy annotations mode
+ r"The object provided inside the 'data' attribute "
+ "Mapped annotation is not a Python type, it's the object "
+ r"\<sqlalchemy.orm.properties.MappedColumn object at .*\>\. "
+ "Expected a Python type.",
+ ):
+
+ class User(decl_base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ big_integer: Mapped[int] = mapped_column()
+ data: Mapped[big_integer] = mapped_column()
+
+ elif argtype.column_class:
+ with expect_raises_message(
+ orm_exc.MappedAnnotationError,
+ # properties.py -> _init_column_for_annotation, type is not
+ # a SQL type
+ re.escape(
+ "Could not locate SQLAlchemy Core type for Python type "
+ "<class 'sqlalchemy.sql.schema.Column'> inside the "
+ "'data' attribute Mapped annotation"
+ ),
+ ):
+
+ class User(decl_base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[Column] = mapped_column()
+
+ elif argtype.ref_to_type:
+ mytype = BigInteger
+ with expect_raises_message(
+ orm_exc.MappedAnnotationError,
+ (
+ # decl_base.py -> _extract_mappable_attributes
+ re.escape(
+ "Could not resolve all types within mapped "
+ 'annotation: "Mapped[mytype]"'
+ )
+ if expect_future_annotations
+ # properties.py -> _init_column_for_annotation, type is
+ # a SQL type
+ else re.escape(
+ "The type provided inside the 'data' attribute Mapped "
+ "annotation is the SQLAlchemy type "
+ "<class 'sqlalchemy.sql.sqltypes.BigInteger'>. "
+ "Expected a Python type instead"
+ )
+ ),
+ ):
+
+ class User(decl_base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[mytype] = mapped_column()
+
+ elif argtype.ref_to_column:
+ mycol = Column("q", BigInteger)
+
+ with expect_raises_message(
+ orm_exc.MappedAnnotationError,
+ # decl_base.py -> _extract_mappable_attributes
+ (
+ re.escape(
+ "Could not resolve all types within mapped "
+ 'annotation: "Mapped[mycol]"'
+ )
+ if expect_future_annotations
+ else
+ # properties.py -> _init_column_for_annotation, object is
+ # not a SQL type or a python type, it's just some object
+ re.escape(
+ "The object provided inside the 'data' attribute "
+ "Mapped "
+ "annotation is not a Python type, it's the object "
+ "Column('q', BigInteger(), table=None). "
+ "Expected a Python type."
+ ) + ), + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[mycol] = mapped_column() + + else: + argtype.fail() def test_construct_rhs_type_override_lhs(self, decl_base): class Element(decl_base): @@ -692,6 +1007,352 @@ class MyClass(decl_base): is_true(MyClass.__table__.c.data_two.nullable) eq_(MyClass.__table__.c.data_three.type.length, 50) + def test_plain_typealias_as_typemap_keys( + self, decl_base: Type[DeclarativeBase] + ): + decl_base.registry.update_type_annotation_map( + {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)} + ) + + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[_StrTypeAlias] + structure: Mapped[_UnionTypeAlias] + + eq_(Test.__table__.c.data.type.length, 30) + is_(Test.__table__.c.structure.type._type_affinity, JSON) + + @testing.variation( + "option", + [ + "plain", + "union", + "union_604", + "union_null", + "union_null_604", + "optional", + "optional_union", + "optional_union_604", + "union_newtype", + "union_null_newtype", + "union_695", + "union_null_695", + ], + ) + @testing.variation("in_map", ["yes", "no", "value"]) + @testing.requires.python312 + def test_pep695_behavior(self, decl_base, in_map, option): + """Issue #11955""" + global tat + + if option.plain: + tat = TypeAliasType("tat", str) + elif option.union: + tat = TypeAliasType("tat", Union[str, int]) + elif option.union_604: + tat = TypeAliasType("tat", str | int) + elif option.union_null: + tat = TypeAliasType("tat", Union[str, int, None]) + elif option.union_null_604: + tat = TypeAliasType("tat", str | int | None) + elif option.optional: + tat = TypeAliasType("tat", Optional[str]) + elif option.optional_union: + tat = TypeAliasType("tat", Optional[Union[str, int]]) + elif option.optional_union_604: + tat = TypeAliasType("tat", Optional[str | int]) + elif option.union_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int]) + elif option.union_null_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int, None]) + elif option.union_695: + tat = TypeAliasType("tat", str | int) + elif option.union_null_695: + tat = TypeAliasType("tat", str | int | None) + else: + option.fail() + + if in_map.yes: + decl_base.registry.update_type_annotation_map({tat: String(99)}) + elif in_map.value and "newtype" not in option.name: + decl_base.registry.update_type_annotation_map( + {tat.__value__: String(99)} + ) + + def declare(): + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[tat] + + return Test.__table__.c.data + + if in_map.yes: + col = declare() + length = 99 + elif ( + in_map.value + and "newtype" not in option.name + or option.optional + or option.plain + ): + with expect_deprecated( + "Matching the provided TypeAliasType 'tat' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + col = declare() + length = 99 if in_map.value else None + else: + with expect_raises_message( + orm_exc.MappedAnnotationError, + r"Could not locate SQLAlchemy Core type for Python type .*tat " + "inside the 'data' attribute Mapped annotation", + ): + declare() + return + + is_true(isinstance(col.type, String)) + eq_(col.type.length, length) + nullable = "null" in option.name or "optional" in option.name + 
eq_(col.nullable, nullable) + + @testing.variation( + "type_", + [ + "str_extension", + "str_typing", + "generic_extension", + "generic_typing", + "generic_typed_extension", + "generic_typed_typing", + ], + ) + @testing.requires.python312 + def test_pep695_typealias_as_typemap_keys( + self, decl_base: Type[DeclarativeBase], type_ + ): + """test #10807""" + + decl_base.registry.update_type_annotation_map( + { + _UnionPep695: JSON, + _StrPep695: String(30), + _TypingStrPep695: String(30), + _GenericPep695: String(30), + _TypingGenericPep695: String(30), + _GenericPep695Typed: String(30), + _TypingGenericPep695Typed: String(30), + } + ) + + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + if type_.str_extension: + data: Mapped[_StrPep695] + elif type_.str_typing: + data: Mapped[_TypingStrPep695] + elif type_.generic_extension: + data: Mapped[_GenericPep695] + elif type_.generic_typing: + data: Mapped[_TypingGenericPep695] + elif type_.generic_typed_extension: + data: Mapped[_GenericPep695Typed] + elif type_.generic_typed_typing: + data: Mapped[_TypingGenericPep695Typed] + else: + type_.fail() + structure: Mapped[_UnionPep695] + + eq_(Test.__table__.c.data.type._type_affinity, String) + eq_(Test.__table__.c.data.type.length, 30) + is_(Test.__table__.c.structure.type._type_affinity, JSON) + + @testing.variation( + "alias_type", + ["none", "typekeyword", "typealias", "typekeyword_nested"], + ) + @testing.requires.python312 + def test_extract_pep593_from_pep695( + self, decl_base: Type[DeclarativeBase], alias_type + ): + """test #11130""" + if alias_type.typekeyword: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword: VARCHAR(33)} # noqa: F821 + ) + if alias_type.typekeyword_nested: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword_nested: VARCHAR(42)} # noqa: F821 + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if alias_type.typekeyword: + data_one: Mapped[strtypalias_keyword] # noqa: F821 + elif alias_type.typealias: + data_one: Mapped[strtypalias_ta] # noqa: F821 + elif alias_type.none: + data_one: Mapped[strtypalias_plain] # noqa: F821 + elif alias_type.typekeyword_nested: + data_one: Mapped[strtypalias_keyword_nested] # noqa: F821 + else: + alias_type.fail() + + table = MyClass.__table__ + assert table is not None + + if alias_type.typekeyword_nested: + # a nested annotation is not supported + eq_(MyClass.data_one.expression.info, {}) + else: + eq_(MyClass.data_one.expression.info, {"hi": "there"}) + + if alias_type.typekeyword: + eq_(MyClass.data_one.type.length, 33) + elif alias_type.typekeyword_nested: + eq_(MyClass.data_one.type.length, 42) + else: + eq_(MyClass.data_one.type.length, None) + + @testing.variation( + "type_", + [ + "literal", + "literal_typing", + "recursive", + "not_literal", + "not_literal_typing", + "generic", + "generic_typing", + "generic_typed", + "generic_typed_typing", + ], + ) + @testing.combinations(True, False, argnames="in_map") + @testing.requires.python312 + def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): + """test #11305.""" + + def declare(): + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + if type_.recursive: + status: Mapped[_RecursiveLiteral695] # noqa: F821 + elif type_.literal: + status: Mapped[_Literal695] # noqa: F821 + elif type_.literal_typing: + status: Mapped[_TypingLiteral695] # noqa: F821 + 
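# --- Illustrative sketch, not part of the patch: the Literal-to-Enum mapping
# --- asserted by test_pep695_literal_defaults_to_enum (#11305). The names
# --- Status and Task are hypothetical. A PEP 695-style alias of a Literal,
# --- keyed to Enum(enum.Enum) in the type_annotation_map, is expanded into a
# --- non-native Enum built from the literal's values.
import enum

from typing_extensions import Literal, TypeAliasType

from sqlalchemy import Enum
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

Status = TypeAliasType("Status", Literal["to-do", "in-progress", "done"])


class Base(DeclarativeBase):
    type_annotation_map = {Status: Enum(enum.Enum)}


class Task(Base):
    __tablename__ = "task"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[Status]


col = Task.__table__.c.status
assert col.type.enums == ["to-do", "in-progress", "done"]
assert col.type.native_enum is False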
elif type_.not_literal: + status: Mapped[_StrPep695] # noqa: F821 + elif type_.not_literal_typing: + status: Mapped[_TypingStrPep695] # noqa: F821 + elif type_.generic: + status: Mapped[_GenericPep695] # noqa: F821 + elif type_.generic_typing: + status: Mapped[_TypingGenericPep695] # noqa: F821 + elif type_.generic_typed: + status: Mapped[_GenericPep695Typed] # noqa: F821 + elif type_.generic_typed_typing: + status: Mapped[_TypingGenericPep695Typed] # noqa: F821 + else: + type_.fail() + + return Foo + + if in_map: + decl_base.registry.update_type_annotation_map( + { + _Literal695: Enum(enum.Enum), # noqa: F821 + _TypingLiteral695: Enum(enum.Enum), # noqa: F821 + _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 + _StrPep695: Enum(enum.Enum), # noqa: F821 + _TypingStrPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695Typed: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695Typed: Enum(enum.Enum), # noqa: F821 + } + ) + if type_.recursive: + with expect_deprecated( + "Mapping recursive TypeAliasType '.+' that resolve to " + "literal to generate an Enum is deprecated. SQLAlchemy " + "2.1 will not support this use case. Please avoid using " + "recursing TypeAliasType", + ): + Foo = declare() + elif type_.literal or type_.literal_typing: + Foo = declare() + else: + with expect_raises_message( + exc.ArgumentError, + "Can't associate TypeAliasType '.+' to an Enum " + "since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal.'a', " + "'b'.` are supported when generating Enums.", + ): + declare() + return + elif ( + type_.generic + or type_.generic_typing + or type_.generic_typed + or type_.generic_typed_typing + ): + # This behaves like 2.1 -> rationale is that no-one asked to + # support such types and in 2.1 will already be like this + # so it makes little sense to add support this late in the 2.0 + # series + with expect_raises_message( + exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type " + ".+ inside the 'status' attribute Mapped annotation", + ): + declare() + return + else: + with expect_deprecated( + "Matching the provided TypeAliasType '.*' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + Foo = declare() + col = Foo.__table__.c.status + if in_map and not type_.not_literal: + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["to-do", "in-progress", "done"]) + is_(col.type.native_enum, False) + else: + is_true(isinstance(col.type, String)) + + @testing.requires.python38 + def test_typing_literal_identity(self, decl_base): + """See issue #11820""" + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + t: Mapped[_TypingLiteral] + te: Mapped[_TypingExtensionsLiteral] + + for col in (Foo.__table__.c.t, Foo.__table__.c.te): + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["a", "b"]) + is_(col.type.native_enum, False) + @testing.requires.python310 def test_we_got_all_attrs_test_annotated(self): argnames = _py_inspect.getfullargspec(mapped_column) @@ -763,7 +1424,9 @@ def test_we_got_all_attrs_test_annotated(self): ), ("index", True, lambda column: column.index is True), ("index", _NoArg.NO_ARG, lambda column: column.index is None), + ("index", False, lambda column: column.index is False), ("unique", True, lambda 
column: column.unique is True), + ("unique", False, lambda column: column.unique is False), ("autoincrement", True, lambda column: column.autoincrement is True), ("system", True, lambda column: column.system is True), ("primary_key", True, lambda column: column.primary_key is True), @@ -832,6 +1495,13 @@ def test_we_got_all_attrs_test_annotated(self): "Argument 'init' is a dataclass argument" ), ), + ( + "hash", + True, + exc.SADeprecationWarning( + "Argument 'hash' is a dataclass argument" + ), + ), argnames="argname, argument, assertion", ) @testing.variation("use_annotated", [True, False, "control"]) @@ -855,6 +1525,7 @@ def test_names_encountered_for_annotated( "repr", "compare", "default_factory", + "hash", ) if is_dataclass: @@ -921,6 +1592,32 @@ class User(Base): argument, ) + @testing.combinations(("index",), ("unique",), argnames="paramname") + @testing.combinations((True,), (False,), (None,), argnames="orig") + @testing.combinations((True,), (False,), (None,), argnames="merging") + def test_index_unique_combinations( + self, paramname, orig, merging, decl_base + ): + """test #11091""" + + global myint + + amc = mapped_column(**{paramname: merging}) + myint = Annotated[int, amc] + + mc = mapped_column(**{paramname: orig}) + + class User(decl_base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + myname: Mapped[myint] = mc + + result = getattr(User.__table__.c.myname, paramname) + if orig is None: + is_(result, merging) + else: + is_(result, orig) + def test_pep484_newtypes_as_typemap_keys( self, decl_base: Type[DeclarativeBase] ): @@ -955,6 +1652,33 @@ class MyClass(decl_base): eq_(MyClass.__table__.c.data_four.type.length, 150) is_false(MyClass.__table__.c.data_four.nullable) + def test_newtype_missing_from_map(self, decl_base): + global str50 + + str50 = NewType("str50", str) + + if compat.py310: + text = ".*str50" + else: + # NewTypes before 3.10 had a very bad repr + # <function NewType.<locals>.new_type at 0x...> + text = ".*NewType.*" + + with expect_deprecated( + f"Matching the provided NewType '{text}' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data_one: Mapped[str50] + + is_true(isinstance(MyClass.data_one.type, String)) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -983,8 +1707,7 @@ class SomeRelated(decl_base): with expect_raises_message( NotImplementedError, - r"Use of the \<class 'sqlalchemy.orm.relationships.Relationship'\> construct inside of an Annotated " + r"Use of the 'Relationship' construct inside of an Annotated " r"object is not yet supported.", ): @@ -1187,22 +1910,68 @@ class RefElementTwo(decl_base): Dict, (str, str), ), + (list, None, testing.requires.python310), + ( + List, + None, + ), + (dict, None, testing.requires.python310), + ( + Dict, + None, + ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + test #11831, regression from #11814 + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + if args is None: + TestType = Annotated[container_typ, 0] + else: + TestType = Annotated[container_typ[args], 0] + elif style.alias: + if
args is None: + TestType = container_typ + else: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if args == (str, str): + type_annotation_map = {TestType[str, str]: JSON()} + elif args is None: + type_annotation_map = {TestType: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if args == (str, str): + data: Mapped[TestType[str, str]] = mapped_column() + elif args is None: + data: Mapped[TestType] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON) @@ -1401,34 +2170,64 @@ class Element(decl_base): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + "union_null", + ("pep604_null", requires.python310), + ], + ) + def test_unions(self, union): + global UnionType our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.union_null: + UnionType = Union[float, Decimal, None] + elif union.pep604: + UnionType = float | Decimal + elif union.pep604_null: + UnionType = float | Decimal | None + else: + union.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionType: our_type} class User(Base): __tablename__ = "users" - __table__: Table id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[Union[float, Decimal]] = mapped_column() - reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + data: Mapped[Union[float, Decimal]] + reverse_data: Mapped[Union[Decimal, float]] - optional_data: Mapped[ - Optional[Union[float, Decimal]] - ] = mapped_column() + optional_data: Mapped[Optional[Union[float, Decimal]]] = ( + mapped_column() + ) # use Optional directly - reverse_optional_data: Mapped[ - Optional[Union[Decimal, float]] - ] = mapped_column() + reverse_optional_data: Mapped[Optional[Union[Decimal, float]]] = ( + mapped_column() + ) # use Union with None, same as Optional but presents differently # (Optional object with __origin__ Union vs. Union) - reverse_u_optional_data: Mapped[ - Union[Decimal, float, None] + reverse_u_optional_data: Mapped[Union[Decimal, float, None]] = ( + mapped_column() + ) + + refer_union: Mapped[UnionType] + refer_union_optional: Mapped[Optional[UnionType]] + + # py38, 37 does not automatically flatten unions, add extra tests + # for this. 
maintain these in order to catch future regressions + # in the behavior of ``Union`` + unflat_union_optional_data: Mapped[ + Union[Union[Decimal, float, None], None] ] = mapped_column() float_data: Mapped[float] = mapped_column() @@ -1437,71 +2236,131 @@ class User(Base): if compat.py310: pep604_data: Mapped[float | Decimal] = mapped_column() pep604_reverse: Mapped[Decimal | float] = mapped_column() - pep604_optional: Mapped[ - Decimal | float | None - ] = mapped_column() + pep604_optional: Mapped[Decimal | float | None] = ( + mapped_column() + ) pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() - pep604_optional_fwd: Mapped[ - "Decimal | float | None" - ] = mapped_column() + pep604_optional_fwd: Mapped["Decimal | float | None"] = ( + mapped_column() + ) - is_(User.__table__.c.data.type, our_type) - is_false(User.__table__.c.data.nullable) - is_(User.__table__.c.reverse_data.type, our_type) - is_(User.__table__.c.optional_data.type, our_type) - is_true(User.__table__.c.optional_data.nullable) + info = [ + ("data", False), + ("reverse_data", False), + ("optional_data", True), + ("reverse_optional_data", True), + ("reverse_u_optional_data", True), + ("refer_union", "null" in union.name), + ("refer_union_optional", True), + ("unflat_union_optional_data", True), + ] + if compat.py310: + info += [ + ("pep604_data", False), + ("pep604_reverse", False), + ("pep604_optional", True), + ("pep604_data_fwd", False), + ("pep604_reverse_fwd", False), + ("pep604_optional_fwd", True), + ] - is_(User.__table__.c.reverse_optional_data.type, our_type) - is_(User.__table__.c.reverse_u_optional_data.type, our_type) - is_true(User.__table__.c.reverse_optional_data.nullable) - is_true(User.__table__.c.reverse_u_optional_data.nullable) + for name, nullable in info: + col = User.__table__.c[name] + is_(col.type, our_type, name) + is_(col.nullable, nullable, name) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + ne_(User.__table__.c.float_data.type, our_type) - if compat.py310: - for suffix in ("", "_fwd"): - data_col = User.__table__.c[f"pep604_data{suffix}"] - reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] - optional_col = User.__table__.c[f"pep604_optional{suffix}"] - is_(data_col.type, our_type) - is_false(data_col.nullable) - is_(reverse_col.type, our_type) - is_false(reverse_col.nullable) - is_(optional_col.type, our_type) - is_true(optional_col.nullable) + is_true(isinstance(User.__table__.c.decimal_data.type, Numeric)) + ne_(User.__table__.c.decimal_data.type, our_type) - @testing.combinations( - ("not_optional",), - ("optional",), - ("optional_fwd_ref",), - ("union_none",), - ("pep604", testing.requires.python310), - ("pep604_fwd_ref", testing.requires.python310), - argnames="optional_on_json", + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + ("pep695", requires.python312), + ], ) + def test_optional_in_annotation_map(self, union): + """See issue #11370""" + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = {_Json: JSON} + elif union.pep604: + type_annotation_map = {_JsonPep604: JSON} + elif union.pep695: + type_annotation_map = {_JsonPep695: JSON} # noqa: F821 + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = 
mapped_column(nullable=False) + elif union.pep604: + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + elif union.pep695: + json1: Mapped[_JsonPep695] # noqa: F821 + json2: Mapped[_JsonPep695] = mapped_column( # noqa: F821 + nullable=False + ) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + + @testing.variation( + "option", + [ + "not_optional", + "optional", + "optional_fwd_ref", + "union_none", + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + ], + ) + @testing.variation("brackets", ["oneset", "twosets"]) @testing.combinations( "include_mc_type", "derive_from_anno", argnames="include_mc_type" ) def test_optional_styles_nested_brackets( - self, optional_on_json, include_mc_type + self, option, brackets, include_mc_type ): + """composed types test, includes tests that were added later for + #12207""" + class Base(DeclarativeBase): if testing.requires.python310.enabled: type_annotation_map = { - Dict[str, str]: JSON, - dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, + list[int] | list[str]: JSON, } else: type_annotation_map = { - Dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, } if include_mc_type == "include_mc_type": mc = mapped_column(JSON) + mc2 = mapped_column(JSON) else: mc = mapped_column() + mc2 = mapped_column() class A(Base): __tablename__ = "a" @@ -1509,21 +2368,67 @@ class A(Base): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[str] = mapped_column() - if optional_on_json == "not_optional": - json: Mapped[Dict[str, str]] = mapped_column() # type: ignore - elif optional_on_json == "optional": - json: Mapped[Optional[Dict[str, str]]] = mc - elif optional_on_json == "optional_fwd_ref": - json: Mapped["Optional[Dict[str, str]]"] = mc - elif optional_on_json == "union_none": - json: Mapped[Union[Dict[str, str], None]] = mc - elif optional_on_json == "pep604": - json: Mapped[dict[str, str] | None] = mc - elif optional_on_json == "pep604_fwd_ref": - json: Mapped["dict[str, str] | None"] = mc + if brackets.oneset: + if option.not_optional: + json: Mapped[Dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + if testing.requires.python310.enabled: + json2: Mapped[dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Dict[str, Decimal]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Optional[dict[str, Decimal]]] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Dict[str, Decimal]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped["Optional[dict[str, Decimal]]"] = mc2 + elif option.union_none: + json: Mapped[Union[Dict[str, Decimal], None]] = mc + json2: Mapped[Union[None, Dict[str, Decimal]]] = mc2 + elif option.pep604: + json: Mapped[dict[str, Decimal] | None] = mc + if testing.requires.python310.enabled: + json2: Mapped[None | dict[str, Decimal]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["dict[str, Decimal] | None"] = mc + if testing.requires.python310.enabled: + json2: Mapped["None | dict[str, Decimal]"] = mc2 + elif brackets.twosets: + if option.not_optional: + json: Mapped[Union[List[int], List[str]]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: 
Mapped[Optional[Union[List[int], List[str]]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + Optional[Union[list[int], list[str]]] + ] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Union[List[int], List[str]]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + "Optional[Union[list[int], list[str]]]" + ] = mc2 + elif option.union_none: + json: Mapped[Union[List[int], List[str], None]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Union[None, list[int], list[str]]] = mc2 + elif option.pep604: + json: Mapped[list[int] | list[str] | None] = mc + json2: Mapped[None | list[int] | list[str]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["list[int] | list[str] | None"] = mc + json2: Mapped["None | list[int] | list[str]"] = mc2 + else: + brackets.fail() is_(A.__table__.c.json.type._type_affinity, JSON) - if optional_on_json == "not_optional": + if hasattr(A, "json2"): + is_(A.__table__.c.json2.type._type_affinity, JSON) + if option.not_optional: + is_false(A.__table__.c.json2.nullable) + else: + is_true(A.__table__.c.json2.nullable) + + if option.not_optional: is_false(A.__table__.c.json.nullable) else: is_true(A.__table__.c.json.nullable) @@ -1734,7 +2639,8 @@ class int_sub(int): ) with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + orm_exc.MappedAnnotationError, + "Could not locate SQLAlchemy Core type", ): class MyClass(Base): @@ -2289,6 +3195,42 @@ class Base(DeclarativeBase): yield Base Base.registry.dispose() + @testing.combinations( + (Relationship, CollectionAttributeImpl), + (Mapped, CollectionAttributeImpl), + (WriteOnlyMapped, WriteOnlyAttributeImpl), + (DynamicMapped, DynamicAttributeImpl), + argnames="mapped_cls,implcls", + ) + def test_use_relationship(self, decl_base, mapped_cls, implcls): + """test #10611""" + + global B + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + # for future annotations support, need to write these + # directly in source code + if mapped_cls is Relationship: + bs: Relationship[List[B]] = relationship() + elif mapped_cls is Mapped: + bs: Mapped[List[B]] = relationship() + elif mapped_cls is WriteOnlyMapped: + bs: WriteOnlyMapped[List[B]] = relationship() + elif mapped_cls is DynamicMapped: + bs: DynamicMapped[List[B]] = relationship() + + decl_base.registry.configure() + assert isinstance(A.bs.impl, implcls) + def test_no_typing_in_rhs(self, decl_base): class A(decl_base): __tablename__ = "a" @@ -2407,9 +3349,9 @@ class A(decl_base): collection_class=list ) elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship(collection_class=list) + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship(collection_class=list) + ) else: datatype.fail() @@ -2436,15 +3378,15 @@ class A(decl_base): if datatype.typing_sequence: bs: Mapped[typing.Sequence[B]] = relationship() elif datatype.collections_sequence: - bs: Mapped[ - collections.abc.Sequence[B] - ] = relationship() + bs: Mapped[collections.abc.Sequence[B]] = ( + relationship() + ) elif datatype.typing_mutable_sequence: bs: Mapped[typing.MutableSequence[B]] = relationship() elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship() + bs: 
Mapped[collections.abc.MutableSequence[B]] = ( + relationship() + ) else: datatype.fail() @@ -2544,7 +3486,7 @@ class B(decl_base): back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "union_none": - a: Mapped["Union[A, None]"] = relationship( + a: Mapped[Union[A, None]] = relationship( back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "pep604": @@ -2649,7 +3591,7 @@ class B(decl_base): is_false(B.__mapper__.attrs["a"].uselist) is_false(B.__mapper__.attrs["a_warg"].uselist) - def test_one_to_one_example(self, decl_base: Type[DeclarativeBase]): + def test_one_to_one_example_quoted(self, decl_base: Type[DeclarativeBase]): """test example in the relationship docs will derive uselist=False correctly""" @@ -2673,6 +3615,32 @@ class Child(decl_base): is_(p1.child, c1) is_(c1.parent, p1) + def test_one_to_one_example_non_quoted( + self, decl_base: Type[DeclarativeBase] + ): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Child(decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(primary_key=True) + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + parent: Mapped["Parent"] = relationship(back_populates="child") + + class Parent(decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(primary_key=True) + child: Mapped[Child] = relationship( # noqa: F821 + back_populates="parent" + ) + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + def test_collection_class_dict_no_collection(self, decl_base): class A(decl_base): __tablename__ = "a" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 6b8becf9c02..acc07ba7d4c 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -4,6 +4,7 @@ from decimal import Decimal import enum import inspect as _py_inspect +import re import typing from typing import Any from typing import cast @@ -20,8 +21,12 @@ from typing import Union import uuid +import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal +from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType +from typing_extensions import TypedDict from sqlalchemy import BIGINT from sqlalchemy import BigInteger @@ -29,6 +34,7 @@ from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -53,16 +59,23 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import DynamicMapped +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm import foreign from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship +from sqlalchemy.orm import remote from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped +from sqlalchemy.orm.attributes import CollectionAttributeImpl from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict +from sqlalchemy.orm.dynamic import DynamicAttributeImpl from sqlalchemy.orm.properties import MappedColumn +from sqlalchemy.orm.writeonly import WriteOnlyAttributeImpl 
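# --- Illustrative sketch, not part of the patch: what the attribute-impl
# --- imports added above are used to assert in test_use_relationship (#10611).
# --- The models Base, Child and Parent are hypothetical. The annotation style
# --- chosen for a relationship selects its attribute implementation class.
from typing import List

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm.attributes import CollectionAttributeImpl


class Base(DeclarativeBase):
    pass


class Child(Base):
    __tablename__ = "child"

    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))


class Parent(Base):
    __tablename__ = "parent"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Mapped[List[...]] selects CollectionAttributeImpl; WriteOnlyMapped[...]
    # and DynamicMapped[...] would select WriteOnlyAttributeImpl and
    # DynamicAttributeImpl respectively.
    kids: Mapped[List[Child]] = relationship()


Base.registry.configure()
assert isinstance(Parent.kids.impl, CollectionAttributeImpl)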
from sqlalchemy.schema import CreateTable from sqlalchemy.sql.base import _NoArg from sqlalchemy.sql.sqltypes import Enum @@ -76,11 +89,76 @@ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import requires from sqlalchemy.testing import Variation +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated +TV = typing.TypeVar("TV") + + +class _SomeDict1(TypedDict): + type: Literal["1"] + + +class _SomeDict2(TypedDict): + type: Literal["2"] + + +_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2] + +_StrTypeAlias: TypeAlias = str + + +if compat.py38: + _TypingLiteral = typing.Literal["a", "b"] +_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] + +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604) + +TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType) + +_StrPep695 = TypeAliasType("_StrPep695", str) +_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str) +_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,)) +_TypingGenericPep695 = TypingTypeAliasType( + "_TypingGenericPep695", List[TV], type_params=(TV,) +) +_GenericPep695Typed = _GenericPep695[int] +_TypingGenericPep695Typed = _TypingGenericPep695[int] +_UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2]) +strtypalias_keyword = TypeAliasType( + "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})] +) +if compat.py310: + strtypalias_keyword_nested = TypeAliasType( + "strtypalias_keyword_nested", + int | Annotated[str, mapped_column(info={"hi": "there"})], + ) +strtypalias_ta: TypeAlias = Annotated[str, mapped_column(info={"hi": "there"})] +strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] +_Literal695 = TypeAliasType( + "_Literal695", Literal["to-do", "in-progress", "done"] +) +_TypingLiteral695 = TypingTypeAliasType( + "_TypingLiteral695", Literal["to-do", "in-progress", "done"] +) +_RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695) + def expect_annotation_syntax_error(name): return expect_raises_message( @@ -154,6 +232,46 @@ class Foo(decl_base): else: eq_(Foo.__table__.c.data.default.arg, 5) + def test_type_inline_declaration(self, decl_base): + """test #10899""" + + class User(decl_base): + __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[Role] + + is_true(isinstance(User.__table__.c.role.type, Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + eq_(User.__table__.c.role.type.name, "role") # and not 'enum' + + def test_type_uses_inner_when_present(self, decl_base): + """test #10899, that we use inner name when appropriate""" + + class Role(enum.Enum): + foo = "foo" + bar = "bar" + + class User(decl_base): 
+ __tablename__ = "user" + + class Role(enum.Enum): + admin = "admin" + user = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[Role] + + is_true(isinstance(User.__table__.c.role.type, Enum)) + eq_(User.__table__.c.role.type.length, 5) + is_(User.__table__.c.role.type.enum_class, User.Role) + eq_(User.__table__.c.role.type.name, "role") # and not 'enum' + def test_legacy_declarative_base(self): typ = VARCHAR(50) Base = declarative_base(type_annotation_map={str: typ}) @@ -168,6 +286,43 @@ class MyClass(Base): is_(MyClass.__table__.c.data.type, typ) is_true(MyClass.__table__.c.id.primary_key) + @testing.variation("style", ["none", "lambda_", "string", "direct"]) + def test_foreign_annotation_propagates_correctly(self, decl_base, style): + """test #10597""" + + class Parent(decl_base): + __tablename__ = "parent" + id: Mapped[int] = mapped_column(primary_key=True) + + class Child(decl_base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + + if style.none: + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + else: + parent_id: Mapped[int] = mapped_column() + + if style.lambda_: + parent: Mapped[Parent] = relationship( + primaryjoin=lambda: remote(Parent.id) + == foreign(Child.parent_id), + ) + elif style.string: + parent: Mapped[Parent] = relationship( + primaryjoin="remote(Parent.id) == " + "foreign(Child.parent_id)", + ) + elif style.direct: + parent: Mapped[Parent] = relationship( + primaryjoin=remote(Parent.id) == foreign(parent_id), + ) + elif style.none: + parent: Mapped[Parent] = relationship() + + assert Child.__mapper__.attrs.parent.strategy.use_get + @testing.combinations( (BIGINT(),), (BIGINT,), @@ -466,19 +621,179 @@ class User(decl_base): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[MyClass] = mapped_column() - def test_construct_lhs_sqlalchemy_type(self, decl_base): - with expect_raises_message( - sa_exc.ArgumentError, - "The type provided inside the 'data' attribute Mapped " - "annotation is the SQLAlchemy type .*BigInteger.*. Expected " - "a Python type instead", - ): + @testing.variation( + "argtype", + [ + "type", + ("column", testing.requires.python310), + ("mapped_column", testing.requires.python310), + "column_class", + "ref_to_type", + ("ref_to_column", testing.requires.python310), + ], + ) + def test_construct_lhs_sqlalchemy_type(self, decl_base, argtype): + """test for #12329. - class User(decl_base): - __tablename__ = "users" + of note here are all the different messages we have for when the + wrong thing is put into Mapped[], and in fact in #12329 we added + another one. - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[BigInteger] = mapped_column() + This is a lot of different messages, but at the same time they + occur at different places in the interpretation of types. If + we were to centralize all these messages, we'd still likely end up + doing distinct messages for each scenario, so instead we added + a new ArgumentError subclass MappedAnnotationError that provides + some commonality to all of these cases. + + + """ + expect_future_annotations = "annotations" in globals() + + if argtype.type: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # properties.py -> _init_column_for_annotation, type is + # a SQL type + "The type provided inside the 'data' attribute Mapped " + "annotation is the SQLAlchemy type .*BigInteger.*. 
Expected " + "a Python type instead", + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[BigInteger] = mapped_column() + + elif argtype.column: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # util.py -> _extract_mapped_subtype + ( + re.escape( + "Could not interpret annotation " + "Mapped[Column('q', BigInteger)]." + ) + if expect_future_annotations + # properties.py -> _init_column_for_annotation, object is + # not a SQL type or a python type, it's just some object + else re.escape( + "The object provided inside the 'data' attribute " + "Mapped annotation is not a Python type, it's the " + "object Column('q', BigInteger(), table=None). " + "Expected a Python type." + ) + ), + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[Column("q", BigInteger)] = ( # noqa: F821 + mapped_column() + ) + + elif argtype.mapped_column: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # properties.py -> _init_column_for_annotation, object is + # not a SQL type or a python type, it's just some object + # interestingly, this raises at the same point for both + # future annotations mode and legacy annotations mode + r"The object provided inside the 'data' attribute " + "Mapped annotation is not a Python type, it's the object " + r"\. " + "Expected a Python type.", + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + big_integer: Mapped[int] = mapped_column() + data: Mapped[big_integer] = mapped_column() + + elif argtype.column_class: + with expect_raises_message( + orm_exc.MappedAnnotationError, + # properties.py -> _init_column_for_annotation, type is not + # a SQL type + re.escape( + "Could not locate SQLAlchemy Core type for Python type " + " inside the " + "'data' attribute Mapped annotation" + ), + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[Column] = mapped_column() + + elif argtype.ref_to_type: + mytype = BigInteger + with expect_raises_message( + orm_exc.MappedAnnotationError, + ( + # decl_base.py -> _exract_mappable_attributes + re.escape( + "Could not resolve all types within mapped " + 'annotation: "Mapped[mytype]"' + ) + if expect_future_annotations + # properties.py -> _init_column_for_annotation, type is + # a SQL type + else re.escape( + "The type provided inside the 'data' attribute Mapped " + "annotation is the SQLAlchemy type " + ". " + "Expected a Python type instead" + ) + ), + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[mytype] = mapped_column() + + elif argtype.ref_to_column: + mycol = Column("q", BigInteger) + + with expect_raises_message( + orm_exc.MappedAnnotationError, + # decl_base.py -> _exract_mappable_attributes + ( + re.escape( + "Could not resolve all types within mapped " + 'annotation: "Mapped[mycol]"' + ) + if expect_future_annotations + else + # properties.py -> _init_column_for_annotation, object is + # not a SQL type or a python type, it's just some object + re.escape( + "The object provided inside the 'data' attribute " + "Mapped " + "annotation is not a Python type, it's the object " + "Column('q', BigInteger(), table=None). " + "Expected a Python type." 
+ ) + ), + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[mycol] = mapped_column() + + else: + argtype.fail() def test_construct_rhs_type_override_lhs(self, decl_base): class Element(decl_base): @@ -683,6 +998,352 @@ class MyClass(decl_base): is_true(MyClass.__table__.c.data_two.nullable) eq_(MyClass.__table__.c.data_three.type.length, 50) + def test_plain_typealias_as_typemap_keys( + self, decl_base: Type[DeclarativeBase] + ): + decl_base.registry.update_type_annotation_map( + {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)} + ) + + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[_StrTypeAlias] + structure: Mapped[_UnionTypeAlias] + + eq_(Test.__table__.c.data.type.length, 30) + is_(Test.__table__.c.structure.type._type_affinity, JSON) + + @testing.variation( + "option", + [ + "plain", + "union", + "union_604", + "union_null", + "union_null_604", + "optional", + "optional_union", + "optional_union_604", + "union_newtype", + "union_null_newtype", + "union_695", + "union_null_695", + ], + ) + @testing.variation("in_map", ["yes", "no", "value"]) + @testing.requires.python312 + def test_pep695_behavior(self, decl_base, in_map, option): + """Issue #11955""" + # anno only: global tat + + if option.plain: + tat = TypeAliasType("tat", str) + elif option.union: + tat = TypeAliasType("tat", Union[str, int]) + elif option.union_604: + tat = TypeAliasType("tat", str | int) + elif option.union_null: + tat = TypeAliasType("tat", Union[str, int, None]) + elif option.union_null_604: + tat = TypeAliasType("tat", str | int | None) + elif option.optional: + tat = TypeAliasType("tat", Optional[str]) + elif option.optional_union: + tat = TypeAliasType("tat", Optional[Union[str, int]]) + elif option.optional_union_604: + tat = TypeAliasType("tat", Optional[str | int]) + elif option.union_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int]) + elif option.union_null_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int, None]) + elif option.union_695: + tat = TypeAliasType("tat", str | int) + elif option.union_null_695: + tat = TypeAliasType("tat", str | int | None) + else: + option.fail() + + if in_map.yes: + decl_base.registry.update_type_annotation_map({tat: String(99)}) + elif in_map.value and "newtype" not in option.name: + decl_base.registry.update_type_annotation_map( + {tat.__value__: String(99)} + ) + + def declare(): + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[tat] + + return Test.__table__.c.data + + if in_map.yes: + col = declare() + length = 99 + elif ( + in_map.value + and "newtype" not in option.name + or option.optional + or option.plain + ): + with expect_deprecated( + "Matching the provided TypeAliasType 'tat' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + col = declare() + length = 99 if in_map.value else None + else: + with expect_raises_message( + orm_exc.MappedAnnotationError, + r"Could not locate SQLAlchemy Core type for Python type .*tat " + "inside the 'data' attribute Mapped annotation", + ): + declare() + return + + is_true(isinstance(col.type, String)) + eq_(col.type.length, length) + nullable = "null" in option.name or "optional" in option.name 
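# --- Illustrative sketch, not part of the patch: the deprecation path that the
# --- expect_deprecated block above matches. The names Str99, Base and T are
# --- hypothetical. Registering only the alias's resolved value (str) still
# --- lets the column resolve, but warns; keying the TypeAliasType itself, as
# --- below, is the supported spelling.
from typing_extensions import TypeAliasType

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

Str99 = TypeAliasType("Str99", str)


class Base(DeclarativeBase):
    # deprecated form: {Str99.__value__: String(99)} matches only via the
    # resolved value and emits the "Matching the provided TypeAliasType ..."
    # warning; the explicit form below does not.
    type_annotation_map = {Str99: String(99)}


class T(Base):
    __tablename__ = "t"

    id: Mapped[int] = mapped_column(primary_key=True)
    data: Mapped[Str99]


assert T.__table__.c.data.type.length == 99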
+ eq_(col.nullable, nullable) + + @testing.variation( + "type_", + [ + "str_extension", + "str_typing", + "generic_extension", + "generic_typing", + "generic_typed_extension", + "generic_typed_typing", + ], + ) + @testing.requires.python312 + def test_pep695_typealias_as_typemap_keys( + self, decl_base: Type[DeclarativeBase], type_ + ): + """test #10807""" + + decl_base.registry.update_type_annotation_map( + { + _UnionPep695: JSON, + _StrPep695: String(30), + _TypingStrPep695: String(30), + _GenericPep695: String(30), + _TypingGenericPep695: String(30), + _GenericPep695Typed: String(30), + _TypingGenericPep695Typed: String(30), + } + ) + + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + if type_.str_extension: + data: Mapped[_StrPep695] + elif type_.str_typing: + data: Mapped[_TypingStrPep695] + elif type_.generic_extension: + data: Mapped[_GenericPep695] + elif type_.generic_typing: + data: Mapped[_TypingGenericPep695] + elif type_.generic_typed_extension: + data: Mapped[_GenericPep695Typed] + elif type_.generic_typed_typing: + data: Mapped[_TypingGenericPep695Typed] + else: + type_.fail() + structure: Mapped[_UnionPep695] + + eq_(Test.__table__.c.data.type._type_affinity, String) + eq_(Test.__table__.c.data.type.length, 30) + is_(Test.__table__.c.structure.type._type_affinity, JSON) + + @testing.variation( + "alias_type", + ["none", "typekeyword", "typealias", "typekeyword_nested"], + ) + @testing.requires.python312 + def test_extract_pep593_from_pep695( + self, decl_base: Type[DeclarativeBase], alias_type + ): + """test #11130""" + if alias_type.typekeyword: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword: VARCHAR(33)} # noqa: F821 + ) + if alias_type.typekeyword_nested: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword_nested: VARCHAR(42)} # noqa: F821 + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if alias_type.typekeyword: + data_one: Mapped[strtypalias_keyword] # noqa: F821 + elif alias_type.typealias: + data_one: Mapped[strtypalias_ta] # noqa: F821 + elif alias_type.none: + data_one: Mapped[strtypalias_plain] # noqa: F821 + elif alias_type.typekeyword_nested: + data_one: Mapped[strtypalias_keyword_nested] # noqa: F821 + else: + alias_type.fail() + + table = MyClass.__table__ + assert table is not None + + if alias_type.typekeyword_nested: + # a nested annotation is not supported + eq_(MyClass.data_one.expression.info, {}) + else: + eq_(MyClass.data_one.expression.info, {"hi": "there"}) + + if alias_type.typekeyword: + eq_(MyClass.data_one.type.length, 33) + elif alias_type.typekeyword_nested: + eq_(MyClass.data_one.type.length, 42) + else: + eq_(MyClass.data_one.type.length, None) + + @testing.variation( + "type_", + [ + "literal", + "literal_typing", + "recursive", + "not_literal", + "not_literal_typing", + "generic", + "generic_typing", + "generic_typed", + "generic_typed_typing", + ], + ) + @testing.combinations(True, False, argnames="in_map") + @testing.requires.python312 + def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): + """test #11305.""" + + def declare(): + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + if type_.recursive: + status: Mapped[_RecursiveLiteral695] # noqa: F821 + elif type_.literal: + status: Mapped[_Literal695] # noqa: F821 + elif type_.literal_typing: + status: Mapped[_TypingLiteral695] # noqa: F821 + 
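# --- Illustrative sketch, not part of the patch, assuming the deprecation
# --- asserted below: an alias-of-an-alias ("recursive" TypeAliasType) that
# --- resolves to a Literal still produces the Enum in 2.0 but warns, and 2.1
# --- is slated to reject it. Status, Recursive and Task are hypothetical.
import enum
import warnings

from typing_extensions import Literal, TypeAliasType

from sqlalchemy import Enum
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

Status = TypeAliasType("Status", Literal["to-do", "in-progress", "done"])
Recursive = TypeAliasType("Recursive", Status)  # alias of an alias


class Base(DeclarativeBase):
    type_annotation_map = {Recursive: Enum(enum.Enum)}


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    class Task(Base):
        __tablename__ = "task"

        id: Mapped[int] = mapped_column(primary_key=True)
        status: Mapped[Recursive]


assert any("recursive TypeAliasType" in str(w.message) for w in caught)
assert Task.__table__.c.status.type.enums == ["to-do", "in-progress", "done"]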
elif type_.not_literal: + status: Mapped[_StrPep695] # noqa: F821 + elif type_.not_literal_typing: + status: Mapped[_TypingStrPep695] # noqa: F821 + elif type_.generic: + status: Mapped[_GenericPep695] # noqa: F821 + elif type_.generic_typing: + status: Mapped[_TypingGenericPep695] # noqa: F821 + elif type_.generic_typed: + status: Mapped[_GenericPep695Typed] # noqa: F821 + elif type_.generic_typed_typing: + status: Mapped[_TypingGenericPep695Typed] # noqa: F821 + else: + type_.fail() + + return Foo + + if in_map: + decl_base.registry.update_type_annotation_map( + { + _Literal695: Enum(enum.Enum), # noqa: F821 + _TypingLiteral695: Enum(enum.Enum), # noqa: F821 + _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 + _StrPep695: Enum(enum.Enum), # noqa: F821 + _TypingStrPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695Typed: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695Typed: Enum(enum.Enum), # noqa: F821 + } + ) + if type_.recursive: + with expect_deprecated( + "Mapping recursive TypeAliasType '.+' that resolve to " + "literal to generate an Enum is deprecated. SQLAlchemy " + "2.1 will not support this use case. Please avoid using " + "recursing TypeAliasType", + ): + Foo = declare() + elif type_.literal or type_.literal_typing: + Foo = declare() + else: + with expect_raises_message( + exc.ArgumentError, + "Can't associate TypeAliasType '.+' to an Enum " + "since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal.'a', " + "'b'.` are supported when generating Enums.", + ): + declare() + return + elif ( + type_.generic + or type_.generic_typing + or type_.generic_typed + or type_.generic_typed_typing + ): + # This behaves like 2.1 -> rationale is that no-one asked to + # support such types and in 2.1 will already be like this + # so it makes little sense to add support this late in the 2.0 + # series + with expect_raises_message( + exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type " + ".+ inside the 'status' attribute Mapped annotation", + ): + declare() + return + else: + with expect_deprecated( + "Matching the provided TypeAliasType '.*' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + Foo = declare() + col = Foo.__table__.c.status + if in_map and not type_.not_literal: + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["to-do", "in-progress", "done"]) + is_(col.type.native_enum, False) + else: + is_true(isinstance(col.type, String)) + + @testing.requires.python38 + def test_typing_literal_identity(self, decl_base): + """See issue #11820""" + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + t: Mapped[_TypingLiteral] + te: Mapped[_TypingExtensionsLiteral] + + for col in (Foo.__table__.c.t, Foo.__table__.c.te): + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["a", "b"]) + is_(col.type.native_enum, False) + @testing.requires.python310 def test_we_got_all_attrs_test_annotated(self): argnames = _py_inspect.getfullargspec(mapped_column) @@ -754,7 +1415,9 @@ def test_we_got_all_attrs_test_annotated(self): ), ("index", True, lambda column: column.index is True), ("index", _NoArg.NO_ARG, lambda column: column.index is None), + ("index", False, lambda column: column.index is False), ("unique", True, lambda 
column: column.unique is True), + ("unique", False, lambda column: column.unique is False), ("autoincrement", True, lambda column: column.autoincrement is True), ("system", True, lambda column: column.system is True), ("primary_key", True, lambda column: column.primary_key is True), @@ -823,6 +1486,13 @@ def test_we_got_all_attrs_test_annotated(self): "Argument 'init' is a dataclass argument" ), ), + ( + "hash", + True, + exc.SADeprecationWarning( + "Argument 'hash' is a dataclass argument" + ), + ), argnames="argname, argument, assertion", ) @testing.variation("use_annotated", [True, False, "control"]) @@ -846,6 +1516,7 @@ def test_names_encountered_for_annotated( "repr", "compare", "default_factory", + "hash", ) if is_dataclass: @@ -912,6 +1583,32 @@ class User(Base): argument, ) + @testing.combinations(("index",), ("unique",), argnames="paramname") + @testing.combinations((True,), (False,), (None,), argnames="orig") + @testing.combinations((True,), (False,), (None,), argnames="merging") + def test_index_unique_combinations( + self, paramname, orig, merging, decl_base + ): + """test #11091""" + + # anno only: global myint + + amc = mapped_column(**{paramname: merging}) + myint = Annotated[int, amc] + + mc = mapped_column(**{paramname: orig}) + + class User(decl_base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + myname: Mapped[myint] = mc + + result = getattr(User.__table__.c.myname, paramname) + if orig is None: + is_(result, merging) + else: + is_(result, orig) + def test_pep484_newtypes_as_typemap_keys( self, decl_base: Type[DeclarativeBase] ): @@ -946,6 +1643,33 @@ class MyClass(decl_base): eq_(MyClass.__table__.c.data_four.type.length, 150) is_false(MyClass.__table__.c.data_four.nullable) + def test_newtype_missing_from_map(self, decl_base): + # anno only: global str50 + + str50 = NewType("str50", str) + + if compat.py310: + text = ".*str50" + else: + # NewTypes before 3.10 had a very bad repr + # <function NewType.<locals>.new_type at 0x...> + text = ".*NewType.*" + + with expect_deprecated( + f"Matching the provided NewType '{text}' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data_one: Mapped[str50] + + is_true(isinstance(MyClass.data_one.type, String)) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -974,8 +1698,7 @@ class SomeRelated(decl_base): with expect_raises_message( NotImplementedError, - r"Use of the \<class 'sqlalchemy.orm.relationships.Relationship'\> construct inside of an Annotated " + r"Use of the 'Relationship' construct inside of an Annotated " r"object is not yet supported.", ): @@ -1178,22 +1901,68 @@ class RefElementTwo(decl_base): Dict, (str, str), ), + (list, None, testing.requires.python310), + ( + List, + None, + ), + (dict, None, testing.requires.python310), + ( + Dict, + None, + ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + test #11831, regression from #11814 + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + if args is None: + TestType = Annotated[container_typ, 0] + else: + TestType = Annotated[container_typ[args], 0] +
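# --- Illustrative sketch, not part of the patch: the pep593 style exercised
# --- above (#9099). JsonDict, Base and Doc are hypothetical names. An
# --- Annotated wrapper around a generic container works as a
# --- type_annotation_map key; the alias and direct styles added for #11814
# --- allow the bare or parametrized container itself as well.
from typing import Dict

from typing_extensions import Annotated

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

JsonDict = Annotated[Dict[str, str], "arbitrary-metadata"]


class Base(DeclarativeBase):
    type_annotation_map = {JsonDict: JSON()}


class Doc(Base):
    __tablename__ = "doc"

    id: Mapped[int] = mapped_column(primary_key=True)
    data: Mapped[JsonDict] = mapped_column()


assert Doc.__table__.c.data.type._type_affinity is JSON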
elif style.alias: + if args is None: + TestType = container_typ + else: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if args == (str, str): + type_annotation_map = {TestType[str, str]: JSON()} + elif args is None: + type_annotation_map = {TestType: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if args == (str, str): + data: Mapped[TestType[str, str]] = mapped_column() + elif args is None: + data: Mapped[TestType] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON) @@ -1392,34 +2161,64 @@ class Element(decl_base): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + "union_null", + ("pep604_null", requires.python310), + ], + ) + def test_unions(self, union): + # anno only: global UnionType our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.union_null: + UnionType = Union[float, Decimal, None] + elif union.pep604: + UnionType = float | Decimal + elif union.pep604_null: + UnionType = float | Decimal | None + else: + union.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionType: our_type} class User(Base): __tablename__ = "users" - __table__: Table id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[Union[float, Decimal]] = mapped_column() - reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + data: Mapped[Union[float, Decimal]] + reverse_data: Mapped[Union[Decimal, float]] - optional_data: Mapped[ - Optional[Union[float, Decimal]] - ] = mapped_column() + optional_data: Mapped[Optional[Union[float, Decimal]]] = ( + mapped_column() + ) # use Optional directly - reverse_optional_data: Mapped[ - Optional[Union[Decimal, float]] - ] = mapped_column() + reverse_optional_data: Mapped[Optional[Union[Decimal, float]]] = ( + mapped_column() + ) # use Union with None, same as Optional but presents differently # (Optional object with __origin__ Union vs. Union) - reverse_u_optional_data: Mapped[ - Union[Decimal, float, None] + reverse_u_optional_data: Mapped[Union[Decimal, float, None]] = ( + mapped_column() + ) + + refer_union: Mapped[UnionType] + refer_union_optional: Mapped[Optional[UnionType]] + + # py38, 37 does not automatically flatten unions, add extra tests + # for this. 
maintain these in order to catch future regressions + # in the behavior of ``Union`` + unflat_union_optional_data: Mapped[ + Union[Union[Decimal, float, None], None] ] = mapped_column() float_data: Mapped[float] = mapped_column() @@ -1428,71 +2227,131 @@ class User(Base): if compat.py310: pep604_data: Mapped[float | Decimal] = mapped_column() pep604_reverse: Mapped[Decimal | float] = mapped_column() - pep604_optional: Mapped[ - Decimal | float | None - ] = mapped_column() + pep604_optional: Mapped[Decimal | float | None] = ( + mapped_column() + ) pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() - pep604_optional_fwd: Mapped[ - "Decimal | float | None" - ] = mapped_column() + pep604_optional_fwd: Mapped["Decimal | float | None"] = ( + mapped_column() + ) - is_(User.__table__.c.data.type, our_type) - is_false(User.__table__.c.data.nullable) - is_(User.__table__.c.reverse_data.type, our_type) - is_(User.__table__.c.optional_data.type, our_type) - is_true(User.__table__.c.optional_data.nullable) + info = [ + ("data", False), + ("reverse_data", False), + ("optional_data", True), + ("reverse_optional_data", True), + ("reverse_u_optional_data", True), + ("refer_union", "null" in union.name), + ("refer_union_optional", True), + ("unflat_union_optional_data", True), + ] + if compat.py310: + info += [ + ("pep604_data", False), + ("pep604_reverse", False), + ("pep604_optional", True), + ("pep604_data_fwd", False), + ("pep604_reverse_fwd", False), + ("pep604_optional_fwd", True), + ] - is_(User.__table__.c.reverse_optional_data.type, our_type) - is_(User.__table__.c.reverse_u_optional_data.type, our_type) - is_true(User.__table__.c.reverse_optional_data.nullable) - is_true(User.__table__.c.reverse_u_optional_data.nullable) + for name, nullable in info: + col = User.__table__.c[name] + is_(col.type, our_type, name) + is_(col.nullable, nullable, name) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + ne_(User.__table__.c.float_data.type, our_type) - if compat.py310: - for suffix in ("", "_fwd"): - data_col = User.__table__.c[f"pep604_data{suffix}"] - reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] - optional_col = User.__table__.c[f"pep604_optional{suffix}"] - is_(data_col.type, our_type) - is_false(data_col.nullable) - is_(reverse_col.type, our_type) - is_false(reverse_col.nullable) - is_(optional_col.type, our_type) - is_true(optional_col.nullable) + is_true(isinstance(User.__table__.c.decimal_data.type, Numeric)) + ne_(User.__table__.c.decimal_data.type, our_type) - @testing.combinations( - ("not_optional",), - ("optional",), - ("optional_fwd_ref",), - ("union_none",), - ("pep604", testing.requires.python310), - ("pep604_fwd_ref", testing.requires.python310), - argnames="optional_on_json", + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + ("pep695", requires.python312), + ], ) + def test_optional_in_annotation_map(self, union): + """See issue #11370""" + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = {_Json: JSON} + elif union.pep604: + type_annotation_map = {_JsonPep604: JSON} + elif union.pep695: + type_annotation_map = {_JsonPep695: JSON} # noqa: F821 + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = 
mapped_column(nullable=False) + elif union.pep604: + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + elif union.pep695: + json1: Mapped[_JsonPep695] # noqa: F821 + json2: Mapped[_JsonPep695] = mapped_column( # noqa: F821 + nullable=False + ) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + + @testing.variation( + "option", + [ + "not_optional", + "optional", + "optional_fwd_ref", + "union_none", + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + ], + ) + @testing.variation("brackets", ["oneset", "twosets"]) @testing.combinations( "include_mc_type", "derive_from_anno", argnames="include_mc_type" ) def test_optional_styles_nested_brackets( - self, optional_on_json, include_mc_type + self, option, brackets, include_mc_type ): + """composed types test, includes tests that were added later for + #12207""" + class Base(DeclarativeBase): if testing.requires.python310.enabled: type_annotation_map = { - Dict[str, str]: JSON, - dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, + list[int] | list[str]: JSON, } else: type_annotation_map = { - Dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, } if include_mc_type == "include_mc_type": mc = mapped_column(JSON) + mc2 = mapped_column(JSON) else: mc = mapped_column() + mc2 = mapped_column() class A(Base): __tablename__ = "a" @@ -1500,21 +2359,67 @@ class A(Base): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[str] = mapped_column() - if optional_on_json == "not_optional": - json: Mapped[Dict[str, str]] = mapped_column() # type: ignore - elif optional_on_json == "optional": - json: Mapped[Optional[Dict[str, str]]] = mc - elif optional_on_json == "optional_fwd_ref": - json: Mapped["Optional[Dict[str, str]]"] = mc - elif optional_on_json == "union_none": - json: Mapped[Union[Dict[str, str], None]] = mc - elif optional_on_json == "pep604": - json: Mapped[dict[str, str] | None] = mc - elif optional_on_json == "pep604_fwd_ref": - json: Mapped["dict[str, str] | None"] = mc + if brackets.oneset: + if option.not_optional: + json: Mapped[Dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + if testing.requires.python310.enabled: + json2: Mapped[dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Dict[str, Decimal]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Optional[dict[str, Decimal]]] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Dict[str, Decimal]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped["Optional[dict[str, Decimal]]"] = mc2 + elif option.union_none: + json: Mapped[Union[Dict[str, Decimal], None]] = mc + json2: Mapped[Union[None, Dict[str, Decimal]]] = mc2 + elif option.pep604: + json: Mapped[dict[str, Decimal] | None] = mc + if testing.requires.python310.enabled: + json2: Mapped[None | dict[str, Decimal]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["dict[str, Decimal] | None"] = mc + if testing.requires.python310.enabled: + json2: Mapped["None | dict[str, Decimal]"] = mc2 + elif brackets.twosets: + if option.not_optional: + json: Mapped[Union[List[int], List[str]]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: 
Mapped[Optional[Union[List[int], List[str]]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + Optional[Union[list[int], list[str]]] + ] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Union[List[int], List[str]]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + "Optional[Union[list[int], list[str]]]" + ] = mc2 + elif option.union_none: + json: Mapped[Union[List[int], List[str], None]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Union[None, list[int], list[str]]] = mc2 + elif option.pep604: + json: Mapped[list[int] | list[str] | None] = mc + json2: Mapped[None | list[int] | list[str]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["list[int] | list[str] | None"] = mc + json2: Mapped["None | list[int] | list[str]"] = mc2 + else: + brackets.fail() is_(A.__table__.c.json.type._type_affinity, JSON) - if optional_on_json == "not_optional": + if hasattr(A, "json2"): + is_(A.__table__.c.json2.type._type_affinity, JSON) + if option.not_optional: + is_false(A.__table__.c.json2.nullable) + else: + is_true(A.__table__.c.json2.nullable) + + if option.not_optional: is_false(A.__table__.c.json.nullable) else: is_true(A.__table__.c.json.nullable) @@ -1725,7 +2630,8 @@ class int_sub(int): ) with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + orm_exc.MappedAnnotationError, + "Could not locate SQLAlchemy Core type", ): class MyClass(Base): @@ -2280,6 +3186,42 @@ class Base(DeclarativeBase): yield Base Base.registry.dispose() + @testing.combinations( + (Relationship, CollectionAttributeImpl), + (Mapped, CollectionAttributeImpl), + (WriteOnlyMapped, WriteOnlyAttributeImpl), + (DynamicMapped, DynamicAttributeImpl), + argnames="mapped_cls,implcls", + ) + def test_use_relationship(self, decl_base, mapped_cls, implcls): + """test #10611""" + + # anno only: global B + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + # for future annotations support, need to write these + # directly in source code + if mapped_cls is Relationship: + bs: Relationship[List[B]] = relationship() + elif mapped_cls is Mapped: + bs: Mapped[List[B]] = relationship() + elif mapped_cls is WriteOnlyMapped: + bs: WriteOnlyMapped[List[B]] = relationship() + elif mapped_cls is DynamicMapped: + bs: DynamicMapped[List[B]] = relationship() + + decl_base.registry.configure() + assert isinstance(A.bs.impl, implcls) + def test_no_typing_in_rhs(self, decl_base): class A(decl_base): __tablename__ = "a" @@ -2398,9 +3340,9 @@ class A(decl_base): collection_class=list ) elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship(collection_class=list) + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship(collection_class=list) + ) else: datatype.fail() @@ -2427,15 +3369,15 @@ class A(decl_base): if datatype.typing_sequence: bs: Mapped[typing.Sequence[B]] = relationship() elif datatype.collections_sequence: - bs: Mapped[ - collections.abc.Sequence[B] - ] = relationship() + bs: Mapped[collections.abc.Sequence[B]] = ( + relationship() + ) elif datatype.typing_mutable_sequence: bs: Mapped[typing.MutableSequence[B]] = relationship() elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship() + bs: 
Mapped[collections.abc.MutableSequence[B]] = ( + relationship() + ) else: datatype.fail() @@ -2535,7 +3477,7 @@ class B(decl_base): back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "union_none": - a: Mapped["Union[A, None]"] = relationship( + a: Mapped[Union[A, None]] = relationship( back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "pep604": @@ -2640,7 +3582,7 @@ class B(decl_base): is_false(B.__mapper__.attrs["a"].uselist) is_false(B.__mapper__.attrs["a_warg"].uselist) - def test_one_to_one_example(self, decl_base: Type[DeclarativeBase]): + def test_one_to_one_example_quoted(self, decl_base: Type[DeclarativeBase]): """test example in the relationship docs will derive uselist=False correctly""" @@ -2664,6 +3606,32 @@ class Child(decl_base): is_(p1.child, c1) is_(c1.parent, p1) + def test_one_to_one_example_non_quoted( + self, decl_base: Type[DeclarativeBase] + ): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Child(decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(primary_key=True) + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + parent: Mapped["Parent"] = relationship(back_populates="child") + + class Parent(decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(primary_key=True) + child: Mapped[Child] = relationship( # noqa: F821 + back_populates="parent" + ) + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + def test_collection_class_dict_no_collection(self, decl_base): class A(decl_base): __tablename__ = "a" diff --git a/test/orm/dml/test_bulk.py b/test/orm/dml/test_bulk.py index baa6c20f83f..3159c139da2 100644 --- a/test/orm/dml/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -2,6 +2,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import Identity from sqlalchemy import insert +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing @@ -89,8 +90,14 @@ def setup_mappers(cls): cls.mapper_registry.map_imperatively(Address, a) cls.mapper_registry.map_imperatively(Order, o) - @testing.combinations("save_objects", "insert_mappings", "insert_stmt") - def test_bulk_save_return_defaults(self, statement_type): + @testing.combinations( + "save_objects", + "insert_mappings", + "insert_stmt", + argnames="statement_type", + ) + @testing.variation("return_defaults", [True, False]) + def test_bulk_save_return_defaults(self, statement_type, return_defaults): (User,) = self.classes("User") s = fixture_session() @@ -101,12 +108,14 @@ def test_bulk_save_return_defaults(self, statement_type): returning_users_id = " RETURNING users.id" with self.sql_execution_asserter() as asserter: - s.bulk_save_objects(objects, return_defaults=True) + s.bulk_save_objects(objects, return_defaults=return_defaults) elif statement_type == "insert_mappings": data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] returning_users_id = " RETURNING users.id" with self.sql_execution_asserter() as asserter: - s.bulk_insert_mappings(User, data, return_defaults=True) + s.bulk_insert_mappings( + User, data, return_defaults=return_defaults + ) elif statement_type == "insert_stmt": data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] @@ -119,7 +128,10 @@ def test_bulk_save_return_defaults(self, statement_type): asserter.assert_( Conditional( - testing.db.dialect.insert_executemany_returning + ( + return_defaults + and testing.db.dialect.insert_executemany_returning + ) 
or statement_type == "insert_stmt", [ CompiledSQL( @@ -129,23 +141,61 @@ def test_bulk_save_return_defaults(self, statement_type): ), ], [ - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u1"}], - ), - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u2"}], - ), - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u3"}], - ), + Conditional( + return_defaults, + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u1"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u2"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u3"}], + ), + ], + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [ + {"name": "u1"}, + {"name": "u2"}, + {"name": "u3"}, + ], + ), + ], + ) ], ) ) + if statement_type == "save_objects": - eq_(objects[0].__dict__["id"], 1) + if return_defaults: + eq_(objects[0].__dict__["id"], 1) + eq_(inspect(objects[0]).key, (User, (1,), None)) + else: + assert "id" not in objects[0].__dict__ + eq_(inspect(objects[0]).key, None) + elif statement_type == "insert_mappings": + # test for #11661 + if return_defaults: + eq_(data[0]["id"], 1) + else: + assert "id" not in data[0] + + def test_bulk_save_objects_defaults_key(self): + User = self.classes.User + + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + key = inspect(pes[0]).key + + s.commit() + eq_(inspect(s.get(User, 1)).key, key) def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") @@ -238,7 +288,7 @@ def test_bulk_save_updated_include_unchanged(self): asserter.assert_( CompiledSQL( - "UPDATE users SET name=:name WHERE " "users.id = :users_id", + "UPDATE users SET name=:name WHERE users.id = :users_id", [ {"users_id": 1, "name": "u1new"}, {"users_id": 2, "name": "u2"}, diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 7af47de8186..6d69b2250c3 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -8,8 +8,10 @@ import uuid from sqlalchemy import bindparam +from sqlalchemy import Computed from sqlalchemy import event from sqlalchemy import exc +from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -23,14 +25,23 @@ from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import Bundle from sqlalchemy.orm import column_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import immediateload +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import lazyload from sqlalchemy.orm import load_only from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import orm_insert_sentinel +from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session +from sqlalchemy.orm import subqueryload from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures @@ -266,6 +277,86 @@ class User(decl_base): ), ) + @testing.requires.insert_returning + @testing.variation( + "insert_type", + [("values", testing.requires.multivalues_inserts), "bulk"], + ) + def 
test_returning_col_property( + self, decl_base, insert_type: testing.Variation + ): + """test #12326""" + + class User(ComparableEntity, decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + name: Mapped[str] + age: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + a_alias = aliased(User) + User.colprop = column_property( + select(func.max(a_alias.age)) + .where(a_alias.id != User.id) + .scalar_subquery() + ) + + sess = fixture_session() + + if insert_type.values: + stmt = insert(User).values( + [ + dict(id=1, name="john", age=25), + dict(id=2, name="jack", age=47), + dict(id=3, name="jill", age=29), + dict(id=4, name="jane", age=37), + ], + ) + params = None + elif insert_type.bulk: + stmt = insert(User) + params = [ + dict(id=1, name="john", age=25), + dict(id=2, name="jack", age=47), + dict(id=3, name="jill", age=29), + dict(id=4, name="jane", age=37), + ] + else: + insert_type.fail() + + stmt = stmt.returning(User) + + result = sess.execute(stmt, params=params) + + # the RETURNING doesn't have the column property in it. + # so to load these, they are all lazy loaded + with self.sql_execution_asserter() as asserter: + eq_( + result.scalars().all(), + [ + User(id=1, name="john", age=25, colprop=47), + User(id=2, name="jack", age=47, colprop=37), + User(id=3, name="jill", age=29, colprop=47), + User(id=4, name="jane", age=37, colprop=47), + ], + ) + + # assert they're all lazy loaded + asserter.assert_( + *[ + CompiledSQL( + 'SELECT (SELECT max(user_1.age) AS max_1 FROM "user" ' + 'AS user_1 WHERE user_1.id != "user".id) AS anon_1 ' + 'FROM "user" WHERE "user".id = :pk_1' + ) + for i in range(4) + ] + ) + @testing.requires.insert_returning @testing.requires.returning_star @testing.variation( @@ -381,6 +472,68 @@ class User(ComparableEntity, decl_base): eq_(result.all(), [User(id=1, name="John", age=30)]) + @testing.requires.insert_returning + @testing.variation( + "insert_type", + ["bulk", ("values", testing.requires.multivalues_inserts), "single"], + ) + def test_insert_returning_bundle(self, decl_base, insert_type): + """test #10776""" + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + + name: Mapped[str] = mapped_column() + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + insert_stmt = insert(User).returning( + User.name, Bundle("mybundle", User.id, User.x, User.y) + ) + + s = fixture_session() + + if insert_type.bulk: + result = s.execute( + insert_stmt, + [ + {"name": "some name 1", "x": 1, "y": 2}, + {"name": "some name 2", "x": 2, "y": 3}, + {"name": "some name 3", "x": 3, "y": 4}, + ], + ) + elif insert_type.values: + result = s.execute( + insert_stmt.values( + [ + {"name": "some name 1", "x": 1, "y": 2}, + {"name": "some name 2", "x": 2, "y": 3}, + {"name": "some name 3", "x": 3, "y": 4}, + ], + ) + ) + elif insert_type.single: + result = s.execute( + insert_stmt, {"name": "some name 1", "x": 1, "y": 2} + ) + else: + insert_type.fail() + + if insert_type.single: + eq_(result.all(), [("some name 1", (1, 1, 2))]) + else: + eq_( + result.all(), + [ + ("some name 1", (1, 1, 2)), + ("some name 2", (2, 2, 3)), + ("some name 3", (3, 3, 4)), + ], + ) + @testing.variation( "use_returning", [(True, testing.requires.insert_returning), False] ) @@ -531,6 +684,103 @@ class Employee(ComparableEntity, decl_base): class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): __backend__ = True + 
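# the tests below exercise ORM "bulk UPDATE by primary key": an + # update() against the mapped class, executed with a list of parameter + # dicts that each carry the primary key; a minimal sketch, reusing + # names from the test body below: + # s.execute(update(Employee), params=[{"uuid": uuid1, "user_name": "e1 new name"}]) + 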
@testing.variation( + "use_onupdate", + [ + "none", + "server", + "callable", + "clientsql", + ("computed", testing.requires.computed_columns), + ], + ) + def test_bulk_update_onupdates( + self, + decl_base, + use_onupdate, + ): + """assert that for now, bulk ORM update by primary key does not + expire or refresh onupdates.""" + + class Employee(ComparableEntity, decl_base): + __tablename__ = "employee" + + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + user_name: Mapped[str] = mapped_column(String(200), nullable=False) + + if use_onupdate.server: + some_server_value: Mapped[str] = mapped_column( + server_onupdate=FetchedValue() + ) + elif use_onupdate.callable: + some_server_value: Mapped[str] = mapped_column( + onupdate=lambda: "value 2" + ) + elif use_onupdate.clientsql: + some_server_value: Mapped[str] = mapped_column( + onupdate=literal("value 2") + ) + elif use_onupdate.computed: + some_server_value: Mapped[str] = mapped_column( + String(255), + Computed(user_name + " computed value"), + nullable=True, + ) + else: + some_server_value: Mapped[str] + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + uuid1 = uuid.uuid4() + + if use_onupdate.computed: + server_old_value, server_new_value = ( + "e1 old name computed value", + "e1 new name computed value", + ) + e1 = Employee(uuid=uuid1, user_name="e1 old name") + else: + server_old_value, server_new_value = ("value 1", "value 2") + e1 = Employee( + uuid=uuid1, + user_name="e1 old name", + some_server_value="value 1", + ) + s.add(e1) + s.flush() + + # for computed col, make sure e1.some_server_value is loaded. + # this will already be the case for all RETURNING backends, so this + # suits just MySQL. + if use_onupdate.computed: + e1.some_server_value + + stmt = update(Employee) + + # perform out of band UPDATE on server value to simulate + # a computed col + if use_onupdate.none or use_onupdate.server: + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") + ) + + execution_options = {} + + s.execute( + stmt, + execution_options=execution_options, + params=[{"uuid": uuid1, "user_name": "e1 new name"}], + ) + + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + # do a full expire, now the new value is definitely there + s.commit() + s.expire_all() + eq_(e1.some_server_value, server_new_value) + @testing.variation( "returning_executemany", [ @@ -794,6 +1044,34 @@ class A(decl_base): result = s.execute(stmt, data) eq_(result.all(), [(1, 5, 9), (2, 5, 9), (3, 5, 9)]) + @testing.requires.update_returning + def test_bulk_update_returning_bundle(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + s.add_all( + [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)], + ) + s.commit() + + stmt = update(A).returning(Bundle("mybundle", A.id, A.x), A.y) + + data = {"x": 5, "y": 9} + + result = s.execute(stmt, data) + eq_(result.all(), [((1, 5), 9), ((2, 5), 9), ((3, 5), 9)]) + def test_bulk_update_w_where_one(self, decl_base): """test use case in #9595""" @@ -882,6 +1160,47 @@ class User(decl_base): ], ) + @testing.requires.update_returning + def test_returning_col_property(self, decl_base): + """test #12326""" + + class User(ComparableEntity, decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, 
autoincrement=False + ) + name: Mapped[str] + age: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + a_alias = aliased(User) + User.colprop = column_property( + select(func.max(a_alias.age)) + .where(a_alias.id != User.id) + .scalar_subquery() + ) + + sess = fixture_session() + + sess.execute( + insert(User), + [ + dict(id=1, name="john", age=25), + dict(id=2, name="jack", age=47), + dict(id=3, name="jill", age=29), + dict(id=4, name="jane", age=37), + ], + ) + + stmt = ( + update(User).values(age=30).where(User.age == 29).returning(User) + ) + + row = sess.execute(stmt).one() + eq_(row[0], User(id=3, name="jill", age=30, colprop=47)) + class BulkDMLReturningInhTest: use_sentinel = False @@ -2207,3 +2526,213 @@ def test_select_from_insert_cte( asserter.assert_( CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}]) ) + + +class EagerLoadTest( + fixtures.DeclarativeMappedTest, testing.AssertsExecutionResults +): + run_inserts = "each" + __requires__ = ("insert_returning",) + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) + cs = relationship("C") + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + a = relationship("A") + + class C(Base): + __tablename__ = "c" + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + @classmethod + def insert_data(cls, connection): + A = cls.classes.A + C = cls.classes.C + with Session(connection) as sess: + sess.add_all( + [ + A(id=1, cs=[C(id=1), C(id=2)]), + A(id=2), + A(id=3, cs=[C(id=3), C(id=4)]), + ] + ) + sess.commit() + + @testing.fixture + def fixture_with_loader_opt(self): + def go(lazy): + class Base(DeclarativeBase): + pass + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + a = relationship("A", lazy=lazy) + + return A, B + + return go + + @testing.combinations( + (selectinload,), + (immediateload,), + ) + def test_insert_supported(self, loader): + A, B = self.classes("A", "B") + + sess = fixture_session() + + result = sess.execute( + insert(B).returning(B).options(loader(B.a)), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 3, "a_id": 2}, + {"id": 4, "a_id": 3}, + {"id": 5, "a_id": 3}, + ], + ).scalars() + + for b in result: + assert "a" in b.__dict__ + + @testing.combinations( + (joinedload,), + (subqueryload,), + ) + def test_insert_not_supported(self, loader): + """test #11853""" + + A, B = self.classes("A", "B") + + sess = fixture_session() + + stmt = insert(B).returning(B).options(loader(B.a)) + + with expect_deprecated( + f"The {loader.__name__} loader option is not compatible " + "with DML statements", + ): + sess.execute(stmt, [{"id": 1, "a_id": 1}]) + + @testing.combinations( + (joinedload,), + (subqueryload,), + (selectinload,), + (immediateload,), + ) + def test_secondary_opt_ok(self, loader): + A, B = self.classes("A", "B") + + sess = fixture_session() + + opt = selectinload(B.a) + opt = getattr(opt, loader.__name__)(A.cs) + + result = sess.execute( + insert(B).returning(B).options(opt), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + 
{"id": 3, "a_id": 2}, + {"id": 4, "a_id": 3}, + {"id": 5, "a_id": 3}, + ], + ).scalars() + + for b in result: + assert "a" in b.__dict__ + assert "cs" in b.a.__dict__ + + @testing.combinations( + ("joined",), + ("select",), + ("subquery",), + ("selectin",), + ("immediate",), + argnames="lazy_opt", + ) + def test_insert_handles_implicit(self, fixture_with_loader_opt, lazy_opt): + """test #11853""" + + A, B = fixture_with_loader_opt(lazy_opt) + + sess = fixture_session() + + for b_obj in sess.execute( + insert(B).returning(B), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 3, "a_id": 2}, + {"id": 4, "a_id": 3}, + {"id": 5, "a_id": 3}, + ], + ).scalars(): + + if lazy_opt in ("select", "joined", "subquery"): + # these aren't supported by DML + assert "a" not in b_obj.__dict__ + else: + # the other three are + assert "a" in b_obj.__dict__ + + @testing.combinations( + (lazyload,), (selectinload,), (immediateload,), argnames="loader_opt" + ) + @testing.combinations( + (joinedload,), + (subqueryload,), + (selectinload,), + (immediateload,), + (lazyload,), + argnames="secondary_opt", + ) + def test_secondary_w_criteria_caching(self, loader_opt, secondary_opt): + """test #11855""" + A, B, C = self.classes("A", "B", "C") + + for i in range(3): + with fixture_session() as sess: + + opt = loader_opt(B.a) + opt = getattr(opt, secondary_opt.__name__)( + A.cs.and_(C.a_id == 1) + ) + stmt = insert(B).returning(B).options(opt) + + b1 = sess.scalar(stmt, [{"a_id": 1}]) + + eq_({c.id for c in b1.a.cs}, {1, 2}) + + opt = loader_opt(B.a) + opt = getattr(opt, secondary_opt.__name__)( + A.cs.and_(C.a_id == 3) + ) + + stmt = insert(B).returning(B).options(opt) + + b3 = sess.scalar(stmt, [{"a_id": 3}]) + + eq_({c.id for c in b3.a.cs}, {3, 4}) diff --git a/test/orm/dml/test_evaluator.py b/test/orm/dml/test_evaluator.py index 81da16914b7..3fc82db6944 100644 --- a/test/orm/dml/test_evaluator.py +++ b/test/orm/dml/test_evaluator.py @@ -370,6 +370,14 @@ def test_custom_op(self): r"Cannot evaluate math operator \"add\" for " r"datatypes JSON, INTEGER", ), + ( + lambda User: User.json + {"bar": "bat"}, + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate concatenate operator \"concat_op\" for " + r"datatypes JSON, JSON", + ), ( lambda User: User.json - 12, "json", diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 03468972d56..387ce161b86 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1,15 +1,22 @@ +from __future__ import annotations + +import uuid + from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc +from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ @@ -21,9 +28,12 @@ from sqlalchemy import values from sqlalchemy.orm import aliased from sqlalchemy.orm import backref +from sqlalchemy.orm import Bundle from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import immediateload from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import 
relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -35,6 +45,7 @@ from sqlalchemy.sql.selectable import Select from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures @@ -42,6 +53,7 @@ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -66,6 +78,7 @@ def define_tables(cls, metadata): metadata, Column("id", Integer, primary_key=True), Column("user_id", ForeignKey("users.id")), + Column("email_address", String(50)), ) m = MetaData() @@ -106,6 +119,24 @@ def insert_data(cls, connection): ], ) + @testing.fixture + def addresses_data( + self, + ): + addresses = self.tables.addresses + + with testing.db.begin() as connection: + connection.execute( + addresses.insert(), + [ + dict(id=1, user_id=1, email_address="jo1"), + dict(id=2, user_id=1, email_address="jo2"), + dict(id=3, user_id=2, email_address="ja1"), + dict(id=4, user_id=3, email_address="ji1"), + dict(id=5, user_id=4, email_address="jan1"), + ], + ) + @classmethod def setup_mappers(cls): User = cls.classes.User @@ -1312,6 +1343,52 @@ def test_update_evaluate_w_explicit_returning(self): ), ) + @testing.requires.update_from_returning + # can't use evaluate because it can't match the col->col in the WHERE + @testing.combinations("fetch", "auto", argnames="synchronize_session") + def test_update_from_multi_returning( + self, synchronize_session, addresses_data + ): + """test #12327""" + User = self.classes.User + Address = self.classes.Address + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + with self.sql_execution_asserter() as asserter: + stmt = ( + update(User) + .where(User.id == Address.user_id) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .returning( + User.id, Address.email_address, func.char_length(User.name) + ) + .execution_options(synchronize_session=synchronize_session) + ) + + rows = sess.execute(stmt).all() + eq_(set(rows), {(2, "ja1", 4), (4, "jan1", 4)}) + + # these are simple values, these are now evaluated even with + # the "fetch" strategy, new in 1.4, so there is no expiry + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " + "FROM addresses " + "WHERE users.id = addresses.user_id AND " + "users.age_int > %(age_int_2)s " + "RETURNING users.id, addresses.email_address, " + "char_length(users.name) AS char_length_1", + [{"age_int_1": 10, "age_int_2": 29}], + dialect="postgresql", + ), + ) + @testing.requires.update_returning @testing.combinations("update", "delete", argnames="crud_type") def test_fetch_w_explicit_returning(self, crud_type): @@ -1351,6 +1428,45 @@ def test_fetch_w_explicit_returning(self, crud_type): # to point to the class, so you can test eq with sets eq_(set(result.all()), expected) + @testing.requires.update_returning + @testing.variation("crud_type", ["update", "delete"]) + @testing.combinations( + "auto", + "evaluate", + "fetch", + False, + 
argnames="synchronize_session", + ) + def test_crud_returning_bundle(self, crud_type, synchronize_session): + """test #10776""" + User = self.classes.User + + sess = fixture_session() + + if crud_type.update: + stmt = ( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session=synchronize_session) + .returning(Bundle("mybundle", User.id, User.age), User.name) + ) + expected = {((4, 27), "jane"), ((2, 37), "jack")} + elif crud_type.delete: + stmt = ( + delete(User) + .filter(User.age > 29) + .execution_options(synchronize_session=synchronize_session) + .returning(Bundle("mybundle", User.id, User.age), User.name) + ) + expected = {((2, 47), "jack"), ((4, 37), "jane")} + else: + crud_type.fail() + + result = sess.execute(stmt) + + eq_(set(result.all()), expected) + @testing.requires.delete_returning @testing.requires.returning_star def test_delete_returning_star(self): @@ -2535,7 +2651,7 @@ def test_update_from_multitable_same_names(self): ) -class ExpressionUpdateTest(fixtures.MappedTest): +class ExpressionUpdateDeleteTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( @@ -2601,6 +2717,27 @@ def do_orm_execute(bulk_ud): eq_(update_stmt.dialect_kwargs, update_args) + def test_delete_args(self): + Data = self.classes.Data + session = fixture_session() + delete_args = {"mysql_limit": 1} + + m1 = testing.mock.Mock() + + @event.listens_for(session, "after_bulk_delete") + def do_orm_execute(bulk_ud): + delete_stmt = ( + bulk_ud.result.context.compiled.compile_state.statement + ) + m1(delete_stmt) + + q = session.query(Data) + q.delete(delete_args=delete_args) + + delete_stmt = m1.mock_calls[0][1][0] + + eq_(delete_stmt.dialect_kwargs, delete_args) + class InheritTest(fixtures.DeclarativeMappedTest): run_inserts = "each" @@ -2924,6 +3061,54 @@ def test_update_from_multitable(self, synchronize_session): ) +class InheritWPolyTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def inherit_fixture(self, decl_base): + def go(poly_type): + + class Person(decl_base): + __tablename__ = "person" + id = Column(Integer, primary_key=True) + type = Column(String(50)) + name = Column(String(50)) + + if poly_type.wpoly: + __mapper_args__ = {"with_polymorphic": "*"} + + class Engineer(Person): + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) + engineer_name = Column(String(50)) + + if poly_type.inline: + __mapper_args__ = {"polymorphic_load": "inline"} + + return Person, Engineer + + return go + + @testing.variation("poly_type", ["wpoly", "inline", "none"]) + def test_update_base_only(self, poly_type, inherit_fixture): + Person, Engineer = inherit_fixture(poly_type) + + self.assert_compile( + update(Person).values(name="n1"), "UPDATE person SET name=:name" + ) + + @testing.variation("poly_type", ["wpoly", "inline", "none"]) + def test_delete_base_only(self, poly_type, inherit_fixture): + Person, Engineer = inherit_fixture(poly_type) + + self.assert_compile(delete(Person), "DELETE FROM person") + + self.assert_compile( + delete(Person).where(Person.id == 7), + "DELETE FROM person WHERE person.id = :id_1", + ) + + class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): __backend__ = True @@ -3205,3 +3390,263 @@ def test_load_from_delete(self, connection, use_from_statement): ) # TODO: state of above objects should be "deleted" + + +class OnUpdatePopulationTest(fixtures.TestBase): + __backend__ = True + + 
@testing.variation("populate_existing", [True, False]) + @testing.variation( + "use_onupdate", + [ + "none", + "server", + "callable", + "clientsql", + ("computed", testing.requires.computed_columns), + ], + ) + @testing.variation( + "use_returning", + [ + ("returning", testing.requires.update_returning), + ("defaults", testing.requires.update_returning), + "none", + ], + ) + @testing.variation("synchronize", ["auto", "fetch", "evaluate"]) + @testing.variation("pk_order", ["first", "middle"]) + def test_update_populate_existing( + self, + decl_base, + populate_existing, + use_onupdate, + use_returning, + synchronize, + pk_order, + ): + """test #11912 and #11917""" + + class Employee(ComparableEntity, decl_base): + __tablename__ = "employee" + + if pk_order.first: + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + user_name: Mapped[str] = mapped_column(String(200), nullable=False) + + if pk_order.middle: + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + + if use_onupdate.server: + some_server_value: Mapped[str] = mapped_column( + server_onupdate=FetchedValue() + ) + elif use_onupdate.callable: + some_server_value: Mapped[str] = mapped_column( + onupdate=lambda: "value 2" + ) + elif use_onupdate.clientsql: + some_server_value: Mapped[str] = mapped_column( + onupdate=literal("value 2") + ) + elif use_onupdate.computed: + some_server_value: Mapped[str] = mapped_column( + String(255), + Computed(user_name + " computed value"), + nullable=True, + ) + else: + some_server_value: Mapped[str] + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + uuid1 = uuid.uuid4() + + if use_onupdate.computed: + server_old_value, server_new_value = ( + "e1 old name computed value", + "e1 new name computed value", + ) + e1 = Employee(uuid=uuid1, user_name="e1 old name") + else: + server_old_value, server_new_value = ("value 1", "value 2") + e1 = Employee( + uuid=uuid1, + user_name="e1 old name", + some_server_value="value 1", + ) + s.add(e1) + s.flush() + + stmt = ( + update(Employee) + .values(user_name="e1 new name") + .where(Employee.uuid == uuid1) + ) + + if use_returning.returning: + stmt = stmt.returning(Employee) + elif use_returning.defaults: + # NOTE: the return_defaults case here has not been analyzed for + # #11912 or #11917. future enhancements may change its behavior + stmt = stmt.return_defaults() + + # perform out of band UPDATE on server value to simulate + # a computed col + if use_onupdate.none or use_onupdate.server: + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") + ) + + execution_options = {} + + if populate_existing: + execution_options["populate_existing"] = True + + if synchronize.evaluate: + execution_options["synchronize_session"] = "evaluate" + if synchronize.fetch: + execution_options["synchronize_session"] = "fetch" + + if use_returning.returning: + rows = s.scalars(stmt, execution_options=execution_options) + else: + s.execute(stmt, execution_options=execution_options) + + if ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + ): + if not use_returning.defaults: + # if server-side onupdate was generated, the col should have + # been expired + assert "some_server_value" not in e1.__dict__ + + # and refreshes when called. this is even if we have RETURNING + # rows we didn't fetch yet. + eq_(e1.some_server_value, server_new_value) + else: + # using return defaults here is not expiring. 
have not + # researched why; it may be because the explicit + # return_defaults interferes with the ORM's call + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + elif use_onupdate.callable: + if not use_returning.defaults or not synchronize.fetch: + # for python-side onupdate, col is populated with local value + assert "some_server_value" in e1.__dict__ + + # and is refreshed + eq_(e1.some_server_value, server_new_value) + else: + assert "some_server_value" in e1.__dict__ + + # and is not refreshed + eq_(e1.some_server_value, server_old_value) + + else: + # no onupdate, then the value was not touched yet, + # even if we used RETURNING with populate_existing, because + # we did not fetch the rows yet + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + # now see if we can fetch rows + if use_returning.returning: + + if populate_existing or not use_onupdate.none: + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_new_value, + ), + }, + ) + + else: + # if no populate existing and no server default, that column + # is not touched at all + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_old_value, + ), + }, + ) + + if use_returning.defaults: + # as mentioned above, the return_defaults() case here remains + # unanalyzed. + if synchronize.fetch or ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + or use_onupdate.none + ): + eq_(e1.some_server_value, server_old_value) + else: + eq_(e1.some_server_value, server_new_value) + + elif ( + populate_existing and use_returning.returning + ) or not use_onupdate.none: + eq_(e1.some_server_value, server_new_value) + else: + # no onupdate specified, and no populate existing with returning, + # the attribute is not refreshed + eq_(e1.some_server_value, server_old_value) + + # do a full expire, now the new value is definitely there + s.commit() + s.expire_all() + eq_(e1.some_server_value, server_new_value) + + +class PGIssue11849Test(fixtures.DeclarativeMappedTest): + __backend__ = True + __only_on__ = ("postgresql",) + + @classmethod + def setup_classes(cls): + + from sqlalchemy.dialects.postgresql import JSONB + + Base = cls.DeclarativeBasic + + class TestTbl(Base): + __tablename__ = "testtbl" + + test_id = Column(Integer, primary_key=True) + test_field = Column(JSONB) + + def test_issue_11849(self): + TestTbl = self.classes.TestTbl + + session = fixture_session() + + obj = TestTbl( + test_id=1, test_field={"test1": 1, "test2": "2", "test3": [3, "3"]} + ) + session.add(obj) + + query = ( + update(TestTbl) + .where(TestTbl.test_id == 1) + .values(test_field=TestTbl.test_field + {"test3": {"test4": 4}}) + ) + session.execute(query) + + # not loaded + assert "test_field" not in obj.__dict__ + + # synchronizes on load + eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}}) diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index 5b5989c9205..d0f8e680d0d 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -469,19 +469,20 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): e.g.:: self._fixture_from_geometry( - "a": { - "subclasses": { - "b": {"polymorphic_load": "selectin"}, - "c": { - "subclasses": { - "d": { - "polymorphic_load": "inlne", "single": True - }, - "e": { - "polymorphic_load": "inline", "single": True + { + "a": { + 
"subclasses": { + "b": {"polymorphic_load": "selectin"}, + "c": { + "subclasses": { + "d": {"polymorphic_load": "inlne", "single": True}, + "e": { + "polymorphic_load": "inline", + "single": True, + }, }, + "polymorphic_load": "selectin", }, - "polymorphic_load": "selectin", } } } @@ -490,42 +491,41 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): would provide the equivalent of:: class a(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) a_data = Column(String(50)) type = Column(String(50)) - __mapper_args__ = { - "polymorphic_on": type, - "polymorphic_identity": "a" - } + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "a"} + class b(a): - __tablename__ = 'b' + __tablename__ = "b" - id = Column(ForeignKey('a.id'), primary_key=True) + id = Column(ForeignKey("a.id"), primary_key=True) b_data = Column(String(50)) __mapper_args__ = { "polymorphic_identity": "b", - "polymorphic_load": "selectin" + "polymorphic_load": "selectin", } # ... + class c(a): - __tablename__ = 'c' + __tablename__ = "c" - class d(c): - # ... - class e(c): - # ... + class d(c): ... + + + class e(c): ... Declarative is used so that we get extra behaviors of declarative, such as single-inheritance column masking. - """ + """ # noqa: E501 run_create_tables = "each" run_define_tables = "each" diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 0f9a623bdac..9100970d440 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -32,6 +32,7 @@ from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_polymorphic from sqlalchemy.orm.interfaces import MANYTOONE from sqlalchemy.testing import AssertsCompiledSQL @@ -818,7 +819,7 @@ class RelationshipTest6(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - global people, managers, data + global people, managers people = Table( "people", metadata, @@ -2476,9 +2477,9 @@ class Retailer(Customer): __mapper_args__ = { "polymorphic_identity": "retailer", - "polymorphic_load": "inline" - if use_poly_on_retailer - else None, + "polymorphic_load": ( + "inline" if use_poly_on_retailer else None + ), } return Customer, Store, Retailer @@ -3148,3 +3149,177 @@ def test_big_query(self, query_type, use_criteria): head, UnitHead(managers=expected_managers), ) + + +@testing.combinations( + (2,), + (3,), + id_="s", + argnames="num_levels", +) +@testing.combinations( + ("with_poly_star",), + ("inline",), + ("selectin",), + ("none",), + id_="s", + argnames="wpoly_type", +) +class SubclassWithPolyEagerLoadTest(fixtures.DeclarativeMappedTest): + """test #11446""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey("a.id")) + + class A(Base): + __tablename__ = "a" + + id = Column(Integer, primary_key=True) + type = Column(String(10)) + bs = relationship("B") + + if cls.wpoly_type == "selectin": + __mapper_args__ = {"polymorphic_on": "type"} + elif cls.wpoly_type == "inline": + __mapper_args__ = {"polymorphic_on": "type"} + elif cls.wpoly_type == "with_poly_star": + __mapper_args__ = { + "with_polymorphic": "*", + "polymorphic_on": "type", + } + else: + __mapper_args__ = {"polymorphic_on": "type"} + + class ASub(A): + __tablename__ = "asub" + id = 
Column(ForeignKey("a.id"), primary_key=True) + sub_data = Column(String(10)) + + if cls.wpoly_type == "selectin": + __mapper_args__ = { + "polymorphic_load": "selectin", + "polymorphic_identity": "asub", + } + elif cls.wpoly_type == "inline": + __mapper_args__ = { + "polymorphic_load": "inline", + "polymorphic_identity": "asub", + } + elif cls.wpoly_type == "with_poly_star": + __mapper_args__ = { + "with_polymorphic": "*", + "polymorphic_identity": "asub", + } + else: + __mapper_args__ = {"polymorphic_identity": "asub"} + + if cls.num_levels == 3: + + class ASubSub(ASub): + __tablename__ = "asubsub" + id = Column(ForeignKey("asub.id"), primary_key=True) + sub_sub_data = Column(String(10)) + + if cls.wpoly_type == "selectin": + __mapper_args__ = { + "polymorphic_load": "selectin", + "polymorphic_identity": "asubsub", + } + elif cls.wpoly_type == "inline": + __mapper_args__ = { + "polymorphic_load": "inline", + "polymorphic_identity": "asubsub", + } + elif cls.wpoly_type == "with_poly_star": + __mapper_args__ = { + "with_polymorphic": "*", + "polymorphic_identity": "asubsub", + } + else: + __mapper_args__ = {"polymorphic_identity": "asubsub"} + + @classmethod + def insert_data(cls, connection): + if cls.num_levels == 3: + ASubSub, B = cls.classes("ASubSub", "B") + + with Session(connection) as sess: + sess.add_all( + [ + ASubSub( + sub_data="sub", + sub_sub_data="subsub", + bs=[B(), B(), B()], + ) + for i in range(3) + ] + ) + + sess.commit() + else: + ASub, B = cls.classes("ASub", "B") + + with Session(connection) as sess: + sess.add_all( + [ + ASub(sub_data="sub", bs=[B(), B(), B()]) + for i in range(3) + ] + ) + sess.commit() + + @testing.variation("query_from", ["aliased_class", "class_", "parent"]) + @testing.combinations(selectinload, subqueryload, argnames="loader_fn") + def test_thing(self, query_from, loader_fn): + + A = self.classes.A + + if self.num_levels == 2: + target = self.classes.ASub + elif self.num_levels == 3: + target = self.classes.ASubSub + + if query_from.aliased_class: + asub_alias = aliased(target) + query = select(asub_alias).options(loader_fn(asub_alias.bs)) + elif query_from.class_: + query = select(target).options(loader_fn(A.bs)) + elif query_from.parent: + query = select(A).options(loader_fn(A.bs)) + + s = fixture_session() + + # NOTE: this is likely a different bug - setting + # polymorphic_load to "inline" and loading from the parent does not + # descend to the ASubSub subclass; however "selectin" setting + # **does**. 
this is inconsistent + if ( + query_from.parent + and self.wpoly_type == "inline" + and self.num_levels == 3 + ): + # this should ideally be "2" + expected_q = 5 + + elif query_from.parent and self.wpoly_type == "none": + expected_q = 5 + elif query_from.parent and self.wpoly_type == "selectin": + expected_q = 3 + else: + expected_q = 2 + + with self.assert_statement_count(testing.db, expected_q): + for obj in s.scalars(query): + # test both that with_polymorphic loaded + eq_(obj.sub_data, "sub") + if self.num_levels == 3: + eq_(obj.sub_sub_data, "subsub") + + # as well as the collection eagerly loaded + assert obj.bs diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index abd6c86b570..9028fd25a43 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1684,7 +1684,7 @@ def test_none(self): s.flush() asserter.assert_( RegexSQL( - "SELECT .* " "FROM c WHERE :param_1 = c.bid", [{"param_1": 3}] + "SELECT .* FROM c WHERE :param_1 = c.bid", [{"param_1": 3}] ), CompiledSQL("DELETE FROM c WHERE c.cid = :cid", [{"cid": 1}]), CompiledSQL("DELETE FROM b WHERE b.id = :id", [{"id": 3}]), @@ -1933,7 +1933,7 @@ def test_refresh_column(self): # a.id is not included in the SELECT list "SELECT b.data FROM a JOIN b ON a.id = b.id " "WHERE a.id = :pk_1", - [{"pk_1": pk}] + [{"pk_1": pk}], # if we used load_scalar_attributes(), it would look like # this # "SELECT b.data AS b_data FROM b WHERE :param_1 = b.id", @@ -3012,7 +3012,7 @@ class D(C): ) def test_optimized_passes(self): - """ "test that the 'optimized load' routine doesn't crash when + """test that the 'optimized load' routine doesn't crash when a column in the join condition is not available.""" base, sub = self.tables.base, self.tables.sub @@ -3744,7 +3744,7 @@ class B(A): __mapper_args__ = {"polymorphic_identity": "b"} with expect_warnings( - r"Mapper\[C\(a\)\] does not indicate a " "'polymorphic_identity'," + r"Mapper\[C\(a\)\] does not indicate a 'polymorphic_identity'," ): class C(A): diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index df286f0d35c..58cf7b54271 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -735,6 +735,66 @@ def test_threelevel_selectin_to_inline_options(self): with self.assert_statement_count(testing.db, 0): eq_(result, [d(d_data="d1"), e(e_data="e1")]) + @testing.variation("include_intermediary_row", [True, False]) + def test_threelevel_load_only_3lev(self, include_intermediary_row): + """test issue #11327""" + + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "b": {"subclasses": {"c": {}}}, + } + } + } + ) + + a, b, c = self.classes("a", "b", "c") + sess = fixture_session() + sess.add(c(a_data="a1", b_data="b1", c_data="c1")) + if include_intermediary_row: + sess.add(b(a_data="a1", b_data="b1")) + sess.commit() + + sess = fixture_session() + + pks = [] + c_pks = [] + with self.sql_execution_asserter(testing.db) as asserter: + + for obj in sess.scalars( + select(a) + .options(selectin_polymorphic(a, classes=[b, c])) + .order_by(a.id) + ): + assert "b_data" in obj.__dict__ + if isinstance(obj, c): + assert "c_data" in obj.__dict__ + c_pks.append(obj.id) + pks.append(obj.id) + + asserter.assert_( + CompiledSQL( + "SELECT a.id, a.type, a.a_data FROM a ORDER BY a.id", {} + ), + AllOf( + CompiledSQL( + "SELECT c.id AS c_id, b.id AS b_id, a.id AS a_id, " + "a.type AS a_type, c.c_data AS c_c_data FROM a JOIN b " + "ON a.id = b.id JOIN c ON 
b.id = c.id WHERE a.id IN " + "(__[POSTCOMPILE_primary_keys]) ORDER BY a.id", + [{"primary_keys": c_pks}], + ), + CompiledSQL( + "SELECT b.id AS b_id, a.id AS a_id, a.type AS a_type, " + "b.b_data AS b_b_data FROM a JOIN b ON a.id = b.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": pks}], + ), + ), + ) + @testing.combinations((True,), (False,)) def test_threelevel_selectin_to_inline_awkward_alias_options( self, use_aliased_class @@ -752,7 +812,9 @@ def test_threelevel_selectin_to_inline_awkward_alias_options( a, b, c, d, e = self.classes("a", "b", "c", "d", "e") sess = fixture_session() - sess.add_all([d(d_data="d1"), e(e_data="e1")]) + sess.add_all( + [d(c_data="c1", d_data="d1"), e(c_data="c2", e_data="e1")] + ) sess.commit() from sqlalchemy import select @@ -840,6 +902,15 @@ def test_threelevel_selectin_to_inline_awkward_alias_options( {}, ), AllOf( + # note this query is added due to the fix made in + # #11327 + CompiledSQL( + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " + "c.c_data AS c_c_data FROM a JOIN c ON a.id = c.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": [1, 2]}], + ), CompiledSQL( "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, " "a.type AS a_type, d.d_data AS d_d_data FROM a " @@ -860,7 +931,10 @@ def test_threelevel_selectin_to_inline_awkward_alias_options( ) with self.assert_statement_count(testing.db, 0): - eq_(result, [d(d_data="d1"), e(e_data="e1")]) + eq_( + result, + [d(c_data="c1", d_data="d1"), e(c_data="c2", e_data="e1")], + ) def test_partial_load_no_invoke_eagers(self): # test issue #4199 @@ -1396,18 +1470,10 @@ def test_wp(self, mapping_fixture, connection): class CompositeAttributesTest(fixtures.TestBase): - @testing.fixture - def mapping_fixture(self, registry, connection): - Base = registry.generate_base() - class BaseCls(Base): - __tablename__ = "base" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) - type = Column(String(50)) - - __mapper_args__ = {"polymorphic_on": type} + @testing.fixture(params=("base", "sub")) + def mapping_fixture(self, request, registry, connection): + Base = registry.generate_base() class XYThing: def __init__(self, x, y): @@ -1427,13 +1493,28 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + class BaseCls(Base): + __tablename__ = "base" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + type = Column(String(50)) + + if request.param == "base": + comp1 = composite( + XYThing, Column("x1", Integer), Column("y1", Integer) + ) + + __mapper_args__ = {"polymorphic_on": type} + class A(ComparableEntity, BaseCls): __tablename__ = "a" id = Column(ForeignKey(BaseCls.id), primary_key=True) thing1 = Column(String(50)) - comp1 = composite( - XYThing, Column("x1", Integer), Column("y1", Integer) - ) + if request.param == "sub": + comp1 = composite( + XYThing, Column("x1", Integer), Column("y1", Integer) + ) __mapper_args__ = { "polymorphic_identity": "a", diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 0b358f8894b..1216aa0106f 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -2060,6 +2060,14 @@ def test_correlation_three(self): [(e3.name,)], ) + def test_with_polymorphic_named(self): + session = fixture_session() + poly = with_polymorphic(Person, "*", name="poly_name") + + res = session.execute(select(poly)).mappings() + 
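# the with_polymorphic entity created with name="poly_name" is keyed + # in each row mapping under that name rather than the entity itself + 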
eq_(res.keys(), ["poly_name"]) + eq_(len(res.all()), 5) + class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): def test_joined_aliasing_unrelated_subuqery(self): diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index daaf937b912..be42dc60904 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -2896,9 +2896,11 @@ def test_query_auto(self, autoalias): m1 = aliased(Manager, flat=True) q = sess.query(Engineer, m1).join(Engineer.manager.of_type(m1)) - with _aliased_join_warning( - r"Manager\(managers\)" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Manager\(managers\)") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT engineers.id AS " diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 52f3cf9c9f7..bfdf0b7bcfa 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -377,6 +377,58 @@ def test_select_from_aliased_w_subclass(self): "WHERE employees_1.type IN (__[POSTCOMPILE_type_1])", ) + @testing.combinations( + ( + lambda Engineer, Report: select(Report.report_id) + .select_from(Engineer) + .join(Engineer.reports), + ), + ( + lambda Engineer, Report: select(Report.report_id).select_from( + orm_join(Engineer, Report, Engineer.reports) + ), + ), + ( + lambda Engineer, Report: select(Report.report_id).join_from( + Engineer, Report, Engineer.reports + ), + ), + ( + lambda Engineer, Report: select(Report.report_id) + .select_from(Engineer) + .join(Report), + ), + argnames="stmt_fn", + ) + @testing.combinations(True, False, argnames="alias_engineer") + def test_select_col_only_from_w_join(self, stmt_fn, alias_engineer): + """test #11412 which seems to have been fixed by #10365""" + + Engineer = self.classes.Engineer + Report = self.classes.Report + + if alias_engineer: + Engineer = aliased(Engineer) + stmt = testing.resolve_lambda( + stmt_fn, Engineer=Engineer, Report=Report + ) + + if alias_engineer: + self.assert_compile( + stmt, + "SELECT reports.report_id FROM employees AS employees_1 " + "JOIN reports ON employees_1.employee_id = " + "reports.employee_id WHERE employees_1.type " + "IN (__[POSTCOMPILE_type_1])", + ) + else: + self.assert_compile( + stmt, + "SELECT reports.report_id FROM employees JOIN reports " + "ON employees.employee_id = reports.employee_id " + "WHERE employees.type IN (__[POSTCOMPILE_type_1])", + ) + @testing.combinations( ( lambda Engineer, Report: select(Report) @@ -1909,9 +1961,11 @@ def test_single_inh_subclass_join_joined_inh_subclass(self, autoalias): e1 = aliased(Engineer, flat=True) q = s.query(Boss).join(e1, e1.manager_id == Boss.id) - with _aliased_join_warning( - r"Mapper\[Engineer\(engineer\)\]" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Mapper\[Engineer\(engineer\)\]") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT manager.id AS manager_id, employee.id AS employee_id, " @@ -1974,9 +2028,11 @@ def test_joined_inh_subclass_join_single_inh_subclass(self, autoalias): b1 = aliased(Boss, flat=True) q = s.query(Engineer).join(b1, Engineer.manager_id == b1.id) - with _aliased_join_warning( - r"Mapper\[Boss\(manager\)\]" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Mapper\[Boss\(manager\)\]") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT engineer.id AS engineer_id, " diff --git a/test/orm/test_assorted_eager.py 
b/test/orm/test_assorted_eager.py index 677f8f20736..f14cdda5b66 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -6,6 +6,7 @@ be cleaned up and modernized. """ + import datetime import sqlalchemy as sa diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 976df514f3b..abd008cadf0 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -464,7 +464,7 @@ def get_bind(self, **kw): engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name] with mock.patch( - "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result" + "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result" ), mock.patch( "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement" ), mock.patch( @@ -529,7 +529,7 @@ def test_bound_connection(self): assert_raises_message( sa.exc.InvalidRequestError, - "Session already has a Connection " "associated", + "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect(), None, diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py index 6d613091def..a1bd399a4cb 100644 --- a/test/orm/test_bundle.py +++ b/test/orm/test_bundle.py @@ -3,6 +3,7 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import select +from sqlalchemy import SelectLabelStyle from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import tuple_ @@ -159,6 +160,151 @@ def test_c_attr(self): select(b1.c.d1, b1.c.d2), "SELECT data.d1, data.d2 FROM data" ) + @testing.variation( + "stmt_type", ["legacy", "newstyle", "newstyle_w_label_conv"] + ) + @testing.variation("col_type", ["orm", "core"]) + def test_dupe_col_name(self, stmt_type, col_type): + """test #11347""" + Data = self.classes.Data + sess = fixture_session() + + if col_type.orm: + b1 = Bundle("b1", Data.d1, Data.d3) + cols = Data.d1, Data.d2 + elif col_type.core: + data_table = self.tables.data + b1 = Bundle("b1", data_table.c.d1, data_table.c.d3) + cols = data_table.c.d1, data_table.c.d2 + else: + col_type.fail() + + if stmt_type.legacy: + row = ( + sess.query(cols[0], cols[1], b1) + .filter(Data.d1 == "d0d1") + .one() + ) + elif stmt_type.newstyle: + row = sess.execute( + select(cols[0], cols[1], b1).filter(Data.d1 == "d0d1") + ).one() + elif stmt_type.newstyle_w_label_conv: + row = sess.execute( + select(cols[0], cols[1], b1) + .filter(Data.d1 == "d0d1") + .set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).one() + else: + stmt_type.fail() + + if stmt_type.newstyle_w_label_conv: + # decision is made here that even if a SELECT with the + # "tablename_plus_colname" label style, within a Bundle we still + # use straight column name, even though the overall row + # uses tablename_colname + eq_( + row._mapping, + {"data_d1": "d0d1", "data_d2": "d0d2", "b1": ("d0d1", "d0d3")}, + ) + else: + eq_( + row._mapping, + {"d1": "d0d1", "d2": "d0d2", "b1": ("d0d1", "d0d3")}, + ) + + eq_(row[2]._mapping, {"d1": "d0d1", "d3": "d0d3"}) + + @testing.variation( + "stmt_type", ["legacy", "newstyle", "newstyle_w_label_conv"] + ) + @testing.variation("col_type", ["orm", "core"]) + def test_dupe_col_name_nested(self, stmt_type, col_type): + """test #11347""" + Data = self.classes.Data + sess = fixture_session() + + class DictBundle(Bundle): + def create_row_processor(self, query, procs, labels): + def proc(row): + return dict(zip(labels, (proc(row) for proc in procs))) + + return proc + + if col_type.core: + data_table = self.tables.data + + b1 = DictBundle("b1", data_table.c.d1, 
data_table.c.d3) + b2 = DictBundle("b2", data_table.c.d2, data_table.c.d3) + b3 = DictBundle("b3", data_table.c.d2, data_table.c.d3, b1, b2) + elif col_type.orm: + b1 = DictBundle("b1", Data.d1, Data.d3) + b2 = DictBundle("b2", Data.d2, Data.d3) + b3 = DictBundle("b3", Data.d2, Data.d3, b1, b2) + else: + col_type.fail() + + if stmt_type.legacy: + row = ( + sess.query(Data.d1, Data.d2, b3) + .filter(Data.d1 == "d0d1") + .one() + ) + elif stmt_type.newstyle: + row = sess.execute( + select(Data.d1, Data.d2, b3).filter(Data.d1 == "d0d1") + ).one() + elif stmt_type.newstyle_w_label_conv: + row = sess.execute( + select(Data.d1, Data.d2, b3) + .filter(Data.d1 == "d0d1") + .set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) + ).one() + else: + stmt_type.fail() + + if stmt_type.newstyle_w_label_conv: + eq_( + row._mapping, + { + "data_d1": "d0d1", + "data_d2": "d0d2", + "b3": { + "d2": "d0d2", + "d3": "d0d3", + "b1": {"d1": "d0d1", "d3": "d0d3"}, + "b2": {"d2": "d0d2", "d3": "d0d3"}, + }, + }, + ) + else: + eq_( + row._mapping, + { + "d1": "d0d1", + "d2": "d0d2", + "b3": { + "d2": "d0d2", + "d3": "d0d3", + "b1": {"d1": "d0d1", "d3": "d0d3"}, + "b2": {"d2": "d0d2", "d3": "d0d3"}, + }, + }, + ) + eq_( + row[2], + { + "d2": "d0d2", + "d3": "d0d3", + "b1": {"d1": "d0d1", "d3": "d0d3"}, + "b2": {"d2": "d0d2", "d3": "d0d3"}, + }, + ) + def test_result(self): Data = self.classes.Data sess = fixture_session() diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index ff70e4718b5..4bd353b84fd 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -643,15 +643,9 @@ def test_wpoly_cache_keys(self): self._run_cache_key_fixture( lambda: ( inspect(Person), - inspect( - aliased(Person, me_stmt), - ), - inspect( - aliased(Person, meb_stmt), - ), - inspect( - with_polymorphic(Person, [Manager, Engineer]), - ), + inspect(aliased(Person, me_stmt)), + inspect(aliased(Person, meb_stmt)), + inspect(with_polymorphic(Person, [Manager, Engineer])), # aliased=True is the same as flat=True for default selectable inspect( with_polymorphic( @@ -695,9 +689,7 @@ def test_wpoly_cache_keys(self): aliased=True, ), ), - inspect( - with_polymorphic(Person, [Manager, Engineer, Boss]), - ), + inspect(with_polymorphic(Person, [Manager, Engineer, Boss])), inspect( with_polymorphic( Person, @@ -712,6 +704,7 @@ def test_wpoly_cache_keys(self): polymorphic_on=literal_column("bar"), ), ), + inspect(with_polymorphic(Person, "*", name="foo")), ), compare_values=True, ) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index ded2c25db79..cd205be5b48 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -16,9 +16,13 @@ from sqlalchemy.orm import Composite from sqlalchemy.orm import composite from sqlalchemy.orm import configure_mappers +from sqlalchemy.orm import defer +from sqlalchemy.orm import load_only from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm import undefer +from sqlalchemy.orm import undefer_group from sqlalchemy.orm.attributes import LoaderCallableStatus from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -411,11 +415,11 @@ def test_bulk_insert_heterogeneous(self, type_): assert_data = [ { "start": d["start"] if "start" in d else None, - "end": d["end"] - if "end" in d - else Point(d["x2"], d["y2"]) - if "x2" in d - else None, + "end": ( + d["end"] + if "end" in d + else Point(d["x2"], d["y2"]) if "x2" in d 
else None + ), "graph_id": d["graph_id"], } for d in data @@ -916,9 +920,11 @@ def test_event_listener_no_value_to_set( mock.call( e1, Point(5, 6), - LoaderCallableStatus.NO_VALUE - if not active_history - else None, + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else None + ), Edge.start.impl, ) ], @@ -965,9 +971,11 @@ def test_event_listener_set_to_new( mock.call( e1, Point(7, 8), - LoaderCallableStatus.NO_VALUE - if not active_history - else Point(5, 6), + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else Point(5, 6) + ), Edge.start.impl, ) ], @@ -1019,9 +1027,11 @@ def test_event_listener_set_to_deleted( [ mock.call( e1, - LoaderCallableStatus.NO_VALUE - if not active_history - else Point(5, 6), + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else Point(5, 6) + ), Edge.start.impl, ) ], @@ -1464,7 +1474,7 @@ def test_query_aliased(self): eq_(sess.query(ae).filter(ae.c == C("a2b1", b2)).one(), a2) -class ConfigurationTest(fixtures.MappedTest): +class ConfigAndDeferralTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( @@ -1502,7 +1512,7 @@ def __ne__(self, other): class Edge(cls.Comparable): pass - def _test_roundtrip(self): + def _test_roundtrip(self, *, assert_deferred=False, options=()): Edge, Point = self.classes.Edge, self.classes.Point e1 = Edge(start=Point(3, 4), end=Point(5, 6)) @@ -1510,7 +1520,19 @@ def _test_roundtrip(self): sess.add(e1) sess.commit() - eq_(sess.query(Edge).one(), Edge(start=Point(3, 4), end=Point(5, 6))) + stmt = select(Edge) + if options: + stmt = stmt.options(*options) + e1 = sess.execute(stmt).scalar_one() + + names = ["start", "end", "x1", "x2", "y1", "y2"] + for name in names: + if assert_deferred: + assert name not in e1.__dict__ + else: + assert name in e1.__dict__ + + eq_(e1, Edge(start=Point(3, 4), end=Point(5, 6))) def test_columns(self): edge, Edge, Point = ( @@ -1556,7 +1578,7 @@ def test_strings(self): self._test_roundtrip() - def test_deferred(self): + def test_deferred_config(self): edge, Edge, Point = ( self.tables.edge, self.classes.Edge, @@ -1574,7 +1596,121 @@ def test_deferred(self): ), }, ) - self._test_roundtrip() + self._test_roundtrip(assert_deferred=True) + + def test_defer_option_on_cols(self): + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + self.mapper_registry.map_imperatively( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, + edge.c.x1, + edge.c.y1, + ), + "end": sa.orm.composite( + Point, + edge.c.x2, + edge.c.y2, + ), + }, + ) + self._test_roundtrip( + assert_deferred=True, + options=( + defer(Edge.x1), + defer(Edge.x2), + defer(Edge.y1), + defer(Edge.y2), + ), + ) + + def test_defer_option_on_composite(self): + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + self.mapper_registry.map_imperatively( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, + edge.c.x1, + edge.c.y1, + ), + "end": sa.orm.composite( + Point, + edge.c.x2, + edge.c.y2, + ), + }, + ) + self._test_roundtrip( + assert_deferred=True, options=(defer(Edge.start), defer(Edge.end)) + ) + + @testing.variation("composite_only", [True, False]) + def test_load_only_option_on_composite(self, composite_only): + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + self.mapper_registry.map_imperatively( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, edge.c.x1, edge.c.y1, deferred=True + ), + "end": sa.orm.composite( + 
Point, + edge.c.x2, + edge.c.y2, + ), + }, + ) + + if composite_only: + self._test_roundtrip( + assert_deferred=False, + options=(load_only(Edge.start, Edge.end),), + ) + else: + self._test_roundtrip( + assert_deferred=False, + options=(load_only(Edge.start, Edge.x2, Edge.y2),), + ) + + def test_defer_option_on_composite_via_group(self): + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + self.mapper_registry.map_imperatively( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, edge.c.x1, edge.c.y1, deferred=True, group="s" + ), + "end": sa.orm.composite( + Point, edge.c.x2, edge.c.y2, deferred=True + ), + }, + ) + self._test_roundtrip( + assert_deferred=False, + options=(undefer_group("s"), undefer(Edge.end)), + ) def test_check_prop_type(self): edge, Edge, Point = ( diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index dd0d597b225..a961962d916 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -368,6 +368,14 @@ class PropagateAttrsTest(QueryTest): def propagate_cases(): return testing.combinations( (lambda: select(1), False), + (lambda User: select(User.id), True), + (lambda User: select(User.id + User.id), True), + (lambda User: select(User.id + User.id + User.id), True), + (lambda User: select(sum([User.id] * 10, User.id)), True), # type: ignore # noqa: E501 + ( + lambda User: select(literal_column("3") + User.id + User.id), + True, + ), (lambda User: select(func.count(User.id)), True), ( lambda User: select(1).select_from(select(User).subquery()), @@ -555,7 +563,7 @@ def test_aliased_delete(self, stmt_type: testing.Variation): self.assert_compile( stmt, - "DELETE FROM users AS users_1 " "WHERE users_1.name = :name_1", + "DELETE FROM users AS users_1 WHERE users_1.name = :name_1", ) @testing.variation("stmt_type", ["core", "orm"]) @@ -1797,7 +1805,7 @@ class InheritedTest(_poly_fixtures._Polymorphic): run_setup_mappers = "once" -class ExplicitWithPolymorhpicTest( +class ExplicitWithPolymorphicTest( _poly_fixtures._PolymorphicUnions, AssertsCompiledSQL ): __dialect__ = "default" @@ -2604,6 +2612,61 @@ def test_cte_recursive_handles_dupe_columns(self): "anon_1.primary_language FROM anon_1", ) + @testing.variation("named", [True, False]) + @testing.variation("flat", [True, False]) + def test_aliased_joined_entities(self, named, flat): + Company = self.classes.Company + Engineer = self.classes.Engineer + + if named: + e1 = aliased(Engineer, flat=flat, name="myengineer") + else: + e1 = aliased(Engineer, flat=flat) + + q = select(Company.name, e1.primary_language).join( + Company.employees.of_type(e1) + ) + + if not flat: + name = "anon_1" if not named else "myengineer" + + self.assert_compile( + q, + "SELECT companies.name, " + f"{name}.engineers_primary_language FROM companies " + "JOIN (SELECT people.person_id AS people_person_id, " + "people.company_id AS people_company_id, " + "people.name AS people_name, people.type AS people_type, " + "engineers.person_id AS engineers_person_id, " + "engineers.status AS engineers_status, " + "engineers.engineer_name AS engineers_engineer_name, " + "engineers.primary_language AS engineers_primary_language " + "FROM people JOIN engineers " + "ON people.person_id = engineers.person_id) AS " + f"{name} " + f"ON companies.company_id = {name}.people_company_id", + ) + elif named: + self.assert_compile( + q, + "SELECT companies.name, " + "myengineer_engineers.primary_language " + "FROM companies JOIN (people AS myengineer_people 
" + "JOIN engineers AS myengineer_engineers " + "ON myengineer_people.person_id = " + "myengineer_engineers.person_id) " + "ON companies.company_id = myengineer_people.company_id", + ) + else: + self.assert_compile( + q, + "SELECT companies.name, engineers_1.primary_language " + "FROM companies JOIN (people AS people_1 " + "JOIN engineers AS engineers_1 " + "ON people_1.person_id = engineers_1.person_id) " + "ON companies.company_id = people_1.company_id", + ) + class RawSelectTest(QueryTest, AssertsCompiledSQL): """older tests from test_query. Here, they are converted to use diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index 7f0f504b569..fb37185f53e 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -5,6 +5,7 @@ T1/T2. """ + from itertools import count from sqlalchemy import bindparam @@ -1187,7 +1188,7 @@ def test_post_update_o2m(self): ], ), CompiledSQL( - "DELETE FROM person " "WHERE person.id = :id", + "DELETE FROM person WHERE person.id = :id", lambda ctx: [{"id": p.id}], ), CompiledSQL( diff --git a/test/orm/test_default_strategies.py b/test/orm/test_default_strategies.py index 657875aa9d8..178b03fe6f6 100644 --- a/test/orm/test_default_strategies.py +++ b/test/orm/test_default_strategies.py @@ -1,11 +1,18 @@ import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import util +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defaultload from sqlalchemy.orm import joinedload from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -738,3 +745,122 @@ def go(): eq_(a1.user, None) self.sql_count_(0, go) + + +class Issue11292Test(fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Parent(Base): + __tablename__ = "parent" + + id = Column(Integer, primary_key=True) + + extension = relationship( + "Extension", back_populates="parent", uselist=False + ) + + class Child(Base): + __tablename__ = "child" + + id = Column(Integer, primary_key=True) + + extensions = relationship("Extension", back_populates="child") + + class Extension(Base): + __tablename__ = "extension" + + id = Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey(Parent.id)) + child_id = Column(Integer, ForeignKey(Child.id)) + + parent = relationship("Parent", back_populates="extension") + child = relationship("Child", back_populates="extensions") + + @classmethod + def insert_data(cls, connection): + Parent, Child, Extension = cls.classes("Parent", "Child", "Extension") + with Session(connection) as session: + for id_ in (1, 2, 3): + session.add(Parent(id=id_)) + session.add(Child(id=id_)) + session.add(Extension(id=id_, parent_id=id_, child_id=id_)) + session.commit() + + @testing.variation("load_as_option", [True, False]) + def test_defaultload_dont_propagate(self, load_as_option): + Parent, Child, Extension = self.classes("Parent", "Child", "Extension") + + session = fixture_session() + + # here, we want the defaultload() to go away on subsequent loads, + # becuase Parent.extension is propagate_to_loaders=False + query = ( + select(Parent) + .join(Extension) + 
.join(Child) + .options( + contains_eager(Parent.extension), + ( + defaultload(Parent.extension).options( + contains_eager(Extension.child) + ) + if load_as_option + else defaultload(Parent.extension).contains_eager( + Extension.child + ) + ), + ) + ) + + parents = session.scalars(query).all() + + eq_( + [(p.id, p.extension.id, p.extension.child.id) for p in parents], + [(1, 1, 1), (2, 2, 2), (3, 3, 3)], + ) + + session.expire_all() + + eq_( + [(p.id, p.extension.id, p.extension.child.id) for p in parents], + [(1, 1, 1), (2, 2, 2), (3, 3, 3)], + ) + + @testing.variation("load_as_option", [True, False]) + def test_defaultload_yes_propagate(self, load_as_option): + Parent, Child, Extension = self.classes("Parent", "Child", "Extension") + + session = fixture_session() + + # here, we want the defaultload() to remain in place on subsequent + # loads, because joinedload() is propagate_to_loaders=True + query = select(Parent).options( + ( + defaultload(Parent.extension).options( + joinedload(Extension.child) + ) + if load_as_option + else defaultload(Parent.extension).joinedload(Extension.child) + ), + ) + + parents = session.scalars(query).all() + + eq_( + [(p.id, p.extension.id, p.extension.child.id) for p in parents], + [(1, 1, 1), (2, 2, 2), (3, 3, 3)], + ) + + session.expire_all() + + # this would be 9 without the joinedload + with self.assert_statement_count(testing.db, 6): + eq_( + [ + (p.id, p.extension.id, p.extension.child.id) + for p in parents + ], + [(1, 1, 1), (2, 2, 2), (3, 3, 3)], + ) diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index 66e3104a95d..dbfe3ef7974 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import TypeDecorator from sqlalchemy import union_all from sqlalchemy import util from sqlalchemy.orm import aliased @@ -2215,9 +2216,21 @@ class C(ComparableEntity, Base): c_expr = query_expression(literal(1)) + class CustomTimeStamp(TypeDecorator): + cache_ok = False + impl = Integer + + class HasNonCacheable(ComparableEntity, Base): + __tablename__ = "non_cacheable" + + id = Column(Integer, primary_key=True) + created = Column(CustomTimeStamp) + msg_translated = query_expression() + @classmethod def insert_data(cls, connection): A, A_default, B, C = cls.classes("A", "A_default", "B", "C") + (HasNonCacheable,) = cls.classes("HasNonCacheable") s = Session(connection) s.add_all( @@ -2230,6 +2243,7 @@ def insert_data(cls, connection): C(id=2, x=2), A_default(id=1, x=1, y=2), A_default(id=2, x=2, y=3), + HasNonCacheable(id=1, created=12345), ] ) @@ -2269,6 +2283,30 @@ def test_expr_default_value(self): ) eq_(c2.all(), [C(c_expr=4)]) + def test_non_cacheable_expr(self): + """test #10990""" + + HasNonCacheable = self.classes.HasNonCacheable + + for i in range(3): + s = fixture_session() + + stmt = ( + select(HasNonCacheable) + .where(HasNonCacheable.created > 10) + .options( + with_expression( + HasNonCacheable.msg_translated, + HasNonCacheable.created + 10, + ) + ) + ) + + eq_( + s.scalars(stmt).all(), + [HasNonCacheable(id=1, created=12345, msg_translated=12355)], + ) + def test_reuse_expr(self): A = self.classes.A diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 23248349cd2..bf545d6ad99 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -1995,7 +1995,7 @@ def test_values_specific_order_by(self): @testing.fails_on("mssql", "FIXME: unknown")
@testing.fails_on( - "oracle", "Oracle doesn't support boolean expressions as " "columns" + "oracle", "Oracle doesn't support boolean expressions as columns" ) @testing.fails_on( "postgresql+pg8000", @@ -2269,11 +2269,13 @@ def _test(self, bound_session, session_present, expect_bound): eq_ignore_whitespace( str(q), - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = ?" - if expect_bound - else "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = ?" + if expect_bound + else "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1" + ), ) def test_query_bound_session(self): @@ -2307,7 +2309,6 @@ def go(): class RequirementsTest(fixtures.MappedTest): - """Tests the contract for user classes.""" @classmethod diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 83f3101f209..465e29929e9 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -275,6 +275,33 @@ def my_filter(self, arg): use_default_dialect=True, ) + @testing.combinations( + ("all", []), + ("one", exc.NoResultFound), + ("one_or_none", None), + argnames="method, expected", + ) + @testing.variation("add_to_session", [True, False]) + def test_transient_raise( + self, user_address_fixture, method, expected, add_to_session + ): + """test 11562""" + User, Address = user_address_fixture() + + u1 = User(name="u1") + if add_to_session: + sess = fixture_session() + sess.add(u1) + + meth = getattr(u1.addresses, method) + if expected is exc.NoResultFound: + with expect_raises_message( + exc.NoResultFound, "No row was found when one was required" + ): + meth() + else: + eq_(meth(), expected) + def test_detached_raise(self, user_address_fixture): """so filtering on a detached dynamic list raises an error...""" @@ -1444,9 +1471,11 @@ def test_delete_cascade( addresses_args={ "order_by": addresses.c.id, "backref": "user", - "cascade": "save-update" - if not delete_cascade_configured - else "all, delete", + "cascade": ( + "save-update" + if not delete_cascade_configured + else "all, delete" + ), } ) @@ -1519,9 +1548,11 @@ class A(decl_base): data: Mapped[str] bs: WriteOnlyMapped["B"] = relationship( # noqa: F821 passive_deletes=passive_deletes, - cascade="all, delete-orphan" - if cascade_deletes - else "save-update, merge", + cascade=( + "all, delete-orphan" + if cascade_deletes + else "save-update, merge" + ), order_by="B.id", ) @@ -1986,9 +2017,11 @@ def _assert_history(self, obj, compare, compare_passive=None): attributes.get_history( obj, attrname, - PassiveFlag.PASSIVE_NO_FETCH - if self.lazy == "write_only" - else PassiveFlag.PASSIVE_OFF, + ( + PassiveFlag.PASSIVE_NO_FETCH + if self.lazy == "write_only" + else PassiveFlag.PASSIVE_OFF + ), ), compare, ) diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index b1b6e86b794..7e0eca62c65 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -26,6 +26,8 @@ from sqlalchemy.orm import lazyload from sqlalchemy.orm import Load from sqlalchemy.orm import load_only +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import undefer @@ -41,6 +43,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL +from 
sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -3696,8 +3699,180 @@ def test_joined_across(self): self._assert_result(q) -class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): +class InnerJoinSplicingWSecondarySelfRefTest( + fixtures.MappedTest, testing.AssertsCompiledSQL +): + """test for issue 11449""" + + __dialect__ = "default" + __backend__ = True # exercise hardcore join nesting on backends + + @classmethod + def define_tables(cls, metadata): + Table( + "kind", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + Table( + "node", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column( + "common_node_id", Integer, ForeignKey("node.id"), nullable=True + ), + Column("kind_id", Integer, ForeignKey("kind.id"), nullable=False), + ) + Table( + "node_group", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + Table( + "node_group_node", + metadata, + Column( + "node_group_id", + Integer, + ForeignKey("node_group.id"), + primary_key=True, + ), + Column( + "node_id", Integer, ForeignKey("node.id"), primary_key=True + ), + ) + + @classmethod + def setup_classes(cls): + class Kind(cls.Comparable): + pass + + class Node(cls.Comparable): + pass + + class NodeGroup(cls.Comparable): + pass + + class NodeGroupNode(cls.Comparable): + pass + + @classmethod + def insert_data(cls, connection): + kind = cls.tables.kind + connection.execute( + kind.insert(), [{"id": 1, "name": "a"}, {"id": 2, "name": "c"}] + ) + node = cls.tables.node + connection.execute( + node.insert(), + {"id": 1, "name": "nc", "kind_id": 2}, + ) + + connection.execute( + node.insert(), + {"id": 2, "name": "na", "kind_id": 1, "common_node_id": 1}, + ) + + node_group = cls.tables.node_group + node_group_node = cls.tables.node_group_node + + connection.execute(node_group.insert(), {"id": 1, "name": "group"}) + connection.execute( + node_group_node.insert(), + {"id": 1, "node_group_id": 1, "node_id": 2}, + ) + connection.commit() + + @testing.fixture(params=["common_nodes,kind", "kind,common_nodes"]) + def node_fixture(self, request): + Kind, Node, NodeGroup, NodeGroupNode = self.classes( + "Kind", "Node", "NodeGroup", "NodeGroupNode" + ) + kind, node, node_group, node_group_node = self.tables( + "kind", "node", "node_group", "node_group_node" + ) + self.mapper_registry.map_imperatively(Kind, kind) + + if request.param == "common_nodes,kind": + self.mapper_registry.map_imperatively( + Node, + node, + properties=dict( + common_node=relationship( + "Node", + remote_side=[node.c.id], + ), + kind=relationship(Kind, innerjoin=True, lazy="joined"), + ), + ) + elif request.param == "kind,common_nodes": + self.mapper_registry.map_imperatively( + Node, + node, + properties=dict( + kind=relationship(Kind, innerjoin=True, lazy="joined"), + common_node=relationship( + "Node", + remote_side=[node.c.id], + ), + ), + ) + + self.mapper_registry.map_imperatively( + NodeGroup, + node_group, + properties=dict( + nodes=relationship(Node, secondary="node_group_node") + ), + ) + self.mapper_registry.map_imperatively(NodeGroupNode, node_group_node) + + def test_select(self, node_fixture): + Kind, Node, NodeGroup, NodeGroupNode = self.classes( + "Kind", "Node", "NodeGroup", "NodeGroupNode" + ) + + session = fixture_session() + with 
self.sql_execution_asserter(testing.db) as asserter: + group = ( + session.scalars( + select(NodeGroup) + .where(NodeGroup.name == "group") + .options( + joinedload(NodeGroup.nodes).joinedload( + Node.common_node + ) + ) + ) + .unique() + .one_or_none() + ) + + eq_(group.nodes[0].common_node.kind.name, "c") + eq_(group.nodes[0].kind.name, "a") + + asserter.assert_( + RegexSQL( + r"SELECT .* FROM node_group " + r"LEFT OUTER JOIN \(node_group_node AS node_group_node_1 " + r"JOIN node AS node_2 " + r"ON node_2.id = node_group_node_1.node_id " + r"JOIN kind AS kind_\d ON kind_\d.id = node_2.kind_id\) " + r"ON node_group.id = node_group_node_1.node_group_id " + r"LEFT OUTER JOIN " + r"\(node AS node_1 JOIN kind AS kind_\d " + r"ON kind_\d.id = node_1.kind_id\) " + r"ON node_1.id = node_2.common_node_id " + r"WHERE node_group.name = :name_5" + ) + ) + + +class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """test #2188""" __dialect__ = "default" @@ -3892,7 +4067,6 @@ def test_standalone_negated(self): class LoadOnExistingTest(_fixtures.FixtureTest): - """test that loaders from a base Query fully populate.""" run_inserts = "once" @@ -5309,7 +5483,6 @@ def go(): class CorrelatedSubqueryTest(fixtures.MappedTest): - """tests for #946, #947, #948. The "users" table is joined to "stuff", and the relationship @@ -6633,7 +6806,6 @@ def go(): class SecondaryOptionsTest(fixtures.MappedTest): - """test that the contains_eager() option doesn't bleed into a secondary load.""" @@ -6940,3 +7112,94 @@ def go(): ) self.assert_sql_count(testing.db, go, 1) + + +class NestedInnerjoinTestIssue11965( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + """test for issue #11965, regression from #11449""" + + __dialect__ = "default" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Source(Base): + __tablename__ = "source" + id: Mapped[int] = mapped_column(primary_key=True) + + class Day(Base): + __tablename__ = "day" + id: Mapped[int] = mapped_column(primary_key=True) + + class Run(Base): + __tablename__ = "run" + id: Mapped[int] = mapped_column(primary_key=True) + + source_id: Mapped[int] = mapped_column( + ForeignKey(Source.id), nullable=False + ) + source = relationship(Source, lazy="joined", innerjoin=True) + + day = relationship( + Day, + lazy="joined", + innerjoin=True, + ) + day_id: Mapped[int] = mapped_column( + ForeignKey(Day.id), nullable=False + ) + + class Event(Base): + __tablename__ = "event" + + id: Mapped[int] = mapped_column(primary_key=True) + run_id: Mapped[int] = mapped_column( + ForeignKey(Run.id), nullable=False + ) + run = relationship(Run, lazy="joined", innerjoin=True) + + class Room(Base): + __tablename__ = "room" + + id: Mapped[int] = mapped_column(primary_key=True) + event_id: Mapped[int] = mapped_column( + ForeignKey(Event.id), nullable=False + ) + event = relationship(Event, foreign_keys=event_id, lazy="joined") + + @classmethod + def insert_data(cls, connection): + Room, Run, Source, Event, Day = cls.classes( + "Room", "Run", "Source", "Event", "Day" + ) + run = Run(source=Source(), day=Day()) + event = Event(run=run) + room = Room(event=event) + with Session(connection) as session: + session.add(room) + session.commit() + + def test_compile(self): + Room = self.classes.Room + self.assert_compile( + select(Room), + "SELECT room.id, room.event_id, source_1.id AS id_1, " + "day_1.id AS id_2, run_1.id AS id_3, run_1.source_id, " + "run_1.day_id, event_1.id AS id_4, event_1.run_id " + "FROM room LEFT OUTER JOIN " + 
"(event AS event_1 " + "JOIN run AS run_1 ON run_1.id = event_1.run_id " + "JOIN day AS day_1 ON day_1.id = run_1.day_id " + "JOIN source AS source_1 ON source_1.id = run_1.source_id) " + "ON event_1.id = room.event_id", + ) + + def test_roundtrip(self): + Room = self.classes.Room + session = fixture_session() + rooms = session.scalars(select(Room)).unique().all() + session.close() + # verify eager-loaded correctly + assert rooms[0].event.run.day diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 56d16dfcd76..2b24e47469d 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -385,14 +385,16 @@ def do_orm_execute(ctx): bind_mapper=ctx.bind_mapper, all_mappers=ctx.all_mappers, is_select=ctx.is_select, + is_from_statement=ctx.is_from_statement, + is_insert=ctx.is_insert, is_update=ctx.is_update, is_delete=ctx.is_delete, is_orm_statement=ctx.is_orm_statement, is_relationship_load=ctx.is_relationship_load, is_column_load=ctx.is_column_load, - lazy_loaded_from=ctx.lazy_loaded_from - if ctx.is_select - else None, + lazy_loaded_from=( + ctx.lazy_loaded_from if ctx.is_select else None + ), ) return canary @@ -421,6 +423,8 @@ def test_non_orm_statements(self, stmt, is_select): bind_mapper=None, all_mappers=[], is_select=is_select, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=False, @@ -451,6 +455,8 @@ def test_all_mappers_accessor_one(self): bind_mapper=inspect(User), all_mappers=[inspect(User), inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -475,6 +481,8 @@ def test_all_mappers_accessor_two(self): bind_mapper=None, all_mappers=[], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=False, @@ -501,6 +509,8 @@ def test_all_mappers_accessor_three(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], # Address not in results is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -531,6 +541,54 @@ def test_select_flags(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=False, + lazy_loaded_from=None, + ), + call.options( + bind_mapper=inspect(User), + all_mappers=[inspect(User)], + is_select=True, + is_from_statement=False, + is_insert=False, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=True, + lazy_loaded_from=None, + ), + ], + ) + + def test_select_from_statement_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + s1 = select(User).filter_by(id=7) + u1 = sess.execute(select(User).from_statement(s1)).scalar_one() + + sess.expire(u1) + + eq_(u1.name, "jack") + + eq_( + canary.mock_calls, + [ + call.options( + bind_mapper=inspect(User), + all_mappers=[inspect(User)], + is_select=True, + is_from_statement=True, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -542,6 +600,8 @@ def test_select_flags(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -570,6 +630,8 @@ def test_lazyload_flags(self): 
bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -581,6 +643,8 @@ def test_lazyload_flags(self): bind_mapper=inspect(Address), all_mappers=[inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -611,6 +675,8 @@ def test_selectinload_flags(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -622,6 +688,8 @@ def test_selectinload_flags(self): bind_mapper=inspect(Address), all_mappers=[inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -652,6 +720,8 @@ def test_subqueryload_flags(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -663,6 +733,8 @@ def test_subqueryload_flags(self): bind_mapper=inspect(Address), all_mappers=[inspect(Address), inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -673,24 +745,45 @@ def test_subqueryload_flags(self): ], ) - def test_update_delete_flags(self): + @testing.variation( + "stmt_type", + [ + ("insert", testing.requires.insert_returning), + ("update", testing.requires.update_returning), + ("delete", testing.requires.delete_returning), + ], + ) + @testing.variation("from_stmt", [True, False]) + def test_update_delete_flags(self, stmt_type, from_stmt): User, Address = self.classes("User", "Address") sess = Session(testing.db, future=True) canary = self._flag_fixture(sess) - sess.execute( - delete(User) - .filter_by(id=18) - .execution_options(synchronize_session="evaluate") - ) - sess.execute( - update(User) - .filter_by(id=18) - .values(name="eighteen") - .execution_options(synchronize_session="evaluate") - ) + if stmt_type.delete: + stmt = ( + delete(User) + .filter_by(id=18) + .execution_options(synchronize_session="evaluate") + ) + elif stmt_type.update: + stmt = ( + update(User) + .filter_by(id=18) + .values(name="eighteen") + .execution_options(synchronize_session="evaluate") + ) + elif stmt_type.insert: + stmt = insert(User).values(name="eighteen") + else: + stmt_type.fail() + + if from_stmt: + stmt = select(User).from_statement(stmt.returning(User)) + + result = sess.execute(stmt) + result.close() eq_( canary.mock_calls, @@ -699,19 +792,10 @@ def test_update_delete_flags(self): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=False, - is_update=False, - is_delete=True, - is_orm_statement=True, - is_relationship_load=False, - is_column_load=False, - lazy_loaded_from=None, - ), - call.options( - bind_mapper=inspect(User), - all_mappers=[inspect(User)], - is_select=False, - is_update=True, - is_delete=False, + is_from_statement=bool(from_stmt), + is_insert=stmt_type.insert, + is_update=stmt_type.update, + is_delete=stmt_type.delete, is_orm_statement=True, is_relationship_load=False, is_column_load=False, @@ -1545,9 +1629,11 @@ def _combinations(fn): ( lambda session: session, "loaded_as_persistent", - lambda session, instance: instance.unloaded - if instance.__class__.__name__ == "A" - else None, + lambda session, instance: ( + instance.unloaded + if instance.__class__.__name__ == "A" + 
else None + ), ), argnames="target, event_name, fn", )(fn) @@ -1669,8 +1755,7 @@ class C(B): class DeferredMapperEventsTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): - - """ "test event listeners against unmapped classes. + """test event listeners against unmapped classes. This incurs special logic. Note if we ever do the "remove" case, it has to get all of these, too. diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 51c86a5f1da..e0d75db7e16 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -1893,7 +1893,9 @@ def test_no_uniquing_cols(self, with_entities): .order_by(User.id) ) - compile_state = ORMSelectCompileState.create_for_statement(stmt, None) + compile_state = ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) is_(compile_state._primary_entity, None) def test_column_queries_one(self): diff --git a/test/orm/test_hasparent.py b/test/orm/test_hasparent.py index 8f61c11970d..72c90b6d5c9 100644 --- a/test/orm/test_hasparent.py +++ b/test/orm/test_hasparent.py @@ -1,4 +1,5 @@ """test the current state of the hasparent() flag.""" + from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import testing diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 4ab9617123c..9bb8071984d 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -21,7 +21,9 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm import foreign from sqlalchemy.orm import relationship +from sqlalchemy.orm import remote from sqlalchemy.orm import Session from sqlalchemy.orm import with_parent from sqlalchemy.testing import assert_raises @@ -993,7 +995,6 @@ def go(): class GetterStateTest(_fixtures.FixtureTest): - """test lazyloader on non-existent attribute returns expected attribute symbols, maintain expected state""" @@ -1080,11 +1081,13 @@ def _u_ad_fixture(self, populate_user, dont_use_get=False): properties={ "user": relationship( User, - primaryjoin=and_( - users.c.id == addresses.c.user_id, users.c.id != 27 - ) - if dont_use_get - else None, + primaryjoin=( + and_( + users.c.id == addresses.c.user_id, users.c.id != 27 + ) + if dont_use_get + else None + ), back_populates="addresses", ) }, @@ -1269,6 +1272,54 @@ def go(): self.assert_sql_count(testing.db, go, 1) + @testing.fixture() + def composite_overlapping_fixture(self, decl_base, connection): + def go(allow_partial_pks): + + class Section(decl_base): + __tablename__ = "sections" + year = Column(Integer, primary_key=True) + idx = Column(Integer, primary_key=True) + parent_idx = Column(Integer) + + if not allow_partial_pks: + __mapper_args__ = {"allow_partial_pks": False} + + ForeignKeyConstraint((year, parent_idx), (year, idx)) + + parent = relationship( + "Section", + primaryjoin=and_( + year == remote(year), + foreign(parent_idx) == remote(idx), + ), + ) + + decl_base.metadata.create_all(connection) + connection.commit() + + with Session(connection) as sess: + sess.add(Section(year=5, idx=1, parent_idx=None)) + sess.commit() + + return Section + + return go + + @testing.variation("allow_partial_pks", [True, False]) + def test_composite_m2o_load_partial_pks( + self, allow_partial_pks, composite_overlapping_fixture + ): + Section = composite_overlapping_fixture(allow_partial_pks) + + session = fixture_session() + section = session.get(Section, (5, 1)) + + with self.assert_statement_count( + testing.db, 1 if 
allow_partial_pks else 0 + ): + testing.is_none(section.parent) + class CorrelatedTest(fixtures.MappedTest): @classmethod diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index f90803d6e4d..4b3bb99c5b1 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2010,12 +2010,12 @@ def _x(self): ) # object gracefully handles this condition - assert not hasattr(User.x, "__name__") + assert not hasattr(User.x, "foobar") assert not hasattr(User.x, "comparator") m.add_property("some_attr", column_property(users.c.name)) - assert not hasattr(User.x, "__name__") + assert not hasattr(User.x, "foobar") assert hasattr(User.x, "comparator") def test_synonym_of_non_property_raises(self): @@ -2555,7 +2555,6 @@ class B(OldStyle, NewStyle): class RequirementsTest(fixtures.MappedTest): - """Tests the contract for user classes.""" @classmethod @@ -3484,7 +3483,7 @@ def test_load_options(self, use_bound): self.assert_compile( stmt, - "SELECT users.id, " "users.name " "FROM users", + "SELECT users.id, users.name FROM users", ) is_true(um.configured) diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 0c8e2651cdb..c313c4b33da 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1476,9 +1476,7 @@ def test_relationship_population_maintained( CountStatements( 0 if load.noload - else 1 - if merge_persistent.merge_persistent - else 2 + else 1 if merge_persistent.merge_persistent else 2 ) ) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 7c96539583f..c6058a80b3b 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -419,7 +419,10 @@ def _option_fixture(self, *arg): # loader option works this way right now; the rest all use # defaultload() for the "chain" elements return strategy_options._generate_from_keys( - strategy_options.Load.contains_eager, arg, True, {} + strategy_options.Load.contains_eager, + arg, + True, + dict(_propagate_to_loaders=True), ) @testing.combinations( @@ -976,10 +979,12 @@ def test_wrong_type_in_option_cls(self, first_element): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Item], - lambda: (joinedload(Keyword),) - if first_element - else (Load(Item).joinedload(Keyword),), - "expected ORM mapped attribute for loader " "strategy argument", + lambda: ( + (joinedload(Keyword),) + if first_element + else (Load(Item).joinedload(Keyword),) + ), + "expected ORM mapped attribute for loader strategy argument", ) @testing.combinations( @@ -990,9 +995,11 @@ def test_wrong_type_in_option_any_random_type(self, rando, first_element): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - lambda: (joinedload(rando),) - if first_element - else (Load(Item).joinedload(rando)), + lambda: ( + (joinedload(rando),) + if first_element + else (Load(Item).joinedload(rando)) + ), "expected ORM mapped attribute for loader strategy argument", ) @@ -1002,9 +1009,11 @@ def test_wrong_type_in_option_descriptor(self, first_element): self._assert_eager_with_entity_exception( [OrderWProp], - lambda: (joinedload(OrderWProp.some_attr),) - if first_element - else (Load(OrderWProp).joinedload(OrderWProp.some_attr),), + lambda: ( + (joinedload(OrderWProp.some_attr),) + if first_element + else (Load(OrderWProp).joinedload(OrderWProp.some_attr),) + ), "expected ORM mapped attribute for loader strategy argument", ) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 96dec4a60b7..18904cc3861 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -654,6 
+654,17 @@ def test_composite_column_mapped_collection(self): ) is_not_none(collections.collection_adapter(repickled.addresses)) + def test_bulk_save_objects_defaults_pickle(self): + "Test for #11332" + users = self.tables.users + + self.mapper_registry.map_imperatively(User, users) + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + state = pickle.dumps(pes) + pickle.loads(state) + class OptionsTest(_Polymorphic): def test_options_of_type(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 3057087e43b..0e30f58ca16 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -697,8 +697,10 @@ def process_result_value(self, value, dialect): sa_exc.InvalidRequestError, r"Can't apply uniqueness to row tuple " r"containing value of type MyType\(\); " - rf"""{'the values returned appear to be' - if uncertain else 'this datatype produces'} """ + rf"""{ + 'the values returned appear to be' + if uncertain else 'this datatype produces' + } """ r"non-hashable values", ): result = s.execute(q).unique().all() @@ -1974,6 +1976,15 @@ def test_in_on_relationship_not_supported(self): assert_raises(NotImplementedError, Address.user.in_, [User(id=5)]) + def test_in_instrumented_attribute(self): + """test #12019""" + User = self.classes.User + + self._test( + User.id.in_([User.id, User.name]), + "users.id IN (users.id, users.name)", + ) + def test_neg(self): User = self.classes.User @@ -3561,7 +3572,7 @@ def test_filter_by_against_label(self): self.assert_compile( q1, - "SELECT users.id AS foo FROM users " "WHERE users.name = :name_1", + "SELECT users.id AS foo FROM users WHERE users.name = :name_1", ) def test_empty_filters(self): @@ -4346,7 +4357,7 @@ def test_exists(self): q1 = sess.query(User) self.assert_compile( sess.query(q1.exists()), - "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1", + "SELECT EXISTS (SELECT 1 FROM users) AS anon_1", ) q2 = sess.query(User).filter(User.name == "fred") @@ -4364,7 +4375,7 @@ def test_exists_col_expression(self): q1 = sess.query(User.id) self.assert_compile( sess.query(q1.exists()), - "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1", + "SELECT EXISTS (SELECT 1 FROM users) AS anon_1", ) def test_exists_labeled_col_expression(self): @@ -4374,7 +4385,7 @@ def test_exists_labeled_col_expression(self): q1 = sess.query(User.id.label("foo")) self.assert_compile( sess.query(q1.exists()), - "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1", + "SELECT EXISTS (SELECT 1 FROM users) AS anon_1", ) def test_exists_arbitrary_col_expression(self): @@ -4384,7 +4395,7 @@ def test_exists_arbitrary_col_expression(self): q1 = sess.query(func.foo(User.id)) self.assert_compile( sess.query(q1.exists()), - "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1", + "SELECT EXISTS (SELECT 1 FROM users) AS anon_1", ) def test_exists_col_warning(self): @@ -5176,7 +5187,7 @@ def test_one_prefix(self): User = self.classes.User sess = fixture_session() query = sess.query(User.name).prefix_with("PREFIX_1") - expected = "SELECT PREFIX_1 " "users.name AS users_name FROM users" + expected = "SELECT PREFIX_1 users.name AS users_name FROM users" self.assert_compile(query, expected, dialect=default.DefaultDialect()) def test_one_suffix(self): @@ -5192,7 +5203,7 @@ def test_many_prefixes(self): sess = fixture_session() query = sess.query(User.name).prefix_with("PREFIX_1", "PREFIX_2") expected = ( - "SELECT PREFIX_1 PREFIX_2 " "users.name AS users_name FROM users" + "SELECT PREFIX_1 PREFIX_2 
users.name AS users_name FROM users" ) self.assert_compile(query, expected, dialect=default.DefaultDialect()) @@ -5535,6 +5546,25 @@ def test_eagerload_opt_disable(self): ) eq_(len(q.all()), 4) + @testing.combinations( + "joined", + "subquery", + "selectin", + "select", + "immediate", + argnames="lazy", + ) + def test_eagerload_config_disable(self, lazy): + self._eagerload_mappings(addresses_lazy=lazy) + + User = self.classes.User + sess = fixture_session() + q = sess.query(User).enable_eagerloads(False).yield_per(1) + objs = q.all() + eq_(len(objs), 4) + for obj in objs: + assert "addresses" not in obj.__dict__ + def test_m2o_joinedload_not_others(self): self._eagerload_mappings(addresses_lazy="joined") Address = self.classes.Address diff --git a/test/orm/test_recursive_loaders.py b/test/orm/test_recursive_loaders.py index 10582e71131..e6ce5ccd7ef 100644 --- a/test/orm/test_recursive_loaders.py +++ b/test/orm/test_recursive_loaders.py @@ -1,3 +1,5 @@ +import logging.handlers + import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import Integer @@ -11,7 +13,6 @@ from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message -from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -258,13 +259,27 @@ def test_unlimited_recursion(self, loader_fn, limited_cache_conn): result = s.scalars(stmt) self._assert_depth(result.one(), 200) + @testing.fixture + def capture_log(self, testing_engine): + existing_level = logging.getLogger("sqlalchemy.engine").level + + buf = logging.handlers.BufferingHandler(100) + logging.getLogger("sqlalchemy.engine").addHandler(buf) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + yield buf + logging.getLogger("sqlalchemy.engine").setLevel(existing_level) + logging.getLogger("sqlalchemy.engine").removeHandler(buf) + @testing.combinations(selectinload, immediateload, argnames="loader_fn") @testing.combinations(4, 9, 12, 25, 41, 55, argnames="depth") @testing.variation("disable_cache", [True, False]) def test_warning_w_no_recursive_opt( - self, loader_fn, depth, limited_cache_conn, disable_cache + self, loader_fn, depth, limited_cache_conn, disable_cache, capture_log ): + buf = capture_log + connection = limited_cache_conn(27) + connection._echo = True Node = self.classes.Node @@ -280,21 +295,24 @@ def test_warning_w_no_recursive_opt( else: exec_opts = {} - # note this is a magic number, it's not important that it's exact, - # just that when someone makes a huge recursive thing, - # it warns - if depth > 8 and not disable_cache: - with expect_warnings( - "Loader depth for query is excessively deep; " - "caching will be disabled for additional loaders." 
- ): - with Session(connection) as s: - result = s.scalars(stmt, execution_options=exec_opts) - self._assert_depth(result.one(), depth) - else: - with Session(connection) as s: - result = s.scalars(stmt, execution_options=exec_opts) - self._assert_depth(result.one(), depth) + with Session(connection) as s: + result = s.scalars(stmt, execution_options=exec_opts) + self._assert_depth(result.one(), depth) + + if not disable_cache: + # note this is a magic number, it's not important that it's + # exact, just that when someone makes a huge recursive thing, + # it disables caching and notes in the logs + if depth > 8: + eq_( + buf.buffer[-1].message[0:55], + "[caching disabled (excess depth for " + "ORM loader options)", + ) + else: + assert buf.buffer[-1].message.startswith( + "[cached since" if i > 0 else "[generated in" + ) if disable_cache: clen = len(connection.engine._compiled_cache) diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index aebdf6922ae..29720f7dc86 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +from functools import partial import random from typing import List @@ -1661,7 +1662,9 @@ class HasTemporal: """Mixin that identifies a class as having a timestamp column""" timestamp = Column( - DateTime, default=datetime.datetime.utcnow, nullable=False + DateTime, + default=partial(datetime.datetime.now, datetime.timezone.utc), + nullable=False, ) cls.HasTemporal = HasTemporal @@ -1908,9 +1911,11 @@ def go(value): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -1976,9 +1981,11 @@ def go(value): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2033,9 +2040,11 @@ def go(value): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2062,6 +2071,55 @@ def go(value): ), ) + @testing.combinations( + (selectinload,), + (subqueryload,), + (lazyload,), + (joinedload,), + argnames="opt", + ) + @testing.variation("use_in", [True, False]) + def test_opts_local_criteria_cachekey( + self, opt, user_address_fixture, use_in + ): + """test #11173""" + User, Address = user_address_fixture + + s = Session(testing.db, future=True) + + def go(value): + if use_in: + expr = ~Address.email_address.in_([value, "some_email"]) + else: + expr = Address.email_address != value + stmt = ( + select(User) + .options( + opt(User.addresses.and_(expr)), + ) + .order_by(User.id) + ) + result = s.execute(stmt) + return result + + for value in ( + "ed@wood.com", + "ed@lala.com", + "ed@wood.com", + "ed@lala.com", + ): + s.close() + result = go(value) + + eq_( + result.scalars().unique().all(), + ( + 
self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), + ) + @testing.combinations( (joinedload, False), (lazyload, True), @@ -2129,9 +2187,11 @@ def go(value): eq_( result, - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) @testing.combinations((True,), (False,), argnames="use_compiled_cache") @@ -2237,9 +2297,11 @@ def go(value): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2309,9 +2371,11 @@ def go(value): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2397,6 +2461,28 @@ def test_select_joinm2m_aliased_local_criteria(self, order_item_fixture): "AND items_1.description != :description_1", ) + def test_use_secondary_table_in_criteria(self, order_item_fixture): + """test #11010 , regression caused by #9779""" + + Order, Item = order_item_fixture + order_items = self.tables.order_items + + stmt = select(Order).join( + Order.items.and_( + order_items.c.item_id > 1, Item.description != "description" + ) + ) + + self.assert_compile( + stmt, + "SELECT orders.id, orders.user_id, orders.address_id, " + "orders.description, orders.isopen FROM orders JOIN order_items " + "AS order_items_1 ON orders.id = order_items_1.order_id " + "JOIN items ON items.id = order_items_1.item_id " + "AND order_items_1.item_id > :item_id_1 " + "AND items.description != :description_1", + ) + class SubqueryCriteriaTest(fixtures.DeclarativeMappedTest): """test #10223""" diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index d6b886be151..104d67f4075 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -183,7 +183,6 @@ def _assert_raises_no_local_remote(self, fn, relname, *arg, **kw): class DependencyTwoParentTest(fixtures.MappedTest): - """Test flush() when a mapper is dependent on multiple relationships""" run_setup_mappers = "once" @@ -430,12 +429,13 @@ def test_collection_relationship_overrides_fk(self): class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): - """Tests the ultimate join condition, a single column that points to itself, e.g. within a SQL function or similar. The test is against a materialized path setup. - this is an **extremely** unusual case:: + this is an **extremely** unusual case: + + .. sourcecode:: text Entity ------ @@ -1022,12 +1022,13 @@ def test_works_two(self): class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): - """Tests a composite FK where, in the relationship(), one col points to itself in the same table. - this is a very unusual case:: + this is a very unusual case: + + .. 
sourcecode:: text company employee ---------- ---------- @@ -1334,7 +1335,8 @@ def _test_no_overwrite(self, sess, expect_failure): # this happens assert_raises_message( AssertionError, - "Dependency rule tried to blank-out primary key column " + "Dependency rule on column 'employee_t.company_id' " + "tried to blank-out primary key column " "'employee_t.company_id'", sess.flush, ) @@ -1505,7 +1507,6 @@ def test_joins_fully(self): class SynonymsAsFKsTest(fixtures.MappedTest): - """Syncrules on foreign keys that are also primary""" @classmethod @@ -1577,7 +1578,6 @@ def test_synonym_fk(self): class FKsAsPksTest(fixtures.MappedTest): - """Syncrules on foreign keys that are also primary""" @classmethod @@ -1669,7 +1669,7 @@ def test_no_delete_PK_AtoB(self): assert_raises_message( AssertionError, - "Dependency rule tried to blank-out " + "Dependency rule on column 'tableA.id' tried to blank-out " "primary key column 'tableB.id' on instance ", sess.flush, ) @@ -1696,7 +1696,7 @@ def test_no_delete_PK_BtoA(self): b1.a = None assert_raises_message( AssertionError, - "Dependency rule tried to blank-out " + "Dependency rule on column 'tableA.id' tried to blank-out " "primary key column 'tableB.id' on instance ", sess.flush, ) @@ -1862,7 +1862,6 @@ def test_delete_manual_BtoA(self): class UniqueColReferenceSwitchTest(fixtures.MappedTest): - """test a relationship based on a primary join against a unique non-pk column""" @@ -1927,7 +1926,6 @@ def test_switch_parent(self): class RelationshipToSelectableTest(fixtures.MappedTest): - """Test a map to a select that relates to a map to the table.""" @classmethod @@ -2021,7 +2019,6 @@ class LineItem(BasicEntity): class FKEquatedToConstantTest(fixtures.MappedTest): - """test a relationship with a non-column entity in the primary join, is not viewonly, and also has the non-column's clause mentioned in the foreign keys list. @@ -2158,7 +2155,6 @@ def test_backref(self): class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): - """test ambiguous joins due to FKs on both sides treated as self-referential. 
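The reworded assertions in the hunks above reflect that the unit of work's "Dependency rule" error now names the column the sync rule copies from (e.g. 'tableA.id') in addition to the primary key column it refused to blank out. A rough sketch of the shape that triggers it, written here in declarative 2.0 style with hypothetical A/B classes (the fixtures themselves use classical mappings):

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class A(Base):
    __tablename__ = "tableA"
    id: Mapped[int] = mapped_column(primary_key=True)


class B(Base):
    __tablename__ = "tableB"
    # a primary key that is simultaneously a foreign key: the sync rule
    # copies A.id into this column at flush time
    id: Mapped[int] = mapped_column(ForeignKey("tableA.id"), primary_key=True)
    a: Mapped[A] = relationship()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as sess:
    b1 = B(a=A(id=1))
    sess.add(b1)
    sess.commit()
    # de-associating asks the flush to null out tableB.id, a primary key,
    # which is refused with the AssertionError the tests assert on
    b1.a = None
    try:
        sess.flush()
    except AssertionError as err:
        print(err)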
@@ -2253,7 +2249,6 @@ def test_mapping(self): class ManualBackrefTest(_fixtures.FixtureTest): - """Test explicit relationships that are backrefs to each other.""" run_inserts = None @@ -2392,7 +2387,6 @@ def test_back_propagates_not_relationship(self): class NoLoadBackPopulates(_fixtures.FixtureTest): - """test the noload strategy which unlike others doesn't use lazyloader to set up instrumentation""" @@ -2639,7 +2633,6 @@ def teardown_test(self): class TypeMatchTest(fixtures.MappedTest): - """test errors raised when trying to add items whose type is not handled by a relationship""" @@ -2907,7 +2900,6 @@ class T2(BasicEntity): class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): - """test op() in conjunction with join conditions""" run_create_tables = run_deletes = None @@ -3185,7 +3177,6 @@ class B(ComparableEntity): class ViewOnlyOverlappingNames(fixtures.MappedTest): - """'viewonly' mappings with overlapping PK column names.""" @classmethod @@ -3441,7 +3432,6 @@ def rel(): class ViewOnlyUniqueNames(fixtures.MappedTest): - """'viewonly' mappings with unique PK column names.""" @classmethod @@ -3543,7 +3533,6 @@ class C3(BasicEntity): class ViewOnlyLocalRemoteM2M(fixtures.TestBase): - """test that local-remote is correctly determined for m2m""" def test_local_remote(self, registry): @@ -3582,7 +3571,6 @@ class B: class ViewOnlyNonEquijoin(fixtures.MappedTest): - """'viewonly' mappings based on non-equijoins.""" @classmethod @@ -3644,7 +3632,6 @@ class Bar(ComparableEntity): class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): - """'viewonly' mappings that contain the same 'remote' column twice""" @classmethod @@ -3718,7 +3705,6 @@ class Bar(ComparableEntity): class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): - """'viewonly' mappings that contain the same 'local' column twice""" @classmethod @@ -3793,7 +3779,6 @@ class Bar(ComparableEntity): class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): - """'viewonly' mappings with a complex join condition.""" @classmethod @@ -3995,7 +3980,6 @@ def go(): class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): - """test a complex annotation using between(). Using declarative here as an integration test for the local() @@ -4612,7 +4596,6 @@ class B(Base): class SecondaryNestedJoinTest( fixtures.MappedTest, AssertsCompiledSQL, testing.AssertsExecutionResults ): - """test support for a relationship where the 'secondary' table is a compound join().
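Both test_opts_local_criteria_cachekey above and the Issue11173Test added to test_subquery_relations.py further below target #11173: the literal embedded in a loader option's .and_() criteria must take part in the statement cache key, otherwise a second run with a different value would silently reuse the first cached result. A minimal standalone sketch of that contract, using a hypothetical User/Address mapping on in-memory SQLite rather than the suite's fixtures:

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)
    addresses: Mapped[list["Address"]] = relationship()


class Address(Base):
    __tablename__ = "addresses"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    email_address: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as setup:
    setup.add(
        User(
            id=1,
            addresses=[
                Address(email_address="ed@wood.com"),
                Address(email_address="ed@lala.com"),
            ],
        )
    )
    setup.commit()

# same statement shape, alternating literals: each run must apply the
# current value to the eager-load criteria, not the first cached one
for value in ("ed@wood.com", "ed@lala.com", "ed@wood.com"):
    with Session(engine) as s:
        stmt = select(User).options(
            selectinload(User.addresses.and_(Address.email_address != value))
        )
        user = s.scalars(stmt).one()
        assert all(a.email_address != value for a in user.addresses)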
@@ -6380,7 +6363,6 @@ def go(): class RelationDeprecationTest(fixtures.MappedTest): - """test usage of the old 'relation' function.""" run_inserts = "once" diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 3a7029110e4..d4ea0e29195 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -1,4 +1,5 @@ """Generic mapping to Select statements""" + import sqlalchemy as sa from sqlalchemy import column from sqlalchemy import Integer diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index c9907c76515..d46362abdc8 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -3340,7 +3340,7 @@ def test_use_join_parent_criteria_degrade_on_defer(self): "FROM a WHERE a.id IN (__[POSTCOMPILE_id_1]) ORDER BY a.id", [{"id_1": [1, 3]}], ), - # in the very unlikely case that the the FK col on parent is + # in the very unlikely case that the FK col on parent is # deferred, we degrade to the JOIN version so that we don't need to # emit either for each parent object individually, or as a second # query for them. @@ -3429,9 +3429,9 @@ def test_use_join_parent_degrade_on_defer(self): testing.db, q.all, CompiledSQL( - "SELECT a.id AS a_id, a.q AS a_q " "FROM a ORDER BY a.id", [{}] + "SELECT a.id AS a_id, a.q AS a_q FROM a ORDER BY a.id", [{}] ), - # in the very unlikely case that the the FK col on parent is + # in the very unlikely case that the FK col on parent is # deferred, we degrade to the JOIN version so that we don't need to # emit either for each parent object individually, or as a second # query for them. diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index 00564cfb656..538c77c0cee 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -3759,3 +3759,81 @@ def test_issue_6419(self): ), ) s.close() + + +class Issue11173Test(fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class SubItem(Base): + __tablename__ = "sub_items" + + id = Column(Integer, primary_key=True, autoincrement=True) + item_id = Column(Integer, ForeignKey("items.id")) + name = Column(String(50)) + number = Column(Integer) + + class Item(Base): + __tablename__ = "items" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(50)) + number = Column(Integer) + sub_items = relationship("SubItem", backref="item") + + @classmethod + def insert_data(cls, connection): + Item, SubItem = cls.classes("Item", "SubItem") + + with Session(connection) as sess: + number_of_items = 50 + number_of_sub_items = 5 + + items = [ + Item(name=f"Item:{i}", number=i) + for i in range(number_of_items) + ] + sess.add_all(items) + for item in items: + item.sub_items = [ + SubItem(name=f"SubItem:{item.id}:{i}", number=i) + for i in range(number_of_sub_items) + ] + sess.commit() + + @testing.variation("use_in", [True, False]) + def test_multiple_queries(self, use_in): + Item, SubItem = self.classes("Item", "SubItem") + + for sub_item_number in (1, 2, 3): + s = fixture_session() + base_query = s.query(Item) + + base_query = base_query.filter(Item.number > 5, Item.number <= 10) + + if use_in: + base_query = base_query.options( + subqueryload( + Item.sub_items.and_( + SubItem.number.in_([sub_item_number, 18, 12]) + ) + ) + ) + else: + base_query = base_query.options( + subqueryload( + Item.sub_items.and_(SubItem.number == sub_item_number) + ) + ) + + items = list(base_query) + + 
eq_(len(items), 5) + + for item in items: + sub_items = list(item.sub_items) + eq_(len(sub_items), 1) + + for sub_item in sub_items: + eq_(sub_item.number, sub_item_number) diff --git a/test/orm/test_sync.py b/test/orm/test_sync.py index c8f511f447a..10d73cb8d64 100644 --- a/test/orm/test_sync.py +++ b/test/orm/test_sync.py @@ -145,7 +145,7 @@ def test_clear_pk(self): eq_(b1.obj().__dict__["id"], 8) assert_raises_message( AssertionError, - "Dependency rule tried to blank-out primary key " + "Dependency rule on column 't1.id' tried to blank-out primary key " "column 't2.id' on instance ' 0.97 and i > 0, unique=random.random() > 0.97 and i > 0, ) diff --git a/test/perf/orm2010.py b/test/perf/orm2010.py index c069430fb1e..520944c9f0b 100644 --- a/test/perf/orm2010.py +++ b/test/perf/orm2010.py @@ -149,14 +149,12 @@ def status(msg): print("Total cpu seconds: %.2f" % stats.total_tt) print( "Total execute calls: %d" - % counts_by_methname[ - "<method 'execute' of 'sqlite3.Cursor' objects>" - ] + % counts_by_methname["<method 'execute' of 'sqlite3.Cursor' objects>"] ) print( "Total executemany calls: %d" % counts_by_methname.get( - "<method 'executemany' of 'sqlite3.Cursor' objects>", 0 + "<method 'executemany' of 'sqlite3.Cursor' objects>", 0 ) ) diff --git a/test/profiles.txt b/test/profiles.txt index 7db24e2ff56..976949e7b73 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -1,499 +1,482 @@ -# /mnt/photon_home/classic/dev/sqlalchemy/test/profiles.txt +# /home/classic/dev/sqlalchemy/test/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles # option - this file will be rewritten including the new count.
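The rewritten header spells out how the callcount entries below work: each line keys a decorated test to an exact environment (arch, platform, Python, driver, cext status), a missing entry skips the test, and running with --write-profiles records new counts. A test using the mechanism looks roughly like this (class name and workload are hypothetical; the decorator is the one the header names):

from sqlalchemy import column, select
from sqlalchemy.testing import fixtures, profiling


class MyCompileTest(fixtures.TestBase):
    # callcounts are only deterministic on CPython
    __requires__ = ("cpython",)

    @profiling.function_call_count()
    def test_compile_select(self):
        # every Python function call made in here is counted and compared
        # against the profiles.txt entry for the current environment;
        # run once with --write-profiles to record a missing entry
        select(column("x")).compile()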
-# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 77 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 77 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75 -test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 78 +test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 78 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_select -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 195 
-test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 195 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 193 -test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 193 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 221 +test.aaa_profiling.test_compiler.CompileTest.test_select x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 221 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_select_labels -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 219 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 243 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 243 -test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 217 
-test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 217 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 245 +test.aaa_profiling.test_compiler.CompileTest.test_select_labels x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 245 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 81 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 86 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 86 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 79 -test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 79 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 87 
+test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 87 +test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 87 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 184 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 187 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 180 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 180 - -# TEST: test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached - -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 44 - -# TEST: test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached[no_embedded] - -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 44 - -# TEST: test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached[require_embedded] - -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_cached[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 70 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 186 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 189 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause 
x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 186 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 189 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 186 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 189 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 186 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 189 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 186 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 189 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[no_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 11 -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 13 +test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 11 +test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 13 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[require_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 13 -test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 15 +test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 13 +test.aaa_profiling.test_misc.CCLookupTest.test_corresponding_column_isolated[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 15 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[no_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 13336 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 13354 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 13347 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 13650 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[require_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[require_embedded] 
x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 13336 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 13354 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 13347 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 13650 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[no_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 29839 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 35374 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 28449 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 35632 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[require_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 29923 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 35374 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 28449 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_aliased_class_select_cols[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 35876 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[no_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1239 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1392 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1261 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1437 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[require_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1257 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1410 
+test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1279 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_many_corresponding_column[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1455 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[no_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1258 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[no_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1395 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1280 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[no_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1440 # TEST: test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[require_embedded] -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1260 -test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[require_embedded] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1397 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1282 +test.aaa_profiling.test_misc.CCLookupTest.test_gen_subq_to_table_single_corresponding_column[require_embedded] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1442 # TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached -test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 303 -test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 303 +test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 303 +test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 303 # TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached -test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 4403 -test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 6103 +test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 4003 +test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6103 # TEST: test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members 
-test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 924 -test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 924 +test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 924 +test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 924 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 55030 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 65340 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 55930 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 65640 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 51230 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53330 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 63640 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 54230 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 63940 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 49530 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 57930 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 66340 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 58530 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 66240 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 54730 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 57030 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 65440 
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 57530 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 65240 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 53730 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48730 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 52040 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 49130 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 51840 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 46030 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52230 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 60040 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 52830 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 60040 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 49130 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 51330 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 59140 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 51830 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 59040 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 48130 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 37005 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 40205 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 37705 
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 40805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 34505 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 36105 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 39305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 36705 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 39805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 33505 # TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set -test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 3599 -test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 3599 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 3599 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 3599 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 3598 # TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove -test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 5527 -test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 5527 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 5527 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5527 +test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 5526 # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 128 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 128 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 136 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 136 
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 132 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_nocextensions 132 # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 128 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 128 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 136 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 136 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 132 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_nocextensions 132 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15341 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26360 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 15360 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 24378 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 15325 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21419 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26438 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 21420 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 24444 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 21384 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 10704 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 11054 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 10804 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 11204 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 
x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 10754 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1154 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1154 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1154 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1154 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 1154 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 4354 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 4604 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 4304 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 4604 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 4304 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98682 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 109932 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 98632 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 112132 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 95532 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96132 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 107582 - -# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results - -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 440705 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 458805 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 96082 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 109782 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 
92982

 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated

-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 26832,1031,97853
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 27722,1217,116453
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 27016,1006,95353
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 28168,1215,116253
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 26604,974,92153

 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity

-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 23981
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 23981
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 23981
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 23981
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 22982

 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity

-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 112466
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 120723
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 113225
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 123983
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 108201

 # TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks

-test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 20730
-test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 22152
+test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 21197
+test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 22705
+test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 20478

 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_load

-test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1453
-test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1542
+test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1481
+test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1581
+test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 1412

 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load

-test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 110,20
-test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 110,20
+test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 108,20
+test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 108,20
+test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 108,20

 # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols

-test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6586
-test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 7406
+test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 6706
+test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 7436
+test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 6316

 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results

-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 275705
-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 297105
+test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 277005
+test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 297305
+test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 263005

 # TEST: test.aaa_profiling.test_orm.SessionTest.test_expire_lots

-test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1212
-test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1212
+test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1212
+test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1212
+test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 1098
+
+# TEST: test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_after_cache
+
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_after_cache x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1418
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_after_cache x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1504
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_after_cache x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 1399
+
+# TEST: test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_no_cache
+
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_no_cache x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1859
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_no_cache x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 1880
+test.aaa_profiling.test_orm.WithExpresionLoaderOptTest.test_from_opt_no_cache x86_64_linux_cpython_3.12_sqlite_pysqlite_dbapiunicode_cextensions 1830

 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect

-test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75
-test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 75
+test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 75

 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect

-test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 24
-test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 24
+test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 24
+test.aaa_profiling.test_pool.QueuePoolTest.test_second_connect x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 24

 # TEST: test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute

-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 55
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 55
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 55
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 55
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 55
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 55
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 53
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 55

 # TEST: test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute

-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 106
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 110
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 110
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 105
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 105
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 108
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 110
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 108
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 110
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 108
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 110
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 108
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 110

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile

-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 9
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 9
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 9
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 9
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 9
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 8
-test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 9
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 8
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 9
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 8
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 9
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 8
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 9
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 8
+test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 9

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings

-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 2604
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 15608
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 89344
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 102348
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 2597
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 15601
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 2637
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 15641
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 2651
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 14655
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 2539
-test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 14614
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 2664
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 14671
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 2669
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 14676
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 2649
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 14656
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 2614
+test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 14621

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0]

-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 22
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 22
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 14

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1]

-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 22
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 24
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 21
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 16
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 16
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 15
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 20
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 15

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2]

-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 22
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 24
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 21
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 16
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 16
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 15
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 14
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 20
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 15

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1]

-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 27
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 29
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 24
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 26
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 17
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 17
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 19
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 17
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 18
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 17
-test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 23
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 24
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 17
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_cextensions 25
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_oracle_oracledb_dbapiunicode_nocextensions 26
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 17
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 18
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 17
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 18

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string

-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 6301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 87041
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 93041
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 269
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 6269
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 361
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 6361
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 5301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 257
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 5277
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 305
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 5307
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 279
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 5281
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 299
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5301
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 272
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5274

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode

-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 6301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 87041
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 93041
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 269
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 6269
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 361
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 6361
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 5301
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 257
-test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 5277
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 305
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 5307
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 279
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 5281
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 299
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5301
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 272
+test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5274

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_string

-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 597
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 6601
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 87337
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 93341
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 590
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 6594
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 630
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 6634
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 642
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 5646
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 532
-test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 5605
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 655
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 5662
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 660
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 5667
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 640
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5647
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 605
+test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5612

 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_unicode

-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 597
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 6601
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 87337
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 93341
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 590
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_nocextensions 6594
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 630
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 6634
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 642
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 5646
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 532
-test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 5605
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_cextensions 655
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_mariadb_mysqldb_dbapiunicode_nocextensions 5662
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_cextensions 660
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_mssql_pyodbc_dbapiunicode_nocextensions 5667
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 640
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5647
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 605
+test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5612
diff --git a/test/requirements.py b/test/requirements.py
index 4a0b365c2b5..2311f6e35fc 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1,7 +1,4 @@
-"""Requirements specific to SQLAlchemy's own unit tests.
-
-
-"""
+"""Requirements specific to SQLAlchemy's own unit tests."""

 from sqlalchemy import exc
 from sqlalchemy.sql import sqltypes
@@ -212,6 +209,19 @@ def non_native_boolean_unconstrained(self):
             ]
         )

+    @property
+    def server_defaults(self):
+        """Target backend supports server side defaults for columns"""
+
+        return exclusions.open()
+
+    @property
+    def expression_server_defaults(self):
+        return skip_if(
+            lambda config: against(config, "mysql", "mariadb")
+            and not self._mysql_expression_defaults(config)
+        )
+
     @property
     def qmark_paramstyle(self):
         return only_on(["sqlite", "+pyodbc"])
@@ -301,7 +311,9 @@ def binary_comparisons(self):
     @property
     def binary_literals(self):
         """target backend supports simple binary literals, e.g. an
-        expression like::
+        expression like:
+
+        .. sourcecode:: sql

             SELECT CAST('foo' AS BINARY)
@@ -491,6 +503,13 @@ def update_from(self):
             "Backend does not support UPDATE..FROM",
         )

+    @property
+    def update_from_returning(self):
+        """Target must support UPDATE..FROM syntax where RETURNING can
+        return columns from the non-primary FROM clause"""
+
+        return self.update_returning + self.update_from + skip_if("sqlite")
+
     @property
     def update_from_using_alias(self):
         """Target must support UPDATE..FROM syntax against an alias"""
@@ -522,7 +541,9 @@ def update_where_target_in_subquery(self):
         present in a subquery in the WHERE clause.

         This is an ANSI-standard syntax that apparently MySQL can't handle,
-        such as::
+        such as:
+
+        .. sourcecode:: sql

             UPDATE documents SET flag=1 WHERE documents.title IN
                 (SELECT max(documents.title) AS title
@@ -613,6 +634,16 @@ def unique_constraint_reflection_no_index_overlap(self):
             + skip_if("oracle")
         )

+    @property
+    def inline_check_constraint_reflection(self):
+        return only_on(
+            [
+                "postgresql",
+                "sqlite",
+                "oracle",
+            ]
+        )
+
     @property
     def check_constraint_reflection(self):
         return only_on(
@@ -784,7 +815,7 @@ def order_by_col_from_union(self):
         #8221.
""" - return fails_if(["mssql", "oracle>=12"]) + return fails_if(["mssql", "oracle < 23"]) @property def parens_in_union_contained_select_w_limit_offset(self): @@ -858,32 +889,27 @@ def pg_prepared_transaction(config): else: return num > 0 - return ( - skip_if( - [ - no_support( - "mssql", "two-phase xact not supported by drivers" - ), - no_support( - "sqlite", "two-phase xact not supported by database" - ), - # in Ia3cbbf56d4882fcc7980f90519412f1711fae74d - # we are evaluating which modern MySQL / MariaDB versions - # can handle two-phase testing without too many problems - # no_support( - # "mysql", - # "recent MySQL community editions have too many " - # "issues (late 2016), disabling for now", - # ), - NotPredicate( - LambdaPredicate( - pg_prepared_transaction, - "max_prepared_transactions not available or zero", - ) - ), - ] - ) - + self.fail_on_oracledb_thin + return skip_if( + [ + no_support("mssql", "two-phase xact not supported by drivers"), + no_support( + "sqlite", "two-phase xact not supported by database" + ), + # in Ia3cbbf56d4882fcc7980f90519412f1711fae74d + # we are evaluating which modern MySQL / MariaDB versions + # can handle two-phase testing without too many problems + # no_support( + # "mysql", + # "recent MySQL community editions have too many " + # "issues (late 2016), disabling for now", + # ), + NotPredicate( + LambdaPredicate( + pg_prepared_transaction, + "max_prepared_transactions not available or zero", + ) + ), + ] ) @property @@ -893,7 +919,8 @@ def two_phase_recovery(self): ["mysql", "mariadb"], "still can't get recover to work w/ MariaDB / MySQL", ) - + skip_if("oracle", "recovery not functional") + + skip_if("oracle+cx_oracle", "recovery not functional") + + skip_if("oracle+oracledb", "recovery can't be reliably tested") ) @property @@ -995,11 +1022,16 @@ def symbol_names_w_double_quote(self): @property def arraysize(self): - return skip_if("+pymssql", "DBAPI is missing this attribute") + return skip_if( + [ + no_support("+pymssql", "DBAPI is missing this attribute"), + no_support("+mysqlconnector", "DBAPI ignores this attribute"), + ] + ) @property def emulated_lastrowid(self): - """ "target dialect retrieves cursor.lastrowid or an equivalent + """target dialect retrieves cursor.lastrowid or an equivalent after an insert() construct executes. """ return fails_on_everything_except( @@ -1027,7 +1059,7 @@ def database_discards_null_for_autoincrement(self): @property def emulated_lastrowid_even_with_sequences(self): - """ "target dialect retrieves cursor.lastrowid or an equivalent + """target dialect retrieves cursor.lastrowid or an equivalent after an insert() construct executes, even if the table has a Sequence on it. """ @@ -1040,7 +1072,7 @@ def emulated_lastrowid_even_with_sequences(self): @property def dbapi_lastrowid(self): - """ "target backend includes a 'lastrowid' accessor on the DBAPI + """target backend includes a 'lastrowid' accessor on the DBAPI cursor object. 
""" @@ -1466,9 +1498,7 @@ def implicit_decimal_binds(self): expr = decimal.Decimal("15.7563") - value = e.scalar( - select(literal(expr)) - ) + value = e.scalar(select(literal(expr))) assert value == expr @@ -1572,6 +1602,16 @@ def postgresql_test_dblink(self): def postgresql_jsonb(self): return only_on("postgresql >= 9.4") + @property + def postgresql_working_nullable_domains(self): + # see https://www.postgresql.org/message-id/flat/a90f53c4-56f3-4b07-aefc-49afdc67dba6%40app.fastmail.com # noqa: E501 + return skip_if( + lambda config: (17, 0) + < config.db.dialect.server_version_info + < (17, 3), + "reflection of nullable domains broken on PG 17.0-17.2", + ) + @property def native_hstore(self): return self.any_psycopg_compatibility @@ -1693,6 +1733,10 @@ def mysql_for_update(self): def mysql_fsp(self): return only_if(["mysql >= 5.6.4", "mariadb"]) + @property + def mysql_notnull_generated_columns(self): + return only_if(["mysql >= 5.7"]) + @property def mysql_fully_case_sensitive(self): return only_if(self._has_mysql_fully_case_sensitive) @@ -1784,6 +1828,15 @@ def _mysql_check_constraints_dont_exist(self, config): # 2. they dont enforce check constraints return not self._mysql_check_constraints_exist(config) + def _mysql_expression_defaults(self, config): + return (against(config, ["mysql", "mariadb"])) and ( + config.db.dialect._support_default_function + ) + + @property + def mysql_expression_defaults(self): + return only_if(self._mysql_expression_defaults) + def _mysql_not_mariadb_102(self, config): return (against(config, ["mysql", "mariadb"])) and ( not config.db.dialect._is_mariadb @@ -1873,16 +1926,6 @@ def oracle5x(self): and config.db.dialect.cx_oracle_ver < (6,) ) - @property - def fail_on_oracledb_thin(self): - def go(config): - if against(config, "oracle+oracledb"): - with config.db.connect() as conn: - return config.db.dialect.is_thin_mode(conn) - return False - - return fails_if(go) - @property def computed_columns(self): return skip_if(["postgresql < 12", "sqlite < 3.31", "mysql < 5.7"]) @@ -2057,3 +2100,42 @@ def go(config): return False return only_if(go, "json_each is required") + + @property + def rowcount_always_cached(self): + """Indicates that ``cursor.rowcount`` is always accessed, + usually in an ``ExecutionContext.post_exec``. + """ + return only_on(["+mariadbconnector"]) + + @property + def rowcount_always_cached_on_insert(self): + """Indicates that ``cursor.rowcount`` is always accessed in an insert + statement. 
+ """ + return only_on(["mssql"]) + + @property + def supports_bitwise_and(self): + """Target database supports bitwise and""" + return exclusions.open() + + @property + def supports_bitwise_or(self): + """Target database supports bitwise or""" + return fails_on(["oracle<21"]) + + @property + def supports_bitwise_not(self): + """Target database supports bitwise not""" + return fails_on(["oracle", "mysql", "mariadb"]) + + @property + def supports_bitwise_xor(self): + """Target database supports bitwise xor""" + return fails_on(["oracle<21"]) + + @property + def supports_bitwise_shift(self): + """Target database supports bitwise left or right shift""" + return fails_on(["oracle"]) diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 6907d213257..5e95d3cb2f7 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -5,7 +5,6 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal_column -from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table @@ -13,50 +12,48 @@ from sqlalchemy import text from sqlalchemy.sql import column from sqlalchemy.sql import table +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -info_table = None - - -class CaseTest(fixtures.TestBase, AssertsCompiledSQL): +class CaseTest(fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default" + run_inserts = "once" + run_deletes = "never" + @classmethod - def setup_test_class(cls): - metadata = MetaData() - global info_table - info_table = Table( - "infos", + def define_tables(cls, metadata): + Table( + "info_table", metadata, Column("pk", Integer, primary_key=True), Column("info", String(30)), ) - with testing.db.begin() as conn: - info_table.create(conn) - - conn.execute( - info_table.insert(), - [ - {"pk": 1, "info": "pk_1_data"}, - {"pk": 2, "info": "pk_2_data"}, - {"pk": 3, "info": "pk_3_data"}, - {"pk": 4, "info": "pk_4_data"}, - {"pk": 5, "info": "pk_5_data"}, - {"pk": 6, "info": "pk_6_data"}, - ], - ) - @classmethod - def teardown_test_class(cls): - with testing.db.begin() as conn: - info_table.drop(conn) + def insert_data(cls, connection): + info_table = cls.tables.info_table + + connection.execute( + info_table.insert(), + [ + {"pk": 1, "info": "pk_1_data"}, + {"pk": 2, "info": "pk_2_data"}, + {"pk": 3, "info": "pk_3_data"}, + {"pk": 4, "info": "pk_4_data"}, + {"pk": 5, "info": "pk_5_data"}, + {"pk": 6, "info": "pk_6_data"}, + ], + ) + connection.commit() @testing.requires.subqueries def test_case(self, connection): + info_table = self.tables.info_table + inner = select( case( (info_table.c.pk < 3, "lessthan3"), @@ -222,6 +219,8 @@ def test_when_dicts(self, test_case, expected): ) def test_text_doesnt_explode(self, connection): + info_table = self.tables.info_table + for s in [ select( case( @@ -255,6 +254,8 @@ def test_text_doenst_explode_even_in_whenlist(self): ) def testcase_with_dict(self): + info_table = self.tables.info_table + query = select( case( { @@ -294,3 +295,61 @@ def testcase_with_dict(self): ("two", 2), ("other", 3), ] + + @testing.variation("add_else", [True, False]) + def test_type_of_case_expression_with_all_nulls(self, add_else): + info_table = self.tables.info_table + + expr = case( + (info_table.c.pk < 0, None), + (info_table.c.pk > 9, None), + else_=column("q") if add_else else None, + ) + + assert 
isinstance(expr.type, NullType) + + @testing.combinations( + lambda info_table: ( + [ + # test non-None in middle of WHENS takes precedence over Nones + (info_table.c.pk < 0, None), + (info_table.c.pk < 5, "five"), + (info_table.c.pk <= 9, info_table.c.pk), + (info_table.c.pk > 9, None), + ], + None, + ), + lambda info_table: ( + # test non-None ELSE takes precedence over WHENs that are None + [(info_table.c.pk < 0, None)], + info_table.c.pk, + ), + lambda info_table: ( + # test non-None WHEN takes precedence over non-None ELSE + [ + (info_table.c.pk < 0, None), + (info_table.c.pk <= 9, info_table.c.pk), + (info_table.c.pk > 9, None), + ], + column("q", String), + ), + lambda info_table: ( + # test last WHEN in list takes precedence + [ + (info_table.c.pk < 0, String), + (info_table.c.pk > 9, None), + (info_table.c.pk <= 9, info_table.c.pk), + ], + column("q", String), + ), + ) + def test_type_of_case_expression(self, when_lambda): + info_table = self.tables.info_table + + whens, else_ = testing.resolve_lambda( + when_lambda, info_table=info_table + ) + + expr = case(*whens, else_=else_) + + assert isinstance(expr.type, Integer) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index b2be90f60cd..77743b9c924 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1,4 +1,5 @@ import importlib +from inspect import signature import itertools import random @@ -42,10 +43,12 @@ from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce from sqlalchemy.sql import visitors +from sqlalchemy.sql.annotation import Annotated +from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.base import HasCacheKey +from sqlalchemy.sql.base import SingletonConstant from sqlalchemy.sql.elements import _label_reference from sqlalchemy.sql.elements import _textual_label_reference -from sqlalchemy.sql.elements import Annotated from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ClauseList @@ -61,10 +64,10 @@ from sqlalchemy.sql.lambdas import LambdaElement from sqlalchemy.sql.lambdas import LambdaOptions from sqlalchemy.sql.selectable import _OffsetLimitParam -from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from sqlalchemy.sql.selectable import NoInit from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Selectable from sqlalchemy.sql.selectable import SelectStatementGrouping @@ -204,6 +207,43 @@ class CoreFixtures: bindparam("bar", type_=String) ), ), + lambda: ( + # test #11471 + text("select * from table") + .columns(a=Integer()) + .add_cte(table_b.select().cte()), + text("select * from table") + .columns(a=Integer()) + .add_cte(table_b.select().where(table_b.c.a > 5).cte()), + ), + lambda: ( + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ).add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ).add_cte(select(table_b).where(table_b.c.a < 1).cte("ttt")), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ) + .add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")) + ._annotate({"foo": "bar"}), + ), + lambda: ( + union( + select(table_a).where(table_a.c.a > 1), + 
select(table_a).where(table_a.c.a < 1), + ).self_group(), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ) + .self_group() + ._annotate({"foo": "bar"}), + ), lambda: ( literal(1).op("+")(literal(1)), literal(1).op("-")(literal(1)), @@ -401,6 +441,7 @@ class CoreFixtures: func.row_number().over(order_by=table_a.c.a, range_=(0, 10)), func.row_number().over(order_by=table_a.c.a, range_=(None, 10)), func.row_number().over(order_by=table_a.c.a, rows=(None, 20)), + func.row_number().over(order_by=table_a.c.a, groups=(None, 20)), func.row_number().over(order_by=table_a.c.b), func.row_number().over( order_by=table_a.c.a, partition_by=table_a.c.b @@ -468,6 +509,21 @@ class CoreFixtures: select(table_a.c.a) .where(table_a.c.b == 5) .with_for_update(nowait=True), + select(table_a.c.a) + .where(table_a.c.b == 5) + .with_for_update(nowait=True, skip_locked=True), + select(table_a.c.a) + .where(table_a.c.b == 5) + .with_for_update(nowait=True, read=True), + select(table_a.c.a) + .where(table_a.c.b == 5) + .with_for_update(of=table_a.c.a), + select(table_a.c.a) + .where(table_a.c.b == 5) + .with_for_update(of=table_a.c.b), + select(table_a.c.a) + .where(table_a.c.b == 5) + .with_for_update(nowait=True, key_share=True), select(table_a.c.a).where(table_a.c.b == 5).correlate(table_b), select(table_a.c.a) .where(table_a.c.b == 5) @@ -482,6 +538,7 @@ class CoreFixtures: select(table_a.c.a).fetch(2, percent=True), select(table_a.c.a).fetch(2, with_ties=True), select(table_a.c.a).fetch(2, with_ties=True, percent=True), + select(table_a.c.a).fetch(2, oracle_fetch_approximate=True), select(table_a.c.a).fetch(2).offset(3), select(table_a.c.a).fetch(2).offset(5), select(table_a.c.a).limit(2).offset(5), @@ -1194,6 +1251,23 @@ def test_cache_key_object_comparators(self, lc1, lc2, lc3): is_true(c1._generate_cache_key() != c3._generate_cache_key()) is_false(c1._generate_cache_key() == c3._generate_cache_key()) + def test_in_with_none(self): + """test #12314""" + + def fixture(): + elements = list( + random_choices([1, 2, None, 3, 4], k=random.randint(1, 7)) + ) + + # slight issue. if the first element is None and not an int, + # the type of the BindParameter goes from Integer to Nulltype. 
+            # but if we set the left side to be Integer then it comes from
+            # that side, and the vast majority of in_() use cases come from
+            # a typed column expression, so this is fine
+            return (column("x", Integer).in_(elements),)
+
+        self._run_cache_key_fixture(fixture, False)
+
     def test_cache_key(self):
         for fixtures_, compare_values in [
             (self.fixtures, True),
@@ -1345,6 +1419,253 @@ def test_generative_cache_key_regen_w_del(self):
         is_not(ck3, None)


+def all_hascachekey_subclasses(ignore_subclasses=()):
+    def find_subclasses(cls: type):
+        for s in class_hierarchy(cls):
+            if (
+                # class_hierarchy may return values that
+                # aren't subclasses of cls
+                not issubclass(s, cls)
+                or "_traverse_internals" not in s.__dict__
+                or any(issubclass(s, ignore) for ignore in ignore_subclasses)
+            ):
+                continue
+            yield s
+
+    return dict.fromkeys(find_subclasses(HasCacheKey))
+
+
+class HasCacheKeySubclass(fixtures.TestBase):
+    custom_traverse = {
+        "AnnotatedFunctionAsBinary": {
+            "sql_function",
+            "left_index",
+            "right_index",
+            "modifiers",
+            "_annotations",
+        },
+        "Annotatednext_value": {"sequence", "_annotations"},
+        "FunctionAsBinary": {
+            "sql_function",
+            "left_index",
+            "right_index",
+            "modifiers",
+        },
+        "next_value": {"sequence"},
+        "array": ({"type", "clauses"}),
+    }
+
+    ignore_keys = {
+        "AnnotatedColumn": {"dialect_options"},
+        "SelectStatementGrouping": {
+            "_independent_ctes",
+            "_independent_ctes_opts",
+        },
+    }
+
+    @testing.combinations(*all_hascachekey_subclasses())
+    def test_traverse_internals(self, cls: type):
+        super_traverse = {}
+        # ignore_super = self.ignore_super.get(cls.__name__, set())
+        for s in cls.mro()[1:]:
+            # if s.__name__ in ignore_super:
+            #     continue
+            if s.__name__ == "Executable":
+                continue
+            for attr in s.__dict__:
+                if not attr.endswith("_traverse_internals"):
+                    continue
+                for k, v in s.__dict__[attr]:
+                    if k not in super_traverse:
+                        super_traverse[k] = v
+        traverse_dict = dict(cls.__dict__["_traverse_internals"])
+        eq_(len(cls.__dict__["_traverse_internals"]), len(traverse_dict))
+        if cls.__name__ in self.custom_traverse:
+            eq_(traverse_dict.keys(), self.custom_traverse[cls.__name__])
+        else:
+            ignore = self.ignore_keys.get(cls.__name__, set())
+
+            left_keys = traverse_dict.keys() | ignore
+            is_true(
+                left_keys >= super_traverse.keys(),
+                f"{left_keys} >= {super_traverse.keys()} - missing: "
+                f"{super_traverse.keys() - left_keys} - ignored {ignore}",
+            )
+
+            subset = {
+                k: v for k, v in traverse_dict.items() if k in super_traverse
+            }
+            eq_(
+                subset,
+                {k: v for k, v in super_traverse.items() if k not in ignore},
+            )
+
+    # name -> (traverse names, init args)
+    custom_init = {
+        "BinaryExpression": (
+            {"right", "operator", "type", "negate", "modifiers", "left"},
+            {"right", "operator", "type_", "negate", "modifiers", "left"},
+        ),
+        "BindParameter": (
+            {"literal_execute", "type", "callable", "value", "key"},
+            {"required", "isoutparam", "literal_execute", "type_", "callable_"}
+            | {"unique", "expanding", "quote", "value", "key"},
+        ),
+        "Cast": ({"type", "clause"}, {"type_", "expression"}),
+        "ClauseList": (
+            {"clauses", "operator"},
+            {"group_contents", "group", "operator", "clauses"},
+        ),
+        "ColumnClause": (
+            {"is_literal", "type", "table", "name"},
+            {"type_", "is_literal", "text"},
+        ),
+        "ExpressionClauseList": (
+            {"clauses", "operator"},
+            {"type_", "operator", "clauses"},
+        ),
+        "FromStatement": (
+            {"_raw_columns", "_with_options", "element"}
+            | {"_propagate_attrs", "_with_context_options"},
+            {"element", "entities"},
+        ),
+        "FunctionAsBinary": (
{"modifiers", "sql_function", "right_index", "left_index"}, + {"right_index", "left_index", "fn"}, + ), + "FunctionElement": ( + {"clause_expr", "_table_value_type", "_with_ordinality"}, + {"clauses"}, + ), + "Function": ( + {"_table_value_type", "clause_expr", "_with_ordinality"} + | {"packagenames", "type", "name"}, + {"type_", "packagenames", "name", "clauses"}, + ), + "Label": ({"_element", "type", "name"}, {"type_", "element", "name"}), + "LambdaElement": ( + {"_resolved"}, + {"role", "opts", "apply_propagate_attrs", "fn"}, + ), + "Load": ( + {"propagate_to_loaders", "additional_source_entities"} + | {"path", "context"}, + {"entity"}, + ), + "LoaderCriteriaOption": ( + {"where_criteria", "entity", "propagate_to_loaders"} + | {"root_entity", "include_aliases"}, + {"where_criteria", "include_aliases", "propagate_to_loaders"} + | {"entity_or_base", "loader_only", "track_closure_variables"}, + ), + "NullLambdaStatement": ({"_resolved"}, {"statement"}), + "ScalarFunctionColumn": ( + {"type", "fn", "name"}, + {"type_", "name", "fn"}, + ), + "ScalarValues": ( + {"_data", "_column_args", "literal_binds"}, + {"columns", "data", "literal_binds"}, + ), + "Select": ( + { + "_having_criteria", + "_distinct", + "_group_by_clauses", + "_fetch_clause", + "_limit_clause", + "_label_style", + "_order_by_clauses", + "_raw_columns", + "_correlate_except", + "_statement_hints", + "_hints", + "_independent_ctes", + "_distinct_on", + "_with_context_options", + "_setup_joins", + "_suffixes", + "_memoized_select_entities", + "_for_update_arg", + "_prefixes", + "_propagate_attrs", + "_with_options", + "_independent_ctes_opts", + "_offset_clause", + "_correlate", + "_where_criteria", + "_annotations", + "_fetch_clause_options", + "_from_obj", + }, + {"entities"}, + ), + "TableValuedColumn": ( + {"scalar_alias", "type", "name"}, + {"type_", "scalar_alias"}, + ), + "TableValueType": ({"_elements"}, {"elements"}), + "TextualSelect": ( + {"column_args", "_annotations", "_independent_ctes"} + | {"element", "_independent_ctes_opts"}, + {"positional", "columns", "text"}, + ), + "Tuple": ({"clauses", "operator"}, {"clauses", "types"}), + "TypeClause": ({"type"}, {"type_"}), + "TypeCoerce": ({"type", "clause"}, {"type_", "expression"}), + "UnaryExpression": ( + {"modifier", "element", "operator"}, + {"operator", "wraps_column_expression"} + | {"type_", "modifier", "element"}, + ), + "Values": ( + {"_column_args", "literal_binds", "name", "_data"}, + {"columns", "name", "literal_binds"}, + ), + "_FrameClause": ( + {"upper_integer_bind", "upper_type"} + | {"lower_type", "lower_integer_bind"}, + {"range_"}, + ), + "_MemoizedSelectEntities": ( + {"_with_options", "_raw_columns", "_setup_joins"}, + {"args"}, + ), + "array": ({"type", "clauses"}, {"clauses", "type_"}), + "next_value": ({"sequence"}, {"seq"}), + } + + @testing.combinations( + *all_hascachekey_subclasses( + ignore_subclasses=[ + Annotated, + NoInit, + SingletonConstant, + DialectKWArgs, + ] + ) + ) + def test_init_args_in_traversal(self, cls: type): + sig = signature(cls.__init__) + init_args = set() + for p in sig.parameters.values(): + if ( + p.name == "self" + or p.name.startswith("_") + or p.kind in (p.VAR_KEYWORD,) + ): + continue + init_args.add(p.name) + + names = {n for n, _ in cls.__dict__["_traverse_internals"]} + if cls.__name__ in self.custom_init: + traverse, inits = self.custom_init[cls.__name__] + eq_(names, traverse) + eq_(init_args, inits) + else: + is_true(names.issuperset(init_args), f"{names} : {init_args}") + + class 
CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod def setup_test_class(cls): @@ -1360,21 +1681,16 @@ def test_all_present(self): also included in the fixtures above. """ - need = { + need = set( cls - for cls in class_hierarchy(ClauseElement) - if issubclass(cls, (ColumnElement, Selectable, LambdaElement)) - and ( - "__init__" in cls.__dict__ - or issubclass(cls, AliasedReturnsRows) + for cls in all_hascachekey_subclasses( + ignore_subclasses=[Annotated, NoInit, SingletonConstant] ) - and not issubclass(cls, (Annotated)) - and cls.__module__.startswith("sqlalchemy.") - and "orm" not in cls.__module__ + if "orm" not in cls.__module__ and "compiler" not in cls.__module__ - and "crud" not in cls.__module__ - and "dialects" not in cls.__module__ # TODO: dialects? - }.difference({ColumnElement, UnaryExpression}) + and "dialects" not in cls.__module__ + and issubclass(cls, (ColumnElement, Selectable, LambdaElement)) + ) for fixture in self.fixtures + self.dont_compare_values_fixtures: case_a = fixture() diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 3bd1bacc6d8..7c43e60db3f 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -12,6 +12,7 @@ import datetime import decimal +import re from typing import TYPE_CHECKING from sqlalchemy import alias @@ -44,6 +45,10 @@ from sqlalchemy import MetaData from sqlalchemy import not_ from sqlalchemy import null +from sqlalchemy import nulls_first +from sqlalchemy import nulls_last +from sqlalchemy import nullsfirst +from sqlalchemy import nullslast from sqlalchemy import Numeric from sqlalchemy import or_ from sqlalchemy import outerjoin @@ -1544,7 +1549,7 @@ def test_scalar_select(self): ) self.assert_compile( select(select(table1.c.name).label("foo")), - "SELECT (SELECT mytable.name FROM mytable) " "AS foo", + "SELECT (SELECT mytable.name FROM mytable) AS foo", ) # scalar selects should not have any attributes on their 'c' or @@ -1668,44 +1673,85 @@ def test_label_comparison_two(self): "foo || :param_1", ) - def test_order_by_labels_enabled(self): + def test_order_by_labels_enabled_negative_cases(self): + """test order_by_labels enabled but the cases where we expect + ORDER BY the expression without the label name""" + lab1 = (table1.c.myid + 12).label("foo") lab2 = func.somefunc(table1.c.name).label("bar") dialect = default.DefaultDialect() + # binary expressions render as the expression without labels self.assert_compile( - select(lab1, lab2).order_by(lab1, desc(lab2)), + select(lab1, lab2).order_by(lab1 + "test"), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY foo, bar DESC", + "ORDER BY mytable.myid + :myid_1 + :param_1", dialect=dialect, ) - # the function embedded label renders as the function + # labels within functions in the columns clause render + # with the expression self.assert_compile( - select(lab1, lab2).order_by(func.hoho(lab1), desc(lab2)), + select(lab1, func.foo(lab1)).order_by(lab1, func.foo(lab1)), "SELECT mytable.myid + :myid_1 AS foo, " - "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY hoho(mytable.myid + :myid_1), bar DESC", + "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable " + "ORDER BY foo, foo(mytable.myid + :myid_1)", dialect=dialect, ) - # binary expressions render as the expression without labels + # here, 'name' is implicitly available, but w/ #3882 we don't + # want to render a name that isn't specifically a Label elsewhere + # in the query self.assert_compile( - select(lab1, lab2).order_by(lab1 + 
"test"), + select(table1.c.myid).order_by(table1.c.name.label("name")), + "SELECT mytable.myid FROM mytable ORDER BY mytable.name", + ) + + # as well as if it doesn't match + self.assert_compile( + select(table1.c.myid).order_by( + func.lower(table1.c.name).label("name") + ), + "SELECT mytable.myid FROM mytable ORDER BY lower(mytable.name)", + ) + + @testing.combinations( + (desc, "DESC"), + (asc, "ASC"), + (nulls_first, "NULLS FIRST"), + (nulls_last, "NULLS LAST"), + (nullsfirst, "NULLS FIRST"), + (nullslast, "NULLS LAST"), + (lambda c: c.desc().nulls_last(), "DESC NULLS LAST"), + (lambda c: c.desc().nullslast(), "DESC NULLS LAST"), + (lambda c: c.nulls_first().asc(), "NULLS FIRST ASC"), + ) + def test_order_by_labels_enabled(self, operator, expected): + """test positive cases with order_by_labels enabled. this is + multipled out to all the ORDER BY modifier operators + (see #11592) + + + """ + lab1 = (table1.c.myid + 12).label("foo") + lab2 = func.somefunc(table1.c.name).label("bar") + dialect = default.DefaultDialect() + + self.assert_compile( + select(lab1, lab2).order_by(lab1, operator(lab2)), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY mytable.myid + :myid_1 + :param_1", + f"ORDER BY foo, bar {expected}", dialect=dialect, ) - # labels within functions in the columns clause render - # with the expression + # the function embedded label renders as the function self.assert_compile( - select(lab1, func.foo(lab1)).order_by(lab1, func.foo(lab1)), + select(lab1, lab2).order_by(func.hoho(lab1), operator(lab2)), "SELECT mytable.myid + :myid_1 AS foo, " - "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable " - "ORDER BY foo, foo(mytable.myid + :myid_1)", + "somefunc(mytable.name) AS bar FROM mytable " + f"ORDER BY hoho(mytable.myid + :myid_1), bar {expected}", dialect=dialect, ) @@ -1713,62 +1759,49 @@ def test_order_by_labels_enabled(self): ly = (func.lower(table1.c.name) + table1.c.description).label("ly") self.assert_compile( - select(lx, ly).order_by(lx, ly.desc()), + select(lx, ly).order_by(lx, operator(ly)), "SELECT mytable.myid + mytable.myid AS lx, " "lower(mytable.name) || mytable.description AS ly " - "FROM mytable ORDER BY lx, ly DESC", + f"FROM mytable ORDER BY lx, ly {expected}", dialect=dialect, ) # expression isn't actually the same thing (even though label is) self.assert_compile( select(lab1, lab2).order_by( - table1.c.myid.label("foo"), desc(table1.c.name.label("bar")) + table1.c.myid.label("foo"), + operator(table1.c.name.label("bar")), ), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY mytable.myid, mytable.name DESC", + f"ORDER BY mytable.myid, mytable.name {expected}", dialect=dialect, ) # it's also an exact match, not aliased etc. self.assert_compile( select(lab1, lab2).order_by( - desc(table1.alias().c.name.label("bar")) + operator(table1.alias().c.name.label("bar")) ), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY mytable_1.name DESC", + f"ORDER BY mytable_1.name {expected}", dialect=dialect, ) # but! 
it's based on lineage lab2_lineage = lab2.element._clone() self.assert_compile( - select(lab1, lab2).order_by(desc(lab2_lineage.label("bar"))), + select(lab1, lab2).order_by(operator(lab2_lineage.label("bar"))), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY bar DESC", + f"ORDER BY bar {expected}", dialect=dialect, ) - # here, 'name' is implicitly available, but w/ #3882 we don't - # want to render a name that isn't specifically a Label elsewhere - # in the query - self.assert_compile( - select(table1.c.myid).order_by(table1.c.name.label("name")), - "SELECT mytable.myid FROM mytable ORDER BY mytable.name", - ) - - # as well as if it doesn't match - self.assert_compile( - select(table1.c.myid).order_by( - func.lower(table1.c.name).label("name") - ), - "SELECT mytable.myid FROM mytable ORDER BY lower(mytable.name)", - ) - def test_order_by_labels_disabled(self): + """test when the order_by_labels feature is disabled entirely""" + lab1 = (table1.c.myid + 12).label("foo") lab2 = func.somefunc(table1.c.name).label("bar") dialect = default.DefaultDialect() @@ -2694,7 +2727,7 @@ def test_deduping_unique_across_selects(self): self.assert_compile( s3, - "SELECT NULL AS anon_1, NULL AS anon__1 " "UNION " + "SELECT NULL AS anon_1, NULL AS anon__1 UNION " # without the feature tested in test_deduping_hash_algo we'd get # "SELECT true AS anon_2, true AS anon__1", "SELECT true AS anon_2, true AS anon__2", @@ -3180,6 +3213,41 @@ def test_over_framespec(self): checkparams={"param_1": 10, "param_2": 1}, ) + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(None, 0))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + "UNBOUNDED PRECEDING AND CURRENT ROW)" + " AS anon_1 FROM mytable", + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(-5, 10))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 PRECEDING AND :param_2 FOLLOWING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 5, "param_2": 10}, + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(1, 10))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 FOLLOWING AND :param_2 FOLLOWING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 1, "param_2": 10}, + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(-10, -1))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 PRECEDING AND :param_2 PRECEDING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 10, "param_2": 1}, + ) + def test_over_invalid_framespecs(self): assert_raises_message( exc.ArgumentError, @@ -3197,10 +3265,35 @@ def test_over_invalid_framespecs(self): assert_raises_message( exc.ArgumentError, - "'range_' and 'rows' are mutually exclusive", + "only one of 'rows', 'range_', or 'groups' may be provided", + func.row_number().over, + range_=(-5, 8), + rows=(-2, 5), + ) + + assert_raises_message( + exc.ArgumentError, + "only one of 'rows', 'range_', or 'groups' may be provided", func.row_number().over, range_=(-5, 8), + groups=(None, None), + ) + + assert_raises_message( + exc.ArgumentError, + "only one of 'rows', 'range_', or 'groups' may be provided", + func.row_number().over, rows=(-2, 5), + groups=(None, None), + ) + + assert_raises_message( + exc.ArgumentError, + "only one of 'rows', 'range_', or 'groups' may be provided", + func.row_number().over, + range_=(-5, 8), + 
rows=(-2, 5), + groups=(None, None), ) def test_over_within_group(self): @@ -3775,7 +3868,7 @@ def test_binds(self): ) assert_raises_message( exc.CompileError, - "conflicts with unique bind parameter " "of the same name", + "conflicts with unique bind parameter of the same name", str, s, ) @@ -3789,7 +3882,7 @@ def test_binds(self): ) assert_raises_message( exc.CompileError, - "conflicts with unique bind parameter " "of the same name", + "conflicts with unique bind parameter of the same name", str, s, ) @@ -4434,7 +4527,7 @@ def test_tuple_expanding_in_no_values(self): ) self.assert_compile( expr, - "(mytable.myid, mytable.name) IN " "(__[POSTCOMPILE_param_1])", + "(mytable.myid, mytable.name) IN (__[POSTCOMPILE_param_1])", checkparams={"param_1": [(1, "foo"), (5, "bar")]}, check_post_param={"param_1": [(1, "foo"), (5, "bar")]}, check_literal_execute={}, @@ -4469,7 +4562,7 @@ def test_tuple_expanding_in_values(self): dialect.tuple_in_values = True self.assert_compile( tuple_(table1.c.myid, table1.c.name).in_([(1, "foo"), (5, "bar")]), - "(mytable.myid, mytable.name) IN " "(__[POSTCOMPILE_param_1])", + "(mytable.myid, mytable.name) IN (__[POSTCOMPILE_param_1])", dialect=dialect, checkparams={"param_1": [(1, "foo"), (5, "bar")]}, check_post_param={"param_1": [(1, "foo"), (5, "bar")]}, @@ -4816,7 +4909,7 @@ def test_render_literal_execute_parameter_literal_binds(self): select(table1.c.myid).where( table1.c.myid == bindparam("foo", 5, literal_execute=True) ), - "SELECT mytable.myid FROM mytable " "WHERE mytable.myid = 5", + "SELECT mytable.myid FROM mytable WHERE mytable.myid = 5", literal_binds=True, ) @@ -4843,7 +4936,7 @@ def test_render_literal_execute_parameter_render_postcompile(self): select(table1.c.myid).where( table1.c.myid == bindparam("foo", 5, literal_execute=True) ), - "SELECT mytable.myid FROM mytable " "WHERE mytable.myid = 5", + "SELECT mytable.myid FROM mytable WHERE mytable.myid = 5", render_postcompile=True, ) @@ -5974,6 +6067,53 @@ def visit_widget(self, element, **kw): ): eq_(str(Grouping(Widget())), "(widget)") + def test_dialect_sub_compile_has_stack(self): + """test #10753""" + + class Widget(ColumnElement): + __visit_name__ = "widget" + stringify_dialect = "sqlite" + + def visit_widget(self, element, **kw): + assert self.stack + return "widget" + + with mock.patch( + "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget", + visit_widget, + create=True, + ): + eq_(str(select(Widget())), "SELECT widget AS anon_1") + + def test_dialect_sub_compile_has_stack_pg_specific(self): + """test #10753""" + my_table = table( + "my_table", column("id"), column("data"), column("user_email") + ) + + from sqlalchemy.dialects.postgresql import insert + + insert_stmt = insert(my_table).values( + id="some_existing_id", data="inserted value" + ) + + do_update_stmt = insert_stmt.on_conflict_do_update( + index_elements=["id"], set_=dict(data="updated value") + ) + + # note! two different bound parameter formats. It's weird yes, + # but this is what I want. They are stringifying without using the + # correct dialect. We could use the PG compiler at the point of + # the insert() but that still would not accommodate params in other + # parts of the statement. 
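+ # concretely, in the assertion below: the INSERT ... VALUES portion + # renders with the default compiler's named style (:param_1, :param_2), + # while the ON CONFLICT portion, handed to the PG compiler, renders + # pyformat (%(param_3)s)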
+ eq_ignore_whitespace( + str(select(do_update_stmt.cte())), + "WITH anon_1 AS (INSERT INTO my_table (id, data) " + "VALUES (:param_1, :param_2) " + "ON CONFLICT (id) " + "DO UPDATE SET data = %(param_3)s) SELECT FROM anon_1", + ) + def test_dialect_sub_compile_w_binds(self): """test sub-compile into a new compiler where state != CompilerState.COMPILING, but we have to render a bindparam @@ -6089,7 +6229,7 @@ def test_dialect_specific_ddl(self): eq_ignore_whitespace( str(schema.AddConstraint(cons)), - "ALTER TABLE testtbl ADD EXCLUDE USING gist " "(room WITH =)", + "ALTER TABLE testtbl ADD EXCLUDE USING gist (room WITH =)", ) def test_try_cast(self): @@ -6595,6 +6735,9 @@ def test_fk_illegal_sql_phrases(self): "FOO RESTRICT", "CASCADE WRONG", "SET NULL", + # test that PostgreSQL's syntax added in #11595 is not + # accepted by base compiler + "SET NULL(postgresql_db.some_column)", ): const = schema.AddConstraint( schema.ForeignKeyConstraint( @@ -6603,7 +6746,7 @@ def test_fk_illegal_sql_phrases(self): ) assert_raises_message( exc.CompileError, - r"Unexpected SQL phrase: '%s'" % phrase, + rf"Unexpected SQL phrase: '{re.escape(phrase)}'", const.compile, ) @@ -6822,65 +6965,59 @@ def test_schema_translate_crud(self): render_schema_translate=True, ) - def test_schema_non_schema_disambiguation(self): - """test #7471""" - - t1 = table("some_table", column("id"), column("q")) - t2 = table("some_table", column("id"), column("p"), schema="foo") - - self.assert_compile( - select(t1, t2), + @testing.combinations( + ( + lambda t1, t2: select(t1, t2), "SELECT some_table_1.id, some_table_1.q, " "foo.some_table.id AS id_1, foo.some_table.p " "FROM some_table AS some_table_1, foo.some_table", - ) - - self.assert_compile( - select(t1, t2).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL), + ), + ( + lambda t1, t2: select(t1, t2).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ), # the original "tablename_colname" label is preserved despite # the alias of some_table "SELECT some_table_1.id AS some_table_id, some_table_1.q AS " "some_table_q, foo.some_table.id AS foo_some_table_id, " "foo.some_table.p AS foo_some_table_p " "FROM some_table AS some_table_1, foo.some_table", - ) - - self.assert_compile( - select(t1, t2).join_from(t1, t2, t1.c.id == t2.c.id), + ), + ( + lambda t1, t2: select(t1, t2).join_from( + t1, t2, t1.c.id == t2.c.id + ), "SELECT some_table_1.id, some_table_1.q, " "foo.some_table.id AS id_1, foo.some_table.p " "FROM some_table AS some_table_1 " "JOIN foo.some_table ON some_table_1.id = foo.some_table.id", - ) - - self.assert_compile( - select(t1, t2).where(t1.c.id == t2.c.id), + ), + ( + lambda t1, t2: select(t1, t2).where(t1.c.id == t2.c.id), "SELECT some_table_1.id, some_table_1.q, " "foo.some_table.id AS id_1, foo.some_table.p " "FROM some_table AS some_table_1, foo.some_table " "WHERE some_table_1.id = foo.some_table.id", - ) - - self.assert_compile( - select(t1).where(t1.c.id == t2.c.id), + ), + ( + lambda t1, t2: select(t1).where(t1.c.id == t2.c.id), "SELECT some_table_1.id, some_table_1.q " "FROM some_table AS some_table_1, foo.some_table " "WHERE some_table_1.id = foo.some_table.id", - ) - - subq = select(t1).where(t1.c.id == t2.c.id).subquery() - self.assert_compile( - select(t2).select_from(t2).join(subq, t2.c.id == subq.c.id), + ), + ( + lambda t2, subq: select(t2) + .select_from(t2) + .join(subq, t2.c.id == subq.c.id), "SELECT foo.some_table.id, foo.some_table.p " "FROM foo.some_table JOIN " "(SELECT some_table_1.id AS id, some_table_1.q AS q " "FROM some_table AS 
some_table_1, foo.some_table " "WHERE some_table_1.id = foo.some_table.id) AS anon_1 " "ON foo.some_table.id = anon_1.id", - ) - - self.assert_compile( - select(t1, subq.c.id) + ), + ( + lambda t1, subq: select(t1, subq.c.id) .select_from(t1) .join(subq, t1.c.id == subq.c.id), # some_table is only aliased inside the subquery. this is not @@ -6892,8 +7029,59 @@ def test_schema_non_schema_disambiguation(self): "FROM some_table AS some_table_1, foo.some_table " "WHERE some_table_1.id = foo.some_table.id) AS anon_1 " "ON some_table.id = anon_1.id", + ), + ( + # issue #12451 + lambda t1alias, t2: select(t2, t1alias), + "SELECT foo.some_table.id, foo.some_table.p, " + "some_table_1.id AS id_1, some_table_1.q FROM foo.some_table, " + "some_table AS some_table_1", + ), + ( + # issue #12451 + lambda t1alias, t2: select(t2).join( + t1alias, t1alias.c.q == t2.c.p + ), + "SELECT foo.some_table.id, foo.some_table.p FROM foo.some_table " + "JOIN some_table AS some_table_1 " + "ON some_table_1.q = foo.some_table.p", + ), + ( + # issue #12451 + lambda t1alias, t2: select(t1alias).join( + t2, t1alias.c.q == t2.c.p + ), + "SELECT some_table_1.id, some_table_1.q " + "FROM some_table AS some_table_1 " + "JOIN foo.some_table ON some_table_1.q = foo.some_table.p", + ), + ( + # issue #12451 + lambda t1alias, t2alias: select(t1alias, t2alias).join( + t2alias, t1alias.c.q == t2alias.c.p + ), + "SELECT some_table_1.id, some_table_1.q, " + "some_table_2.id AS id_1, some_table_2.p " + "FROM some_table AS some_table_1 " + "JOIN foo.some_table AS some_table_2 " + "ON some_table_1.q = some_table_2.p", + ), + ) + def test_schema_non_schema_disambiguation(self, stmt, expected): + """test #7471, and its regression #12451""" + + t1 = table("some_table", column("id"), column("q")) + t2 = table("some_table", column("id"), column("p"), schema="foo") + t1alias = t1.alias() + t2alias = t2.alias() + subq = select(t1).where(t1.c.id == t2.c.id).subquery() + + stmt = testing.resolve_lambda( + stmt, t1=t1, t2=t2, subq=subq, t1alias=t1alias, t2alias=t2alias ) + self.assert_compile(stmt, expected) + def test_alias(self): a = alias(table4, "remtable") self.assert_compile( @@ -7290,7 +7478,7 @@ def test_correlate_auto_where_singlefrom(self): s = select(t1.c.a) s2 = select(t1).where(t1.c.a == s.scalar_subquery()) self.assert_compile( - s2, "SELECT t1.a FROM t1 WHERE t1.a = " "(SELECT t1.a FROM t1)" + s2, "SELECT t1.a FROM t1 WHERE t1.a = (SELECT t1.a FROM t1)" ) def test_correlate_semiauto_where_singlefrom(self): @@ -7478,7 +7666,6 @@ def test_val_and_null(self): class ResultMapTest(fixtures.TestBase): - """test the behavior of the 'entry stack' and the determination when the result_map needs to be populated. 
@@ -7693,9 +7880,9 @@ def test_select_wraps_for_translate_ambiguity(self): with mock.patch.object( dialect.statement_compiler, "translate_select_structure", - lambda self, to_translate, **kw: wrapped_again - if to_translate is stmt - else to_translate, + lambda self, to_translate, **kw: ( + wrapped_again if to_translate is stmt else to_translate + ), ): compiled = stmt.compile(dialect=dialect) @@ -7752,9 +7939,9 @@ def test_select_wraps_for_translate_ambiguity_dupe_cols(self): with mock.patch.object( dialect.statement_compiler, "translate_select_structure", - lambda self, to_translate, **kw: wrapped_again - if to_translate is stmt - else to_translate, + lambda self, to_translate, **kw: ( + wrapped_again if to_translate is stmt else to_translate + ), ): compiled = stmt.compile(dialect=dialect) diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 54fcba576ca..ebd44cdcb57 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -286,7 +286,7 @@ def _assert_cyclic_constraint_supports_alter(self, metadata, auto=False): if auto: fk_assertions.append( CompiledSQL( - "ALTER TABLE a ADD " "FOREIGN KEY(bid) REFERENCES b (id)" + "ALTER TABLE a ADD FOREIGN KEY(bid) REFERENCES b (id)" ) ) assertions.append(AllOf(*fk_assertions)) @@ -409,10 +409,10 @@ def test_cycle_unnamed_fks(self): ), AllOf( CompiledSQL( - "ALTER TABLE b ADD " "FOREIGN KEY(aid) REFERENCES a (id)" + "ALTER TABLE b ADD FOREIGN KEY(aid) REFERENCES a (id)" ), CompiledSQL( - "ALTER TABLE a ADD " "FOREIGN KEY(bid) REFERENCES b (id)" + "ALTER TABLE a ADD FOREIGN KEY(bid) REFERENCES b (id)" ), ), ] @@ -720,10 +720,10 @@ def test_index_create_inline(self): RegexSQL("^CREATE TABLE events"), AllOf( CompiledSQL( - "CREATE UNIQUE INDEX ix_events_name ON events " "(name)" + "CREATE UNIQUE INDEX ix_events_name ON events (name)" ), CompiledSQL( - "CREATE INDEX ix_events_location ON events " "(location)" + "CREATE INDEX ix_events_location ON events (location)" ), CompiledSQL( "CREATE UNIQUE INDEX sport_announcer ON events " @@ -817,7 +817,7 @@ def test_too_long_index_name(self): self.assert_compile( schema.CreateIndex(ix1), - "CREATE INDEX %s " "ON %s (%s)" % (exp, tname, cname), + "CREATE INDEX %s ON %s (%s)" % (exp, tname, cname), dialect=dialect, ) @@ -1219,7 +1219,11 @@ def test_render_ck_constraint_external(self): "CHECK (a < b) DEFERRABLE INITIALLY DEFERRED", ) - def test_external_ck_constraint_cancels_internal(self): + @testing.variation("isolate", [True, False]) + @testing.variation("type_", ["add", "drop"]) + def test_external_ck_constraint_cancels_internal( + self, isolate: testing.Variation, type_: testing.Variation + ): t, t2 = self._constraint_create_fixture() constraint = CheckConstraint( @@ -1230,15 +1234,27 @@ def test_external_ck_constraint_cancels_internal(self): table=t, ) - schema.AddConstraint(constraint) - - # once we make an AddConstraint, - # inline compilation of the CONSTRAINT - # is disabled - self.assert_compile( - schema.CreateTable(t), - "CREATE TABLE tbl (" "a INTEGER, " "b INTEGER" ")", - ) + if type_.add: + cls = schema.AddConstraint + elif type_.drop: + cls = schema.DropConstraint + else: + type_.fail() + + if not isolate: + cls(constraint, isolate_from_table=False) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER, " + "CONSTRAINT my_test_constraint CHECK (a < b) " + "DEFERRABLE INITIALLY DEFERRED)", + ) + else: + cls(constraint) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER)", + ) def 
test_render_drop_constraint(self): t, t2 = self._constraint_create_fixture() diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index d044212aa60..92b83b7fe35 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -8,6 +8,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true +from sqlalchemy import union_all from sqlalchemy import update from sqlalchemy.dialects import mssql from sqlalchemy.engine import default @@ -296,7 +297,9 @@ def test_recursive_union_alias_one(self): def test_recursive_union_no_alias_two(self): """ - pg's example:: + pg's example: + + .. sourcecode:: sql WITH RECURSIVE t(n) AS ( VALUES (1) @@ -490,16 +493,22 @@ def test_recursive_union_alias_four(self): ) @testing.combinations(True, False, argnames="identical") - @testing.combinations(True, False, argnames="use_clone") - def test_conflicting_names(self, identical, use_clone): + @testing.variation("clone_type", ["none", "clone", "annotated"]) + def test_conflicting_names(self, identical, clone_type): """test a flat out name conflict.""" s1 = select(1) c1 = s1.cte(name="cte1", recursive=True) - if use_clone: + if clone_type.clone: c2 = c1._clone() if not identical: c2 = c2.union(select(2)) + elif clone_type.annotated: + # this does not seem to trigger the issue that was fixed in + # #12364, however it is still a worthy test + c2 = c1._annotate({"foo": "bar"}) + if not identical: + c2 = c2.union(select(2)) else: if identical: s2 = select(1) @@ -509,19 +518,53 @@ def test_conflicting_names(self, identical, use_clone): s = select(c1, c2) - if use_clone and identical: + if clone_type.clone and identical: self.assert_compile( s, 'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1, ' 'cte1.1 AS "1_1" FROM cte1', ) + elif clone_type.annotated and identical: + # annotated seems to have a slightly different rendering + # scheme here + self.assert_compile( + s, + 'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1, ' 'cte1.1 AS "1__1" FROM cte1', + ) else: assert_raises_message( CompileError, - "Multiple, unrelated CTEs found " "with the same name: 'cte1'", + "Multiple, unrelated CTEs found with the same name: 'cte1'", s.compile, ) + @testing.variation("annotated", [True, False]) + def test_cte_w_annotated(self, annotated): + """test #12364""" + + A = table("a", column("i"), column("j")) + B = table("b", column("i"), column("j")) + + a = select(A).where(A.c.i > A.c.j).cte("filtered_a") + + if annotated: + a = a._annotate({"foo": "bar"}) + + a1 = select(a.c.i, literal(1).label("j")) + b = select(B).join(a, a.c.i == B.c.i).where(B.c.j.is_not(None)) + + query = union_all(a1, b) + self.assert_compile( + query, + "WITH filtered_a AS " + "(SELECT a.i AS i, a.j AS j FROM a WHERE a.i > a.j) " + "SELECT filtered_a.i, :param_1 AS j FROM filtered_a " + "UNION ALL SELECT b.i, b.j " + "FROM b JOIN filtered_a ON filtered_a.i = b.i " + "WHERE b.j IS NOT NULL", + ) + def test_with_recursive_no_name_currently_buggy(self): s1 = select(1) c1 = s1.cte(name="cte1", recursive=True) @@ -613,7 +656,7 @@ def test_order_by_group_by_label_w_scalar_subquery( stmt, "WITH anon_1 AS (SELECT test.a AS b FROM test %s b) " "SELECT (SELECT anon_1.b FROM anon_1) AS c" - % ("ORDER BY" if order_by == "order_by" else "GROUP BY") + % ("ORDER BY" if order_by == "order_by" else "GROUP BY"), # prior to the fix, the use_object version came out as: # "WITH anon_1 AS (SELECT test.a AS b FROM test " # "ORDER BY test.a) " @@ -1383,6 +1426,36 @@ def test_insert_w_cte_in_scalar_subquery(self, dialect): else: assert False +
@testing.variation("operation", ["insert", "update", "delete"]) + def test_stringify_standalone_dml_cte(self, operation): + """test issue discovered as part of #10753""" + + t1 = table("table_1", column("id"), column("val")) + + if operation.insert: + stmt = t1.insert() + expected = ( + "INSERT INTO table_1 (id, val) VALUES (:id, :val) " + "RETURNING table_1.id, table_1.val" + ) + elif operation.update: + stmt = t1.update() + expected = ( + "UPDATE table_1 SET id=:id, val=:val " + "RETURNING table_1.id, table_1.val" + ) + elif operation.delete: + stmt = t1.delete() + expected = "DELETE FROM table_1 RETURNING table_1.id, table_1.val" + else: + operation.fail() + + stmt = stmt.returning(t1.c.id, t1.c.val) + + cte = stmt.cte() + + self.assert_compile(cte, expected) + @testing.combinations( ("default_enhanced",), ("postgresql",), @@ -1827,6 +1900,37 @@ def test_insert_uses_independent_cte(self): checkparams={"id": 1, "price": 20, "param_1": 10, "price_1": 50}, ) + @testing.variation("num_ctes", ["one", "two"]) + def test_multiple_multivalues_inserts(self, num_ctes): + """test #12363""" + + t1 = table("table1", column("id"), column("a"), column("b")) + + t2 = table("table2", column("id"), column("a"), column("b")) + + if num_ctes.one: + self.assert_compile( + insert(t1) + .values([{"a": 1}, {"a": 2}]) + .add_cte(insert(t2).values([{"a": 5}, {"a": 6}]).cte()), + "WITH anon_1 AS " + "(INSERT INTO table2 (a) VALUES (:param_1), (:param_2)) " + "INSERT INTO table1 (a) VALUES (:a_m0), (:a_m1)", + ) + + elif num_ctes.two: + self.assert_compile( + insert(t1) + .values([{"a": 1}, {"a": 2}]) + .add_cte(insert(t1).values([{"b": 5}, {"b": 6}]).cte()) + .add_cte(insert(t2).values([{"a": 5}, {"a": 6}]).cte()), + "WITH anon_1 AS " + "(INSERT INTO table1 (b) VALUES (:param_1), (:param_2)), " + "anon_2 AS " + "(INSERT INTO table2 (a) VALUES (:param_3), (:param_4)) " + "INSERT INTO table1 (a) VALUES (:a_m0), (:a_m1)", + ) + def test_insert_from_select_uses_independent_cte(self): """test #7036""" diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index bbfb3b07782..bcfdfcdb9c9 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1234,7 +1234,6 @@ def test_col_w_nonoptional_sequence_non_autoinc_no_firing( class SpecialTypePKTest(fixtures.TestBase): - """test process_result_value in conjunction with primary key columns. 
Also tests that "autoincrement" checks are against diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index dbb5644cd1e..96b636bd058 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -326,7 +326,7 @@ def test_append_column_after_replace_selectable(self): sel = select(basefrom.c.a) with testing.expect_deprecated( - r"The Selectable.replace_selectable\(\) " "method is deprecated" + r"The Selectable.replace_selectable\(\) method is deprecated" ): replaced = sel.replace_selectable( basefrom, basefrom.join(joinfrom, basefrom.c.a == joinfrom.c.a) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index e474e75d756..d044d8b57f0 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -54,7 +54,6 @@ class TraversalTest( fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL ): - """test ClauseVisitor's traversal, particularly its ability to copy and modify a ClauseElement in place.""" @@ -362,7 +361,6 @@ class CustomObj(Column): class BinaryEndpointTraversalTest(fixtures.TestBase): - """test the special binary product visit""" def _assert_traversal(self, expr, expected): @@ -443,7 +441,6 @@ def test_subquery(self): class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): - """test copy-in-place behavior of various ClauseElements.""" __dialect__ = "default" @@ -2188,7 +2185,7 @@ def test_table_to_alias_8(self): def test_table_to_alias_9(self): s = select(literal_column("*")).select_from(t1).alias("foo") self.assert_compile( - s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo" + s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo" ) def test_table_to_alias_10(self): @@ -2197,13 +2194,13 @@ def test_table_to_alias_10(self): vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse(s.select()), - "SELECT foo.* FROM (SELECT * FROM table1 " "AS t1alias) AS foo", + "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo", ) def test_table_to_alias_11(self): s = select(literal_column("*")).select_from(t1).alias("foo") self.assert_compile( - s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo" + s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo" ) def test_table_to_alias_12(self): @@ -2212,7 +2209,7 @@ def test_table_to_alias_12(self): ff = vis.traverse(func.count(t1.c.col1).label("foo")) self.assert_compile( select(ff), - "SELECT count(t1alias.col1) AS foo FROM " "table1 AS t1alias", + "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias", ) assert list(_from_objects(ff)) == [t1alias] @@ -2703,7 +2700,7 @@ def test_splice_2(self): ) self.assert_compile( sql_util.splice_joins(table1, j2), - "table1 JOIN table4 AS table4_1 ON " "table1.col3 = table4_1.col3", + "table1 JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3", ) self.assert_compile( sql_util.splice_joins(sql_util.splice_joins(table1, j1), j2), @@ -2716,7 +2713,6 @@ def test_splice_2(self): class SelectTest(fixtures.TestBase, AssertsCompiledSQL): - """tests the generative capability of Select""" __dialect__ = "default" @@ -2730,23 +2726,23 @@ def setup_test_class(cls): def test_columns(self): s = t1.select() self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) select_copy = s.add_columns(column("yyy")) self.assert_compile( select_copy, - "SELECT table1.col1, table1.col2, " "table1.col3, yyy FROM table1", + "SELECT table1.col1, table1.col2, 
table1.col3, yyy FROM table1", ) is_not(s.selected_columns, select_copy.selected_columns) is_not(s._raw_columns, select_copy._raw_columns) self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) def test_froms(self): s = t1.select() self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) select_copy = s.select_from(t2) self.assert_compile( @@ -2756,13 +2752,13 @@ def test_froms(self): ) self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) def test_prefixes(self): s = t1.select() self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) select_copy = s.prefix_with("FOOBER") self.assert_compile( @@ -2771,7 +2767,7 @@ def test_prefixes(self): "table1.col3 FROM table1", ) self.assert_compile( - s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1" ) def test_execution_options(self): @@ -2811,7 +2807,6 @@ def _NOTYET_test_execution_options_in_text(self): class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests the generative capability of Insert, Update""" __dialect__ = "default" diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 139499d941e..6608c51073b 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -97,7 +97,7 @@ def test_plain_cartesian(self): @testing.combinations(("lateral",), ("cartesian",), ("join",)) def test_lateral_subqueries(self, control): """ - :: + .. sourcecode:: sql test=> create table a (id integer); CREATE TABLE diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index c47601b7616..28cdb03a965 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -844,7 +844,58 @@ def test_funcfilter_windowing_rows(self): "AS anon_1 FROM mytable", ) + def test_funcfilter_windowing_groups(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(groups=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > :name_1) " + "OVER (PARTITION BY mytable.description GROUPS BETWEEN :param_1 " + "FOLLOWING AND :param_2 FOLLOWING) " + "AS anon_1 FROM mytable", + ) + + def test_funcfilter_windowing_groups_positional(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(groups=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > ?) " + "OVER (PARTITION BY mytable.description GROUPS BETWEEN ? " + "FOLLOWING AND ? 
FOLLOWING) " + "AS anon_1 FROM mytable", + checkpositional=("foo", 1, 5), + dialect="default_qmark", + ) + + def test_funcfilter_more_criteria(self): + ff = func.rank().filter(table1.c.name > "foo") + ff2 = ff.filter(table1.c.myid == 1) + self.assert_compile( + select(ff, ff2), + "SELECT rank() FILTER (WHERE mytable.name > :name_1) AS anon_1, " + "rank() FILTER (WHERE mytable.name > :name_1 AND " + "mytable.myid = :myid_1) AS anon_2 FROM mytable", + {"name_1": "foo", "myid_1": 1}, + ) + def test_funcfilter_within_group(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .within_group(table1.c.name) + ), + "SELECT rank() FILTER (WHERE mytable.name > :name_1) " + "WITHIN GROUP (ORDER BY mytable.name) " + "AS anon_1 FROM mytable", + ) + + def test_within_group(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group(table1.c.name), @@ -858,7 +909,7 @@ def test_funcfilter_within_group(self): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_multi(self): + def test_within_group_multi(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group( @@ -874,7 +925,7 @@ def test_funcfilter_within_group_multi(self): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_desc(self): + def test_within_group_desc(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group(table1.c.name.desc()), @@ -888,7 +939,7 @@ def test_funcfilter_within_group_desc(self): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_w_over(self): + def test_within_group_w_over(self): stmt = select( table1.c.myid, func.percentile_cont(0.5) @@ -904,6 +955,23 @@ def test_funcfilter_within_group_w_over(self): {"percentile_cont_1": 0.5}, ) + def test_within_group_filter(self): + stmt = select( + table1.c.myid, + func.percentile_cont(0.5) + .within_group(table1.c.name) + .filter(table1.c.myid > 42), + ) + self.assert_compile( + stmt, + "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " + "WITHIN GROUP (ORDER BY mytable.name) " + "FILTER (WHERE mytable.myid > :myid_1) " + "AS anon_1 " + "FROM mytable", + {"percentile_cont_1": 0.5, "myid_1": 42}, + ) + def test_incorrect_none_type(self): from sqlalchemy.sql.expression import FunctionElement @@ -1586,8 +1654,7 @@ def test_json_object_keys_with_ordinality(self): def test_alias_column(self): """ - - :: + .. sourcecode:: sql SELECT x, y FROM @@ -1618,8 +1685,7 @@ def test_column_valued_one(self): def test_column_valued_two(self): """ - - :: + .. sourcecode:: sql SELECT x, y FROM @@ -1734,7 +1800,7 @@ def test_render_derived_with_lateral(self, apply_alias_after_lateral): def test_function_alias(self): """ - :: + .. 
sourcecode:: sql SELECT result_elem -> 'Field' as field FROM "check" AS check_, json_array_elements( diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index ddfb9aea200..a5cfad5b694 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -1120,7 +1120,7 @@ def test_anticipate_no_pk_non_composite_pk(self): Column("q", Integer), ) with expect_warnings( - "Column 't.x' is marked as a member.*" "may not store NULL.$" + "Column 't.x' is marked as a member.*may not store NULL.$" ): self.assert_compile( t.insert(), "INSERT INTO t (q) VALUES (:q)", params={"q": 5} @@ -1136,7 +1136,7 @@ def test_anticipate_no_pk_non_composite_pk_implicit_returning(self): d = postgresql.dialect() d.implicit_returning = True with expect_warnings( - "Column 't.x' is marked as a member.*" "may not store NULL.$" + "Column 't.x' is marked as a member.*may not store NULL.$" ): self.assert_compile( t.insert(), @@ -1156,7 +1156,7 @@ def test_anticipate_no_pk_non_composite_pk_prefetch(self): d.implicit_returning = False with expect_warnings( - "Column 't.x' is marked as a member.*" "may not store NULL.$" + "Column 't.x' is marked as a member.*may not store NULL.$" ): self.assert_compile( t.insert(), @@ -1172,7 +1172,7 @@ def test_anticipate_no_pk_lower_case_table(self): Column("notpk", String(10), nullable=True), ) with expect_warnings( - "Column 't.id' is marked as a member.*" "may not store NULL.$" + "Column 't.id' is marked as a member.*may not store NULL.$" ): self.assert_compile( t.insert(), @@ -1755,7 +1755,7 @@ def test_sql_expression_pk_autoinc_lastinserted(self): self.assert_compile( stmt, - "INSERT INTO sometable (id, data) VALUES " "(foobar(), ?)", + "INSERT INTO sometable (id, data) VALUES (foobar(), ?)", checkparams={"data": "foo"}, params={"data": "foo"}, dialect=dialect, diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 29484696da8..f80b4c447ea 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -17,6 +17,7 @@ from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import literal +from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import sql @@ -472,7 +473,6 @@ def test_no_inserted_pk_on_returning( class TableInsertTest(fixtures.TablesTest): - """test for consistent insert behavior across dialects regarding the inline() method, values() method, lower-case 't' tables. @@ -771,6 +771,27 @@ def define_tables(cls, metadata): Column("x_value", String(50)), Column("y_value", String(50)), ) + Table( + "uniq_cons", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50), unique=True), + ) + + @testing.variation("use_returning", [True, False]) + def test_returning_integrity_error(self, connection, use_returning): + """test for #11532""" + + stmt = self.tables.uniq_cons.insert() + if use_returning: + stmt = stmt.returning(self.tables.uniq_cons.c.id) + + # pymssql thought it would be funny to use OperationalError for + # a unique key violation. 
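+ # hence the expect_raises() below accepts OperationalError alongside + # the IntegrityError that other drivers raise for the duplicate key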
+ with expect_raises((exc.IntegrityError, exc.OperationalError)): + connection.execute( + stmt, [{"data": "the data"}, {"data": "the data"}] + ) def test_insert_unicode_keys(self, connection): table = self.tables["Unitéble2"] @@ -788,7 +809,8 @@ def test_insert_unicode_keys(self, connection): eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)]) - def test_insert_returning_values(self, connection): + @testing.variation("preserve_rowcount", [True, False]) + def test_insert_returning_values(self, connection, preserve_rowcount): t = self.tables.data conn = connection @@ -797,7 +819,14 @@ {"x": "x%d" % i, "y": "y%d" % i} for i in range(1, page_size * 2 + 27) ] - result = conn.execute(t.insert().returning(t.c.x, t.c.y), data) + if preserve_rowcount: + eo = {"preserve_rowcount": True} + else: + eo = {} + + result = conn.execute( + t.insert().returning(t.c.x, t.c.y), data, execution_options=eo + ) eq_([tup[0] for tup in result.cursor.description], ["x", "y"]) eq_(result.keys(), ["x", "y"]) @@ -815,6 +844,9 @@ # assert result.closed assert result.cursor is None + if preserve_rowcount: + eq_(result.rowcount, len(data)) + def test_insert_returning_preexecute_pk(self, metadata, connection): counter = itertools.count(1) @@ -1037,10 +1069,14 @@ def test_insert_w_bindparam_in_subq( eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) - def test_insert_returning_defaults(self, connection): + @testing.variation("preserve_rowcount", [True, False]) + def test_insert_returning_defaults(self, connection, preserve_rowcount): t = self.tables.data - conn = connection + if preserve_rowcount: + conn = connection.execution_options(preserve_rowcount=True) + else: + conn = connection result = conn.execute(t.insert(), {"x": "x0", "y": "y0"}) first_pk = result.inserted_primary_key[0] @@ -1055,6 +1091,9 @@ [(pk, 5) for pk in range(1 + first_pk, total_rows + first_pk)], ) + if preserve_rowcount: + eq_(result.rowcount, total_rows - 1) # range starts from 1 + def test_insert_return_pks_default_values(self, connection): """test sending multiple, empty rows into an INSERT and getting primary key values back. @@ -1439,12 +1478,138 @@ def test_invalid_identities( coll(expected_data), ) + @testing.requires.sequences + @testing.variation("explicit_sentinel", [True, False]) + @testing.variation("sequence_actually_translates", [True, False]) + @testing.variation("the_table_translates", [True, False]) + def test_sequence_schema_translate( + self, + metadata, + connection, + explicit_sentinel, + warn_for_downgrades, + randomize_returning, + sort_by_parameter_order, + sequence_actually_translates, + the_table_translates, + ): + """test #11157""" + + # so there's a bit of a bug which is that functions has_table() + # and has_sequence() do not take schema translate map into account, + # at all. So on MySQL, where we don't have transactional DDL, the + # DROP for Table / Sequence does not really work for all test runs + # when the schema is set to a "to be translated" kind of name.
+ # so, make a Table/Sequence with fixed schema name for the CREATE, + # then use a different object for the test that has a translate + # schema name + Table( + "t1", + metadata, + Column( + "id", + Integer, + Sequence("some_seq", start=1, schema=config.test_schema), + primary_key=True, + insert_sentinel=bool(explicit_sentinel), + ), + Column("data", String(50)), + schema=config.test_schema if the_table_translates else None, + ) + metadata.create_all(connection) + + if sequence_actually_translates: + connection = connection.execution_options( + schema_translate_map={ + "should_be_translated": config.test_schema + } + ) + sequence = Sequence( + "some_seq", start=1, schema="should_be_translated" + ) + else: + connection = connection.execution_options( + schema_translate_map={"foo": "bar"} + ) + sequence = Sequence("some_seq", start=1, schema=config.test_schema) + + m2 = MetaData() + t1 = Table( + "t1", + m2, + Column( + "id", + Integer, + sequence, + primary_key=True, + insert_sentinel=bool(explicit_sentinel), + ), + Column("data", String(50)), + schema=( + "should_be_translated" + if sequence_actually_translates and the_table_translates + else config.test_schema if the_table_translates else None + ), + ) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + use_imv = testing.db.dialect.use_insertmanyvalues + if ( + use_imv + and explicit_sentinel + and sort_by_parameter_order + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Column t1.id can't be explicitly marked as a sentinel " + r"column .* as the particular type of default generation", + ): + connection.execute(stmt, data) + return + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + server_autoincrement=True, + autoincrement_is_sequence=True, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + expected_data = [(i + 1, f"d{i}") for i in range(10)] + + eq_( + coll(result), + coll(expected_data), + ) + @testing.combinations( Integer(), String(50), (ARRAY(Integer()), testing.requires.array_type), DateTime(), Uuid(), + Uuid(native_uuid=False), argnames="datatype", ) def test_inserts_w_all_nulls( @@ -1620,10 +1785,8 @@ def test_sentinel_cant_match_keys( """test assertions to ensure sentinel values passed in parameter structures can be identified when they come back in cursor.fetchall(). - Values that are further modified by the database driver or by - SQL expressions (as in the case below) before being INSERTed - won't match coming back out, so datatypes need to implement - _sentinel_value_resolver() if this is the case. + Sentinels are now matched based on the data on the outside of the + type, that is, before the bind, and after the result. 
""" @@ -1636,11 +1799,8 @@ def bind_expression(self, bindparam): if resolve_sentinel_values: - def _sentinel_value_resolver(self, dialect): - def fix_sentinels(value): - return value.lower() - - return fix_sentinels + def process_result_value(self, value, dialect): + return value.replace("upper", "UPPER") t1 = Table( "data", @@ -1672,10 +1832,16 @@ def fix_sentinels(value): connection.execute(stmt, data) else: result = connection.execute(stmt, data) - eq_( - set(result.all()), - {(f"d{i}", f"upper_d{i}") for i in range(10)}, - ) + if resolve_sentinel_values: + eq_( + set(result.all()), + {(f"d{i}", f"UPPER_d{i}") for i in range(10)}, + ) + else: + eq_( + set(result.all()), + {(f"d{i}", f"upper_d{i}") for i in range(10)}, + ) @testing.variation("add_insert_sentinel", [True, False]) def test_sentinel_insert_default_pk_only( @@ -1766,9 +1932,11 @@ def test_no_sentinel_on_non_int_ss_function( Column( "id", Uuid(), - server_default=func.gen_random_uuid() - if default_type.server_side - else None, + server_default=( + func.gen_random_uuid() + if default_type.server_side + else None + ), default=uuid.uuid4 if default_type.client_side else None, primary_key=True, insert_sentinel=bool(add_insert_sentinel), @@ -1987,6 +2155,8 @@ def test_sentinel_col_configurations( "return_type", ["include_sentinel", "default_only", "return_defaults"] ) @testing.variation("add_sentinel_flag_to_col", [True, False]) + @testing.variation("native_uuid", [True, False]) + @testing.variation("as_uuid", [True, False]) def test_sentinel_on_non_autoinc_primary_key( self, metadata, @@ -1995,8 +2165,13 @@ def test_sentinel_on_non_autoinc_primary_key( sort_by_parameter_order, randomize_returning, add_sentinel_flag_to_col, + native_uuid, + as_uuid, ): uuids = [uuid.uuid4() for i in range(10)] + if not as_uuid: + uuids = [str(u) for u in uuids] + _some_uuids = iter(uuids) t1 = Table( @@ -2004,7 +2179,7 @@ def test_sentinel_on_non_autoinc_primary_key( metadata, Column( "id", - Uuid(), + Uuid(native_uuid=bool(native_uuid), as_uuid=bool(as_uuid)), default=functools.partial(next, _some_uuids), primary_key=True, insert_sentinel=bool(add_sentinel_flag_to_col), @@ -2060,7 +2235,7 @@ def test_sentinel_on_non_autoinc_primary_key( collection_cls(r), collection_cls( [ - (uuids[i], f"d{i+1}", "some_server_default") + (uuids[i], f"d{i + 1}", "some_server_default") for i in range(5) ] ), @@ -2072,7 +2247,7 @@ def test_sentinel_on_non_autoinc_primary_key( collection_cls( [ ( - f"d{i+1}", + f"d{i + 1}", "some_server_default", ) for i in range(5) @@ -2096,6 +2271,8 @@ def test_sentinel_on_non_autoinc_primary_key( else: return_type.fail() + @testing.variation("native_uuid", [True, False]) + @testing.variation("as_uuid", [True, False]) def test_client_composite_pk( self, metadata, @@ -2103,15 +2280,19 @@ def test_client_composite_pk( randomize_returning, sort_by_parameter_order, warn_for_downgrades, + native_uuid, + as_uuid, ): uuids = [uuid.uuid4() for i in range(10)] + if not as_uuid: + uuids = [str(u) for u in uuids] t1 = Table( "data", metadata, Column( "id1", - Uuid(), + Uuid(as_uuid=bool(as_uuid), native_uuid=bool(native_uuid)), default=functools.partial(next, iter(uuids)), primary_key=True, ), diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index eed861fe17b..9eb20dd4e59 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -221,7 +221,7 @@ def go(val): self.assert_compile( go("u1"), - "SELECT users.id FROM users " "WHERE users.name = 'u1'", + "SELECT users.id FROM users WHERE users.name = 'u1'", 
literal_binds=True, ) @@ -413,9 +413,11 @@ def run_my_statement(parameter, add_criteria=False): stmt = lambda_stmt(lambda: select(tab)) stmt = stmt.add_criteria( - lambda s: s.where(tab.c.col > parameter) - if add_criteria - else s.where(tab.c.col == parameter), + lambda s: ( + s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter) + ), ) stmt += lambda s: s.order_by(tab.c.id) @@ -437,9 +439,11 @@ def run_my_statement(parameter, add_criteria=False): stmt = lambda_stmt(lambda: select(tab)) stmt = stmt.add_criteria( - lambda s: s.where(tab.c.col > parameter) - if add_criteria - else s.where(tab.c.col == parameter), + lambda s: ( + s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter) + ), track_on=[add_criteria], ) @@ -1885,6 +1889,47 @@ def upd(id_, newname): (7, "foo"), ) + def test_bindparam_not_cached(self, user_address_fixture, testing_engine): + """test #12084""" + + users, addresses = user_address_fixture + + engine = testing_engine( + share_pool=True, options={"query_cache_size": 0} + ) + with engine.begin() as conn: + conn.execute( + users.insert(), + [{"id": 7, "name": "bar"}, {"id": 8, "name": "foo"}], + ) + + def make_query(stmt, *criteria): + for crit in criteria: + stmt += lambda s: s.where(crit) + + return stmt + + for i in range(2): + with engine.connect() as conn: + stmt = lambda_stmt(lambda: select(users)) + # create a filter criterion that will never match anything + stmt1 = make_query( + stmt, + users.c.name == "bar", + users.c.name == "foo", + ) + + assert len(conn.scalars(stmt1).all()) == 0 + + stmt2 = make_query( + stmt, + users.c.name == "bar", + users.c.name == "bar", + users.c.name == "foo", + ) + + assert len(conn.scalars(stmt2).all()) == 0 + class DeferredLambdaElementTest( fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL @@ -1945,9 +1990,9 @@ def test_detect_change_in_binds_tracking_negative(self): # lambda produces either "t1 IN vv" or "t2 IN qq" based on the # argument. 
will not produce a consistent cache key elem = lambdas.DeferredLambdaElement( - lambda tab: tab.c.q.in_(vv) - if tab.name == "t1" - else tab.c.q.in_(qq), + lambda tab: ( + tab.c.q.in_(vv) if tab.name == "t1" else tab.c.q.in_(qq) + ), roles.WhereHavingRole, lambda_args=(t1,), opts=lambdas.LambdaOptions(track_closure_variables=False), diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 0b35adc1ccc..1b068c02f7f 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -751,13 +751,25 @@ def test_assorted_repr(self): comment="foo", ), "Column('foo', Integer(), table=None, primary_key=True, " - "nullable=False, onupdate=%s, default=%s, server_default=%s, " - "comment='foo')" - % ( - ColumnDefault(1), - ColumnDefault(42), - DefaultClause("42"), + f"nullable=False, onupdate={ColumnDefault(1)}, default=" + f"{ColumnDefault(42)}, server_default={DefaultClause('42')}, " + "comment='foo')", + ), + ( + Column( + "foo", + Integer, + primary_key=True, + nullable=False, + onupdate=1, + insert_default=42, + server_default="42", + comment="foo", ), + "Column('foo', Integer(), table=None, primary_key=True, " + f"nullable=False, onupdate={ColumnDefault(1)}, default=" + f"{ColumnDefault(42)}, server_default={DefaultClause('42')}, " + "comment='foo')", ), ( Table("bar", MetaData(), Column("x", String)), @@ -1777,6 +1789,18 @@ def test_invalid_objects(self): 12, ) + assert_raises_message( + tsa.exc.ArgumentError, + "'SchemaItem' object, such as a 'Column' or a " + "'Constraint' expected, got " + r"\(Column\('q', Integer\(\), table=None\), " + r"Column\('p', Integer\(\), table=None\)\)", + Table, + "asdf", + MetaData(), + (Column("q", Integer), Column("p", Integer)), + ) + def test_reset_exported_passes(self): m = MetaData() @@ -2371,17 +2395,27 @@ def test_inherit_schema_enum(self): t1 = Table("x", m, Column("y", type_), schema="z") eq_(t1.c.y.type.schema, "z") - def test_to_metadata_copy_type(self): + @testing.variation("assign_metadata", [True, False]) + def test_to_metadata_copy_type(self, assign_metadata): m1 = MetaData() - type_ = self.MyType() + if assign_metadata: + type_ = self.MyType(metadata=m1) + else: + type_ = self.MyType() + t1 = Table("x", m1, Column("y", type_)) m2 = MetaData() t2 = t1.to_metadata(m2) - # metadata isn't set - is_(t2.c.y.type.metadata, None) + if assign_metadata: + # metadata was transferred + # issue #11802 + is_(t2.c.y.type.metadata, m2) + else: + # metadata isn't set + is_(t2.c.y.type.metadata, None) # our test type sets table, though is_(t2.c.y.type.table, t2) @@ -2411,11 +2445,34 @@ def test_to_metadata_independent_schema(self): eq_(t2.c.y.type.schema, None) - def test_to_metadata_inherit_schema(self): + @testing.combinations( + ("name", "foobar", "name"), + ("schema", "someschema", "schema"), + ("inherit_schema", True, "inherit_schema"), + ("metadata", MetaData(), "metadata"), + ) + def test_copy_args(self, argname, value, attrname): + kw = {argname: value} + e1 = self.MyType(**kw) + + e1_copy = e1.copy() + + eq_(getattr(e1_copy, attrname), value) + + @testing.variation("already_has_a_schema", [True, False]) + def test_to_metadata_inherit_schema(self, already_has_a_schema): m1 = MetaData() - type_ = self.MyType(inherit_schema=True) + if already_has_a_schema: + type_ = self.MyType(schema="foo", inherit_schema=True) + eq_(type_.schema, "foo") + else: + type_ = self.MyType(inherit_schema=True) + t1 = Table("x", m1, Column("y", type_)) + # note that inherit_schema means the schema mutates to be that + # of the table + is_(type_.schema, None) 
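# with inherit_schema=True, copying the table to another MetaData with an explicit schema should then see the type pick up that schema as well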
m2 = MetaData() t2 = t1.to_metadata(m2, schema="bar") @@ -2870,7 +2927,7 @@ def go(): assert_raises_message( exc.InvalidRequestError, - "Table 'users' is already defined for this " "MetaData instance.", + "Table 'users' is already defined for this MetaData instance.", go, ) @@ -4134,7 +4191,6 @@ def test_pickle_ck_binary_annotated_col(self, no_pickle_annotated): class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): - """Test Column() construction.""" __dialect__ = "default" @@ -4365,6 +4421,28 @@ def compile_(element, compiler, **kw): deregister(schema.CreateColumn) + @testing.combinations(("index",), ("unique",), argnames="paramname") + @testing.combinations((True,), (False,), (None,), argnames="orig") + @testing.combinations((True,), (False,), (None,), argnames="merging") + def test_merge_index_unique(self, paramname, orig, merging): + """test #11091""" + source = Column(**{paramname: merging}) + + target = Column(**{paramname: orig}) + + source._merge(target) + + target_copy = target._copy() + for col in ( + target, + target_copy, + ): + result = getattr(col, paramname) + if orig is None: + is_(result, merging) + else: + is_(result, orig) + @testing.combinations( ("default", lambda ctx: 10), ("default", func.foo()), @@ -4550,7 +4628,6 @@ def test_dont_merge_column( class ColumnDefaultsTest(fixtures.TestBase): - """test assignment of default fixtures to columns""" def _fixture(self, *arg, **kw): @@ -4659,6 +4736,16 @@ def test_column_default_onupdate_keyword_as_clause(self): assert c.onupdate.arg == target assert c.onupdate.column is c + def test_column_insert_default(self): + c = self._fixture(insert_default="y") + assert c.default.arg == "y" + + def test_column_insert_default_precedence_on_default(self): + c = self._fixture(insert_default="x", default="y") + assert c.default.arg == "x" + c = self._fixture(default="y", insert_default="x") + assert c.default.arg == "x" + class ColumnOptionsTest(fixtures.TestBase): def test_default_generators(self): @@ -5655,7 +5742,7 @@ def test_ix_allcols_truncation(self): dialect.max_identifier_length = 15 self.assert_compile( schema.CreateIndex(ix), - "CREATE INDEX ix_user_2de9 ON " '"user" (data, "Data2", "Data3")', + 'CREATE INDEX ix_user_2de9 ON "user" (data, "Data2", "Data3")', dialect=dialect, ) @@ -5780,9 +5867,11 @@ def test_fk_ref_local_referent_has_no_type(self, col_has_type): "b", metadata, Column("id", Integer, primary_key=True), - Column("aid", ForeignKey("a.id")) - if not col_has_type - else Column("aid", Integer, ForeignKey("a.id")), + ( + Column("aid", ForeignKey("a.id")) + if not col_has_type + else Column("aid", Integer, ForeignKey("a.id")) + ), ) fks = list( c for c in b.constraints if isinstance(c, ForeignKeyConstraint) @@ -5937,7 +6026,7 @@ def test_schematype_ck_name_boolean_no_name(self): # no issue with native boolean self.assert_compile( schema.CreateTable(u1), - 'CREATE TABLE "user" (' "x BOOLEAN" ")", + """CREATE TABLE "user" (x BOOLEAN)""", dialect="postgresql", ) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index af51010c761..8ef260a179f 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -83,10 +83,18 @@ def operate(self, op, *other, **kwargs): return op +class ColExpressionDuckTypeOnly: + def __init__(self, expr): + self.expr = expr + + def __clause_element__(self): + return self.expr + + class DefaultColumnComparatorTest( testing.AssertsCompiledSQL, fixtures.TestBase ): - dialect = "default_enhanced" + dialect = __dialect__ = "default_enhanced"
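    # (editorial note, not part of the patch) AssertsCompiledSQL consults the
    # class-level __dialect__ attribute as the default target for
    # assert_compile(), so the "dialect = __dialect__ =" double assignment
    # makes "default_enhanced" the implicit compile dialect while keeping the
    # existing "dialect" attribute in place for tests that reference it.
    # ColExpressionDuckTypeOnly above deliberately has no SQLAlchemy base
    # class: it is accepted purely through the __clause_element__() coercion
    # hook (exercised by test_in_14_5 further down), e.g.
    # select(ColExpressionDuckTypeOnly(column("q"))) renders "SELECT q".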
@testing.combinations((operators.desc_op, desc), (operators.asc_op, asc)) def test_scalar(self, operator, compare_to): @@ -419,7 +427,7 @@ def test_parenthesized_exprs(self, op, reverse, negate): ), ( lambda p, q: (1 - p) * (2 - q) * (3 - p) * (4 - q), - "(:p_1 - t.p) * (:q_1 - t.q) * " "(:p_2 - t.p) * (:q_2 - t.q)", + "(:p_1 - t.p) * (:q_1 - t.q) * (:p_2 - t.p) * (:q_2 - t.q)", ), ( lambda p, q: ( @@ -483,19 +491,24 @@ def test_associatives(self, op, reverse, negate): if negate: self.assert_compile( select(~expr), - f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) " - "AS anon_1 FROM t" - if not reverse - else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) " - "AS anon_1 FROM t", + ( + f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) " + "AS anon_1 FROM t" + if not reverse + else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) " + "AS anon_1 FROM t" + ), ) else: self.assert_compile( select(expr), - f"SELECT t.q{opstring}t.p{opstring}{exprs} AS anon_1 FROM t" - if not reverse - else f"SELECT {exprs}{opstring}t.q{opstring}t.p " - f"AS anon_1 FROM t", + ( + f"SELECT t.q{opstring}t.p{opstring}{exprs} " + "AS anon_1 FROM t" + if not reverse + else f"SELECT {exprs}{opstring}t.q{opstring}t.p " + "AS anon_1 FROM t" + ), ) @testing.combinations( @@ -565,9 +578,11 @@ def test_non_associatives(self, op, reverse, negate): self.assert_compile( select(~expr), - f"SELECT {str_expr} AS anon_1 FROM t" - if not reverse - else f"SELECT {str_expr} AS anon_1 FROM t", + ( + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t" + ), ) else: if reverse: @@ -583,9 +598,11 @@ def test_non_associatives(self, op, reverse, negate): self.assert_compile( select(expr), - f"SELECT {str_expr} AS anon_1 FROM t" - if not reverse - else f"SELECT {str_expr} AS anon_1 FROM t", + ( + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t" + ), ) @@ -650,9 +667,11 @@ def test_modulus(self, modulus, paramstyle): col = column("somecol", modulus()) self.assert_compile( col.modulus(), - "somecol %%" - if paramstyle in ("format", "pyformat") - else "somecol %", + ( + "somecol %%" + if paramstyle in ("format", "pyformat") + else "somecol %" + ), dialect=default.DefaultDialect(paramstyle=paramstyle), ) @@ -667,9 +686,11 @@ def test_modulus_prefix(self, modulus, paramstyle): col = column("somecol", modulus()) self.assert_compile( col.modulus_prefix(), - "%% somecol" - if paramstyle in ("format", "pyformat") - else "% somecol", + ( + "%% somecol" + if paramstyle in ("format", "pyformat") + else "% somecol" + ), dialect=default.DefaultDialect(paramstyle=paramstyle), ) @@ -1272,7 +1293,6 @@ def _adapt_expression(self, op, other_comparator): class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): - """test standalone booleans being wrapped in an AsBoolean, as well as true/false compilation.""" @@ -1433,7 +1453,6 @@ def test_twelve(self): class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): - """test interaction of and_()/or_() with boolean , null constants""" __dialect__ = default.DefaultDialect(supports_native_boolean=True) @@ -2187,6 +2206,15 @@ def test_in_14(self): "mytable.myid IN (mytable.myid)", ) + def test_in_14_5(self): + """test #12019""" + self.assert_compile( + self.table1.c.myid.in_( + [ColExpressionDuckTypeOnly(self.table1.c.myid)] + ), + "mytable.myid IN (mytable.myid)", + ) + def test_in_15(self): self.assert_compile( self.table1.c.myid.in_(["a", self.table1.c.myid]), @@ -2302,8 +2330,23 @@ def 
test_in_27(self): ) def test_in_28(self): + """revised to test #12314""" + self.assert_compile( + self.table1.c.myid.in_([None]), + "mytable.myid IN (__[POSTCOMPILE_myid_1])", + ) + + @testing.combinations( + [1, 2, None, 3], + [None, None, None], + [None, 2, 3, 3], + ) + def test_in_null_combinations(self, expr): + """test #12314""" + self.assert_compile( - self.table1.c.myid.in_([None]), "mytable.myid IN (NULL)" + self.table1.c.myid.in_(expr), + "mytable.myid IN (__[POSTCOMPILE_myid_1])", ) @testing.combinations(True, False) @@ -3216,7 +3259,7 @@ def test_regexp_precedence_1(self): self.table.c.myid.match("foo"), self.table.c.myid.regexp_match("xx"), ), - "mytable.myid MATCH :myid_1 AND " "mytable.myid :myid_2", + "mytable.myid MATCH :myid_1 AND mytable.myid :myid_2", ) self.assert_compile( and_( @@ -4540,7 +4583,7 @@ def t_fixture(self): ) return t - @testing.combinations( + null_comparisons = testing.combinations( lambda col: any_(col) == None, lambda col: col.any_() == None, lambda col: any_(col) == null(), @@ -4551,12 +4594,23 @@ def t_fixture(self): lambda col: None == col.any_(), argnames="expr", ) + + @null_comparisons @testing.combinations("int", "array", argnames="datatype") def test_any_generic_null(self, datatype, expr, t_fixture): col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name) + @null_comparisons + @testing.combinations("int", "array", argnames="datatype") + def test_any_generic_null_negate(self, datatype, expr, t_fixture): + col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval + + self.assert_compile( + ~expr(col), "NOT (NULL = ANY (tab1.%s))" % col.name + ) + @testing.fixture( params=[ ("ANY", any_), @@ -4565,48 +4619,78 @@ def test_any_generic_null(self, datatype, expr, t_fixture): ("ALL", lambda x: x.all_()), ] ) - def operator(self, request): + def any_all_operators(self, request): return request.param + # test legacy array any() / all(). 
these are superseded by the + # any_() / all_() versions @testing.fixture( params=[ ("ANY", lambda x, *o: x.any(*o)), ("ALL", lambda x, *o: x.all(*o)), ] ) - def array_op(self, request): + def legacy_any_all_operators(self, request): return request.param - def test_array(self, t_fixture, operator): + def test_array(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(t.c.arrval), f":param_1 = {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_comparator_array(self, t_fixture, operator): + def test_comparator_inline_negate(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators + self.assert_compile( + 5 != fn(t.c.arrval), + f":param_1 != {op} (tab1.arrval)", + checkparams={"param_1": 5}, + ) + + @testing.combinations( + (operator.eq, "="), + (operator.ne, "!="), + (operator.gt, ">"), + (operator.le, "<="), + argnames="operator,opstring", + ) + def test_comparator_outer_negate( + self, t_fixture, any_all_operators, operator, opstring + ): + """test #10817""" + t = t_fixture + op, fn = any_all_operators + self.assert_compile( + ~(operator(5, fn(t.c.arrval))), + f"NOT (:param_1 {opstring} {op} (tab1.arrval))", + checkparams={"param_1": 5}, + ) + + def test_comparator_array(self, t_fixture, any_all_operators): + t = t_fixture + op, fn = any_all_operators self.assert_compile( 5 > fn(t.c.arrval), f":param_1 > {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_comparator_array_wexpr(self, t_fixture, operator): + def test_comparator_array_wexpr(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( t.c.data > fn(t.c.arrval), f"tab1.data > {op} (tab1.arrval)", checkparams={}, ) - def test_illegal_ops(self, t_fixture, operator): + def test_illegal_ops(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators assert_raises_message( exc.ArgumentError, @@ -4622,10 +4706,10 @@ def test_illegal_ops(self, t_fixture, operator): t.c.data + fn(t.c.arrval), f"tab1.data + {op} (tab1.arrval)" ) - def test_bindparam_coercion(self, t_fixture, array_op): + def test_bindparam_coercion(self, t_fixture, legacy_any_all_operators): """test #7979""" t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators expr = fn(t.c.arrval, bindparam("param")) expected = f"%(param)s = {op} (tab1.arrval)" @@ -4633,9 +4717,11 @@ def test_bindparam_coercion(self, t_fixture, array_op): self.assert_compile(expr, expected, dialect="postgresql") - def test_array_comparator_accessor(self, t_fixture, array_op): + def test_array_comparator_accessor( + self, t_fixture, legacy_any_all_operators + ): t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators self.assert_compile( fn(t.c.arrval, 5, operator.gt), @@ -4643,9 +4729,11 @@ def test_array_comparator_accessor(self, t_fixture, array_op): checkparams={"arrval_1": 5}, ) - def test_array_comparator_negate_accessor(self, t_fixture, array_op): + def test_array_comparator_negate_accessor( + self, t_fixture, legacy_any_all_operators + ): t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators self.assert_compile( ~fn(t.c.arrval, 5, operator.gt), @@ -4653,9 +4741,9 @@ def test_array_comparator_negate_accessor(self, t_fixture, array_op): checkparams={"arrval_1": 5}, ) - def test_array_expression(self, t_fixture, operator): + def test_array_expression(self, t_fixture, any_all_operators): t = t_fixture - op, fn = 
operator + op, fn = any_all_operators self.assert_compile( 5 == fn(t.c.arrval[5:6] + postgresql.array([3, 4])), @@ -4671,9 +4759,9 @@ def test_array_expression(self, t_fixture, operator): dialect="postgresql", ) - def test_subq(self, t_fixture, operator): + def test_subq(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(select(t.c.data).where(t.c.data < 10).scalar_subquery()), @@ -4682,9 +4770,9 @@ def test_subq(self, t_fixture, operator): checkparams={"data_1": 10, "param_1": 5}, ) - def test_scalar_values(self, t_fixture, operator): + def test_scalar_values(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(values(t.c.data).data([(1,), (42,)]).scalar_values()), diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 54943897e11..5d7788fcf1c 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1076,7 +1076,6 @@ def test_select_distinct_limit_offset(self, connection): class CompoundTest(fixtures.TablesTest): - """test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases.""" @@ -1463,7 +1462,6 @@ def test_composite_alias(self, connection): class JoinTest(fixtures.TablesTest): - """Tests join execution. The compiled SQL emitted by the dialect might be ANSI joins or diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 08c9c4207ef..58a64e5c381 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -195,7 +195,9 @@ def test_labels(self): """test the quoting of labels. If labels aren't quoted, a query in postgresql in particular will - fail since it produces:: + fail since it produces: + + .. sourcecode:: sql SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" @@ -821,7 +823,7 @@ def test_apply_labels_shouldnt_quote(self): # what if table/schema *are* quoted? 
self.assert_compile( t1.select().set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL), - "SELECT " "Foo.T1.Col1 AS Foo_T1_Col1 " "FROM " "Foo.T1", + "SELECT Foo.T1.Col1 AS Foo_T1_Col1 FROM Foo.T1", ) def test_quote_flag_propagate_check_constraint(self): @@ -830,7 +832,7 @@ def test_quote_flag_propagate_check_constraint(self): CheckConstraint(t.c.x > 5) self.assert_compile( schema.CreateTable(t), - "CREATE TABLE t (" '"x" INTEGER, ' 'CHECK ("x" > 5)' ")", + 'CREATE TABLE t ("x" INTEGER, CHECK ("x" > 5))', ) def test_quote_flag_propagate_index(self): @@ -858,7 +860,6 @@ def test_quote_flag_propagate_anon_label(self): class PreparerTest(fixtures.TestBase): - """Test the db-agnostic quoting services of IdentifierPreparer.""" def test_unformat(self): diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index a5d1befa206..93c5c892969 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1,4 +1,5 @@ import collections +from collections import defaultdict import collections.abc as collections_abc from contextlib import contextmanager import csv @@ -490,7 +491,7 @@ def test_pickled_rows(self, connection, use_pickle, use_labels): if use_pickle: with expect_raises_message( exc.NoSuchColumnError, - "Row was unpickled; lookup by ColumnElement is " "unsupported", + "Row was unpickled; lookup by ColumnElement is unsupported", ): result[0]._mapping[users.c.user_id] else: @@ -499,7 +500,7 @@ def test_pickled_rows(self, connection, use_pickle, use_labels): if use_pickle: with expect_raises_message( exc.NoSuchColumnError, - "Row was unpickled; lookup by ColumnElement is " "unsupported", + "Row was unpickled; lookup by ColumnElement is unsupported", ): result[0]._mapping[users.c.user_name] else: @@ -527,8 +528,14 @@ def test_pickle_rows_other_process(self, connection, use_labels): "import sqlalchemy; import pickle; print([" f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])" ) + parts = list(sys.path) + if os.environ.get("PYTHONPATH"): + parts.append(os.environ["PYTHONPATH"]) + pythonpath = os.pathsep.join(parts) proc = subprocess.run( - [sys.executable, "-c", code], stdout=subprocess.PIPE + [sys.executable, "-c", code], + stdout=subprocess.PIPE, + env={**os.environ, "PYTHONPATH": pythonpath}, ) exp = str([r[0] for r in result]).encode() eq_(proc.returncode, 0) @@ -1301,11 +1308,15 @@ def test_label_against_star( stmt = select( *[ - text("*") - if colname == "*" - else users.c.user_name.label("name_label") - if colname == "name_label" - else users.c[colname] + ( + text("*") + if colname == "*" + else ( + users.c.user_name.label("name_label") + if colname == "name_label" + else users.c[colname] + ) + ) for colname in cols ] ) @@ -1730,6 +1741,29 @@ def __getitem__(self, i): eq_(proxy.key, "value") eq_(proxy._mapping["key"], "value") + @contextmanager + def cursor_wrapper(self, engine): + calls = defaultdict(int) + + class CursorWrapper: + def __init__(self, real_cursor): + self.real_cursor = real_cursor + + def __getattr__(self, name): + calls[name] += 1 + return getattr(self.real_cursor, name) + + create_cursor = engine.dialect.execution_ctx_cls.create_cursor + + def new_create(context): + cursor = create_cursor(context) + return CursorWrapper(cursor) + + with patch.object( + engine.dialect.execution_ctx_cls, "create_cursor", new_create + ): + yield calls + def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine): """assert that rowcount is only called on deletes and updates. 
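[editorial sketch, not part of the patch] The cursor_wrapper fixture above
counts DBAPI cursor attribute access by interposing a plain proxy object:
__getattr__ fires only for names not found on the wrapper itself, so every
delegated lookup can be tallied before being forwarded to the real cursor.
A minimal standalone version of the same pattern (all names invented here):

    from collections import defaultdict

    class CountingProxy:
        def __init__(self, target, calls):
            self._target = target
            self._calls = calls

        def __getattr__(self, name):
            # name not found on the proxy itself -> count, then delegate
            self._calls[name] += 1
            return getattr(self._target, name)

    class FakeCursor:
        rowcount = 3

    calls = defaultdict(int)
    cursor = CountingProxy(FakeCursor(), calls)
    cursor.rowcount
    cursor.rowcount
    assert calls["rowcount"] == 2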
@@ -1741,33 +1775,71 @@ def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine): engine = testing_engine() + req = testing.requires + t = Table("t1", metadata, Column("data", String(10))) metadata.create_all(engine) - - with patch.object( - engine.dialect.execution_ctx_cls, "rowcount" - ) as mock_rowcount: + count = 0 + with self.cursor_wrapper(engine) as call_counts: with engine.begin() as conn: - mock_rowcount.__get__ = Mock() conn.execute( t.insert(), [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], ) - - eq_(len(mock_rowcount.__get__.mock_calls), 0) + if ( + req.rowcount_always_cached.enabled + or req.rowcount_always_cached_on_insert.enabled + ): + count += 1 + eq_(call_counts["rowcount"], count) eq_( conn.execute(t.select()).fetchall(), [("d1",), ("d2",), ("d3",)], ) - eq_(len(mock_rowcount.__get__.mock_calls), 0) + if req.rowcount_always_cached.enabled: + count += 1 + eq_(call_counts["rowcount"], count) conn.execute(t.update(), {"data": "d4"}) - eq_(len(mock_rowcount.__get__.mock_calls), 1) + count += 1 + eq_(call_counts["rowcount"], count) conn.execute(t.delete()) - eq_(len(mock_rowcount.__get__.mock_calls), 2) + count += 1 + eq_(call_counts["rowcount"], count) + + def test_rowcount_always_called_when_preserve_rowcount( + self, metadata, testing_engine + ): + """assert that rowcount is called on any statement when + ``preserve_rowcount=True``. + + """ + + engine = testing_engine() + + t = Table("t1", metadata, Column("data", String(10))) + metadata.create_all(engine) + + with self.cursor_wrapper(engine) as call_counts: + with engine.begin() as conn: + conn = conn.execution_options(preserve_rowcount=True) + # Do not use insertmanyvalues on any driver + conn.execute(t.insert(), {"data": "d1"}) + + eq_(call_counts["rowcount"], 1) + + eq_(conn.execute(t.select()).fetchall(), [("d1",)]) + eq_(call_counts["rowcount"], 2) + + conn.execute(t.update(), {"data": "d4"}) + + eq_(call_counts["rowcount"], 3) + + conn.execute(t.delete()) + eq_(call_counts["rowcount"], 4) def test_row_is_sequence(self): row = Row(object(), [None], {}, ["value"]) @@ -2501,6 +2573,60 @@ def test_keyed_accessor_column_is_repeated_multiple_times( eq_(row[6], "d3") eq_(row[7], "d3") + @testing.requires.duplicate_names_in_cursor_description + @testing.combinations((None,), (0,), (1,), (2,), argnames="pos") + @testing.variation("texttype", ["literal", "text"]) + def test_dupe_col_targeting(self, connection, pos, texttype): + """test #11306""" + + keyed2 = self.tables.keyed2 + col = keyed2.c.b + data_value = "b2" + + cols = [col, col, col] + expected = [data_value, data_value, data_value] + + if pos is not None: + if texttype.literal: + cols[pos] = literal_column("10") + elif texttype.text: + cols[pos] = text("10") + else: + texttype.fail() + + expected[pos] = 10 + + stmt = select(*cols) + + result = connection.execute(stmt) + + if texttype.text and pos is not None: + # when using text(), the name of the col is taken from + # cursor.description directly since we don't know what's + # inside a text() + key_for_text_col = result.cursor.description[pos][0] + elif texttype.literal and pos is not None: + # for literal_column(), we use the text + key_for_text_col = "10" + + eq_(result.all(), [tuple(expected)]) + + result = connection.execute(stmt).mappings() + if pos is None: + eq_(set(result.keys()), {"b", "b__1", "b__2"}) + eq_( + result.all(), + [{"b": data_value, "b__1": data_value, "b__2": data_value}], + ) + + else: + eq_(set(result.keys()), {"b", "b__1", key_for_text_col}) + + eq_( + result.all(), + 
[{"b": data_value, "b__1": data_value, key_for_text_col: 10}], + ) + def test_columnclause_schema_column_one(self, connection): # originally addressed by [ticket:2932], however liberalized # Column-targeting rules are deprecated diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 4d55c435db1..6cccd01d4a9 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -690,7 +690,6 @@ def test_insert(self, connection): class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): - """test returning() works with columns that define 'key'.""" __requires__ = ("insert_returning",) @@ -1561,9 +1560,11 @@ def test_upsert_data_w_defaults(self, connection, update_cols): config, t1, (t1.c.id, t1.c.insdef, t1.c.data), - set_lambda=(lambda excluded: {"data": excluded.data + " excluded"}) - if update_cols - else None, + set_lambda=( + (lambda excluded: {"data": excluded.data + " excluded"}) + if update_cols + else None + ), ) upserted_rows = connection.execute( diff --git a/test/sql/test_select.py b/test/sql/test_select.py index e772c5911d0..2bef71dd1e5 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -469,6 +469,21 @@ def test_select_multiple_compound_elements(self, methname, joiner): " %(joiner)s SELECT :param_3 AS anon_3" % {"joiner": joiner}, ) + @testing.combinations( + lambda stmt: stmt.with_statement_hint("some hint"), + lambda stmt: stmt.with_hint(table("x"), "some hint"), + lambda stmt: stmt.where(column("q") == 5), + lambda stmt: stmt.having(column("q") == 5), + lambda stmt: stmt.order_by(column("q")), + lambda stmt: stmt.group_by(column("q")), + # TODO: continue + ) + def test_methods_generative(self, testcase): + s1 = select(1) + s2 = testing.resolve_lambda(testcase, stmt=s1) + + assert s1 is not s2 + class ColumnCollectionAsSelectTest(fixtures.TestBase, AssertsCompiledSQL): """tests related to #8285.""" diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index a146a94c600..6a7be981412 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1,5 +1,9 @@ """Test various algorithmic properties of selectables.""" + from itertools import zip_longest +import random +import threading +import time from sqlalchemy import and_ from sqlalchemy import bindparam @@ -41,6 +45,7 @@ from sqlalchemy.sql import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.sql import operators +from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors @@ -1961,7 +1966,6 @@ def test_fk_join(self): class AnonLabelTest(fixtures.TestBase): - """Test behaviors fixed by [ticket:2168].""" def test_anon_labels_named_column(self): @@ -2044,6 +2048,16 @@ def test_join_standalone_alias_flat(self): "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b", ) + def test_join_alias_name_flat(self): + a = table("a", column("a")) + b = table("b", column("b")) + self.assert_compile( + a.join(b, a.c.a == b.c.b)._anonymous_fromclause( + name="foo", flat=True + ), + "a AS foo_a JOIN b AS foo_b ON foo_a.a = foo_b.b", + ) + def test_composed_join_alias_flat(self): a = table("a", column("a")) b = table("b", column("b")) @@ -2062,6 +2076,24 @@ def test_composed_join_alias_flat(self): "ON b_1.b = c_1.c", ) + def test_composed_join_alias_name_flat(self): + a = table("a", column("a")) + b = table("b", column("b")) + c = table("c", column("c")) + d = table("d", column("d")) + + j1 = a.join(b, a.c.a == b.c.b) + j2 = 
c.join(d, c.c.c == d.c.d) + + self.assert_compile( + j1.join(j2, b.c.b == c.c.c)._anonymous_fromclause( + name="foo", flat=True + ), + "a AS foo_a JOIN b AS foo_b ON foo_a.a = foo_b.b JOIN " + "(c AS foo_c JOIN d AS foo_d ON foo_c.c = foo_d.d) " + "ON foo_b.b = foo_c.c", + ) + def test_composed_join_alias(self): a = table("a", column("a")) b = table("b", column("b")) @@ -3023,6 +3055,37 @@ def test_replacement_traverse_preserve(self): eq_(whereclause.left._annotations, {"foo": "bar"}) eq_(whereclause.right._annotations, {"foo": "bar"}) + @testing.variation("use_col_ahead_of_time", [True, False]) + def test_set_type_on_column(self, use_col_ahead_of_time): + """test related to #10597""" + + col = Column() + + col_anno = col._annotate({"foo": "bar"}) + + if use_col_ahead_of_time: + expr = col_anno == bindparam("foo") + + # this could only be fixed if we put some kind of a container + # that receives the type directly rather than using NullType; + # like a PendingType or something + + is_(expr.right.type._type_affinity, sqltypes.NullType) + + assert "type" not in col_anno.__dict__ + + col.name = "name" + col._set_type(Integer()) + + eq_(col_anno.name, "name") + is_(col_anno.type._type_affinity, Integer) + + expr = col_anno == bindparam("foo") + + is_(expr.right.type._type_affinity, Integer) + + assert "type" in col_anno.__dict__ + @testing.combinations(True, False, None) def test_setup_inherit_cache(self, inherit_cache_value): if inherit_cache_value is None: @@ -3982,3 +4045,39 @@ def test_copy_internals_multiple_nesting(self): a3 = a2._clone() a3._copy_internals() is_(a1.corresponding_column(a3.c.c), a1.c.c) + + +class FromClauseConcurrencyTest(fixtures.TestBase): + """test for issue 12302""" + + @testing.requires.timing_intensive + def test_c_collection(self): + dictionary_meta = MetaData() + all_indexes_table = Table( + "all_indexes", + dictionary_meta, + *[Column(f"col{i}", Integer) for i in range(50)], + ) + + fails = 0 + + def use_table(): + nonlocal fails + try: + for i in range(3): + time.sleep(random.random() * 0.0001) + all_indexes.c.col35 + except: + fails += 1 + raise + + for j in range(1000): + all_indexes = all_indexes_table.alias("a_indexes") + + threads = [threading.Thread(target=use_table) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not fails, "one or more runs failed" diff --git a/test/sql/test_text.py b/test/sql/test_text.py index de40c8f4298..941a02d9e7e 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -71,7 +71,6 @@ def test_text_adds_to_result_map(self): class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): - """test the usage of text() implicit within the select() construct when strings are passed.""" @@ -471,7 +470,7 @@ def test_escaping_double_colons(self): r"SELECT * FROM pg_attribute WHERE " r"attrelid = :tab\:\:regclass" ), - "SELECT * FROM pg_attribute WHERE " "attrelid = %(tab)s::regclass", + "SELECT * FROM pg_attribute WHERE attrelid = %(tab)s::regclass", params={"tab": None}, dialect="postgresql", ) @@ -484,7 +483,7 @@ def test_double_colons_dont_actually_need_escaping(self): r"SELECT * FROM pg_attribute WHERE " r"attrelid = foo::regclass" ), - "SELECT * FROM pg_attribute WHERE " "attrelid = foo::regclass", + "SELECT * FROM pg_attribute WHERE attrelid = foo::regclass", params={}, dialect="postgresql", ) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index eb91d9c4cdf..abea93418c4 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -3,6 +3,10 @@ import 
importlib import operator import os +import pickle +import subprocess +import sys +from tempfile import mkstemp import sqlalchemy as sa from sqlalchemy import and_ @@ -15,6 +19,7 @@ from sqlalchemy import cast from sqlalchemy import CHAR from sqlalchemy import CLOB +from sqlalchemy import collate from sqlalchemy import DATE from sqlalchemy import Date from sqlalchemy import DATETIME @@ -57,14 +62,17 @@ from sqlalchemy import types from sqlalchemy import Unicode from sqlalchemy import util +from sqlalchemy import VARBINARY from sqlalchemy import VARCHAR import sqlalchemy.dialects.mysql as mysql import sqlalchemy.dialects.oracle as oracle import sqlalchemy.dialects.postgresql as pg from sqlalchemy.engine import default +from sqlalchemy.engine import interfaces from sqlalchemy.schema import AddConstraint from sqlalchemy.schema import CheckConstraint from sqlalchemy.sql import column +from sqlalchemy.sql import compiler from sqlalchemy.sql import ddl from sqlalchemy.sql import elements from sqlalchemy.sql import null @@ -289,6 +297,7 @@ def test_adapt_method(self, is_down_adaption, typ, target_adaptions): "schema", "metadata", "name", + "dispatch", ): continue # assert each value was copied, or that @@ -443,6 +452,11 @@ def load_dialect_impl(self, dialect): class AsGenericTest(fixtures.TestBase): @testing.combinations( (String(), String()), + (VARBINARY(), LargeBinary()), + (mysql.BINARY(), LargeBinary()), + (mysql.MEDIUMBLOB(), LargeBinary()), + (oracle.RAW(), LargeBinary()), + (pg.BYTEA(), LargeBinary()), (VARCHAR(length=100), String(length=100)), (NVARCHAR(length=100), Unicode(length=100)), (DATE(), Date()), @@ -465,6 +479,9 @@ def test_as_generic(self, t1, t2): (t,) for t in _all_types(omit_special_types=True) if not util.method_is_overridden(t, TypeEngine.as_generic) + and not util.method_is_overridden( + t, TypeEngine._generic_type_affinity + ) ] ) def test_as_generic_all_types_heuristic(self, type_): @@ -496,6 +513,11 @@ def test_as_generic_all_types_custom(self, type_): assert isinstance(gentype, TypeEngine) +class SomeTypeDecorator(TypeDecorator): + impl = String() + cache_ok = True + + class PickleTypesTest(fixtures.TestBase): @testing.combinations( ("Boo", Boolean()), @@ -507,23 +529,112 @@ class PickleTypesTest(fixtures.TestBase): ("Big", BigInteger()), ("Num", Numeric()), ("Flo", Float()), + ("Enu", Enum("one", "two", "three")), ("Dat", DateTime()), ("Dat", Date()), ("Tim", Time()), ("Lar", LargeBinary()), ("Pic", PickleType()), ("Int", Interval()), + ("Dec", SomeTypeDecorator()), + argnames="name,type_", id_="ar", ) - def test_pickle_types(self, name, type_): + @testing.variation("use_adapt", [True, False]) + def test_pickle_types(self, name, type_, use_adapt): + + if use_adapt: + type_ = type_.copy() + column_type = Column(name, type_) meta = MetaData() Table("foo", meta, column_type) + expr = select(1).where(column_type == bindparam("q")) + for loads, dumps in picklers(): loads(dumps(column_type)) loads(dumps(meta)) + expr_str_one = str(expr) + ne = loads(dumps(expr)) + + eq_(str(ne), expr_str_one) + + re_pickle_it = loads(dumps(ne)) + eq_(str(re_pickle_it), expr_str_one) + + def test_pickle_td_comparator(self): + comparator = SomeTypeDecorator().comparator_factory(column("q")) + + expected_mro = ( + TypeDecorator.Comparator, + sqltypes.Concatenable.Comparator, + TypeEngine.Comparator, + ) + eq_(comparator.__class__.__mro__[1:4], expected_mro) + + for loads, dumps in picklers(): + unpickled = loads(dumps(comparator)) + eq_(unpickled.__class__.__mro__[1:4], expected_mro) + + 
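    # (editorial note, not part of the patch) the second round-trip below is
    # the interesting step: an object can unpickle successfully and still
    # yield a class that cannot be pickled again, so re-dumping the unpickled
    # comparator checks that the reconstructed Comparator class is itself
    # round-trippable.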
reunpickled = loads(dumps(unpickled)) + eq_(reunpickled.__class__.__mro__[1:4], expected_mro) + + @testing.combinations( + ("Str", String()), + ("Tex", Text()), + ("Uni", Unicode()), + ("Boo", Boolean()), + ("Dat", DateTime()), + ("Dat", Date()), + ("Tim", Time()), + ("Lar", LargeBinary()), + ("Pic", PickleType()), + ("Int", Interval()), + ("Enu", Enum("one", "two", "three")), + argnames="name,type_", + id_="ar", + ) + @testing.variation("use_adapt", [True, False]) + def test_pickle_types_other_process(self, name, type_, use_adapt): + """test for #11530 + + this does a full exec of python interpreter so the number of variations + here is reduced to just a single pickler, else each case takes + a full second. + + """ + + if use_adapt: + type_ = type_.copy() + + column_type = Column(name, type_) + meta = MetaData() + Table("foo", meta, column_type) + + for target in column_type, meta: + f, name = mkstemp("pkl") + with os.fdopen(f, "wb") as f: + pickle.dump(target, f) + + name = name.replace(os.sep, "/") + code = ( + "import sqlalchemy; import pickle; " + f"pickle.load(open('''{name}''', 'rb'))" + ) + parts = list(sys.path) + if os.environ.get("PYTHONPATH"): + parts.append(os.environ["PYTHONPATH"]) + pythonpath = os.pathsep.join(parts) + proc = subprocess.run( + [sys.executable, "-c", code], + env={**os.environ, "PYTHONPATH": pythonpath}, + stderr=subprocess.PIPE, + ) + eq_(proc.returncode, 0, proc.stderr.decode(errors="replace")) + os.unlink(name) + class _UserDefinedTypeFixture: @classmethod @@ -1417,9 +1528,11 @@ def col_to_bind(col): # on the way in here eq_( conn.execute(new_stmt).fetchall(), - [("x", "BIND_INxBIND_OUT")] - if coerce_fn is type_coerce - else [("x", "xBIND_OUT")], + ( + [("x", "BIND_INxBIND_OUT")] + if coerce_fn is type_coerce + else [("x", "xBIND_OUT")] + ), ) def test_cast_bind(self, connection): @@ -1441,9 +1554,11 @@ def _test_bind(self, coerce_fn, conn): eq_( conn.execute(stmt).fetchall(), - [("x", "BIND_INxBIND_OUT")] - if coerce_fn is type_coerce - else [("x", "xBIND_OUT")], + ( + [("x", "BIND_INxBIND_OUT")] + if coerce_fn is type_coerce + else [("x", "xBIND_OUT")] + ), ) def test_cast_existing_typed(self, connection): @@ -1691,6 +1806,19 @@ def get_col_spec(self): ) self.composite = self.variant.with_variant(self.UTypeThree(), "mysql") + def test_copy_doesnt_lose_variants(self): + """test #11176""" + + v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql") + + v_c = v.copy() + + self.assert_compile(v_c, "UTYPEONE", dialect="default") + + self.assert_compile( + v_c, "UTYPETWO", dialect=dialects.postgresql.dialect() + ) + def test_one_dialect_is_req(self): with expect_raises_message( exc.ArgumentError, "At least one dialect name is required" @@ -2299,7 +2427,7 @@ def test_variant_we_are_default(self, metadata): assert_raises( (exc.DBAPIError,), connection.exec_driver_sql, - "insert into my_table " "(data) values('four')", + "insert into my_table (data) values('four')", ) trans.rollback() @@ -3284,6 +3412,91 @@ def test_control(self, connection): ], ) + @testing.fixture + def renders_bind_cast(self): + class MyText(Text): + render_bind_cast = True + + class MyCompiler(compiler.SQLCompiler): + def render_bind_cast(self, type_, dbapi_type, sqltext): + return f"""{sqltext}->BINDCAST->[{ + self.dialect.type_compiler_instance.process( + dbapi_type, identifier_preparer=self.preparer + ) + }]""" + + class MyDialect(default.DefaultDialect): + bind_typing = interfaces.BindTyping.RENDER_CASTS + colspecs = {Text: MyText} + statement_compiler = MyCompiler + + return 
MyDialect() + + @testing.combinations( + (lambda c1: c1.like("qpr"), "q LIKE :q_1->BINDCAST->[TEXT]"), + ( + lambda c2: c2.like("qpr"), + 'q LIKE :q_1->BINDCAST->[TEXT COLLATE "xyz"]', + ), + ( + # new behavior, a type with no collation passed into collate() + # now has a new type with that collation, so we get the collate + # on the right side bind-cast. previous to #11576 we'd only + # get TEXT for the bindcast. + lambda c1: collate(c1, "abc").like("qpr"), + '(q COLLATE abc) LIKE :param_1->BINDCAST->[TEXT COLLATE "abc"]', + ), + ( + lambda c2: collate(c2, "abc").like("qpr"), + '(q COLLATE abc) LIKE :param_1->BINDCAST->[TEXT COLLATE "abc"]', + ), + argnames="testcase,expected", + ) + @testing.variation("use_type_decorator", [True, False]) + def test_collate_type_interaction( + self, renders_bind_cast, testcase, expected, use_type_decorator + ): + """test #11576. + + This involves dialects that use the render_bind_cast feature only, + currently asyncpg and psycopg. However, the implementation of the + feature is mostly in Core, so a fixture dialect / compiler is used so + that the test is agnostic of those dialects. + + """ + + if use_type_decorator: + + class MyTextThing(TypeDecorator): + cache_ok = True + impl = Text + + c1 = Column("q", MyTextThing()) + c2 = Column("q", MyTextThing(collation="xyz")) + else: + c1 = Column("q", Text()) + c2 = Column("q", Text(collation="xyz")) + + expr = testing.resolve_lambda(testcase, c1=c1, c2=c2) + if use_type_decorator: + assert isinstance(expr.left.type, MyTextThing) + self.assert_compile(expr, expected, dialect=renders_bind_cast) + + # original types still work, have not been modified + eq_(c1.type.collation, None) + eq_(c2.type.collation, "xyz") + + self.assert_compile( + c1.like("qpr"), + "q LIKE :q_1->BINDCAST->[TEXT]", + dialect=renders_bind_cast, + ) + self.assert_compile( + c2.like("qpr"), + 'q LIKE :q_1->BINDCAST->[TEXT COLLATE "xyz"]', + dialect=renders_bind_cast, + ) + def test_bind_adapt(self, connection): # test an untyped bind gets the left side's type @@ -3876,7 +4089,6 @@ def get_col_spec(self, **kw): class NumericRawSQLTest(fixtures.TestBase): - """Test what DBAPIs and dialects return without any typing information supplied at the SQLA level. @@ -4007,7 +4219,6 @@ def test_integer_literal_processor(self): class BooleanTest( fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL ): - """test edge cases for booleans.
Note that the main boolean test suite is now in testing/suite/test_types.py diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 74cf1eb4f2e..b741d5d8c0b 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -14,6 +14,7 @@ from sqlalchemy.sql import column from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import roles +from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -174,3 +175,12 @@ def test_unwrap_order_by(self, expr, expected): for a, b in zip_longest(unwrapped, expected): assert a is not None and a.compare(b) + + def test_column_collection_get(self): + col_id = column("id", Integer) + col_alt = column("alt", Integer) + table1 = table("mytable", col_id) + + is_(table1.columns.get("id"), col_id) + is_(table1.columns.get("alt"), None) + is_(table1.columns.get("alt", col_alt), col_alt) diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 4567daa3866..0f1e588bd95 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -10,16 +10,20 @@ from sqlalchemy import select from sqlalchemy import Text from sqlalchemy import UniqueConstraint +from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import INT4RANGE +from sqlalchemy.dialects.postgresql import INT8MULTIRANGE from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column - # test #6402 c1 = Column(UUID()) @@ -77,3 +81,76 @@ class Test(Base): ).on_conflict_do_update( unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 ).excluded.foo.desc() + +s1 = insert(Test) +s1.on_conflict_do_update(set_=s1.excluded) + + +# EXPECTED_TYPE: Column[Range[int]] +reveal_type(Column(INT4RANGE())) +# EXPECTED_TYPE: Column[Range[datetime.date]] +reveal_type(Column("foo", DATERANGE())) +# EXPECTED_TYPE: Column[Sequence[Range[int]]] +reveal_type(Column(INT8MULTIRANGE())) +# EXPECTED_TYPE: Column[Sequence[Range[datetime.datetime]]] +reveal_type(Column("foo", TSTZMULTIRANGE())) + + +range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) + +# EXPECTED_TYPE: Select[Tuple[Range[int], Sequence[Range[int]]]] +reveal_type(range_col_stmt) + +array_from_ints = array(range(2)) + +# EXPECTED_TYPE: array[int] +reveal_type(array_from_ints) + +array_of_strings = array([], type_=Text) + +# EXPECTED_TYPE: array[str] +reveal_type(array_of_strings) + +array_of_ints = array([0], type_=Integer) + +# EXPECTED_TYPE: array[int] +reveal_type(array_of_ints) + +# EXPECTED_MYPY: Cannot infer type argument 1 of "array" +array([0], type_=Text) + +# EXPECTED_TYPE: ARRAY[str] +reveal_type(ARRAY(Text)) + +# EXPECTED_TYPE: Column[Sequence[int]] +reveal_type(Column(type_=ARRAY(Integer))) + +stmt_array_agg = select(func.array_agg(Column("num", type_=Integer))) + +# EXPECTED_TYPE: Select[Tuple[Sequence[int]]] +reveal_type(stmt_array_agg) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] 
+reveal_type(select(func.array_agg(Test.ident_str))) + +stmt_array_agg_order_by_1 = select( + func.array_agg( + aggregate_order_by( + Column("title", type_=Text), + Column("date", type_=DATERANGE).desc(), + Column("id", type_=Integer), + ), + ) +) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] +reveal_type(stmt_array_agg_order_by_1) + +stmt_array_agg_order_by_2 = select( + func.array_agg( + aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident), + ) +) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] +reveal_type(stmt_array_agg_order_by_2) diff --git a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py index 00debda5096..456f402937a 100644 --- a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py +++ b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py @@ -21,3 +21,6 @@ class Test(Base): insert(Test).on_conflict_do_nothing("foo", Test.id > 0).on_conflict_do_update( unique, Test.id > 0, {"id": 42, Test.data: 99}, Test.id == 22 ).excluded.foo.desc() + +s1 = insert(Test) +s1.on_conflict_do_update(set_=s1.excluded) diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py new file mode 100644 index 00000000000..eedcc309474 --- /dev/null +++ b/test/typing/plain_files/engine/engine_result.py @@ -0,0 +1,93 @@ +from typing import Tuple + +from sqlalchemy import column +from sqlalchemy.engine import Result +from sqlalchemy.engine import Row + + +def row_one(row: Row[Tuple[int, str, bool]]) -> None: + # EXPECTED_TYPE: Any + reveal_type(row[0]) + # EXPECTED_TYPE: Any + reveal_type(row[1]) + # EXPECTED_TYPE: Any + reveal_type(row[2]) + + # EXPECTED_MYPY: No overload variant of "__getitem__" of "Row" matches argument type "str" # noqa: E501 + row["a"] + + # EXPECTED_TYPE: RowMapping + reveal_type(row._mapping) + rm = row._mapping + # EXPECTED_TYPE: Any + reveal_type(rm["foo"]) + # EXPECTED_TYPE: Any + reveal_type(rm[column("bar")]) + + # EXPECTED_MYPY_RE: Invalid index type "int" for "RowMapping"; expected type "(str \| SQLCoreOperations\[Any\]|Union\[str, SQLCoreOperations\[Any\]\])" # noqa: E501 + rm[3] + + +def result_one( + res: Result[Tuple[int, str]], r_single: Result[Tuple[float]] +) -> None: + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(res.one()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.one_or_none()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.fetchone()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.first()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.all()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.fetchmany()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.fetchall()) + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(next(res)) + for rf in res: + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(rf) + for rp in res.partitions(): + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(rp) + + # EXPECTED_TYPE: ScalarResult[Any] + res_s = reveal_type(res.scalars()) + # EXPECTED_TYPE: ScalarResult[Any] + res_s = reveal_type(res.scalars(0)) + # EXPECTED_TYPE: Any + reveal_type(res_s.one()) + # EXPECTED_TYPE: ScalarResult[Any] + reveal_type(res.scalars(1)) + # EXPECTED_TYPE: MappingResult + reveal_type(res.mappings()) + # EXPECTED_TYPE: FrozenResult[Tuple[int, str]] + reveal_type(res.freeze()) + + # EXPECTED_TYPE: Any + reveal_type(res.scalar_one()) + # EXPECTED_TYPE: Union[Any, None] + 
reveal_type(res.scalar_one_or_none()) + # EXPECTED_TYPE: Any + reveal_type(res.scalar()) + + # EXPECTED_TYPE: ScalarResult[float] + res_s2 = reveal_type(r_single.scalars()) + # EXPECTED_TYPE: ScalarResult[float] + res_s2 = reveal_type(r_single.scalars(0)) + # EXPECTED_TYPE: float + reveal_type(res_s2.one()) + # EXPECTED_TYPE: ScalarResult[Any] + reveal_type(r_single.scalars(1)) + # EXPECTED_TYPE: MappingResult + reveal_type(r_single.mappings()) + + # EXPECTED_TYPE: float + reveal_type(r_single.scalar_one()) + # EXPECTED_TYPE: Union[float, None] + reveal_type(r_single.scalar_one_or_none()) + # EXPECTED_TYPE: Union[float, None] + reveal_type(r_single.scalar()) diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py index 5777b914841..a204fb9182f 100644 --- a/test/typing/plain_files/engine/engines.py +++ b/test/typing/plain_files/engine/engines.py @@ -1,5 +1,6 @@ from sqlalchemy import create_engine from sqlalchemy import Pool +from sqlalchemy import select from sqlalchemy import text @@ -30,5 +31,9 @@ def regular() -> None: engine = create_engine("postgresql://scott:tiger@localhost/test") status: str = engine.pool.status() other_pool: Pool = engine.pool.recreate() + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement print(status, other_pool) diff --git a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py index 664ff0411df..b081aa1b130 100644 --- a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py +++ b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py @@ -2,6 +2,7 @@ for asynchronous ORM use. """ + from __future__ import annotations import asyncio @@ -51,6 +52,10 @@ def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any: pass +def work_with_wrong_parameter(session: Session, foo: int) -> Any: + pass + + async def async_main() -> None: """Main program function.""" @@ -70,6 +75,9 @@ async def async_main() -> None: await session.run_sync(work_with_a_session_one) await session.run_sync(work_with_a_session_two, param="foo") + # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncSession" + await session.run_sync(work_with_wrong_parameter) + session.add_all( [ A(bs=[B(), B()], data="a1"), diff --git a/test/typing/plain_files/ext/asyncio/engines.py b/test/typing/plain_files/ext/asyncio/engines.py index 598d319a776..7c93466e0bf 100644 --- a/test/typing/plain_files/ext/asyncio/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,7 +1,17 @@ +from typing import Any + +from sqlalchemy import Connection +from sqlalchemy import Enum +from sqlalchemy import MetaData +from sqlalchemy import select from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine +def work_sync(conn: Connection, foo: int) -> Any: + pass + + async def asyncio() -> None: e = create_async_engine("sqlite://") @@ -53,3 +63,35 @@ async def asyncio() -> None: # EXPECTED_TYPE: CursorResult[Any] reveal_type(result) + + await conn.run_sync(work_sync, 1) + + # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection" + await conn.run_sync(work_sync) + + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement + + async with e.connect() as conn: + metadata = MetaData() + + await conn.run_sync(metadata.create_all) + await conn.run_sync(metadata.reflect) + await conn.run_sync(metadata.drop_all) + + # Just to avoid creating new 
constructs manually: + for _, table in metadata.tables.items(): + await conn.run_sync(table.create) + await conn.run_sync(table.drop) + + # Indexes: + for index in table.indexes: + await conn.run_sync(index.create) + await conn.run_sync(index.drop) + + # Test for enum types: + enum = Enum("a", "b") + await conn.run_sync(enum.create) + await conn.run_sync(enum.drop) diff --git a/test/typing/plain_files/ext/misc_ext.py b/test/typing/plain_files/ext/misc_ext.py new file mode 100644 index 00000000000..c44d09bb3e6 --- /dev/null +++ b/test/typing/plain_files/ext/misc_ext.py @@ -0,0 +1,17 @@ +from typing import Any + +from sqlalchemy import JSON +from sqlalchemy import Select +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.ext.mutable import MutableList +from sqlalchemy.sql.compiler import SQLCompiler + + +@compiles(Select[Any], "my_cool_driver") +def go(sel: Select[Any], compiler: SQLCompiler, **kw: Any) -> str: + return "select 42" + + +MutableList.as_mutable(JSON) +MutableDict.as_mutable(JSON()) diff --git a/test/typing/plain_files/orm/issue_9340.py b/test/typing/plain_files/orm/issue_9340.py index 72dc72df1ec..a4fe8c08831 100644 --- a/test/typing/plain_files/orm/issue_9340.py +++ b/test/typing/plain_files/orm/issue_9340.py @@ -10,8 +10,7 @@ from sqlalchemy.orm import with_polymorphic -class Base(DeclarativeBase): - ... +class Base(DeclarativeBase): ... class Message(Base): diff --git a/test/typing/plain_files/orm/mapped_column.py b/test/typing/plain_files/orm/mapped_column.py index 26f5722a6fc..81080a4faa5 100644 --- a/test/typing/plain_files/orm/mapped_column.py +++ b/test/typing/plain_files/orm/mapped_column.py @@ -1,13 +1,20 @@ from typing import Optional +from sqlalchemy import Boolean +from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import String +from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import UniqueConstraint from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.schema import SchemaConst class Base(DeclarativeBase): @@ -94,3 +101,84 @@ class X(Base): ) __table_args__ = (UniqueConstraint(a, b, name="uq1"), Index("ix1", c, d)) + + +mapped_column() +mapped_column( + init=True, + repr=True, + default=42, + compare=True, + kw_only=True, + primary_key=True, + deferred=True, + deferred_group="str", + deferred_raiseload=True, + use_existing_column=True, + name="str", + type_=Integer(), + doc="str", + key="str", + index=True, + unique=True, + info={"str": 42}, + active_history=True, + quote=True, + system=True, + comment="str", + sort_order=-1, + any_kwarg="str", + another_kwarg=42, +) + +mapped_column(default_factory=lambda: 1) +mapped_column(default_factory=lambda: "str") + +mapped_column(nullable=True) +mapped_column(nullable=SchemaConst.NULL_UNSPECIFIED) + +mapped_column(autoincrement=True) +mapped_column(autoincrement="auto") +mapped_column(autoincrement="ignore_fk") + +mapped_column(onupdate=1) +mapped_column(onupdate="str") + +mapped_column(insert_default=1) +mapped_column(insert_default="str") + +mapped_column(server_default=FetchedValue()) +mapped_column(server_default=true()) +mapped_column(server_default=func.now()) +mapped_column(server_default="NOW()") +mapped_column(server_default=text("NOW()")) 
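# (editorial note, not part of the patch) server_default / server_onupdate
# are DDL-level settings: the value is either rendered into CREATE TABLE or
# assumed to be computed by the database, hence the str / text() /
# literal_column() / FetchedValue() forms exercised in this block. The
# default / insert_default / onupdate parameters further up are client-side
# instead, evaluated by SQLAlchemy itself when a flush builds the statement.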
+mapped_column(server_default=literal_column("false", Boolean)) + +mapped_column(server_onupdate=FetchedValue()) +mapped_column(server_onupdate=true()) +mapped_column(server_onupdate=func.now()) +mapped_column(server_onupdate="NOW()") +mapped_column(server_onupdate=text("NOW()")) +mapped_column(server_onupdate=literal_column("false", Boolean)) + +mapped_column( + default=None, + nullable=None, + primary_key=None, + deferred_group=None, + deferred_raiseload=None, + name=None, + type_=None, + doc=None, + key=None, + index=None, + unique=None, + info=None, + onupdate=None, + insert_default=None, + server_default=None, + server_onupdate=None, + quote=None, + comment=None, + any_kwarg=None, +) diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py index 1a17ee3848b..9eca6e9593f 100644 --- a/test/typing/plain_files/orm/mapped_covariant.py +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -1,13 +1,17 @@ """Tests Mapped covariance.""" from datetime import datetime +from typing import List from typing import Protocol +from typing import Sequence +from typing import TypeVar from typing import Union from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Nullable from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -17,15 +21,17 @@ class ParentProtocol(Protocol): - name: Mapped[str] + # Read-only for simplicity, mutable protocol members are complicated, + # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + @property + def name(self) -> Mapped[str]: ... class ChildProtocol(Protocol): # Read-only for simplicity, mutable protocol members are complicated, # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected @property - def parent(self) -> Mapped[ParentProtocol]: - ... + def parent(self) -> Mapped[ParentProtocol]: ... 
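[editorial sketch, not part of the patch] The move from a plain
"name: Mapped[str]" attribute to a read-only property on ParentProtocol is
what makes the protocol usable covariantly: mypy rejects covariant
subtyping of mutable protocol members, while a property places the member
in a read-only, covariant position. The same idea in miniature, independent
of SQLAlchemy (all names invented here):

    from typing import Protocol

    class HasName(Protocol):
        @property
        def name(self) -> str: ...  # read-only, so covariance is allowed

    class Person:
        def __init__(self, name: str) -> None:
            self.name = name  # a plain attribute satisfies the property

    def greet(obj: HasName) -> str:
        return "hello " + obj.name

    assert greet(Person("alice")) == "hello alice"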
def get_parent_name(child: ChildProtocol) -> str: @@ -44,6 +50,8 @@ class Parent(Base): name: Mapped[str] = mapped_column(primary_key=True) + children: Mapped[Sequence["Child"]] = relationship("Child") + class Child(Base): __tablename__ = "child" @@ -56,6 +64,23 @@ class Child(Base): assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo" +# Make sure that relationships are covariant as well +_BaseT = TypeVar("_BaseT", bound=Base, covariant=True) +RelationshipType = Union[ + InstrumentedAttribute[_BaseT], + InstrumentedAttribute[Sequence[_BaseT]], + InstrumentedAttribute[Union[_BaseT, None]], +] + + +def operate_on_relationships( + relationships: List[RelationshipType[_BaseT]], +) -> int: + return len(relationships) + + +assert operate_on_relationships([Parent.children, Child.parent]) == 2 + # other test diff --git a/test/typing/plain_files/orm/orm_querying.py b/test/typing/plain_files/orm/orm_querying.py index fa59baad43a..8f18e2fcc18 100644 --- a/test/typing/plain_files/orm/orm_querying.py +++ b/test/typing/plain_files/orm/orm_querying.py @@ -1,7 +1,9 @@ from __future__ import annotations +from sqlalchemy import ColumnElement from sqlalchemy import ForeignKey from sqlalchemy import orm +from sqlalchemy import ScalarSelect from sqlalchemy import select from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase @@ -124,3 +126,26 @@ def load_options_error() -> None: # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .* orm.undefer(B.a).undefer("bar"), ) + + +# test 10959 +def test_10959_with_loader_criteria() -> None: + def where_criteria(cls_: type[A]) -> ColumnElement[bool]: + return cls_.data == "some data" + + orm.with_loader_criteria(A, lambda cls: cls.data == "some data") + orm.with_loader_criteria(A, where_criteria) + + +def test_10937() -> None: + stmt: ScalarSelect[bool] = select(A.id == B.id).scalar_subquery() + stmt1: ScalarSelect[bool] = select(A.id > 0).scalar_subquery() + stmt2: ScalarSelect[int] = select(A.id + 2).scalar_subquery() + stmt3: ScalarSelect[str] = select(A.data + B.data).scalar_subquery() + + select(stmt, stmt2, stmt3, stmt1) + + +def test_bundles() -> None: + b1 = orm.Bundle("b1", A.id, A.data) + orm.Bundle("b2", A.id, A.data, b1) diff --git a/test/typing/plain_files/orm/relationship.py b/test/typing/plain_files/orm/relationship.py index d0ab35249d1..a972e23b83e 100644 --- a/test/typing/plain_files/orm/relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -1,6 +1,5 @@ -"""this suite experiments with other kinds of relationship syntaxes. 
+"""this suite experiments with other kinds of relationship syntaxes.""" -""" from __future__ import annotations import typing @@ -15,24 +14,38 @@ from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import Table +from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm import with_polymorphic class Base(DeclarativeBase): pass +class Group(Base): + __tablename__ = "group" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + addresses_style_one_anno_only: Mapped[List["User"]] + addresses_style_two_anno_only: Mapped[Set["User"]] + + class User(Base): __tablename__ = "user" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column() + group_id = mapped_column(ForeignKey("group.id")) # this currently doesnt generate an error. not sure how to get the # overloads to hit this one, nor am i sure i really want to do that @@ -57,6 +70,19 @@ class Address(Base): user_style_one: Mapped[User] = relationship() user_style_two: Mapped["User"] = relationship() + rel_style_one: Relationship[List["MoreMail"]] = relationship() + # everything works even if using Relationship instead of Mapped + # users should use Mapped though + rel_style_one_anno_only: Relationship[Set["MoreMail"]] + + +class MoreMail(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + aggress_id = mapped_column(ForeignKey("address.id")) + email: Mapped[str] + class SelfReferential(Base): """test for #9150""" @@ -80,6 +106,30 @@ class SelfReferential(Base): ) +class Employee(Base): + __tablename__ = "employee" + id: Mapped[int] = mapped_column(primary_key=True) + team_id: Mapped[int] = mapped_column(ForeignKey("team.id")) + team: Mapped["Team"] = relationship(back_populates="employees") + + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "employee", + } + + +class Team(Base): + __tablename__ = "team" + id: Mapped[int] = mapped_column(primary_key=True) + employees: Mapped[list[Employee]] = relationship("Employee") + + +class Engineer(Employee): + engineer_info: Mapped[str] + + __mapper_args__ = {"polymorphic_identity": "engineer"} + + if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] reveal_type(User.extra) @@ -99,6 +149,30 @@ class SelfReferential(Base): # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\] reveal_type(User.addresses_style_two) + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.User\]\] + reveal_type(Group.addresses_style_one_anno_only) + + # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.User\]\] + reveal_type(Group.addresses_style_two_anno_only) + + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.MoreMail\]\] + reveal_type(Address.rel_style_one) + + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[relationship.MoreMail\]\] + reveal_type(Address.rel_style_one_anno_only) + + # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\] + reveal_type(Team.employees.of_type(Engineer)) + + # 
EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\] + reveal_type(Team.employees.of_type(aliased(Employee))) + + # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\] + reveal_type(Team.employees.of_type(aliased(Engineer))) + + # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\] + reveal_type(Team.employees.of_type(with_polymorphic(Employee, [Engineer]))) + mapper_registry: registry = registry() diff --git a/test/typing/plain_files/orm/session.py b/test/typing/plain_files/orm/session.py index 0f1c35eafa1..43fb17a7542 100644 --- a/test/typing/plain_files/orm/session.py +++ b/test/typing/plain_files/orm/session.py @@ -97,6 +97,12 @@ class Address(Base): User.id ).offset(User.id) + # test #11083 + + with sess.begin() as tx: + # EXPECTED_TYPE: SessionTransaction + reveal_type(tx) + # more result tests in typed_results.py diff --git a/test/typing/plain_files/orm/trad_relationship_uselist.py b/test/typing/plain_files/orm/trad_relationship_uselist.py index 8d7d7e71a2e..e15fe709341 100644 --- a/test/typing/plain_files/orm/trad_relationship_uselist.py +++ b/test/typing/plain_files/orm/trad_relationship_uselist.py @@ -1,7 +1,5 @@ -"""traditional relationship patterns with explicit uselist. +"""traditional relationship patterns with explicit uselist.""" - -""" import typing from typing import cast from typing import Dict diff --git a/test/typing/plain_files/orm/traditional_relationship.py b/test/typing/plain_files/orm/traditional_relationship.py index 02afc7c8012..bd6bada528c 100644 --- a/test/typing/plain_files/orm/traditional_relationship.py +++ b/test/typing/plain_files/orm/traditional_relationship.py @@ -5,6 +5,7 @@ if no uselists are present. """ + import typing from typing import List from typing import Set diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index 7d8a2dd1a32..b1226da30fc 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -97,7 +97,7 @@ def t_select_3() -> None: # awkwardnesses that aren't really worth it ua(id=1, name="foo") - # EXPECTED_TYPE: Type[User] + # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] reveal_type(ua) stmt = select(ua.id, ua.name).filter(User.id == 5) @@ -529,13 +529,13 @@ def t_aliased_fromclause() -> None: a4 = aliased(user_table) - # EXPECTED_TYPE: Type[User] + # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] reveal_type(a1) - # EXPECTED_TYPE: Type[User] + # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] reveal_type(a2) - # EXPECTED_TYPE: Type[User] + # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] reveal_type(a3) # EXPECTED_TYPE: FromClause diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index 57aae8fac81..5ce0793ac69 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -6,20 +6,26 @@ """ - from __future__ import annotations from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import ColumnElement from sqlalchemy import desc +from sqlalchemy import except_ +from sqlalchemy import except_all from sqlalchemy import Integer +from sqlalchemy import intersect +from sqlalchemy import intersect_all from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import SQLColumnExpression from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import union +from 
sqlalchemy import union_all from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -99,6 +105,11 @@ def core_expr(email: str) -> SQLColumnExpression[bool]: # EXPECTED_TYPE: Select[Tuple[int]] reveal_type(stmt2) +stmt3 = select(User.id).exists().select() + +# EXPECTED_TYPE: Select[Tuple[bool]] +reveal_type(stmt3) + receives_str_col_expr(User.email) receives_str_col_expr(User.email + "some expr") @@ -172,3 +183,75 @@ def core_expr(email: str) -> SQLColumnExpression[bool]: literal("5"): "q", column("q"): "q", } + +# compound selects (issue #11922): + +str_col = ColumnElement[str]() +int_col = ColumnElement[int]() + +first_stmt = select(str_col, int_col) +second_stmt = select(str_col, int_col) +third_stmt = select(int_col, str_col) + +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(union(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(union_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(except_(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(except_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(intersect(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(intersect_all(first_stmt, second_stmt)) + +# EXPECTED_TYPE: Result[Tuple[str, int]] +reveal_type(Session().execute(union(first_stmt, second_stmt))) +# EXPECTED_TYPE: Result[Tuple[str, int]] +reveal_type(Session().execute(union_all(first_stmt, second_stmt))) + +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.union(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.union_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.except_(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.except_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.intersect(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.intersect_all(second_stmt)) + +# TODO: the following do not error because _SelectStatementForCompoundArgument +# includes untyped elements so the type checker falls back on them when +# the type does not match. Also for the standalone functions mypy +# loses the plot and returns a random type back. 
See TODO in the +# overloads + +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(union(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(union_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(except_(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(except_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(intersect(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(intersect_all(first_stmt, third_stmt)) + +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.union(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.union_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.except_(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.except_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.intersect(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Tuple[str, int]] +reveal_type(first_stmt.intersect_all(third_stmt)) diff --git a/test/typing/plain_files/sql/core_ddl.py b/test/typing/plain_files/sql/core_ddl.py index b7e0ec5350f..549375d0af2 100644 --- a/test/typing/plain_files/sql/core_ddl.py +++ b/test/typing/plain_files/sql/core_ddl.py @@ -138,10 +138,18 @@ Column(Integer, server_default=literal_column("42", Integer), nullable=False) # server_onupdate -Column("name", server_onupdate=FetchedValue(), nullable=False) Column(server_onupdate=FetchedValue(), nullable=False) +Column(server_onupdate="now()", nullable=False) +Column("name", server_onupdate=FetchedValue(), nullable=False) Column("name", Integer, server_onupdate=FetchedValue(), nullable=False) +Column("name", Integer, server_onupdate=text("now()"), nullable=False) +Column(Boolean, nullable=False, server_default=true()) Column(Integer, server_onupdate=FetchedValue(), nullable=False) +Column(DateTime, server_onupdate="now()") +Column(DateTime, server_onupdate=text("now()")) +Column(DateTime, server_onupdate=FetchedValue()) +Column(Boolean, server_onupdate=literal_column("false", Boolean)) +Column(Integer, server_onupdate=literal_column("42", Integer), nullable=False) # TypeEngine.with_variant should accept both a TypeEngine instance and the Concrete Type Integer().with_variant(Integer, "mysql") diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index e66e554cff7..e1cea4193e4 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -1,125 +1,165 @@ """this file is generated by tools/generate_sql_functions.py""" +from typing import Tuple + from sqlalchemy import column from sqlalchemy import func +from sqlalchemy import Integer +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy import Sequence +from sqlalchemy import String # START GENERATED FUNCTION TYPING TESTS # code within this block is **programmatically, # statically generated** by tools/generate_sql_functions.py -stmt1 = select(func.aggregate_strings(column("x"), column("x"))) +stmt1 = select(func.aggregate_strings(column("x", String), ",")) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt1) -stmt2 = select(func.char_length(column("x"))) +stmt2 = select(func.array_agg(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Sequence\[.*int\]\]\] reveal_type(stmt2) -stmt3 = 
select(func.concat()) +stmt3 = select(func.char_length(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt3) -stmt4 = select(func.count(column("x"))) +stmt4 = select(func.coalesce(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt4) -stmt5 = select(func.cume_dist()) +stmt5 = select(func.concat()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt5) -stmt6 = select(func.current_date()) +stmt6 = select(func.count(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt6) -stmt7 = select(func.current_time()) +stmt7 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt7) -stmt8 = select(func.current_timestamp()) +stmt8 = select(func.current_date()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] reveal_type(stmt8) -stmt9 = select(func.current_user()) +stmt9 = select(func.current_time()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] reveal_type(stmt9) -stmt10 = select(func.dense_rank()) +stmt10 = select(func.current_timestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt10) -stmt11 = select(func.localtime()) +stmt11 = select(func.current_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt11) -stmt12 = select(func.localtimestamp()) +stmt12 = select(func.dense_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt12) -stmt13 = select(func.next_value(column("x"))) +stmt13 = select(func.localtime()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt13) -stmt14 = select(func.now()) +stmt14 = select(func.localtimestamp()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt14) -stmt15 = select(func.percent_rank()) +stmt15 = select(func.max(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt15) -stmt16 = select(func.rank()) +stmt16 = select(func.min(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt16) -stmt17 = select(func.session_user()) +stmt17 = select(func.next_value(Sequence("x_seq"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt17) -stmt18 = select(func.sysdate()) +stmt18 = select(func.now()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt18) -stmt19 = select(func.user()) +stmt19 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt19) + +stmt20 = select(func.rank()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt20) + + +stmt21 = select(func.session_user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt21) + + +stmt22 = select(func.sum(column("x", Integer))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt22) + + +stmt23 = select(func.sysdate()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] 
+reveal_type(stmt23) + + +stmt24 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt24) + # END GENERATED FUNCTION TYPING TESTS + +stmt_count: Select[Tuple[int, int, int]] = select( + func.count(), func.count("*"), func.count(1) +) diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index edfbd6bb2b1..24b720f6710 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -1,4 +1,7 @@ +from sqlalchemy import column from sqlalchemy import func +from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -14,10 +17,56 @@ class Foo(Base): id: Mapped[int] = mapped_column(primary_key=True) a: Mapped[int] b: Mapped[int] + c: Mapped[str] -func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())) func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(order_by="a", partition_by=("a", "b")) func.row_number().over(partition_by="a", order_by=("a", "b")) + + +# EXPECTED_TYPE: Function[Any] +reveal_type(func.row_number().filter()) +# EXPECTED_TYPE: FunctionFilter[Any] +reveal_type(func.row_number().filter(Foo.a > 0)) +# EXPECTED_TYPE: FunctionFilter[Any] +reveal_type(func.row_number().within_group(Foo.a).filter(Foo.b < 0)) +# EXPECTED_TYPE: WithinGroup[Any] +reveal_type(func.row_number().within_group(Foo.a)) +# EXPECTED_TYPE: WithinGroup[Any] +reveal_type(func.row_number().filter(Foo.a > 0).within_group(Foo.a)) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().filter(Foo.a > 0).over()) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().within_group(Foo.a).over()) + +# test #10801 +# EXPECTED_TYPE: max[int] +reveal_type(func.max(Foo.b)) + + +stmt1 = select(Foo.a, func.min(Foo.b)).group_by(Foo.a) +# EXPECTED_TYPE: Select[Tuple[int, int]] +reveal_type(stmt1) + +# test #10818 +# EXPECTED_TYPE: coalesce[str] +reveal_type(func.coalesce(Foo.c, "a", "b")) +# EXPECTED_TYPE: coalesce[str] +reveal_type(func.coalesce("a", "b")) +# EXPECTED_TYPE: coalesce[int] +reveal_type(func.coalesce(column("x", Integer), 3)) + + +stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a) +# EXPECTED_TYPE: Select[Tuple[int, str]] +reveal_type(stmt2) + + +# EXPECTED_TYPE: TableValuedAlias +reveal_type(func.json_each().table_valued("key", "value")) +# EXPECTED_TYPE: TableValuedAlias +reveal_type(func.json_each().table_valued(Foo.a, Foo.b)) diff --git a/test/typing/plain_files/sql/misc.py b/test/typing/plain_files/sql/misc.py new file mode 100644 index 00000000000..d598af06ef0 --- /dev/null +++ b/test/typing/plain_files/sql/misc.py @@ -0,0 +1,37 @@ +from typing import Any + +from sqlalchemy import column +from sqlalchemy import ColumnElement +from sqlalchemy import Integer +from sqlalchemy import literal +from sqlalchemy import table + + +def test_col_accessors() -> None: + t = table("t", column("a"), column("b"), column("c")) + + t.c.a + t.c["a"] + + t.c[2] + t.c[0, 1] + t.c[0, 1, "b", "c"] + t.c[(0, 1, "b", "c")] + + t.c[:-1] + t.c[0:2] + + +def test_col_get() -> None: + col_id = column("id", Integer) + col_alt = column("alt", Integer) + tbl = table("mytable", col_id) + + # EXPECTED_TYPE: Union[ColumnClause[Any], None] + 
reveal_type(tbl.c.get("id")) + # EXPECTED_TYPE: Union[ColumnClause[Any], None] + reveal_type(tbl.c.get("id", None)) + # EXPECTED_TYPE: Union[ColumnClause[Any], ColumnClause[int]] + reveal_type(tbl.c.get("alt", col_alt)) + col: ColumnElement[Any] = tbl.c.get("foo", literal("bar")) + print(col) diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py index 2e2f31df9cf..d52461d41f1 100644 --- a/test/typing/plain_files/sql/operators.py +++ b/test/typing/plain_files/sql/operators.py @@ -1,3 +1,4 @@ +import datetime as dt from decimal import Decimal from typing import Any from typing import List @@ -6,6 +7,7 @@ from sqlalchemy import BigInteger from sqlalchemy import column from sqlalchemy import ColumnElement +from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String @@ -100,6 +102,10 @@ class A(Base): add1: "ColumnElement[int]" = A.id + A.id add2: "ColumnElement[int]" = A.id + 1 add3: "ColumnElement[int]" = 1 + A.id +add_date: "ColumnElement[dt.date]" = func.current_date() + dt.timedelta(days=1) +add_datetime: "ColumnElement[dt.datetime]" = ( + func.current_timestamp() + dt.timedelta(seconds=1) +) sub1: "ColumnElement[int]" = A.id - A.id sub2: "ColumnElement[int]" = A.id - 1 @@ -148,3 +154,8 @@ class A(Base): # op functions t1 = operators.eq(A.id, 1) select().where(t1) + +# EXPECTED_TYPE: BinaryExpression[Any] +reveal_type(col.op("->>")("field")) +# EXPECTED_TYPE: Union[BinaryExpression[Any], Grouping[Any]] +reveal_type(col.op("->>")("field").self_group()) diff --git a/test/typing/plain_files/sql/selectables.py b/test/typing/plain_files/sql/selectables.py deleted file mode 100644 index 7d31124587f..00000000000 --- a/test/typing/plain_files/sql/selectables.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy import column -from sqlalchemy import table - - -def test_col_accessors() -> None: - t = table("t", column("a"), column("b"), column("c")) - - t.c.a - t.c["a"] - - t.c[2] - t.c[0, 1] - t.c[0, 1, "b", "c"] - t.c[(0, 1, "b", "c")] - - t.c[:-1] - t.c[0:2] diff --git a/test/typing/plain_files/sql/typed_results.py b/test/typing/plain_files/sql/typed_results.py index c7842a7e799..9ed591815af 100644 --- a/test/typing/plain_files/sql/typed_results.py +++ b/test/typing/plain_files/sql/typed_results.py @@ -9,6 +9,7 @@ from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import create_engine +from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import MetaData @@ -118,9 +119,22 @@ def t_result_ctxmanager() -> None: reveal_type(r4) -def t_core_mappings() -> None: +def t_mappings() -> None: r = connection.execute(select(t_user)).mappings().one() - r.get(t_user.c.id) + r["name"] # string + r.get(t_user.c.id) # column + + r2 = connection.execute(select(User)).mappings().one() + r2[User.id] # orm attribute + r2[User.__table__.c.id] # from clause column + + m2 = User.id * 2 + s2 = User.__table__.c.id + 2 + fn = func.abs(User.id) + r3 = connection.execute(select(m2, s2, fn)).mappings().one() + r3[m2] # col element + r3[s2] # also col element + r3[fn] # function def t_entity_varieties() -> None: diff --git a/test/typing/test_overloads.py b/test/typing/test_overloads.py index 968b60d9264..e58b78211b1 100644 --- a/test/typing/test_overloads.py +++ b/test/typing/test_overloads.py @@ -9,6 +9,7 @@ from sqlalchemy.sql.base import Executable from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import eq_ +from 
sqlalchemy.util.typing import is_fwd_ref engine_execution_options = { "compiled_cache": "Optional[CompiledCacheType]", @@ -24,6 +25,7 @@ "stream_results": "bool", "max_row_buffer": "int", "yield_per": "int", + "preserve_rowcount": "bool", } orm_dql_execution_options = { @@ -77,6 +79,9 @@ def test_methods(self, class_, expected): @testing.combinations( (CoreExecuteOptionsParameter, core_execution_options), + # note: this failed on python 3.14.0b1 + # due to https://github.com/python/cpython/issues/133701. + # something to keep in mind in case it breaks again (OrmExecuteOptionsParameter, orm_execution_options), ) def test_typed_dicts(self, typ, expected): @@ -89,7 +94,7 @@ def test_typed_dicts(self, typ, expected): expected.pop("opt") assert_annotations = { - key: fwd_ref.__forward_arg__ + key: fwd_ref.__forward_arg__ if is_fwd_ref(fwd_ref) else fwd_ref for key, fwd_ref in typed_dict.__annotations__.items() } eq_(assert_annotations, expected) diff --git a/tools/format_docs_code.py b/tools/format_docs_code.py index 7bae0126b02..a3b6965c862 100644 --- a/tools/format_docs_code.py +++ b/tools/format_docs_code.py @@ -6,12 +6,15 @@ .. versionadded:: 2.0 """ + # mypy: ignore-errors from argparse import ArgumentParser from argparse import RawDescriptionHelpFormatter from collections.abc import Iterator +import dataclasses from functools import partial +from itertools import chain from pathlib import Path import re from typing import NamedTuple @@ -24,7 +27,14 @@ home = Path(__file__).parent.parent -ignore_paths = (re.compile(r"changelog/unreleased_\d{2}"),) +ignore_paths = ( + re.compile(r"changelog/unreleased_\d{2}"), + re.compile(r"README\.unittests\.rst"), + re.compile(r"\.tox"), + re.compile(r"build"), +) + +CUSTOM_TARGET_VERSIONS = {"declarative_tables.rst": "PY312"} class BlockLine(NamedTuple): @@ -44,6 +54,7 @@ def _format_block( errors: list[tuple[int, str, Exception]], is_doctest: bool, file: str, + is_python_file: bool, ) -> list[str]: if not is_doctest: # The first line may have additional padding. Remove then restore later @@ -57,8 +68,15 @@ def _format_block( add_padding = None code = "\n".join(l.code for l in input_block) + mode = PYTHON_BLACK_MODE if is_python_file else RST_BLACK_MODE + custom_target = CUSTOM_TARGET_VERSIONS.get(Path(file).name) + if custom_target: + mode = dataclasses.replace( + mode, target_versions={TargetVersion[custom_target]} + ) + try: - formatted = format_str(code, mode=BLACK_MODE) + formatted = format_str(code, mode=mode) except Exception as e: start_line = input_block[0].line_no first_error = not errors @@ -118,6 +136,7 @@ def _format_block( r"^(((?!\.\.).+::)|(\.\.\s*sourcecode::(.*py.*)?)|(::))$" ) start_space = re.compile(r"^(\s*)[^ ]?") +not_python_line = re.compile(r"^\s+[$:]") def format_file( @@ -130,6 +149,8 @@ def format_file( doctest_block: _Block | None = None plain_block: _Block | None = None + is_python_file = file.suffix == ".py" + plain_code_section = False plain_padding = None plain_padding_len = None @@ -143,6 +164,7 @@ def format_file( errors=errors, is_doctest=True, file=str(file), + is_python_file=is_python_file, ) def doctest_format(): @@ -157,6 +179,7 @@ def doctest_format(): errors=errors, is_doctest=False, file=str(file), + is_python_file=is_python_file, ) def plain_format(): @@ -245,6 +268,14 @@ def plain_format(): ] continue buffer.append(line) + elif ( + is_python_file + and not plain_block + and not_python_line.match(line) + ): + # not a python block. 
ignore it + plain_code_section = False + buffer.append(line) else: # start of a plain block assert not doctest_block @@ -287,9 +318,12 @@ def plain_format(): def iter_files(directory: str) -> Iterator[Path]: + dir_path = home / directory yield from ( file - for file in (home / directory).glob("./**/*.rst") + for file in chain( + dir_path.glob("./**/*.rst"), dir_path.glob("./**/*.py") + ) if not any(pattern.search(file.as_posix()) for pattern in ignore_paths) ) @@ -316,11 +350,13 @@ def main( print( f"{to_reformat} file(s) would be reformatted;", ( - f"{sum(formatting_error_counts)} formatting errors " - f"reported in {len(formatting_error_counts)} files" - ) - if formatting_error_counts - else "no formatting errors reported", + ( + f"{sum(formatting_error_counts)} formatting errors " + f"reported in {len(formatting_error_counts)} files" + ) + if formatting_error_counts + else "no formatting errors reported" + ), ) exit(1) @@ -349,7 +385,7 @@ def main( "-d", "--directory", help="Find documents in this directory and its sub dirs", - default="doc/build", + default=".", ) parser.add_argument( "-c", @@ -369,7 +405,8 @@ def main( "-l", "--project-line-length", help="Configure the line length to the project value instead " - "of using the black default of 88", + "of using the black default of 88. Python files always use the " + "project line length", action="store_true", ) parser.add_argument( @@ -382,15 +419,24 @@ def main( args = parser.parse_args() config = parse_pyproject_toml(home / "pyproject.toml") - BLACK_MODE = Mode( - target_versions={ - TargetVersion[val.upper()] - for val in config.get("target_version", []) - if val != "py27" - }, - line_length=config.get("line_length", DEFAULT_LINE_LENGTH) - if args.project_line_length - else DEFAULT_LINE_LENGTH, + target_versions = { + TargetVersion[val.upper()] + for val in config.get("target_version", []) + if val != "py27" + } + + RST_BLACK_MODE = Mode( + target_versions=target_versions, + line_length=( + config.get("line_length", DEFAULT_LINE_LENGTH) + if args.project_line_length + else DEFAULT_LINE_LENGTH + ), + ) + PYTHON_BLACK_MODE = Mode( + target_versions=target_versions, + # Remove a few chars to account for normal indent + line_length=(config.get("line_length", 4) - 4 or DEFAULT_LINE_LENGTH), ) REPORT_ONLY_DOCTEST = args.report_doctest diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index 9881d26426f..b9f9d572b00 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -40,6 +40,7 @@ .. 
versionadded:: 2.0 """ + # mypy: ignore-errors from __future__ import annotations @@ -85,9 +86,9 @@ def __repr__(self) -> str: return self.sym -classes: collections.defaultdict[ - str, Dict[str, Tuple[Any, ...]] -] = collections.defaultdict(dict) +classes: collections.defaultdict[str, Dict[str, Tuple[Any, ...]]] = ( + collections.defaultdict(dict) +) _T = TypeVar("_T", bound="Any") @@ -214,18 +215,22 @@ def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None: if spec.defaults: new_defaults = tuple( - _repr_sym("util.EMPTY_DICT") - if df is util.EMPTY_DICT - else df + ( + _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + ) for df in spec.defaults ) elem[3] = new_defaults if spec.kwonlydefaults: new_kwonlydefaults = { - name: _repr_sym("util.EMPTY_DICT") - if df is util.EMPTY_DICT - else df + name: ( + _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + ) for name, df in spec.kwonlydefaults.items() } elem[5] = new_kwonlydefaults @@ -365,11 +370,14 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: # use tempfile in same path as the module, or at least in the # current working directory, so that black / zimports use # local pyproject.toml - with NamedTemporaryFile( - mode="w", - delete=False, - suffix=".py", - ) as buf, open(filename) as orig_py: + with ( + NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf, + open(filename) as orig_py, + ): in_block = False current_clsname = None for line in orig_py: @@ -415,9 +423,9 @@ def main(cmd: code_writer_cmd) -> None: from sqlalchemy import util from sqlalchemy.util import langhelpers - util.create_proxy_methods = ( - langhelpers.create_proxy_methods - ) = create_proxy_methods + util.create_proxy_methods = langhelpers.create_proxy_methods = ( + create_proxy_methods + ) for entry in entries: if cmd.args.module in {"all", entry}: diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 848a9272250..624fbb75ed2 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -1,6 +1,5 @@ -"""Generate inline stubs for generic functions on func +"""Generate inline stubs for generic functions on func""" -""" # mypy: ignore-errors from __future__ import annotations @@ -9,8 +8,12 @@ import re from tempfile import NamedTemporaryFile import textwrap +import typing + +import typing_extensions from sqlalchemy.sql.functions import _registry +from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.types import TypeEngine from sqlalchemy.util.tool_support import code_writer_cmd @@ -18,15 +21,21 @@ def _fns_in_deterministic_order(): reg = _registry["_default"] for key in sorted(reg): - yield key, reg[key] + cls = reg[key] + if cls is ReturnTypeFromArgs: + continue + yield key, cls def process_functions(filename: str, cmd: code_writer_cmd) -> str: - with NamedTemporaryFile( - mode="w", - delete=False, - suffix=".py", - ) as buf, open(filename) as orig_py: + with ( + NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf, + open(filename) as orig_py, + ): indent = "" in_block = False @@ -53,23 +62,86 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: for key, fn_class in _fns_in_deterministic_order(): is_reserved_word = key in builtins - guess_its_generic = bool(fn_class.__parameters__) + if issubclass(fn_class, ReturnTypeFromArgs): + buf.write( + textwrap.indent( + f""" + +# set ColumnElement[_T] as a separate overload, to appease +# mypy which seems to not want 
to accept _T from +# _ColumnExpressionArgument. Seems somewhat related to the covariant +# _HasClauseElement as of mypy 1.15 + +@overload +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + +@overload +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + +@overload +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + + """, + indent, + ) + ) + else: + guess_its_generic = bool(fn_class.__parameters__) + + # the latest flake8 is quite broken here: + # 1. it insists on linting f-strings, no option + # to turn it off + # 2. the f-string indentation rules are either broken + # or completely impossible to figure out + # 3. there's no way to E501 a too-long f-string, + # so I can't even put the expressions all on one line + # to get around the indentation errors + # 4. Therefore here I have to concat part of the + # string outside of the f-string + _type = fn_class.__name__ + _type += "[Any]" if guess_its_generic else "" + _reserved_word = ( + " # noqa: A001" if is_reserved_word else "" + ) + + # now the f-string + buf.write( + textwrap.indent( + f""" @property -def {key}(self) -> Type[{fn_class.__name__}{ - '[Any]' if guess_its_generic else '' -}]:{ - ' # noqa: A001' if is_reserved_word else '' -} +def {key}(self) -> Type[{_type}]:{_reserved_word} ... 
""", - indent, + indent, + ) ) - ) m = re.match( r"^( *)# START GENERATED FUNCTION TYPING TESTS", @@ -92,15 +164,61 @@ def {key}(self) -> Type[{fn_class.__name__}{ count = 0 for key, fn_class in _fns_in_deterministic_order(): - if hasattr(fn_class, "type") and isinstance( + if issubclass(fn_class, ReturnTypeFromArgs): + count += 1 + + # Would be ReturnTypeFromArgs + (orig_base,) = typing_extensions.get_original_bases( + fn_class + ) + # Type parameter of ReturnTypeFromArgs + (rtype,) = typing.get_args(orig_base) + # The origin type, if rtype is a generic + orig_type = typing.get_origin(rtype) + if orig_type is not None: + coltype = rf".*{orig_type.__name__}\[.*int\]" + else: + coltype = ".*int" + + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}(column('x', Integer))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[{coltype}\]\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + elif fn_class.__name__ == "aggregate_strings": + count += 1 + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}(column('x', String), ',')) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + + elif hasattr(fn_class, "type") and isinstance( fn_class.type, TypeEngine ): python_type = fn_class.type.python_type python_expr = rf"Tuple\[.*{python_type.__name__}\]" argspec = inspect.getfullargspec(fn_class) - args = ", ".join( - 'column("x")' for elem in argspec.args[1:] - ) + if fn_class.__name__ == "next_value": + args = "Sequence('x_seq')" + else: + args = ", ".join( + 'column("x")' for elem in argspec.args[1:] + ) count += 1 buf.write( diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py index 476636b1d0f..8884095b1fa 100644 --- a/tools/generate_tuple_map_overloads.py +++ b/tools/generate_tuple_map_overloads.py @@ -16,6 +16,7 @@ .. versionadded:: 2.0 """ + # mypy: ignore-errors from __future__ import annotations @@ -36,15 +37,21 @@ sys.path.append(str(Path(__file__).parent.parent)) -def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: +def process_module( + modname: str, filename: str, expected_number: int, cmd: code_writer_cmd +) -> str: # use tempfile in same path as the module, or at least in the # current working directory, so that black / zimports use # local pyproject.toml - with NamedTemporaryFile( - mode="w", - delete=False, - suffix=".py", - ) as buf, open(filename) as orig_py: + found = 0 + with ( + NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf, + open(filename) as orig_py, + ): indent = "" in_block = False current_fnname = given_fnname = None @@ -54,6 +61,7 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: line, ) if m: + found += 1 indent = m.group(1) given_fnname = current_fnname = m.group(2) if current_fnname.startswith("self."): @@ -110,16 +118,20 @@ def {current_fnname}( if not in_block: buf.write(line) + if found != expected_number: + raise Exception( + f"{modname} processed {found}. 
expected {expected_number}" + ) return buf.name -def run_module(modname: str, cmd: code_writer_cmd) -> None: +def run_module(modname: str, count: int, cmd: code_writer_cmd) -> None: cmd.write_status(f"importing module {modname}\n") mod = importlib.import_module(modname) destination_path = mod.__file__ assert destination_path is not None - tempfile = process_module(modname, destination_path, cmd) + tempfile = process_module(modname, destination_path, count, cmd) cmd.run_zimports(tempfile) cmd.run_black(tempfile) @@ -127,17 +139,17 @@ def run_module(modname: str, cmd: code_writer_cmd) -> None: def main(cmd: code_writer_cmd) -> None: - for modname in entries: + for modname, count in entries: if cmd.args.module in {"all", modname}: - run_module(modname, cmd) + run_module(modname, count, cmd) entries = [ - "sqlalchemy.sql._selectable_constructors", - "sqlalchemy.orm.session", - "sqlalchemy.orm.query", - "sqlalchemy.sql.selectable", - "sqlalchemy.sql.dml", + ("sqlalchemy.sql._selectable_constructors", 1), + ("sqlalchemy.orm.session", 1), + ("sqlalchemy.orm.query", 1), + ("sqlalchemy.sql.selectable", 1), + ("sqlalchemy.sql.dml", 3), ] if __name__ == "__main__": @@ -146,7 +158,7 @@ def main(cmd: code_writer_cmd) -> None: with cmd.add_arguments() as parser: parser.add_argument( "--module", - choices=entries + ["all"], + choices=[n for n, _ in entries] + ["all"], default="all", help="Which file to generate. Default is to regenerate all files", ) diff --git a/tools/normalize_file_headers.py b/tools/normalize_file_headers.py new file mode 100644 index 00000000000..ba4cd5734f8 --- /dev/null +++ b/tools/normalize_file_headers.py @@ -0,0 +1,69 @@ +from datetime import date +from pathlib import Path +import re + +from sqlalchemy.util.tool_support import code_writer_cmd + +sa_path = Path(__file__).parent.parent / "lib/sqlalchemy" + + +file_re = re.compile(r"^# [\w+/]+.(?:pyx?|pxd)$", re.MULTILINE) +license_re = re.compile( + r"Copyright .C. (\d+)-\d+ the SQLAlchemy authors and contributors" +) + +this_year = date.today().year +license_ = f""" +# Copyright (C) 2005-{this_year} the SQLAlchemy authors and \ +contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" + + +def run_file(cmd: code_writer_cmd, file: Path, update_year: bool): + content = file.read_text("utf-8") + path = str(file.relative_to(sa_path)).replace("\\", "/") # handle windows + path_comment = f"# {path}" + has_license = bool(license_re.search(content)) + if file_re.match(content.strip()): + if has_license: + to_sub = path_comment + else: + to_sub = path_comment + license_ + content = file_re.sub(to_sub, content, count=1) + else: + content = path_comment + ("\n" if has_license else license_) + content + + if has_license and update_year: + content = license_re.sub( + rf"Copyright (C) \1-{this_year} the SQLAlchemy " + "authors and contributors", + content, + 1, + ) + cmd.write_output_file_from_text(content, file) + + +def run(cmd: code_writer_cmd, update_year: bool): + i = 0 + for ext in ("py", "pyx", "pxd"): + for file in sa_path.glob(f"**/*.{ext}"): + run_file(cmd, file, update_year) + i += 1 + cmd.write_status(f"\nDone. 
Processed {i} files.") + + +if __name__ == "__main__": + cmd = code_writer_cmd(__file__) + with cmd.add_arguments() as parser: + parser.add_argument( + "--update-year", + action="store_true", + help="Update the year in the license files", + ) + + with cmd.run_program(): + run(cmd, cmd.args.update_year) diff --git a/tools/sync_test_files.py b/tools/sync_test_files.py index f855cd12c2d..4c825c2d7fb 100644 --- a/tools/sync_test_files.py +++ b/tools/sync_test_files.py @@ -6,6 +6,7 @@ from __future__ import annotations from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Any from typing import Iterable @@ -34,7 +35,15 @@ def run_operation( source_data = Path(source).read_text().replace(remove_str, "") dest_data = header.format(source=source, this_file=this_file) + source_data - cmd.write_output_file_from_text(dest_data, dest) + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf: + buf.write(dest_data) + + cmd.run_black(buf.name) + cmd.write_output_file_from_tempfile(buf.name, dest) def main(file: str, cmd: code_writer_cmd) -> None: @@ -51,7 +60,11 @@ def main(file: str, cmd: code_writer_cmd) -> None: "typed_annotation": { "source": "test/orm/declarative/test_typed_mapping.py", "dest": "test/orm/declarative/test_tm_future_annotations_sync.py", - } + }, + "dc_typed_annotation": { + "source": "test/orm/declarative/test_dc_transforms.py", + "dest": "test/orm/declarative/test_dc_transforms_future_anno_sync.py", + }, } if __name__ == "__main__": diff --git a/tools/trace_orm_adapter.py b/tools/trace_orm_adapter.py index de8098bcb8f..72bb08cc484 100644 --- a/tools/trace_orm_adapter.py +++ b/tools/trace_orm_adapter.py @@ -3,26 +3,27 @@ Demos:: - python tools/trace_orm_adapter.py -m pytest \ + $ python tools/trace_orm_adapter.py -m pytest \ test/orm/inheritance/test_polymorphic_rel.py::PolymorphicAliasedJoinsTest::test_primary_eager_aliasing_joinedload - python tools/trace_orm_adapter.py -m pytest \ + $ python tools/trace_orm_adapter.py -m pytest \ test/orm/test_eager_relations.py::LazyLoadOptSpecificityTest::test_pathed_joinedload_aliased_abs_bcs - python tools/trace_orm_adapter.py my_test_script.py + $ python tools/trace_orm_adapter.py my_test_script.py The above two tests should spit out a ton of debug output. If a test or program has no debug output at all, that's a good thing! It means ORMAdapter isn't used for that case. 
-You can then set a breakpoint at the end of any adapt step: +You can then set a breakpoint at the end of any adapt step:: - python tools/trace_orm_adapter.py -d 10 -m pytest -s \ + $ python tools/trace_orm_adapter.py -d 10 -m pytest -s \ test/orm/test_eager_relations.py::LazyLoadOptSpecificityTest::test_pathed_joinedload_aliased_abs_bcs """ # noqa: E501 + # mypy: ignore-errors diff --git a/tox.ini b/tox.ini index 5b557338883..f776b2a4b63 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,20 @@ [tox] envlist = py +[greenletextras] +extras= + asyncio + sqlite: aiosqlite + sqlite_file: aiosqlite + postgresql: postgresql_asyncpg + mysql: asyncmy + mysql: aiomysql + mssql: aioodbc + + # not greenlet, but tends to not have packaging until the py version + # has been fully released + mssql: mssql_pymssql + [testenv] cov_args=--cov=sqlalchemy --cov-report term --cov-append --cov-report xml --exclude-tag memory-intensive --exclude-tag timing-intensive -k "not aaa_profiling" @@ -14,25 +28,22 @@ usedevelop= cov: True extras= - sqlite: aiosqlite - sqlite_file: aiosqlite - sqlite_file: sqlcipher; python_version < '3.10' + # this can be limited to specific python versions IF there is no + # greenlet available for the most recent python. otherwise + # keep this present in all cases + py{38,39,310,311,312,313,314}: {[greenletextras]extras} + postgresql: postgresql - postgresql: postgresql_asyncpg postgresql: postgresql_pg8000 postgresql: postgresql_psycopg mysql: mysql mysql: pymysql - mysql: asyncmy - mysql: aiomysql mysql: mariadb_connector oracle: oracle oracle: oracle_oracledb mssql: mssql - mssql: aioodbc - py{3,37,38,39,310,311}-mssql: mssql_pymssql install_command= # TODO: I can find no way to get pip / tox / anyone to have this @@ -41,30 +52,29 @@ install_command= python -I -m pip install --only-binary=pymssql {opts} {packages} deps= - pytest>=7.0.0rc1,<8 + typing-extensions>=4.13.0rc1; python_version > '3.7' + + pytest>=7.0.0,<8.4 # tracked by https://github.com/pytest-dev/pytest-xdist/issues/907 pytest-xdist!=3.3.0 - py312: greenlet>=3.0.0a1 + dbapimain-sqlite: git+https://github.com/omnilib/aiosqlite.git\#egg=aiosqlite - dbapimain-sqlite: git+https://github.com/omnilib/aiosqlite.git#egg=aiosqlite - dbapimain-sqlite: git+https://github.com/coleifer/sqlcipher3.git#egg=sqlcipher3 + dbapimain-postgresql: git+https://github.com/psycopg/psycopg2.git\#egg=psycopg2 + dbapimain-postgresql: git+https://github.com/MagicStack/asyncpg.git\#egg=asyncpg + dbapimain-postgresql: git+https://github.com/tlocke/pg8000.git\#egg=pg8000 + dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git\#egg=psycopg&subdirectory=psycopg + # dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git\#egg=psycopg-c&subdirectory=psycopg_c - dbapimain-postgresql: git+https://github.com/psycopg/psycopg2.git#egg=psycopg2 - dbapimain-postgresql: git+https://github.com/MagicStack/asyncpg.git#egg=asyncpg - dbapimain-postgresql: git+https://github.com/tlocke/pg8000.git#egg=pg8000 - dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg&subdirectory=psycopg - # dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg-c&subdirectory=psycopg_c + dbapimain-mysql: git+https://github.com/PyMySQL/mysqlclient-python.git\#egg=mysqlclient + dbapimain-mysql: git+https://github.com/PyMySQL/PyMySQL.git\#egg=pymysql - dbapimain-mysql: git+https://github.com/PyMySQL/mysqlclient-python.git#egg=mysqlclient - dbapimain-mysql: git+https://github.com/PyMySQL/PyMySQL.git#egg=pymysql +# dbapimain-mysql: 
git+https://github.com/mariadb-corporation/mariadb-connector-python\#egg=mariadb -# dbapimain-mysql: git+https://github.com/mariadb-corporation/mariadb-connector-python#egg=mariadb + dbapimain-oracle: git+https://github.com/oracle/python-cx_Oracle.git\#egg=cx_Oracle - dbapimain-oracle: git+https://github.com/oracle/python-cx_Oracle.git#egg=cx_Oracle - - py312-mssql: git+https://github.com/mkleehammer/pyodbc.git#egg=pyodbc - dbapimain-mssql: git+https://github.com/mkleehammer/pyodbc.git#egg=pyodbc + py313-mssql: git+https://github.com/mkleehammer/pyodbc.git\#egg=pyodbc + dbapimain-mssql: git+https://github.com/mkleehammer/pyodbc.git\#egg=pyodbc cov: pytest-cov @@ -89,19 +99,16 @@ setenv= PYTHONNOUSERSITE=1 PYTEST_EXCLUDES=-m "not memory_intensive and not mypy" + # ensure older pip is installed for EOL python versions + py37: VIRTUALENV_PIP=24.0 + PYTEST_COLOR={tty:--color=yes} MYPY_COLOR={tty:--color-output} - # pytest 'rewrite' is hitting lots of deprecation warnings under py312 and - # i can't find any way to ignore those warnings, so this turns it off - py312: PYTEST_ARGS=--assert plain - - BASECOMMAND=python -m pytest {env:PYTEST_ARGS} {env:PYTEST_COLOR} --rootdir {toxinidir} --log-info=sqlalchemy.testing + BASECOMMAND=python -m pytest {env:PYTEST_COLOR} --rootdir {toxinidir} --log-info=sqlalchemy.testing WORKERS={env:TOX_WORKERS:-n4 --max-worker-restart=5} - - nocext: DISABLE_SQLALCHEMY_CEXT=1 cext: REQUIRE_SQLALCHEMY_CEXT=1 cov: COVERAGE={[testenv]cov_args} @@ -110,21 +117,26 @@ setenv= oracle: WORKERS={env:TOX_WORKERS:-n2 --max-worker-restart=5} oracle: ORACLE={env:TOX_ORACLE:--db oracle} - oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb} + + oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb --dbdriver oracledb_async} sqlite: SQLITE={env:TOX_SQLITE:--db sqlite} sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} - sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric --dbdriver aiosqlite} + py{38,39,310,311,312,313}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric --dbdriver aiosqlite} + py{314}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric} + sqlite-nogreenlet: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric} - py{37,38,39}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher} + # note all of these would need limiting for py314 if we want tests to run until + # greenlet is available. 
I just dont see any clean way to do this in tox without writing + # all the versions out every time and it's ridiculous - # omit pysqlcipher for Python 3.10 - py{3,310,311,312}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} + sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000 --dbdriver psycopg --dbdriver psycopg_async} + postgresql-nogreenlet: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver pg8000 --dbdriver psycopg} # limit driver list for memusage target memusage: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite} @@ -134,10 +146,14 @@ setenv= mysql: MYSQL={env:TOX_MYSQL:--db mysql} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver asyncmy --dbdriver aiomysql --dbdriver mariadbconnector} + mysql-nogreenlet: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector} mssql: MSSQL={env:TOX_MSSQL:--db mssql} - py{3,37,38,39,310,311}-mssql: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver aioodbc --dbdriver pymssql} - py312-mssql: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver aioodbc} + mssql: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver aioodbc --dbdriver pymssql} + py{314}-mssql: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver aioodbc} + + mssql-nogreenlet: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver pymssql} + py{314}-mssql-nogreenlet: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc} oracle,mssql,sqlite_file: IDENTS=--write-idents db_idents.txt @@ -166,30 +182,36 @@ commands= # this line is only meaningful when usedevelop=True is enabled. we use # that flag for coverage mode. 
nocext: sh -c "rm -f lib/sqlalchemy/*.so" - + nogreenlet: pip uninstall -y greenlet {env:BASECOMMAND} {env:WORKERS} {env:SQLITE:} {env:EXTRA_SQLITE_DRIVERS:} {env:POSTGRESQL:} {env:EXTRA_PG_DRIVERS:} {env:MYSQL:} {env:EXTRA_MYSQL_DRIVERS:} {env:ORACLE:} {env:EXTRA_ORACLE_DRIVERS:} {env:MSSQL:} {env:EXTRA_MSSQL_DRIVERS:} {env:IDENTS:} {env:PYTEST_EXCLUDES:} {env:COVERAGE:} {posargs} oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt [testenv:pep484] deps= - greenlet != 0.4.17 + greenlet >= 1 importlib_metadata; python_version < '3.8' - mypy >= 1.6.0 + mypy >= 1.14.0 + types-greenlet commands = mypy {env:MYPY_COLOR} ./lib/sqlalchemy # pyright changes too often with not-exactly-correct errors # suddenly appearing for it to be stable enough for CI # pyright +extras = + {[greenletextras]extras} + [testenv:mypy] deps= - pytest>=7.0.0rc1,<8 + pytest>=7.0.0rc1,<8.4 pytest-xdist - greenlet != 0.4.17 + greenlet >= 1 importlib_metadata; python_version < '3.8' - mypy >= 1.2.0 + mypy >= 1.2.0,<1.11.0 patch==1.* +extras= + {[greenletextras]extras} commands = pytest {env:PYTEST_COLOR} -m mypy {posargs} @@ -200,6 +222,9 @@ deps= {[testenv:mypy]deps} pytest-cov +extras= + {[greenletextras]extras} + commands = pytest {env:PYTEST_COLOR} -m mypy {env:COVERAGE} {posargs} @@ -209,8 +234,12 @@ setenv= # thanks to https://julien.danjou.info/the-best-flake8-extensions/ [testenv:lint] basepython = python3 + +extras= + {[greenletextras]extras} + deps= - flake8==6.0.0 + flake8==7.2.0 flake8-import-order flake8-builtins flake8-future-annotations>=0.0.5 @@ -222,7 +251,7 @@ deps= # in case it requires a version pin pydocstyle pygments - black==23.3.0 + black==25.1.0 slotscheck>=0.17.0 # required by generate_tuple_map_overloads @@ -244,6 +273,7 @@ commands = python ./tools/generate_proxy_methods.py --check python ./tools/sync_test_files.py --check python ./tools/generate_sql_functions.py --check + python ./tools/normalize_file_headers.py --check python ./tools/walk_packages.py @@ -254,10 +284,15 @@ basepython = {[testenv:lint]basepython} deps = {[testenv:lint]deps} allowlist_externals = {[testenv:lint]allowlist_externals} commands = {[testenv:lint]commands} +extras = {[testenv:lint]extras} + # command run in the github action when cext are active. [testenv:github-cext] +extras= + {[greenletextras]extras} + deps = {[testenv]deps} .[aiosqlite] commands= # command run in the github action when cext are not active. [testenv:github-nocext] +extras= + {[greenletextras]extras} + deps = {[testenv]deps} .[aiosqlite] commands=