From 716a1ea463983cf8052e28bf8ff1ea679191171e Mon Sep 17 00:00:00 2001 From: Serget Date: Mon, 13 Mar 2023 20:53:04 +0300 Subject: [PATCH 1/3] add support async postgres driver --- core/testcontainers/core/generic.py | 11 +++++++---- postgres/setup.py | 2 ++ postgres/testcontainers/postgres/__init__.py | 8 ++++++-- postgres/tests/test_postgres.py | 12 ++++++++++++ 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index f295bc01a..b0b690dcf 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -28,13 +28,16 @@ class DbContainer(DockerContainer): """ Generic database container. """ + + DEFAULT_DRIVER = None + @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) - def _connect(self) -> None: + def _connect(self, driver: Optional[str] = None) -> None: import sqlalchemy - engine = sqlalchemy.create_engine(self.get_connection_url()) + engine = sqlalchemy.create_engine(self.get_connection_url(driver=driver)) engine.connect() - def get_connection_url(self) -> str: + def get_connection_url(self, host=None, driver=None) -> str: raise NotImplementedError def _create_connection_url(self, dialect: str, username: str, password: str, @@ -54,7 +57,7 @@ def _create_connection_url(self, dialect: str, username: str, password: str, def start(self) -> 'DbContainer': self._configure() super().start() - self._connect() + self._connect(self.DEFAULT_DRIVER) return self def _configure(self) -> None: diff --git a/postgres/setup.py b/postgres/setup.py index 1d9abd351..05f7893c1 100644 --- a/postgres/setup.py +++ b/postgres/setup.py @@ -14,6 +14,8 @@ "testcontainers-core", "sqlalchemy", "psycopg2-binary", + "asyncpg", + "pytest-asyncio", ], python_requires=">=3.7", ) diff --git a/postgres/testcontainers/postgres/__init__.py b/postgres/testcontainers/postgres/__init__.py index fb8206952..246508821 100644 --- a/postgres/testcontainers/postgres/__init__.py +++ b/postgres/testcontainers/postgres/__init__.py @@ -41,6 +41,7 @@ class PostgresContainer(DbContainer): POSTGRES_USER = os.environ.get("POSTGRES_USER", "test") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "test") POSTGRES_DB = os.environ.get("POSTGRES_DB", "test") + DEFAULT_DRIVER = "psycopg2" def __init__(self, image: str = "postgres:latest", port: int = 5432, user: Optional[str] = None, password: Optional[str] = None, dbname: Optional[str] = None, @@ -59,9 +60,12 @@ def _configure(self) -> None: self.with_env("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD) self.with_env("POSTGRES_DB", self.POSTGRES_DB) - def get_connection_url(self, host=None) -> str: + def get_connection_url(self, host=None, driver: str = None) -> str: + if not driver: + driver = self.driver + return super()._create_connection_url( - dialect="postgresql+{}".format(self.driver), username=self.POSTGRES_USER, + dialect="postgresql+{}".format(driver), username=self.POSTGRES_USER, password=self.POSTGRES_PASSWORD, db_name=self.POSTGRES_DB, host=host, port=self.port_to_expose, ) diff --git a/postgres/tests/test_postgres.py b/postgres/tests/test_postgres.py index c00c1b3fe..9042f0b8e 100644 --- a/postgres/tests/test_postgres.py +++ b/postgres/tests/test_postgres.py @@ -1,4 +1,6 @@ +import pytest import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine from testcontainers.postgres import PostgresContainer @@ -18,3 +20,13 @@ def test_docker_run_postgres_with_driver_pg8000(): engine = sqlalchemy.create_engine(postgres.get_connection_url()) with engine.begin() as connection: connection.execute(sqlalchemy.text("select 1=1")) + + +@pytest.mark.asyncio +async def test_docker_run_async_postgres(): + with PostgresContainer("postgres:9.5", driver="asyncpg") as postgres: + engine = create_async_engine(postgres.get_connection_url()) + async with engine.begin() as connection: + result = await connection.execute(sqlalchemy.text("select version()")) + for row in result: + assert row[0].lower().startswith("postgresql 9.5") From 4faa31cb5293d0aa3d55e8bd2bdc02410c513f38 Mon Sep 17 00:00:00 2001 From: Serget Date: Thu, 16 Mar 2023 14:12:08 +0300 Subject: [PATCH 2/3] refactoring postgresql container --- core/testcontainers/core/generic.py | 13 +++++------ postgres/testcontainers/postgres/__init__.py | 24 +++++++++++++++++--- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index b0b690dcf..a76c98beb 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -29,15 +29,14 @@ class DbContainer(DockerContainer): Generic database container. """ - DEFAULT_DRIVER = None - @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) - def _connect(self, driver: Optional[str] = None) -> None: + def _connect(self) -> None: import sqlalchemy - engine = sqlalchemy.create_engine(self.get_connection_url(driver=driver)) - engine.connect() + engine = sqlalchemy.create_engine(self.get_connection_url()) + conn = engine.connect() + conn.close() - def get_connection_url(self, host=None, driver=None) -> str: + def get_connection_url(self) -> str: raise NotImplementedError def _create_connection_url(self, dialect: str, username: str, password: str, @@ -57,7 +56,7 @@ def _create_connection_url(self, dialect: str, username: str, password: str, def start(self) -> 'DbContainer': self._configure() super().start() - self._connect(self.DEFAULT_DRIVER) + self._connect() return self def _configure(self) -> None: diff --git a/postgres/testcontainers/postgres/__init__.py b/postgres/testcontainers/postgres/__init__.py index 246508821..67f8e7f62 100644 --- a/postgres/testcontainers/postgres/__init__.py +++ b/postgres/testcontainers/postgres/__init__.py @@ -13,6 +13,14 @@ import os from typing import Optional from testcontainers.core.generic import DbContainer +from testcontainers.core.waiting_utils import wait_container_is_ready + +ADDITIONAL_TRANSIENT_ERRORS = [] +try: + from sqlalchemy.exc import DBAPIError + ADDITIONAL_TRANSIENT_ERRORS.append(DBAPIError) +except ImportError: + pass class PostgresContainer(DbContainer): @@ -45,7 +53,10 @@ class PostgresContainer(DbContainer): def __init__(self, image: str = "postgres:latest", port: int = 5432, user: Optional[str] = None, password: Optional[str] = None, dbname: Optional[str] = None, - driver: str = "psycopg2", **kwargs) -> None: + driver: Optional[str] = None, **kwargs) -> None: + if driver is None: + driver = self.DEFAULT_DRIVER + super(PostgresContainer, self).__init__(image=image, **kwargs) self.POSTGRES_USER = user or self.POSTGRES_USER self.POSTGRES_PASSWORD = password or self.POSTGRES_PASSWORD @@ -55,13 +66,20 @@ def __init__(self, image: str = "postgres:latest", port: int = 5432, user: Optio self.with_exposed_ports(self.port_to_expose) + @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) + def _connect(self) -> None: + import sqlalchemy + engine = sqlalchemy.create_engine(self.get_connection_url(driver=self.DEFAULT_DRIVER)) + conn = engine.connect() + conn.close() + def _configure(self) -> None: self.with_env("POSTGRES_USER", self.POSTGRES_USER) self.with_env("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD) self.with_env("POSTGRES_DB", self.POSTGRES_DB) - def get_connection_url(self, host=None, driver: str = None) -> str: - if not driver: + def get_connection_url(self, host: Optional[str] = None, driver: Optional[str] = None) -> str: + if driver is None: driver = self.driver return super()._create_connection_url( From 0b4df386efa63fdfaa2bd0647576be8e1282cefc Mon Sep 17 00:00:00 2001 From: Serget Date: Wed, 12 Apr 2023 12:18:41 +0300 Subject: [PATCH 3/3] move package pytest-asyncio in requirements.in --- postgres/setup.py | 1 - requirements.in | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres/setup.py b/postgres/setup.py index 05f7893c1..4257fc33e 100644 --- a/postgres/setup.py +++ b/postgres/setup.py @@ -15,7 +15,6 @@ "sqlalchemy", "psycopg2-binary", "asyncpg", - "pytest-asyncio", ], python_requires=">=3.7", ) diff --git a/requirements.in b/requirements.in index 0204fdb8e..127eb578d 100644 --- a/requirements.in +++ b/requirements.in @@ -27,6 +27,7 @@ flake8<3.8.0 # 3.8.0 adds a dependency on importlib-metadata which conflicts wi pg8000 pytest pytest-cov +pytest-asyncio sphinx twine wheel