diff --git a/.pylintrc b/.pylintrc index 7360038..4845be9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -10,7 +10,7 @@ fail-under=10.0 # Add files or directories to the blacklist. They should be base names, not # paths. -ignore=CVS +ignore=CVS,stubs # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. @@ -479,7 +479,7 @@ ignore-comments=yes ignore-docstrings=yes # Ignore imports when computing similarities. -ignore-imports=no +ignore-imports=yes # Minimum lines number of a similarity. min-similarity-lines=4 diff --git a/README.md b/README.md index e627766..d802d93 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # DB Wrapper Lib -A simple wrapper on [aio-libs/aiopg](https://github.com/aio-libs/aiopg). +A simple wrapper on [aio-libs/aiopg](https://github.com/aio-libs/aiopg) or [psycopg/psycopg2](https://github.com/psycopg/psycopg2). Encapsulates connection logic & execution logic into one Client class for convenience. ## Installation @@ -9,32 +9,32 @@ Install with `pip` from releases on this repo. For example, you can install version 0.1.0 with the following command: ``` -$ pip install https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/0.1.2/db-wrapper-0.1.2.tar.gz +$ pip install https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.1.0/db-wrapper-2.1.0.tar.gz ``` -If looking for a different release version, just replace the two instances of `0.1.2` in the command with the version number you need. +If looking for a different release version, just replace the two instances of `2.1.0` in the command with the version number you need. ## Usage -This library uses a fairly simple API to manage asynchronously connecting to, executing queries on, & disconnecting from a PostgresQL database. -Additionally, it includes a very simple Model abstraction to help with defining queries for a given model & managing separation of concerns in your application. 
+This library uses a fairly simple API to manage connecting to, executing queries on, & disconnecting from a PostgresQL database, in both synchronous & asynchronous APIs. +Additionally, it includes a very simple Model abstraction to help with declaring data types, enforcing types at runtime, defining queries for a given model, & managing separation of concerns in your application. -### Example `Client` +### Example: Clients Intializing a database `Client` & executing a query begins with defining a connection & giving it to `Client` on intialization: ```python -from db_wrapper import ConnectionParameters +from db_wrapper import ConnectionParameters, AsyncClient connection_parameters = ConnectionParameters( host='localhost', user='postgres', password='postgres', database='postgres') -client = Client(connection_parameters) +client = AsyncClient(connection_parameters) ``` -From there, you need to tell the client to connect using `Client.connect()` before you can execute any queries. +From there, you need to tell the client to connect using `client.connect()` before you can execute any queries. This method is asynchronous though, so you need to supply an async/await runtime. ```python @@ -45,7 +45,7 @@ import asyncio async def a_query() -> None: # we'll come back to this part - # just know that it usins async/await to call Client.execute_and_return + # just know that it uses async/await to call Client.execute_and_return result = await client.execute_and_return(query) # do something with the result... @@ -78,14 +78,42 @@ async def a_query() -> None: ``` -### Example: `Model` +Alternatively, everything can also be done synchronously, using an API that is almost exactly the same. +Simply drop the async/await keywords & skip the async event loop, then proceed in exactly the same fashion: -Using `Model` isn't necessary at all, you can just interact directly with the `Client` instance using it's `execute` & `execute_and_return` methods to execute SQL queries as needed. 
-`Model` may be helpful in managing your separation of concerns by giving you a single place to define queries related to a given data model in your database. -Additionally, `Model` will be helpful in defining types, if you're using mypy to check your types in development. +```python +from db_wrapper import ConnectionParameters, SyncClient + +connection_parameters = ConnectionParameters( + host='localhost', + user='postgres', + password='postgres', + database='postgres') +client = SyncClient(connection_parameters) + + +def a_query() -> None: + query = 'SELECT table_name' \ + 'FROM information_schema.tables' \ + 'WHERE table_schema = public' + result = client.execute_and_return(query) + + assert result[0] == 'postgres' + + +client.connect() +a_query() +client.disconnect() +``` + +### Example: Models + +Using `AsyncModel` or `SyncModel` isn't necessary at all, you can just interact directly with the Client instance using its `execute` & `execute_and_return` methods to execute SQL queries as needed. +A Model may be helpful in managing your separation of concerns by giving you a single place to define queries related to a given data model in your database. +Additionally, `Model` will be helpful in defining types, if you're using mypy to check your types in development, & in enforcing types at runtime using pydantic. It has no concept of types at runtime, however, & cannot be relied upon to constrain data types & shapes during runtime. -A `Model` instance has 4 properties, corresponding with each of the CRUD operations: `create`, `read`, `update`, & `delete`. +A Model instance has 4 properties, corresponding with each of the CRUD operations: `create`, `read`, `update`, & `delete`. Each CRUD property has one built-in method to handle the simplest of queries for you already (create one record, read one record by id, update one record by id, & delete one record by id). 
Using a model requires defining it's expected type (using `ModelData`), initializing a new instance, then calling the query methods as needed. @@ -102,10 +130,11 @@ class AModel(ModelData): a_boolean_value: bool ``` -Subclassing `ModelData` is important because `Model` expects all records to be constrained to a dictionary containing at least one field labeled `_id` & constrained to the UUID type. This means the above `AModel` will contain records that look like the following dictionary in python: +Subclassing `ModelData` is important because `Model` expects all records to be constrained to a subclass of `ModelData`, containing at least one property labeled `_id` constrained to the UUID type. +This means the above `AModel` will contain records that look like the following dictionary in python: ```python -a_model_result = { +a_model_result.dict() == { _id: UUID(...), a_string_value: 'some string', a_number_value: 12345, @@ -113,24 +142,24 @@ a_model_result = { } ``` -Then to initialize your `Model` with your new expected type, simply initialize `Model` by passing `AModel` as a type parameter, a `Client` instance, & the name of the table this `Model` will be represented on: +Then to initialize your Model with your new expected type, simply initialize `AsyncModel` or `SyncModel` by passing `AModel` as a type parameter, a matching Client instance, & the name of the table this Model will be represented on: ```python from db_wrapper import ( ConnectionParameters, - Client, - Model, + AsyncClient, + AsyncModel, ModelData, ) connection_parameters = ConnectionParameters(...) -client = Client(...) +client = AsyncClient(...) class AModel(ModelData): # ... 
-a_model = Model[AModel](client, 'a_table_name') +a_model = AsyncModel[AModel](client, 'a_table_name') ``` From there, you can query your new `Model` by calling CRUD methods on the instance: @@ -142,7 +171,8 @@ from typing import List async get_some_record() -> List[AModel]: - return await a_model.read.one_by_id('some record id') # NOTE: in reality the id would be a UUID + return await a_model.read.one_by_id('some record id') + # NOTE: in reality the id would be a UUID ``` Of course, just having methods for creating, reading, updating, or deleting a single record at a time often won't be enough. @@ -152,8 +182,7 @@ For example, if you want to write an additional query for reading any record tha ```python from db_wrapper import ModelData -from db_wrapper.model import Read -from psycopg2 import sql +from db_wrapper.model import AsyncRead, sql # ... @@ -162,7 +191,7 @@ class AnotherModel(ModelData): a_field: str -class ExtendedReader(Read[AnotherModel]): +class ExtendedReader(AsyncRead[AnotherModel]): """Add custom method to Model.read.""" async def all_with_some_string_value(self) -> List[AnotherModel]: @@ -173,7 +202,9 @@ class ExtendedReader(Read[AnotherModel]): # sql module query = sql.SQL( 'SELECT * ' - 'FROM {table} ' # a Model knows it's own table name, no need to specify it manually here + 'FROM {table} ' + # a Model knows it's own table name, + # no need to specify it manually here 'WHERE a_field = 'some value';' ).format(table=self._table) @@ -183,24 +214,24 @@ class ExtendedReader(Read[AnotherModel]): return result ``` -Then, you would subclass `Model` & redefine it's read property to be an instance of your new `ExtendedReader` class: +Then, you would subclass `AsyncModel` & redefine it's read property to be an instance of your new `ExtendedReader` class: ```python -from db_wrapper import Client, Model, ModelData +from db_wrapper import AsyncClient, AsyncModel, ModelData # ... 
-class ExtendedModel(Model[AnotherModel]): +class ExtendedModel(AsyncModel[AnotherModel]): """Build an AnotherModel instance.""" read: ExtendedReader - def __init__(self, client: Client) -> None: + def __init__(self, client: AsyncClient) -> None: super().__init__(client, 'another_model_table') # you can supply your table name here self.read = ExtendedReader(self.client, self.table) ``` -Finally, using your `ExtendedModel` is simple, just initialize the class with a `Client` instance & use it just as you would your previous `Model` instance, `a_model`: +Finally, using your `ExtendedModel` is simple, just initialize the class with a `AsyncClient` instance & use it just as you would your previous `AsyncModel` instance, `a_model`: ```python # ... diff --git a/db_wrapper/__init__.py b/db_wrapper/__init__.py index a366ad6..00953f5 100644 --- a/db_wrapper/__init__.py +++ b/db_wrapper/__init__.py @@ -18,5 +18,5 @@ Model instance. """ from .connection import ConnectionParameters -from .client import Client -from .model import Model, ModelData +from .client import AsyncClient, SyncClient +from .model import AsyncModel, SyncModel, ModelData diff --git a/db_wrapper/client/__init__.py b/db_wrapper/client/__init__.py new file mode 100644 index 0000000..a99c6b2 --- /dev/null +++ b/db_wrapper/client/__init__.py @@ -0,0 +1,8 @@ +"""Create a database Client for managing connections & executing queries.""" + +from typing import Union + +from .async_client import AsyncClient +from .sync_client import SyncClient + +Client = Union[AsyncClient, SyncClient] diff --git a/db_wrapper/client.py b/db_wrapper/client/async_client.py similarity index 72% rename from db_wrapper/client.py rename to db_wrapper/client/async_client.py index 7792be3..afed995 100644 --- a/db_wrapper/client.py +++ b/db_wrapper/client/async_client.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import ( + cast, Any, TypeVar, Union, @@ -10,27 +11,20 @@ List, Dict) -import aiopg # type: ignore -from 
psycopg2.extras import register_uuid -# importing for the sole purpose of re-exporting -# pylint: disable=unused-import +import aiopg +from psycopg2.extras import register_uuid, RealDictCursor, RealDictRow # type: ignore from psycopg2 import sql -from .connection import ConnectionParameters, connect +from db_wrapper.connection import ConnectionParameters, get_pool # add uuid support to psycopg2 & Postgres register_uuid() -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name -T = TypeVar('T') - -# pylint: disable=unsubscriptable-object Query = Union[str, sql.Composed] -class Client: +class AsyncClient: """Class to manage database connection & expose necessary methods to user. Stores connection parameters on init, then exposes methods to @@ -39,18 +33,19 @@ class Client: """ _connection_params: ConnectionParameters - _connection: aiopg.Connection + _pool: aiopg.Pool def __init__(self, connection_params: ConnectionParameters) -> None: self._connection_params = connection_params async def connect(self) -> None: - """Connect to the database.""" - self._connection = await connect(self._connection_params) + """Create a database connection pool.""" + self._pool = await get_pool(self._connection_params) async def disconnect(self) -> None: - """Disconnect from the database.""" - await self._connection.close() + """Close database connection pool.""" + self._pool.close() + await self._pool.wait_closed() # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object @@ -60,6 +55,11 @@ async def _execute_query( query: Query, params: Optional[Dict[Hashable, Any]] = None, ) -> None: + # aiopg type is incorrect & thinks execute only takes str + # when in the query is passed through to psycopg2's + # cursor.execute which does accept sql.Composed objects. 
+ query = cast(str, query) + if params: await cursor.execute(query, params) else: @@ -82,7 +82,7 @@ async def execute( Returns: None """ - async with self._connection.cursor() as cursor: + with (await self._pool.cursor(cursor_factory=RealDictCursor) ) as cursor: await self._execute_query(cursor, query, params) # PENDS python 3.9 support in pylint @@ -91,7 +91,7 @@ async def execute_and_return( self, query: Query, params: Optional[Dict[Hashable, Any]] = None, - ) -> List[T]: + ) -> List[RealDictRow]: """Execute the given SQL query & return the result. Arguments: @@ -102,8 +102,8 @@ async def execute_and_return( Returns: List containing all the rows that matched the query. """ - async with self._connection.cursor() as cursor: + with (await self._pool.cursor(cursor_factory=RealDictCursor) ) as cursor: await self._execute_query(cursor, query, params) - result: List[T] = await cursor.fetchall() + result: List[RealDictRow] = await cursor.fetchall() return result diff --git a/db_wrapper/client/sync_client.py b/db_wrapper/client/sync_client.py new file mode 100644 index 0000000..72d64fd --- /dev/null +++ b/db_wrapper/client/sync_client.py @@ -0,0 +1,102 @@ +"""Wrapper on aiopg to simplify connecting to & interacting with db.""" + +from __future__ import annotations +from typing import ( + Any, + Dict, + Hashable, + List, + Optional, + Union, +) + +from psycopg2.extras import register_uuid, RealDictRow +from psycopg2 import sql +# pylint can't seem to find the items in psycopg2 despite being available +from psycopg2._psycopg import cursor # pylint: disable=no-name-in-module + +from db_wrapper.connection import ( + sync_connect, + ConnectionParameters, +) + +# add uuid support to psycopg2 & Postgres +register_uuid() + + +Query = Union[str, sql.Composed] + + +class SyncClient: + """Class to manage database connection & expose necessary methods to user. 
 + + Stores connection parameters on init, then exposes methods to + synchronously connect & disconnect the database, as well as execute SQL + queries. + """ + + _connection_params: ConnectionParameters + _connection: Any + + def __init__(self, connection_params: ConnectionParameters) -> None: + self._connection_params = connection_params + + def connect(self) -> None: + """Connect to the database.""" + self._connection = sync_connect(self._connection_params) + + def disconnect(self) -> None: + """Disconnect from the database.""" + self._connection.close() + + @staticmethod + def _execute_query( + db_cursor: cursor, + query: Query, + params: Optional[Dict[Hashable, Any]] = None, + ) -> None: + if params: + db_cursor.execute(query, params) + else: + db_cursor.execute(query) + + db_cursor.connection.commit() + + def execute( + self, + query: Query, + params: Optional[Dict[Hashable, Any]] = None, + ) -> None: + """Execute the given SQL query. + + Arguments: + query (Query) -- the SQL query to execute + params (dict) -- a dictionary of parameters to interpolate when + executing the query + + Returns: + None + """ + with self._connection.cursor() as db_cursor: + self._execute_query(db_cursor, query, params) + + def execute_and_return( + self, + query: Query, + params: Optional[Dict[Hashable, Any]] = None, + ) -> List[RealDictRow]: + """Execute the given SQL query & return the result. + + Arguments: + query (Query) -- the SQL query to execute + params (dict) -- a dictionary of parameters to interpolate when + executing the query + + Returns: + List containing all the rows that matched the query. 
+ """ + with self._connection.cursor() as db_cursor: + self._execute_query(db_cursor, query, params) + + result: List[RealDictRow] = db_cursor.fetchall() + return result diff --git a/db_wrapper/connection.py b/db_wrapper/connection.py index 2c4ae05..0574d55 100644 --- a/db_wrapper/connection.py +++ b/db_wrapper/connection.py @@ -3,18 +3,16 @@ import asyncio from dataclasses import dataclass import logging -from typing import Optional +import time +from typing import Optional, Any +from psycopg2 import ( # type: ignore + connect as psycopg2Connect, + OperationalError as psycopg2OpError, +) from psycopg2.extras import RealDictCursor # type: ignore -from psycopg2 import OperationalError # type: ignore -# no stubs available, starting out by just determining correct return -# types & annotating in my wrappers here -import aiopg # type: ignore - -# -# Postgres Connection helper -# +import aiopg LOGGER = logging.getLogger(__name__) @@ -24,6 +22,7 @@ class ConnectionParameters: """Defines connection parameters for database.""" host: str + port: int user: str password: str database: str @@ -37,42 +36,91 @@ async def _try_connect( user = connection_params.user password = connection_params.password host = connection_params.host + port = connection_params.port + + dsn = f"dbname={database} user={user} password={password} " \ + f"host={host} port={port}" - dsn = f'dbname={database} user={user} password={password} host={host}' + # return await aiopg.create_pool(dsn) # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object - connection: Optional[aiopg.Connection] = None + pool: Optional[aiopg.Connection] = None - LOGGER.info( - f'Attempting to connect to database {database} as {user}@{host}...') + LOGGER.info(f"Attempting to connect to database {database} as " + f"{user}@{host}:{port}...") + + while pool is None: + try: + pool = await aiopg.create_pool(dsn) + except psycopg2OpError as err: + print(type(err)) + if retries > 12: + raise ConnectionError( + "Max 
number of connection attempts has been reached (12)" + ) from err + + LOGGER.info( + f"Connection failed ({retries} time(s))" + "retrying again in 5 seconds...") + + await asyncio.sleep(5) + return await _try_connect(connection_params, retries + 1) + + return pool + + +def _sync_try_connect( + connection_params: ConnectionParameters, + retries: int = 1 +) -> Any: + database = connection_params.database + user = connection_params.user + password = connection_params.password + host = connection_params.host + port = connection_params.port + + dsn = f"dbname={database} user={user} password={password} " + \ + f"host={host} port={port}" + + connection: Optional[Any] = None + + LOGGER.info(f"Attempting to connect to database {database} " + f"as {user}@{host}:{port}...") while connection is None: try: - connection = await aiopg.connect( + connection = psycopg2Connect( dsn, cursor_factory=RealDictCursor) - except OperationalError as err: + except psycopg2OpError as err: print(type(err)) if retries > 12: raise ConnectionError( - 'Max number of connection attempts has been reached (12)' + "Max number of connection attempts has been reached (12)" ) from err LOGGER.info( - f'Connection failed ({retries} time(s))' - 'retrying again in 5 seconds...') + f"Connection failed ({retries} time(s))" + "retrying again in 5 seconds...") - await asyncio.sleep(5) - return await _try_connect(connection_params, retries + 1) + time.sleep(5) + return _sync_try_connect(connection_params, retries + 1) return connection # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object -async def connect( +async def get_pool( connection_params: ConnectionParameters -) -> aiopg.Connection: - """Establish database connection.""" +) -> aiopg.Pool: + """Establish database connection pool.""" return await _try_connect(connection_params) + + +def sync_connect( + connection_params: ConnectionParameters +) -> Any: + """Establish database connection.""" + return 
_sync_try_connect(connection_params) diff --git a/db_wrapper/model.py b/db_wrapper/model.py deleted file mode 100644 index dcc0e4c..0000000 --- a/db_wrapper/model.py +++ /dev/null @@ -1,277 +0,0 @@ -"""Convenience class to simplify database interactions for given interface.""" - -# std lib dependencies -from __future__ import annotations -from typing import ( - TypeVar, - Generic, - Any, - Tuple, - List, - Dict, - TypedDict -) -from uuid import UUID - -# third party dependencies -from pydantic import BaseModel - -# internal dependency -from .client import Client, sql - - -class ModelData(BaseModel): - """Base interface for ModelData to be used in Model.""" - - # PENDS python 3.9 support in pylint - # pylint: disable=inherit-non-class - # pylint: disable=too-few-public-methods - - id: UUID - - -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name -T = TypeVar('T', bound=ModelData) - - -class UnexpectedMultipleResults(Exception): - """Raised when query receives multiple results when only one expected.""" - - def __init__(self, results: List[Any]) -> None: - message = 'Multiple results received when only ' \ - f'one was expected: {results}' - super().__init__(self, message) - - -class NoResultFound(Exception): - """Raised when query receives no results when 1+ results expected.""" - - def __init__(self) -> None: - message = 'No result was found' - super().__init__(self, message) - - -# pylint: disable=too-few-public-methods -class Create(Generic[T]): - """Encapsulate Create operations for Model.read.""" - - _client: Client - _table: sql.Composable - - def __init__(self, client: Client, table: sql.Composable) -> None: - self._client = client - self._table = table - - async def one(self, item: T) -> T: - """Create one new record with a given item.""" - columns: List[sql.Identifier] = [] - values: List[sql.Literal] = [] - - for column, value in item.dict().items(): - values.append(sql.Literal(value)) - - columns.append(sql.Identifier(column)) - 
- query = sql.SQL( - 'INSERT INTO {table} ({columns}) ' - 'VALUES ({values}) ' - 'RETURNING *;' - ).format( - table=self._table, - columns=sql.SQL(',').join(columns), - values=sql.SQL(',').join(values), - ) - - result: List[T] = await self._client.execute_and_return(query) - - return result[0] - - -class Read(Generic[T]): - """Encapsulate Read operations for Model.read.""" - - _client: Client - _table: sql.Composable - - def __init__(self, client: Client, table: sql.Composable) -> None: - self._client = client - self._table = table - - async def one_by_id(self, id_value: str) -> T: - """Read a row by it's id.""" - query = sql.SQL( - 'SELECT * ' - 'FROM {table} ' - 'WHERE id = {id_value};' - ).format( - table=self._table, - id_value=sql.Literal(id_value) - ) - - result: List[T] = await self._client.execute_and_return(query) - - # Should only return one item from DB - if len(result) > 1: - raise UnexpectedMultipleResults(result) - if len(result) == 0: - raise NoResultFound() - - return result[0] - - -class Update(Generic[T]): - """Encapsulate Update operations for Model.read.""" - - _client: Client - _table: sql.Composable - - def __init__(self, client: Client, table: sql.Composable) -> None: - self._client = client - self._table = table - - async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T: - """Apply changes to row with given id. 
- - Arguments: - id_value (string) - the id of the row to update - changes (dict) - a dictionary of changes to apply, - matches keys to column names & values to values - - Returns: - full value of row updated - """ - def compose_one_change(change: Tuple[str, Any]) -> sql.Composed: - key = change[0] - value = change[1] - - return sql.SQL("{key} = {value}").format( - key=sql.Identifier(key), value=sql.Literal(value)) - - def compose_changes(changes: Dict[str, Any]) -> sql.Composed: - return sql.SQL(',').join( - [compose_one_change(change) for change in changes.items()]) - - query = sql.SQL( - 'UPDATE {table} ' - 'SET {changes} ' - 'WHERE id = {id_value} ' - 'RETURNING *;' - ).format( - table=self._table, - changes=compose_changes(changes), - id_value=sql.Literal(id_value), - ) - - result: List[T] = await self._client.execute_and_return(query) - - return result[0] - - -class Delete(Generic[T]): - """Encapsulate Delete operations for Model.read.""" - - _client: Client - _table: sql.Composable - - def __init__(self, client: Client, table: sql.Composable) -> None: - self._client = client - self._table = table - - async def one_by_id(self, id_value: str) -> T: - """Delete one record with matching ID.""" - query = sql.SQL( - 'DELETE FROM {table} ' - 'WHERE id = {id_value} ' - 'RETURNING *;' - ).format( - table=self._table, - id_value=sql.Literal(id_value) - ) - - result: List[T] = await self._client.execute_and_return(query) - - # Should only return one item from DB - if len(result) > 1: - raise UnexpectedMultipleResults(result) - if len(result) == 0: - raise NoResultFound() - - return result[0] - - -class Model(Generic[T]): - """Class to manage execution of database queries for a model.""" - - # Properties don't need docstrings - # pylint: disable=missing-function-docstring - - client: Client - table: sql.Identifier - - _create: Create[T] - _read: Read[T] - _update: Update[T] - _delete: Delete[T] - - # PENDS python 3.9 support in pylint - # pylint: 
disable=unsubscriptable-object - def __init__( - self, - client: Client, table: str, - ) -> None: - self.client = client - self.table = sql.Identifier(table) - - self._create = Create[T](self.client, self.table) - self._read = Read[T](self.client, self.table) - self._update = Update[T](self.client, self.table) - self._delete = Delete[T](self.client, self.table) - - @property - def create(self) -> Create[T]: - """Methods for creating new records of the Model.""" - return self._create - - @create.setter - def create(self, creator: Create[T]) -> None: - if isinstance(creator, Create): - self._create = creator - else: - raise TypeError('Model.create must be an instance of Create.') - - @property - def read(self) -> Read[T]: - """Methods for reading records of the Model.""" - return self._read - - @read.setter - def read(self, reader: Read[T]) -> None: - if isinstance(reader, Read): - self._read = reader - else: - raise TypeError('Model.read must be an instance of Read.') - - @property - def update(self) -> Update[T]: - """Methods for updating records of the Model.""" - return self._update - - @update.setter - def update(self, updater: Update[T]) -> None: - if isinstance(updater, Update): - self._update = updater - else: - raise TypeError('Model.update must be an instance of Update.') - - @property - def delete(self) -> Delete[T]: - """Methods for deleting records of the Model.""" - return self._delete - - @delete.setter - def delete(self, deleter: Delete[T]) -> None: - if isinstance(deleter, Delete): - self._delete = deleter - else: - raise TypeError('Model.delete must be an instance of Delete.') diff --git a/db_wrapper/model/__init__.py b/db_wrapper/model/__init__.py new file mode 100644 index 0000000..6bae307 --- /dev/null +++ b/db_wrapper/model/__init__.py @@ -0,0 +1,18 @@ +"""Convenience objects to simplify database interactions w/ given interface.""" + +from psycopg2.extras import RealDictRow +from .async_model import ( + AsyncModel, + AsyncCreate, + AsyncRead, + 
AsyncUpdate, + AsyncDelete +) +from .sync_model import ( + SyncModel, + SyncCreate, + SyncRead, + SyncUpdate, + SyncDelete +) +from .base import ModelData, sql diff --git a/db_wrapper/model/async_model.py b/db_wrapper/model/async_model.py new file mode 100644 index 0000000..73cb23c --- /dev/null +++ b/db_wrapper/model/async_model.py @@ -0,0 +1,218 @@ +"""Asynchronous Model objects.""" + +from typing import Any, Dict, List, Type +from uuid import UUID + +from psycopg2.extras import RealDictRow + +from db_wrapper.client import AsyncClient +from .base import ( + ensure_exactly_one, + sql, + T, + CreateABC, + ReadABC, + UpdateABC, + DeleteABC, + ModelABC, +) + + +class AsyncCreate(CreateABC[T]): + """Create methods designed to use an AsyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: AsyncClient + + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + async def one(self, item: T) -> T: + """Create one new record with a given item.""" + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(self._query_one(item)) + result: T = self._return_constructor(**query_result[0]) + + return result + + +class AsyncRead(ReadABC[T]): + """Create methods designed to use an AsyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: AsyncClient + + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + async def one_by_id(self, id_value: UUID) -> T: + """Read a row by it's id.""" + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value)) + + # Should only return one item from DB + ensure_exactly_one(query_result) + + result: T = self._return_constructor(**query_result[0]) + + return result + + +class 
AsyncUpdate(UpdateABC[T]): + """Create methods designed to use an AsyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: AsyncClient + + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + async def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T: + """Apply changes to row with given id. + + Arguments: + id_value (string) - the id of the row to update + changes (dict) - a dictionary of changes to apply, + matches keys to column names & values to values + + Returns: + full value of row updated + """ + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value, changes)) + + ensure_exactly_one(query_result) + result: T = self._return_constructor(**query_result[0]) + + return result + + +class AsyncDelete(DeleteABC[T]): + """Create methods designed to use an AsyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: AsyncClient + + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + async def one_by_id(self, id_value: str) -> T: + """Delete one record with matching ID.""" + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value)) + + # Should only return one item from DB + ensure_exactly_one(query_result) + result = self._return_constructor(**query_result[0]) + + return result + + +class AsyncModel(ModelABC[T]): + """Class to manage execution of database queries for a model.""" + + # Properties don't need docstrings + # pylint: disable=missing-function-docstring + + client: AsyncClient + + _create: AsyncCreate[T] + _read: AsyncRead[T] + _update: AsyncUpdate[T] + _delete: AsyncDelete[T] + + def __init__( + self, + client: AsyncClient, + table: str, + 
return_constructor: Type[T], + ) -> None: + super().__init__(client, table) + + self._create = AsyncCreate[T]( + self.client, self.table, return_constructor) + self._read = AsyncRead[T]( + self.client, self.table, return_constructor) + self._update = AsyncUpdate[T]( + self.client, self.table, return_constructor) + self._delete = AsyncDelete[T]( + self.client, self.table, return_constructor) + + @property + def create(self) -> AsyncCreate[T]: + """Methods for creating new records of the Model.""" + return self._create + + @create.setter + def create(self, creator: AsyncCreate[T]) -> None: + if isinstance(creator, AsyncCreate): + self._create = creator + else: + raise TypeError('Model.create must be an instance of AsyncCreate.') + + @property + def read(self) -> AsyncRead[T]: + """Methods for reading records of the Model.""" + return self._read + + @read.setter + def read(self, reader: AsyncRead[T]) -> None: + if isinstance(reader, AsyncRead): + self._read = reader + else: + raise TypeError('Model.read must be an instance of AsyncRead.') + + @property + def update(self) -> AsyncUpdate[T]: + """Methods for updating records of the Model.""" + return self._update + + @update.setter + def update(self, updater: AsyncUpdate[T]) -> None: + if isinstance(updater, AsyncUpdate): + self._update = updater + else: + raise TypeError('Model.update must be an instance of AsyncUpdate.') + + @property + def delete(self) -> AsyncDelete[T]: + """Methods for deleting records of the Model.""" + return self._delete + + @delete.setter + def delete(self, deleter: AsyncDelete[T]) -> None: + if isinstance(deleter, AsyncDelete): + self._delete = deleter + else: + raise TypeError('Model.delete must be an instance of AsyncDelete.') diff --git a/db_wrapper/model/base.py b/db_wrapper/model/base.py new file mode 100644 index 0000000..fb24805 --- /dev/null +++ b/db_wrapper/model/base.py @@ -0,0 +1,199 @@ +"""Base classes for building Async/SyncModel.""" + +# std lib dependencies +from __future__ 
import annotations +from typing import ( + Any, + Dict, + Generic, + List, + Tuple, + Type, + TypeVar, +) +from uuid import UUID + +# third party dependencies +from psycopg2 import sql +# pylint is unable to parse c module to check contents +from pydantic import BaseModel # pylint: disable=no-name-in-module + +# internal dependency +from db_wrapper.client import ( + Client, +) + + +class ModelData(BaseModel): + """Base interface for ModelData to be used in Model.""" + + # PENDS python 3.9 support in pylint + # pylint: disable=inherit-non-class + # pylint: disable=too-few-public-methods + + id: UUID + + +# Generic doesn't need a more descriptive name +T = TypeVar('T', bound=ModelData) # pylint: disable=invalid-name + + +class UnexpectedMultipleResults(Exception): + """Raised when query receives multiple results when only one expected.""" + + def __init__(self, results: List[Any]) -> None: + message = 'Multiple results received when only ' \ + f'one was expected: {results}' + super().__init__(self, message) + + +class NoResultFound(Exception): + """Raised when query receives no results when 1+ results expected.""" + + def __init__(self) -> None: + message = 'No result was found' + super().__init__(self, message) + + +def ensure_exactly_one(result: List[Any]) -> None: + """Raise appropriate Exceptions if list longer than 1.""" + if len(result) > 1: + raise UnexpectedMultipleResults(result) + if len(result) == 0: + raise NoResultFound() + + +class CRUDABC(Generic[T]): + """Encapsulate object creation behavior for all CRUD objects.""" + + # pylint: disable=too-few-public-methods + + _table: sql.Composable + _return_constructor: Type[T] + + def __init__( + self, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + self._table = table + self._return_constructor = return_constructor + + +class CreateABC(CRUDABC[T]): + """Encapsulate Create operations for Model.create.""" + + # pylint: disable=too-few-public-methods + + def _query_one(self, item: T) -> 
sql.Composed: + """Build query to create one new record with a given item.""" + columns: List[sql.Identifier] = [] + values: List[sql.Literal] = [] + + for column, value in item.dict().items(): + values.append(sql.Literal(value)) + + columns.append(sql.Identifier(column)) + + query = sql.SQL( + 'INSERT INTO {table} ({columns}) ' + 'VALUES ({values}) ' + 'RETURNING *;' + ).format( + table=self._table, + columns=sql.SQL(',').join(columns), + values=sql.SQL(',').join(values), + ) + + return query + + +class ReadABC(CRUDABC[T]): + """Encapsulate Read operations for Model.read.""" + + # pylint: disable=too-few-public-methods + + def _query_one_by_id(self, id_value: UUID) -> sql.Composed: + """Build query to read a row by it's id.""" + query = sql.SQL( + 'SELECT * ' + 'FROM {table} ' + 'WHERE id = {id_value};' + ).format( + table=self._table, + id_value=sql.Literal(str(id_value)) + ) + + return query + + +class UpdateABC(CRUDABC[T]): + """Encapsulate Update operations for Model.read.""" + + # pylint: disable=too-few-public-methods + + def _query_one_by_id( + self, + id_value: UUID, + changes: Dict[str, Any] + ) -> sql.Composed: + """Build Query to apply changes to row with given id.""" + def compose_one_change(change: Tuple[str, Any]) -> sql.Composed: + key = change[0] + value = change[1] + + return sql.SQL("{key} = {value}").format( + key=sql.Identifier(key), value=sql.Literal(value)) + + def compose_changes(changes: Dict[str, Any]) -> sql.Composed: + return sql.SQL(',').join( + [compose_one_change(change) for change in changes.items()]) + + query = sql.SQL( + 'UPDATE {table} ' + 'SET {changes} ' + 'WHERE id = {id_value} ' + 'RETURNING *;' + ).format( + table=self._table, + changes=compose_changes(changes), + id_value=sql.Literal(str(id_value)), + ) + + return query + + +class DeleteABC(CRUDABC[T]): + """Encapsulate Delete operations for Model.read.""" + + # pylint: disable=too-few-public-methods + + def _query_one_by_id(self, id_value: str) -> sql.Composed: + """Build 
query to delete one record with matching ID.""" + query = sql.SQL( + 'DELETE FROM {table} ' + 'WHERE id = {id_value} ' + 'RETURNING *;' + ).format( + table=self._table, + id_value=sql.Literal(id_value) + ) + + return query + + +class ModelABC(Generic[T]): + """Class to manage execution of database queries for a model.""" + + # pylint: disable=too-few-public-methods + + client: Client + table: sql.Identifier + + def __init__( + self, + client: Client, + table: str, + ) -> None: + self.client = client + self.table = sql.Identifier(table) diff --git a/db_wrapper/model/sync_model.py b/db_wrapper/model/sync_model.py new file mode 100644 index 0000000..7d65602 --- /dev/null +++ b/db_wrapper/model/sync_model.py @@ -0,0 +1,216 @@ +"""Synchronous Model objects.""" + +from typing import Any, Dict, List, Type +from uuid import UUID + +from psycopg2.extras import RealDictRow + +from db_wrapper.client import SyncClient +from .base import ( + ensure_exactly_one, + sql, + T, + CreateABC, + DeleteABC, + ReadABC, + UpdateABC, + ModelABC, +) + + +class SyncCreate(CreateABC[T]): + """Create methods designed to use a SyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: SyncClient + + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + def one(self, item: T) -> T: + """Create one new record with a given item.""" + query_result: List[RealDictRow] = self._client.execute_and_return( + self._query_one(item)) + result: T = self._return_constructor(**query_result[0]) + + return result + + +class SyncRead(ReadABC[T]): + """Create methods designed to use an SyncClient.""" + + # pylint: disable=too-few-public-methods + + _client: SyncClient + + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) + self._client = client + + def 
class SyncUpdate(UpdateABC[T]):
    """Update methods designed to use a SyncClient."""

    # pylint: disable=too-few-public-methods

    _client: SyncClient

    def __init__(
        self,
        client: SyncClient,
        table: sql.Composable,
        return_constructor: Type[T]
    ) -> None:
        super().__init__(table, return_constructor)
        self._client = client

    def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T:
        """Apply changes to row with given id.

        Arguments:
            id_value (UUID) - the id of the row to update
            changes (dict)  - a dictionary of changes to apply,
                              matches keys to column names & values to values

        Returns:
            full value of row updated

        Raises:
            UnexpectedMultipleResults, NoResultFound - if the UPDATE
            returns more or fewer than exactly one row
        """
        query_result: List[RealDictRow] = self._client.execute_and_return(
            self._query_one_by_id(id_value, changes))

        ensure_exactly_one(query_result)

        result: T = self._return_constructor(**query_result[0])

        return result


class SyncDelete(DeleteABC[T]):
    """Delete methods designed to use a SyncClient."""

    # pylint: disable=too-few-public-methods

    _client: SyncClient

    def __init__(
        self,
        client: SyncClient,
        table: sql.Composable,
        return_constructor: Type[T]
    ) -> None:
        super().__init__(table, return_constructor)
        self._client = client

    def one_by_id(self, id_value: str) -> T:
        """Delete one record with matching ID.

        NOTE(review): takes `str` while Read/Update take `UUID` —
        mirrors the async implementation; confirm before changing.
        """
        query_result: List[RealDictRow] = self._client.execute_and_return(
            self._query_one_by_id(id_value))

        ensure_exactly_one(query_result)

        result: T = self._return_constructor(**query_result[0])

        return result


class SyncModel(ModelABC[T]):
    """Class to manage execution of database queries for a model.

    Aggregates the four sync CRUD objects; each property setter
    type-checks its replacement so extensions must subclass the
    matching Sync* CRUD class.
    """

    # Properties don't need docstrings
    # pylint: disable=missing-function-docstring

    client: SyncClient

    _create: SyncCreate[T]
    _read: SyncRead[T]
    _update: SyncUpdate[T]
    _delete: SyncDelete[T]

    def __init__(
        self,
        client: SyncClient,
        table: str,
        return_constructor: Type[T],
    ) -> None:
        super().__init__(client, table)

        self._create = SyncCreate[T](
            self.client, self.table, return_constructor)
        self._read = SyncRead[T](
            self.client, self.table, return_constructor)
        self._update = SyncUpdate[T](
            self.client, self.table, return_constructor)
        self._delete = SyncDelete[T](
            self.client, self.table, return_constructor)

    @property
    def create(self) -> SyncCreate[T]:
        """Methods for creating new records of the Model."""
        return self._create

    @create.setter
    def create(self, creator: SyncCreate[T]) -> None:
        if isinstance(creator, SyncCreate):
            self._create = creator
        else:
            raise TypeError('Model.create must be an instance of SyncCreate.')

    @property
    def read(self) -> SyncRead[T]:
        """Methods for reading records of the Model."""
        return self._read

    @read.setter
    def read(self, reader: SyncRead[T]) -> None:
        if isinstance(reader, SyncRead):
            self._read = reader
        else:
            raise TypeError('Model.read must be an instance of SyncRead.')

    @property
    def update(self) -> SyncUpdate[T]:
        """Methods for updating records of the Model."""
        return self._update

    @update.setter
    def update(self, updater: SyncUpdate[T]) -> None:
        if isinstance(updater, SyncUpdate):
            self._update = updater
        else:
            raise TypeError('Model.update must be an instance of SyncUpdate.')

    @property
    def delete(self) -> SyncDelete[T]:
        """Methods for deleting records of the Model."""
        return self._delete

    @delete.setter
    def delete(self, deleter: SyncDelete[T]) -> None:
        if isinstance(deleter, SyncDelete):
            self._delete = deleter
        else:
            raise TypeError('Model.delete must be an instance of SyncDelete.')
a/example/.python-version b/example/.python-version index 6bd1074..f69abe4 100644 --- a/example/.python-version +++ b/example/.python-version @@ -1 +1 @@ -3.9.1 +3.9.7 diff --git a/example/example/__init__.py b/example/example/__init__.py new file mode 100644 index 0000000..8aef547 --- /dev/null +++ b/example/example/__init__.py @@ -0,0 +1,2 @@ +from . import models +from .main import run diff --git a/example/example/example.py b/example/example/main.py similarity index 77% rename from example/example/example.py rename to example/example/main.py index 372037b..7123ad7 100644 --- a/example/example/example.py +++ b/example/example/main.py @@ -1,14 +1,21 @@ -"""An example of how to use Client & Model together.""" +"""An example of how to use AsyncClient & AsyncModel together.""" import asyncio import json +import logging import os from uuid import uuid4, UUID from typing import Any, List -from db_wrapper import ConnectionParameters, Client, Model +from db_wrapper import ConnectionParameters, AsyncClient, AsyncModel -from models import AModel, ExtendedModel, ExtendedModelData +from example.models import ( + AModel, + ExtendedModel, + ExtendedModelData, +) + +logging.basicConfig(level=logging.INFO) class UUIDJsonEncoder(json.JSONEncoder): @@ -23,23 +30,24 @@ def default(self, obj: Any) -> Any: conn_params = ConnectionParameters( host=os.getenv('DB_HOST', 'localhost'), + port=int(os.getenv('DB_PORT', '5432')), # user=os.getenv('DB_USER', 'postgres'), # password=os.getenv('DB_PASS', 'postgres'), # database=os.getenv('DB_NAME', 'postgres')) user=os.getenv('DB_USER', 'test'), password=os.getenv('DB_PASS', 'pass'), database=os.getenv('DB_NAME', 'dev')) -client = Client(conn_params) +client = AsyncClient(conn_params) -a_model = Model[AModel](client, 'a_model') +a_model = AsyncModel[AModel](client, 'a_model', AModel) extended_model = ExtendedModel(client) async def create_a_model_record() -> UUID: """ - Show how to use a simple Model instance. 
+ Show how to use a simple AsyncModel instance. - Create a new record using the default Model.create.one method. + Create a new record using the default AsyncModel.create.one method. """ new_record = AModel(**{ 'id': uuid4(), @@ -55,13 +63,13 @@ async def create_a_model_record() -> UUID: async def read_a_model(id_value: UUID) -> AModel: """Show how to read a record with a given id value.""" - # read.one_by_id expects a string, so UUID values need - # converted using str() - return await a_model.read.one_by_id(str(id_value)) + return await a_model.read.one_by_id(id_value) async def create_extended_models() -> None: - """Show how using an extended Model can be the same as the defaults.""" + """ + Show how using an extended AsyncModel can be the same as the defaults. + """ dicts = [{ 'id': uuid4(), 'string': 'something', @@ -97,7 +105,7 @@ async def read_extended_models() -> List[ExtendedModelData]: """Show how to use an extended Model's new methods.""" # We defined read.all in ./models/extended_model.py's ExtendedRead class, # then replaced ExtendedModel's read property with ExtendedRead. - # As a result, we can call it just like any other method on Model.read. + # As a result, we can call it just like any other method on AsyncModel.read return await extended_model.read.all() @@ -115,13 +123,15 @@ async def run() -> None: new_id = await create_a_model_record() created_a_model = await read_a_model(new_id) await create_extended_models() - extended_models = await read_extended_models() + created_extended_models = await read_extended_models() finally: await client.disconnect() # Print results to stdout - print(json.dumps(created_a_model, cls=UUIDJsonEncoder)) - print(json.dumps(extended_models, cls=UUIDJsonEncoder)) + print(json.dumps(created_a_model.dict(), cls=UUIDJsonEncoder)) + print(json.dumps([model.dict() + for model in created_extended_models], + cls=UUIDJsonEncoder)) if __name__ == '__main__': # A simple app can be run using asyncio's run method. 
diff --git a/example/example/models/extended_model.py b/example/example/models/extended_model.py index 85f7324..97ebd6f 100644 --- a/example/example/models/extended_model.py +++ b/example/example/models/extended_model.py @@ -5,9 +5,10 @@ from psycopg2 import sql from psycopg2.extensions import register_adapter -from psycopg2.extras import Json +from psycopg2.extras import Json # type: ignore -from db_wrapper.model import ModelData, Model, Read, Create, Client +from db_wrapper import AsyncClient, AsyncModel, ModelData +from db_wrapper.model import AsyncRead, AsyncCreate, RealDictRow # tell psycopg2 to adapt all dictionaries to json instead of # the default hstore @@ -26,7 +27,7 @@ class ExtendedModelData(ModelData): data: Dict[str, Any] -class ExtendedCreator(Create[ExtendedModelData]): +class ExtendedCreator(AsyncCreate[ExtendedModelData]): """Add custom json loading to Model.create.""" # pylint: disable=too-few-public-methods @@ -54,13 +55,14 @@ async def one(self, item: ExtendedModelData) -> ExtendedModelData: values=sql.SQL(',').join(values), ) - result: List[ExtendedModelData] = \ + query_result: List[RealDictRow] = \ await self._client.execute_and_return(query) + result = self._return_constructor(**query_result[0]) - return result[0] + return result -class ExtendedReader(Read[ExtendedModelData]): +class ExtendedReader(AsyncRead[ExtendedModelData]): """Add custom method to Model.read.""" async def all_by_string(self, string: str) -> List[ExtendedModelData]: @@ -74,8 +76,10 @@ async def all_by_string(self, string: str) -> List[ExtendedModelData]: string=sql.Identifier(string) ) - result: List[ExtendedModelData] = await self \ - ._client.execute_and_return(query) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(query) + result = [self._return_constructor(**row) + for row in query_result] return result @@ -84,19 +88,23 @@ async def all(self) -> List[ExtendedModelData]: query = sql.SQL('SELECT * FROM {table}').format( table=self._table) 
- result: List[ExtendedModelData] = await self \ - ._client.execute_and_return(query) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(query) + result = [self._return_constructor(**row) + for row in query_result] return result -class ExtendedModel(Model[ExtendedModelData]): +class ExtendedModel(AsyncModel[ExtendedModelData]): """Build an ExampleItem Model instance.""" read: ExtendedReader create: ExtendedCreator - def __init__(self, client: Client) -> None: - super().__init__(client, 'extended_model') - self.read = ExtendedReader(self.client, self.table) - self.create = ExtendedCreator(self.client, self.table) + def __init__(self, client: AsyncClient) -> None: + super().__init__(client, 'extended_model', ExtendedModelData) + self.read = ExtendedReader( + self.client, self.table, ExtendedModelData) + self.create = ExtendedCreator( + self.client, self.table, ExtendedModelData) diff --git a/example/manage.py b/example/manage.py index c95a88d..4c8b7bc 100644 --- a/example/manage.py +++ b/example/manage.py @@ -14,11 +14,11 @@ import time from typing import Any, Optional, Generator, List, Tuple -from migra import Migration # type: ignore +from migra import Migration from psycopg2 import connect, OperationalError # type: ignore from psycopg2 import sql from psycopg2.sql import Composed -from sqlbag import ( # type: ignore +from sqlbag import ( S, load_sql_from_folder) @@ -170,13 +170,13 @@ def _get_schema_diff( def _temp_db(host: str, user: str, password: str) -> Generator[str, Any, Any]: """Create, yield, & remove a temporary database as context.""" connection = _resilient_connect( - f'postgres://{user}:{password}@{host}/{DB_NAME}') + f'postgresql://{user}:{password}@{host}/{DB_NAME}') connection.set_session(autocommit=True) name = _temp_name() with connection.cursor() as cursor: _create_db(cursor, name) - yield f'postgres://{user}:{password}@{host}/{name}' + yield f'postgresql://{user}:{password}@{host}/{name}' _drop_db(cursor, name) 
connection.close() diff --git a/example/requirements/prod.txt b/example/requirements/prod.txt index 8d4c2e7..c642fd0 100644 --- a/example/requirements/prod.txt +++ b/example/requirements/prod.txt @@ -1,2 +1 @@ -psycopg2-binary>=2.8.6,<3.0.0 -https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.0.0-alpha/db_wrapper-2.0.0a0-py3-none-any.whl +-e ../ diff --git a/example_sync/.mypy.ini b/example_sync/.mypy.ini new file mode 100644 index 0000000..991f7c8 --- /dev/null +++ b/example_sync/.mypy.ini @@ -0,0 +1,26 @@ +[mypy] + +; +; Import discovery +; +; tell mypy where the project root is for module resolution +mypy_path=$PYTHONPATH,$PYTHONPATH/stubs +; tell mypy to look at namespace packages (ones without an __init__.py) +namespace_packages = True + +; +; Strict mode, almost +; +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +warn_unreachable = True +strict_equality = True diff --git a/example_sync/.python-version b/example_sync/.python-version new file mode 100644 index 0000000..6bd1074 --- /dev/null +++ b/example_sync/.python-version @@ -0,0 +1 @@ +3.9.1 diff --git a/example_sync/Dockerfile b/example_sync/Dockerfile new file mode 100644 index 0000000..0f72d47 --- /dev/null +++ b/example_sync/Dockerfile @@ -0,0 +1,30 @@ +FROM python:3.9.1-alpine3.12 + +# load source code +RUN mkdir /app +COPY example /app +COPY requirements/prod.txt /app/requirements.txt +COPY requirements/db_wrapper-0.1.0a0.tar.gz /app/requirements/db_wrapper-0.1.0a0.tar.gz + +VOLUME /app +WORKDIR /app + +# install python dependencies +RUN apk add --no-cache --virtual .build-deps \ + # needed to build psycopg2 & yarl + gcc \ + # needed to build yarl + musl-dev \ + # needed to build psycopg2 + 
postgresql-dev \ + # runtime dependency for psycopg2 + && apk add --no-cache libpq \ + # install python packages + && ls -la requirements \ + && pip install -r requirements.txt \ + # then remove build dependencies + && apk del .build-deps + +# start server +ENTRYPOINT ["python"] +CMD ["example.py"] diff --git a/example_sync/docker-compose.yml b/example_sync/docker-compose.yml new file mode 100644 index 0000000..14611fa --- /dev/null +++ b/example_sync/docker-compose.yml @@ -0,0 +1,28 @@ +version: "3" # this is the Docker Compose specification + # version, not the app stack version + +services: + + # seed: + # build: . + # environment: + # MODE: development + # DB_HOST: db + # DB_USER: test + # DB_PASS: pass + # DB_NAME: dev + # volumes: + # - ./example:/app + + db: + image: postgres:13-alpine + restart: always + ports: + - 5432:5432 + environment: + POSTGRES_DB: dev + POSTGRES_USER: test + POSTGRES_PASSWORD: pass # DONT DO THIS IN PROD -- acceptable in dev, + # but docker secrets should be used in prod + # in conjunction with docker swarm in place + # of docker compose diff --git a/example_sync/example/__init__.py b/example_sync/example/__init__.py new file mode 100644 index 0000000..8aef547 --- /dev/null +++ b/example_sync/example/__init__.py @@ -0,0 +1,2 @@ +from . 
"""An example of how to use SyncClient & Model together."""

import json
import logging
import os
from uuid import uuid4, UUID
from typing import Any, List

from db_wrapper import SyncClient, ConnectionParameters
from db_wrapper.model import SyncModel as Model

from example.models import (
    AModel,
    ExtendedModel,
    ExtendedModelData,
)

logging.basicConfig(level=logging.INFO)


class UUIDJsonEncoder(json.JSONEncoder):
    """Extended Json Encoder to allow encoding of objects containing UUID."""

    def default(self, obj: Any) -> Any:
        if isinstance(obj, UUID):
            return str(obj)

        # BUG FIX: returning `obj` unchanged made json.dumps call default()
        # on the same unserializable object again, recursing forever;
        # delegating to the base class raises a clean TypeError instead.
        return super().default(obj)


conn_params = ConnectionParameters(
    host=os.getenv('DB_HOST', 'localhost'),
    port=int(os.getenv('DB_PORT', '5432')),
    # user=os.getenv('DB_USER', 'postgres'),
    # password=os.getenv('DB_PASS', 'postgres'),
    # database=os.getenv('DB_NAME', 'postgres'))
    user=os.getenv('DB_USER', 'test'),
    password=os.getenv('DB_PASS', 'pass'),
    database=os.getenv('DB_NAME', 'dev'))
client = SyncClient(conn_params)

a_model = Model[AModel](client, 'a_model', AModel)
extended_model = ExtendedModel(client)


def create_a_model_record() -> UUID:
    """
    Show how to use a simple Model instance.

    Create a new record using the default Model.create.one method.
    """
    new_record = AModel(**{
        'id': uuid4(),
        'string': 'some string',
        'integer': 1,
        'array': ['an', 'array', 'of', 'strings'],
    })

    a_model.create.one(new_record)

    return new_record.id


def read_a_model(id_value: UUID) -> AModel:
    """Show how to read a record with a given id value."""
    return a_model.read.one_by_id(id_value)


def create_extended_models() -> None:
    """Show how using an extended Model can be the same as the defaults."""
    dicts = [{
        'id': uuid4(),
        'string': 'something',
        'integer': 1,
        'data': {'a': 1, 'b': 2, 'c': True}
    }, {
        'id': uuid4(),
        'string': 'something',
        'integer': 1,
        'data': {'a': 1, 'b': 2, 'c': True}
    }, {
        'id': uuid4(),
        'string': 'something',
        'integer': 1,
        'data': {'a': 1, 'b': 2, 'c': True}
    }, {
        'id': uuid4(),
        'string': 'something',
        'integer': 1,
        'data': {'a': 1, 'b': 2, 'c': True}
    }]

    new_records: List[ExtendedModelData] = [
        ExtendedModelData(**record) for record in dicts]

    # by looping over a list of records, you can use the default create.one
    # method to create each record as a separate transaction
    for record in new_records:
        extended_model.create.one(record)


def read_extended_models() -> List[ExtendedModelData]:
    """Show how to use an extended Model's new methods."""
    # We defined read.all in ./models/extended_model.py's ExtendedRead class,
    # then replaced ExtendedModel's read property with ExtendedRead.
    # As a result, we can call it just like any other method on Model.read.
    return extended_model.read.all()


def run() -> None:
    """Show how to make a connection, execute queries, & disconnect."""
    # First, have the client make a connection to the database
    client.connect()

    # Then, execute queries using the models that were initialized
    # with the client above.
    # Doing this inside a try/finally block allows client to gracefully
    # disconnect even when an exception is thrown.
    try:
        new_id = create_a_model_record()
        created_a_model = read_a_model(new_id)
        create_extended_models()
        created_extended_models = read_extended_models()
    finally:
        client.disconnect()

    # Print results to stdout
    print(json.dumps(created_a_model.dict(), cls=UUIDJsonEncoder))
    print(json.dumps([model.dict()
                      for model in created_extended_models],
                     cls=UUIDJsonEncoder))


if __name__ == '__main__':
    run()
from db_wrapper import SyncClient, SyncModel, ModelData
from db_wrapper.model import RealDictRow, SyncRead, SyncCreate

# tell psycopg2 to adapt all dictionaries to json instead of
# the default hstore
register_adapter(dict, Json)


class ExtendedModelData(ModelData):
    """An example Item."""

    string: str
    integer: int
    data: Dict[str, Any]


class ExtendedCreator(SyncCreate[ExtendedModelData]):
    """Add custom json loading to SyncModel.create."""

    # pylint: disable=too-few-public-methods

    def one(self, item: ExtendedModelData) -> ExtendedModelData:
        """Override default SyncModel.create.one method."""
        columns: List[sql.Identifier] = []
        values: List[sql.Literal] = []

        for column, value in item.dict().items():
            # the jsonb column needs explicit serialization
            if column == 'data':
                values.append(sql.Literal(json.dumps(value)))
            else:
                values.append(sql.Literal(value))

            columns.append(sql.Identifier(column))

        query = sql.SQL(
            'INSERT INTO {table} ({columns}) '
            'VALUES ({values}) '
            'RETURNING *;'
        ).format(
            table=self._table,
            columns=sql.SQL(',').join(columns),
            values=sql.SQL(',').join(values),
        )

        query_result: List[RealDictRow] = \
            self._client.execute_and_return(query)

        return self._return_constructor(**query_result[0])


class ExtendedReader(SyncRead[ExtendedModelData]):
    """Add custom method to Model.read."""

    def all_by_string(self, string: str) -> List[ExtendedModelData]:
        """Read all rows with matching `string` value."""
        query = sql.SQL(
            'SELECT * '
            'FROM {table} '
            'WHERE string = {string};'
        ).format(
            table=self._table,
            # BUG FIX: the matched value must be rendered as a SQL literal;
            # sql.Identifier quoted it as a *column name*, producing
            # `WHERE string = "something"` & breaking the query.
            string=sql.Literal(string)
        )

        query_result: List[RealDictRow] = \
            self._client.execute_and_return(query)

        # BUG FIX: construct model instances from the raw rows, matching
        # the behavior (& the declared return type) of `all` below.
        return [self._return_constructor(**row) for row in query_result]

    def all(self) -> List[ExtendedModelData]:
        """Read all rows."""
        query = sql.SQL('SELECT * FROM {table}').format(
            table=self._table)

        query_result: List[RealDictRow] = \
            self._client.execute_and_return(query)

        return [self._return_constructor(**item) for item in query_result]


class ExtendedModel(SyncModel[ExtendedModelData]):
    """Build an ExampleItem SyncModel instance."""

    read: ExtendedReader
    create: ExtendedCreator

    def __init__(self, client: SyncClient) -> None:
        super().__init__(client, 'extended_model', ExtendedModelData)
        self.read = ExtendedReader(
            self.client,
            self.table,
            ExtendedModelData
        )
        self.create = ExtendedCreator(
            self.client,
            self.table,
            ExtendedModelData
        )
"""Script for managing database migrations.

Exposes two methods:
    sync    diff app to live db & apply changes, use for dev primarily
    pending diff schema dump & save to file, used for prod primarily
"""

from contextlib import contextmanager
import io
import os
import random
import string
import sys
import time
from typing import Any, Optional, Generator, List, Tuple

from migra import Migration
from psycopg2 import connect, OperationalError  # type: ignore
from psycopg2 import sql
from psycopg2.sql import Composed
from sqlbag import (
    S,
    load_sql_from_folder)

# DB_USER = os.getenv('DB_USER', 'postgres')
# DB_PASS = os.getenv('DB_PASS', 'postgres')
# DB_HOST = os.getenv('DB_HOST', 'localhost')
# DB_NAME = os.getenv('DB_NAME', 'postgres')
DB_USER = os.getenv('DB_USER', 'test')
DB_PASS = os.getenv('DB_PASS', 'pass')
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_NAME = os.getenv('DB_NAME', 'dev')

DB_URL = f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:5432/{DB_NAME}'


def _try_connect(dsn: str, retries: int = 1) -> Any:
    """Attempt to connect, retrying every 5s up to 12 times.

    Raises ConnectionError (chained to the last OperationalError) once
    the retry budget is exhausted.
    """
    # rewritten from the original while-loop + recursion combination
    # (the loop body always returned, so the loop never iterated) into
    # a single plain retry loop with identical observable behavior
    attempt = retries

    while True:
        print(f'Attempting to connect to database at {dsn}')

        try:
            return connect(dsn)
        except OperationalError as err:
            print(type(err))
            if attempt > 12:
                raise ConnectionError(
                    'Max number of connection attempts has been reached (12)'
                ) from err

            print(
                f'Connection failed ({attempt} time(s)) '
                'retrying again in 5 seconds...')

            time.sleep(5)
            attempt += 1


def _resilient_connect(dsn: str) -> Any:
    """Handle connecting to db, attempt to reconnect on failure."""
    return _try_connect(dsn)


def _prompt(question: str) -> bool:
    """Prompt user with simple yes/no question & return True if answer is y."""
    print(f'{question} ', end='')
    return input().strip().lower() == 'y'


def _temp_name() -> str:
    """
    Generate a temporary name.

    Prefixes a string of 10 random characters with 'temp_db' & returns it.
    """
    random_letters = [random.choice(string.ascii_lowercase) for _ in range(10)]
    return 'temp_db' + "".join(random_letters)


def _create_db(cursor: Any, name: str) -> None:
    """Create a database with a given name."""
    query = sql.SQL('create database {name};').format(
        name=sql.Identifier(name))

    cursor.execute(query)


def _kill_query(dbname: str) -> Composed:
    """Build & return SQL query that kills connections to a given database."""
    query = """
        SELECT
            pg_terminate_backend(pg_stat_activity.pid)
        FROM
            pg_stat_activity
        WHERE
            pg_stat_activity.datname = {dbname}
            AND pid <> pg_backend_pid();
    """

    return sql.SQL(query).format(dbname=sql.Literal(dbname))


def _drop_db(cursor: Any, name: str) -> None:
    """Drop a database with a given name."""
    revoke: Composed = sql.SQL(
        'REVOKE CONNECT ON DATABASE {name} FROM PUBLIC;'
    ).format(
        name=sql.Identifier(name))

    kill_other_connections: Composed = _kill_query(name)

    drop: Composed = sql.SQL('DROP DATABASE {name};').format(
        name=sql.Identifier(name))

    cursor.execute(revoke)
    cursor.execute(kill_other_connections)
    cursor.execute(drop)


def _load_pre_migration(dsn: str) -> None:
    """
    Load schema for production server.

    Uses sql schema file saved at migrations/production.dump.sql
    """
    connection = _resilient_connect(dsn)
    connection.set_session(autocommit=True)

    with connection.cursor() as cursor:
        # BUG FIX: the dump file was opened without ever being closed;
        # a with-block guarantees release of the file handle
        with open('migrations/production.dump.sql', 'r') as dump_file:
            cursor.execute(dump_file.read())

    connection.close()


def _load_from_app(session: S) -> None:
    """
    Load schema from application source.

    Uses all .sql files stored at ./example/models/**
    """
    load_sql_from_folder(session, 'example/models')


@contextmanager
def _get_schema_diff(
        from_db_url: str,
        target_db_url: str
) -> Generator[Tuple[str, Migration], Any, Any]:
    """Get schema diff between two databases using djrobstep/migra."""
    with S(from_db_url) as from_schema_session, \
            S(target_db_url) as target_schema_session:
        migration = Migration(
            from_schema_session,
            target_schema_session)
        migration.set_safety(False)
        migration.add_all_changes()

        yield migration.sql, migration


@contextmanager
def _temp_db(host: str, user: str, password: str) -> Generator[str, Any, Any]:
    """Create, yield, & remove a temporary database as context."""
    connection = _resilient_connect(
        f'postgresql://{user}:{password}@{host}/{DB_NAME}')
    connection.set_session(autocommit=True)
    name = _temp_name()

    with connection.cursor() as cursor:
        _create_db(cursor, name)
        yield f'postgresql://{user}:{password}@{host}/{name}'
        _drop_db(cursor, name)

    connection.close()


def sync(args: List[str]) -> None:
    """
    Compare live database to application schema & apply changes to database.

    Uses running database specified for application via
    `DB_[USER|PASS|HOST|NAME]` environment variables & compares to application
    schema defined at `./example/models/**/*.sql`.
    """
    # define if prompts are needed or not
    no_prompt = 'no_prompt' in args

    # create temp database for app schema
    with _temp_db(
        host=DB_HOST,
        user=DB_USER,
        password=DB_PASS
    ) as temp_db_url:
        print(f'db url: {DB_URL}')
        print(f'temp url: {temp_db_url}')

        # create sessions for current db state & target schema
        with S(DB_URL) as from_schema_session, \
                S(temp_db_url) as target_schema_session:
            # load target schema to temp db
            _load_from_app(target_schema_session)

            # diff target db & current db
            migration = Migration(
                from_schema_session,
                target_schema_session)
            migration.set_safety(False)
            migration.add_all_changes()

            # handle changes
            if migration.statements:
                print('\nTHE FOLLOWING CHANGES ARE PENDING:', end='\n\n')
                print(migration.sql)

                if no_prompt or _prompt('Apply these changes?'):
                    print('Applying...')
                    migration.apply()
                    print('Changes applied.')
                else:
                    print('Not applying.')

            else:
                print('Already synced.')
+ """ + # create temporary databases for prod & target schemas + with _temp_db( + host=DB_HOST, + user=DB_USER, + password=DB_PASS + ) as prod_schema_db_url, _temp_db( + host=DB_HOST, + user=DB_USER, + password=DB_PASS + ) as target_db_url: + print(f'prod temp url: {prod_schema_db_url}') + print(f'target temp url: {target_db_url}') + + # create sessions for both databases + with S(prod_schema_db_url) as from_schema_session, \ + S(target_db_url) as target_schema_session: + # load both schemas into their databases + _load_pre_migration(prod_schema_db_url) + _load_from_app(target_schema_session) + + # get a diff + migration = Migration( + from_schema_session, + target_schema_session) + migration.set_safety(False) + migration.add_all_changes() + + if migration.statements: + print('\nTHE FOLLOWING CHANGES ARE PENDING:', end='\n\n') + print(migration.sql) + else: + print('No changes needed, setting pending.sql to empty.') + + # write pending changes to file + with io.open('migrations/pending.sql', 'w') as file: + file.write(migration.sql) + + print('Changes written to ./migrations/pending.sql.') + + +if __name__ == '__main__': + tasks = { + 'sync': sync, + 'pending': pending, + } + + try: + ARGS: List[str] = [] + + if len(sys.argv) > 2: + ARGS = sys.argv[2:] + + tasks[sys.argv[1]](ARGS) + except KeyError: + print('No such task') + except IndexError: + print('No task given') diff --git a/example_sync/requirements/dev.txt b/example_sync/requirements/dev.txt new file mode 100644 index 0000000..3059dcb --- /dev/null +++ b/example_sync/requirements/dev.txt @@ -0,0 +1,7 @@ +autopep8>=1.5.4 +mypy>=0.800,<0.810 +pycodestyle>=2.6.0 +pydocstyle>=5.1.1 +pylint>=2.6.0 +migra>=3.0.16,<4.0.0 +sqlbag>=0.1.15,<0.2.0 diff --git a/example_sync/requirements/prod.txt b/example_sync/requirements/prod.txt new file mode 100644 index 0000000..e60595e --- /dev/null +++ b/example_sync/requirements/prod.txt @@ -0,0 +1 @@ 
+https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.3.0/db_wrapper-2.3.0-py3-none-any.whl diff --git a/example_sync/stubs/psycopg2/extras.pyi b/example_sync/stubs/psycopg2/extras.pyi new file mode 100644 index 0000000..65451da --- /dev/null +++ b/example_sync/stubs/psycopg2/extras.pyi @@ -0,0 +1,195 @@ +from typing import Any, Optional +# from .compat import PY2 as PY2, PY3 as PY3, lru_cache as lru_cache +# from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident +# from collections import OrderedDict +# from psycopg2._ipaddress import register_ipaddress as register_ipaddress +# from psycopg2._json import Json as Json, json as json, register_default_json as register_default_json, register_default_jsonb as register_default_jsonb, register_json as register_json +# from psycopg2._psycopg import REPLICATION_LOGICAL as REPLICATION_LOGICAL, REPLICATION_PHYSICAL as REPLICATION_PHYSICAL, ReplicationConnection as _replicationConnection, ReplicationCursor as _replicationCursor, ReplicationMessage as ReplicationMessage +# from psycopg2._range import DateRange as DateRange, DateTimeRange as DateTimeRange, DateTimeTZRange as DateTimeTZRange, NumericRange as NumericRange, Range as Range, RangeAdapter as RangeAdapter, RangeCaster as RangeCaster, register_range as register_range +# +# +# class DictCursorBase(_cursor): +# row_factory: Any = ... +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# def fetchone(self): ... +# def fetchmany(self, size: Optional[Any] = ...): ... +# def fetchall(self): ... +# def __iter__(self) -> Any: ... +# +# +# class DictConnection(_connection): +# def cursor(self, *args: Any, **kwargs: Any): ... +# +# +# class DictCursor(DictCursorBase): +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# index: Any = ... +# def execute(self, query: Any, vars: Optional[Any] = ...): ... +# def callproc(self, procname: Any, vars: Optional[Any] = ...): ... 
+# +# +# class DictRow(list): +# def __init__(self, cursor: Any) -> None: ... +# def __getitem__(self, x: Any): ... +# def __setitem__(self, x: Any, v: Any) -> None: ... +# def items(self): ... +# def keys(self): ... +# def values(self): ... +# def get(self, x: Any, default: Optional[Any] = ...): ... +# def copy(self): ... +# def __contains__(self, x: Any): ... +# def __reduce__(self): ... +# +# +# class RealDictConnection(_connection): +# def cursor(self, *args: Any, **kwargs: Any): ... +# +# +# class RealDictCursor(DictCursorBase): +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# column_mapping: Any = ... +# def execute(self, query: Any, vars: Optional[Any] = ...): ... +# def callproc(self, procname: Any, vars: Optional[Any] = ...): ... +# +# +# class RealDictRow(OrderedDict): +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# def __setitem__(self, key: Any, value: Any) -> None: ... +# +# +# class NamedTupleConnection(_connection): +# def cursor(self, *args: Any, **kwargs: Any): ... +# +# +# class NamedTupleCursor(_cursor): +# Record: Any = ... +# MAX_CACHE: int = ... +# def execute(self, query: Any, vars: Optional[Any] = ...): ... +# def executemany(self, query: Any, vars: Any): ... +# def callproc(self, procname: Any, vars: Optional[Any] = ...): ... +# def fetchone(self): ... +# def fetchmany(self, size: Optional[Any] = ...): ... +# def fetchall(self): ... +# def __iter__(self) -> Any: ... +# +# +# class LoggingConnection(_connection): +# log: Any = ... +# def initialize(self, logobj: Any) -> None: ... +# def filter(self, msg: Any, curs: Any): ... +# def cursor(self, *args: Any, **kwargs: Any): ... +# +# +# class LoggingCursor(_cursor): +# def execute(self, query: Any, vars: Optional[Any] = ...): ... +# def callproc(self, procname: Any, vars: Optional[Any] = ...): ... +# +# +# class MinTimeLoggingConnection(LoggingConnection): +# def initialize(self, logobj: Any, mintime: int = ...) -> None: ... 
+# def filter(self, msg: Any, curs: Any): ... +# def cursor(self, *args: Any, **kwargs: Any): ... +# +# +# class MinTimeLoggingCursor(LoggingCursor): +# timestamp: Any = ... +# def execute(self, query: Any, vars: Optional[Any] = ...): ... +# def callproc(self, procname: Any, vars: Optional[Any] = ...): ... +# +# +# class LogicalReplicationConnection(_replicationConnection): +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# +# +# class PhysicalReplicationConnection(_replicationConnection): +# def __init__(self, *args: Any, **kwargs: Any) -> None: ... +# +# +# class StopReplication(Exception): +# ... +# +# +# class ReplicationCursor(_replicationCursor): +# def create_replication_slot( +# self, slot_name: Any, slot_type: Optional[Any] = ..., output_plugin: Optional[Any] = ...) -> None: ... +# +# def drop_replication_slot(self, slot_name: Any) -> None: ... +# def start_replication(self, slot_name: Optional[Any] = ..., slot_type: Optional[Any] = ..., start_lsn: int = ..., +# timeline: int = ..., options: Optional[Any] = ..., decode: bool = ..., status_interval: int = ...) -> None: ... +# +# def fileno(self): ... +# +# +# class UUID_adapter: +# def __init__(self, uuid: Any) -> None: ... +# def __conform__(self, proto: Any): ... +# def getquoted(self): ... +# +# + +# pylint: disable=unsubscriptable-object +def register_uuid(oids: Optional[Any] = ..., + conn_or_curs: Optional[Any] = ...) -> None: ... +# +# +# class Inet: +# addr: Any = ... +# def __init__(self, addr: Any) -> None: ... +# def prepare(self, conn: Any) -> None: ... +# def getquoted(self): ... +# def __conform__(self, proto: Any): ... +# +# +# def register_inet(oid: Optional[Any] = ..., +# conn_or_curs: Optional[Any] = ...): ... +# +# +# def wait_select(conn: Any) -> None: ... +# +# +# class HstoreAdapter: +# wrapped: Any = ... +# def __init__(self, wrapped: Any) -> None: ... +# conn: Any = ... +# getquoted: Any = ... +# def prepare(self, conn: Any) -> None: ... 
+# @classmethod +# def parse(self, s: Any, cur: Any, _bsdec: Any = ...): ... +# @classmethod +# def parse_unicode(self, s: Any, cur: Any): ... +# @classmethod +# def get_oids(self, conn_or_curs: Any): ... +# +# +# def register_hstore(conn_or_curs: Any, globally: bool = ..., unicode: bool = ..., +# oid: Optional[Any] = ..., array_oid: Optional[Any] = ...) -> None: ... +# +# +# class CompositeCaster: +# name: Any = ... +# schema: Any = ... +# oid: Any = ... +# array_oid: Any = ... +# attnames: Any = ... +# atttypes: Any = ... +# typecaster: Any = ... +# array_typecaster: Any = ... +# def __init__(self, name: Any, oid: Any, attrs: Any, +# array_oid: Optional[Any] = ..., schema: Optional[Any] = ...) -> None: ... +# +# def parse(self, s: Any, curs: Any): ... +# def make(self, values: Any): ... +# @classmethod +# def tokenize(self, s: Any): ... +# +# +# def register_composite(name: Any, conn_or_curs: Any, +# globally: bool = ..., factory: Optional[Any] = ...): ... +# +# +# def execute_batch(cur: Any, sql: Any, argslist: Any, +# page_size: int = ...) -> None: ... +# +# +# def execute_values(cur: Any, sql: Any, argslist: Any, +# template: Optional[Any] = ..., page_size: int = ..., fetch: bool = ...): ... diff --git a/example_sync/stubs/psycopg2/sql.pyi b/example_sync/stubs/psycopg2/sql.pyi new file mode 100644 index 0000000..6b9f1c5 --- /dev/null +++ b/example_sync/stubs/psycopg2/sql.pyi @@ -0,0 +1,76 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-self-use +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=super-init-not-called + +from __future__ import annotations +from typing import Any, Optional, Union, Hashable, Iterable, List, Tuple + + +class Composable: + # pylint: disable=unsubscriptable-object + def __init__( + self, + wrapped: Union[str, List[Union[str, Composable]]] + ) -> None: ... 
+ # pylint: disable=unsubscriptable-object + def as_string(self, context: Any) -> Optional[str]: ... + def __add__(self, other: Any) -> Composable: ... + # pylint: disable=invalid-name + def __mul__(self, n: Any) -> Composable: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + + +class Composed(Composable): + def __init__(self, seq: List[Composable]) -> None: ... + @property + def seq(self) -> List[str]: ... + def as_string(self, context: Any) -> str: ... + def __iter__(self) -> Iterable[str]: ... + def __add__(self, other: Any) -> Composed: ... + def join(self, joiner: Any) -> Composed: ... + + +class SQL(Composable): + def __init__(self, string: str) -> None: ... + @property + def string(self) -> str: ... + def as_string(self, context: Any) -> str: ... + def format(self, *args: Any, **kwargs: Any) -> Composed: ... + def join(self, seq: Any) -> Composed: ... + + +class Identifier(Composable): + # pylint: disable=unsubscriptable-object + def __init__( + self, + *strings: Union[str, Hashable] + ) -> None: ... + @property + def strings(self) -> Tuple[Union[str, Hashable]]: ... + @property + # pylint: disable=unsubscriptable-object + def string(self) -> Optional[str]: ... + def as_string(self, context: Any) -> str: ... + + +class Literal(Composable): + @property + def wrapped(self) -> List[Composable]: ... + def as_string(self, context: Any) -> str: ... + + +class Placeholder(Composable): + # pylint: disable=unsubscriptable-object + def __init__(self, name: Optional[str] = ...) -> None: ... + @property + def name(self) -> str: ... + def as_string(self, context: Any) -> str: ... 
+ + +NULL: Any +DEFAULT: Any diff --git a/requirements/dev.txt b/requirements/dev.txt index 63dd0b8..77681be 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,6 +1,6 @@ autopep8>=1.5.4 +build>=0.3.1,<0.4.0 mypy>=0.800,<0.810 pycodestyle>=2.6.0 pydocstyle>=5.1.1 pylint>=2.6.0 - diff --git a/scripts/install b/scripts/install new file mode 100755 index 0000000..b5d3924 --- /dev/null +++ b/scripts/install @@ -0,0 +1,66 @@ +#!/usr/bin/env bash + +# +# NAVIGATE TO CORRECT DIRECTORY +# + +# start by going to script dir so all movements +# from here are relative +SCRIPT_DIR=`dirname $(realpath "$0")` +cd $SCRIPT_DIR +# go up to root +cd .. + + +# +# INSTALL FROM LISTS +# + +function dev { + echo "" + echo "Installing dev requirements..." + echo "" + dev_result=0 + + # install from dev list + pip install -r requirements/dev.txt +} + +function prod { + echo "" + echo "Installing prod requirements..." + echo "" + prod_result=0 + + # install from dev list + pip install -r requirements/prod.txt +} + +# Install dev, prod, or all requirements depending on argument given +if [ $# -eq 0 ]; then + dev + prod + + if [[ $dev_result != 0 && $prod_result != 0 ]]; then + echo "Errors found in both dev & prod installation. See output above." + exit $dev_result + elif [[ $dev_result != 0 && $prod_result == 0 ]]; then + echo "Errors found in dev installation. See output above." + exit $dev_result + elif [[ $dev_result == 0 && $prod_result != 0 ]]; then + echo "Errors found in prod installation. See output above." + exit $prod_result + else + exit 0 + fi + +elif [[ $1 == 'dev' || $1 == 'development' ]]; then + dev ${@:2} + exit $dev_result +elif [[ $1 == 'prod' || $1 == 'production' ]]; then + prod ${@:2} + exit $prod_result +else + echo "Bad argument given, either specify \`dev\` or \`prod\` requirements by giving either word as your first argument to this script, or run both by giving no arguments." 
+ exit 1 +fi diff --git a/scripts/pj b/scripts/pj new file mode 100755 index 0000000..160e764 --- /dev/null +++ b/scripts/pj @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + + +# +# NAVIGATE TO CORRECT DIRECTORY +# + +# start by going to script dir so all movements +# from here are relative +SCRIPT_DIR=`dirname $(realpath "$0")` +cd $SCRIPT_DIR + + +# +# PARSE & RUN COMMANDS +# + +# check if expansion of $1 is null +# from: https://stackoverflow.com/a/6482403 +if [ -z $1 ]; then + echo 'A command must be given.' + exit 1 +fi + +function run { + $1 ${@:2} + + exit $? +} + +# first check hard-coded shortcuts +if [[ $1 = "t" ]]; then + run ./test ${@:2} +fi +if [[ $1 = "d" ]]; then + run ./dev ${@:2} +fi + +# otherwise, search for exact match to script name +for file in ./*; do + file_name=${file#"./"} + + if [[ $file_name = $1 ]]; then + run $file ${@:2} + fi +done diff --git a/scripts/typecheck b/scripts/typecheck index 13de26b..f68d5a0 100755 --- a/scripts/typecheck +++ b/scripts/typecheck @@ -38,7 +38,7 @@ function check { echo "" echo "Checking ./tests..." echo "" - mypy tests + mypy test mypy_tests_result=$? if [ $mypy_tests_result != 0 ]; then diff --git a/setup.py b/setup.py index b22bbeb..ead47ee 100644 --- a/setup.py +++ b/setup.py @@ -3,12 +3,15 @@ with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() +short_description = \ + "Simple wrapper on aiopg to handle postgres connections & basic Models." 
+ setuptools.setup( name="db_wrapper", - version="2.0.1", + version="2.4.0", author="Andrew Chang-DeWitt", author_email="andrew@andrew-chang-dewitt.dev", - description="Simple wrapper on aiopg to handle postgres connections & basic Models.", + description=short_description, long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/cheese-drawer/lib-python-db-wrapper/", diff --git a/stubs/psycopg2/__init__.pyi b/stubs/psycopg2/__init__.pyi new file mode 100644 index 0000000..2ec209d --- /dev/null +++ b/stubs/psycopg2/__init__.pyi @@ -0,0 +1,56 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Optional +from psycopg2._psycopg import ( + BINARY, + Binary, + connection, + DATETIME, + DataError, + DatabaseError, + Date, + DateFromTicks, + Error, + IntegrityError, + InterfaceError, + InternalError, + NUMBER, + NotSupportedError, + OperationalError, + ProgrammingError, + ROWID, + STRING, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, + Warning, + __libpq_version__, + apilevel, + paramstyle, + threadsafety +) + +connection = connection +OperationalError = OperationalError + + +def connect( + dsn: Optional[Any] = ..., + connection_factory: Optional[Any] = ..., + cursor_factory: Optional[Any] = ..., + **kwargs: Any +) -> connection: ... 
diff --git a/stubs/psycopg2/_ipaddress.pyi b/stubs/psycopg2/_ipaddress.pyi new file mode 100644 index 0000000..0af4c8c --- /dev/null +++ b/stubs/psycopg2/_ipaddress.pyi @@ -0,0 +1,33 @@ +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-module-docstring +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called +# pylint: disable=unused-import +# pylint: disable=useless-import-alias +# pylint: disable=line-too-long + +from typing import Any, Optional + +from psycopg2.extensions import ( + QuotedString, + new_array_type, + new_type, + register_adapter, + register_type, +) + +ipaddress: Any + + +def register_ipaddress(conn_or_curs: Optional[Any] = ...) -> None: ... +def cast_interface(s: Any, cur: Optional[Any] = ...) -> Any: ... +def cast_network(s: Any, cur: Optional[Any] = ...) -> Any: ... +def adapt_ipaddress(obj: Any) -> Any: ... 
diff --git a/stubs/psycopg2/_json.pyi b/stubs/psycopg2/_json.pyi new file mode 100644 index 0000000..276088e --- /dev/null +++ b/stubs/psycopg2/_json.pyi @@ -0,0 +1,61 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-arguments +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Optional + +from psycopg2._psycopg import ( + ISQLQuote, + QuotedString, + new_array_type, + new_type, + register_type +) + +JSON_OID: int +JSONARRAY_OID: int +JSONB_OID: int +JSONBARRAY_OID: int + + +class Json: + adapted: Any = ... + def __init__(self, adapted: Any, dumps: Optional[Any] = ...) -> None: ... + def __conform__(self, proto: Any) -> Any: ... + def dumps(self, obj: Any) -> Any: ... + def prepare(self, conn: Any) -> None: ... + def getquoted(self) -> Any: ... + + +def register_json( + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ..., + oid: Optional[Any] = ..., + array_oid: Optional[Any] = ..., + name: str = ... +) -> Any: ... + + +def register_default_json( + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ...) -> Any: ... + + +def register_default_jsonb( + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ...) -> Any: ... 
diff --git a/stubs/psycopg2/_psycopg.pyi b/stubs/psycopg2/_psycopg.pyi new file mode 100644 index 0000000..b2b617b --- /dev/null +++ b/stubs/psycopg2/_psycopg.pyi @@ -0,0 +1,576 @@ +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-module-docstring +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Iterator + +__libpq_version__: Any +BINARY: Any +BINARYARRAY: Any +BOOLEAN: Any +BOOLEANARRAY: Any +BYTES: Any +BYTESARRAY: Any +CIDRARRAY: Any +DATE: Any +DATEARRAY: Any +DATETIME: Any +DATETIMEARRAY: Any +DATETIMETZ: Any +DATETIMETZARRAY: Any +DECIMAL: Any +DECIMALARRAY: Any +FLOAT: Any +FLOATARRAY: Any +INETARRAY: Any +INTEGER: Any +INTEGERARRAY: Any +INTERVAL: Any +INTERVALARRAY: Any +LONGINTEGER: Any +LONGINTEGERARRAY: Any +MACADDRARRAY: Any +NUMBER: Any +PYDATE: Any +PYDATEARRAY: Any +PYDATETIME: Any +PYDATETIMEARRAY: Any +PYDATETIMETZ: Any +PYDATETIMETZARRAY: Any +PYINTERVAL: Any +PYINTERVALARRAY: Any +PYTIME: Any +PYTIMEARRAY: Any +REPLICATION_LOGICAL: int +REPLICATION_PHYSICAL: int +ROWID: Any +ROWIDARRAY: Any +STRING: Any +STRINGARRAY: Any +TIME: Any +TIMEARRAY: Any +UNICODE: Any +UNICODEARRAY: Any +UNKNOWN: Any +adapters: Any +apilevel: str +binary_types: Any +encodings: Any +paramstyle: str +sqlstate_errors: Any +string_types: Any +threadsafety: int + +newdate: Any +newtime: Any +newtimestamp: Any +newtypeobject: Any + + +def Date(year: Any, month: Any, day: Any) -> newdate: ... +def DateFromPy(*args: Any, **kwargs: Any) -> Any: ... +def DateFromTicks(ticks: Any) -> newdate: ... +def IntervalFromPy(*args: Any, **kwargs: Any) -> Any: ... 
+ + +def Time( + hour: Any, + minutes: Any, + seconds: Any, + tzinfo: Any = ... +) -> newtime: ... + + +def TimeFromPy(*args: Any, **kwargs: Any) -> Any: ... +def TimeFromTicks(ticks: Any) -> newtime: ... + + +def Timestamp( + year: Any, + month: Any, + day: Any, + hour: Any, + minutes: Any, + seconds: Any, + tzinfo: Any = ... +) -> newtimestamp: ... + + +def TimestampFromPy(*args: Any, **kwargs: Any) -> Any: ... +def TimestampFromTicks(ticks: Any) -> newtimestamp: ... +def _connect(*args: Any, **kwargs: Any) -> Any: ... +def adapt(*args: Any, **kwargs: Any) -> Any: ... +def encrypt_password(*args: Any, **kwargs: Any) -> Any: ... +def get_wait_callback(*args: Any, **kwargs: Any) -> Any: ... +def libpq_version(*args: Any, **kwargs: Any) -> Any: ... +def new_array_type(oids: Any, name: Any, baseobj: Any) -> newtypeobject: ... +def new_type(oids: Any, name: Any, castobj: Any) -> newtypeobject: ... +def parse_dsn(*args: Any, **kwargs: Any) -> Any: ... +def quote_ident(*args: Any, **kwargs: Any) -> Any: ... +def register_type(*args: Any, **kwargs: Any) -> Any: ... +def set_wait_callback(_none: Any) -> Any: ... + + +class AsIs: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +prepareforbinaryencodingusingconn: Any + + +class Binary: + adapted: Any = ... + buffer: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def prepare(conn: Any) -> prepareforbinaryencodingusingconn: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class Boolean: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class Column: + display_size: Any = ... 
+ internal_size: Any = ... + name: Any = ... + null_ok: Any = ... + precision: Any = ... + scale: Any = ... + table_column: Any = ... + table_oid: Any = ... + type_code: Any = ... + __hash__: Any = ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __eq__(self, other: Any) -> Any: ... + def __ge__(self, other: Any) -> Any: ... + def __getitem__(self, index: Any) -> Any: ... + def __getstate__(self) -> Any: ... + def __gt__(self, other: Any) -> Any: ... + def __le__(self, other: Any) -> Any: ... + def __len__(self) -> Any: ... + def __lt__(self, other: Any) -> Any: ... + def __ne__(self, other: Any) -> Any: ... + def __setstate__(self, state: Any) -> Any: ... + + +class ConnectionInfo: + backend_pid: Any = ... + dbname: Any = ... + dsn_parameters: Any = ... + error_message: Any = ... + host: Any = ... + needs_password: Any = ... + options: Any = ... + password: Any = ... + port: Any = ... + protocol_version: Any = ... + server_version: Any = ... + socket: Any = ... + ssl_attribute_names: Any = ... + ssl_in_use: Any = ... + status: Any = ... + transaction_status: Any = ... + used_password: Any = ... + user: Any = ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def parameter_status(self, *args: Any, **kwargs: Any) -> Any: ... + def ssl_attribute(self, *args: Any, **kwargs: Any) -> Any: ... + + +class Error(Exception): + cursor: Any = ... + diag: Any = ... + pgcode: Any = ... + pgerror: Any = ... + def __init__(self, *args: Any, ** + kwargs: Any) -> None: ... + + def __reduce__(self) -> Any: ... + def __setstate__(self, state: Any) -> Any: ... + + +class DatabaseError(Error): + ... + + +class DataError(DatabaseError): + ... + + +class Decimal: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class Diagnostics: + column_name: Any = ... + constraint_name: Any = ... 
+ context: Any = ... + datatype_name: Any = ... + internal_position: Any = ... + internal_query: Any = ... + message_detail: Any = ... + message_hint: Any = ... + message_primary: Any = ... + schema_name: Any = ... + severity: Any = ... + severity_nonlocalized: Any = ... + source_file: Any = ... + source_function: Any = ... + source_line: Any = ... + sqlstate: Any = ... + statement_position: Any = ... + table_name: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + +class Float: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class ISQLQuote: + _wrapped: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getbinary(self, *args: Any, **kwargs: Any) -> Any: ... + def getbuffer(self, *args: Any, **kwargs: Any) -> Any: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + + +class Int: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class IntegrityError(DatabaseError): + ... + + +class InterfaceError(Error): + ... + + +class InternalError(DatabaseError): + ... + + +class List: + adapted: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def prepare(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class NotSupportedError(DatabaseError): + ... + + +class Notify: + channel: Any = ... + payload: Any = ... + pid: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __eq__(self, other: Any) -> Any: ... + def __ge__(self, other: Any) -> Any: ... + def __getitem__(self, index: Any) -> Any: ... 
+ def __gt__(self, other: Any) -> Any: ... + def __hash__(self) -> Any: ... # pylint: disable=invalid-hash-returned + def __le__(self, other: Any) -> Any: ... + def __len__(self) -> Any: ... + def __lt__(self, other: Any) -> Any: ... + def __ne__(self, other: Any) -> Any: ... + + +class OperationalError(DatabaseError): + ... + + +class ProgrammingError(DatabaseError): + ... + + +class QueryCanceledError(OperationalError): + ... + + +class QuotedString: + adapted: Any = ... + buffer: Any = ... + encoding: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + def prepare(self, *args: Any, **kwargs: Any) -> Any: ... + def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class connection: + DataError: Any = ... + DatabaseError: Any = ... + Error: Any = ... + IntegrityError: Any = ... + InterfaceError: Any = ... + InternalError: Any = ... + NotSupportedError: Any = ... + OperationalError: Any = ... + ProgrammingError: Any = ... + Warning: Any = ... + async_: Any = ... + autocommit: Any = ... + binary_types: Any = ... + closed: Any = ... + cursor_factory: Any = ... + deferrable: Any = ... + dsn: Any = ... + encoding: Any = ... + info: Any = ... + isolation_level: Any = ... + notices: Any = ... + notifies: Any = ... + pgconn_ptr: Any = ... + protocol_version: Any = ... + readonly: Any = ... + server_version: Any = ... + status: Any = ... + string_types: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def cancel(self, *args: Any, **kwargs: Any) -> Any: ... + def close(self, *args: Any, **kwargs: Any) -> Any: ... + def commit(self, *args: Any, **kwargs: Any) -> Any: ... + + @staticmethod + def cursor( + name: Any = ..., + cursor_factory: Any = ..., + withhold: Any = ... + ) -> Any: ... + + def fileno(self, *args: Any, **kwargs: Any) -> Any: ... + def get_backend_pid(self, *args: Any, **kwargs: Any) -> Any: ... 
+ def get_dsn_parameters(self, *args: Any, **kwargs: Any) -> Any: ... + def get_native_connection(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def get_parameter_status(parameter: Any) -> Any: ... + def get_transaction_status(self, *args: Any, **kwargs: Any) -> Any: ... + def isexecuting(self, *args: Any, **kwargs: Any) -> Any: ... + + @staticmethod + def lobject( + oid: Any = ..., + mode: Any = ..., + new_oid: Any = ..., + new_file: Any = ..., + lobject_factory: Any = ... + ) -> Any: ... + + def poll(self, *args: Any, **kwargs: Any) -> Any: ... + def reset(self, *args: Any, **kwargs: Any) -> Any: ... + def rollback(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def set_client_encoding(encoding: Any) -> Any: ... + @staticmethod + def set_isolation_level(level: Any) -> Any: ... + def set_session(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def tpc_begin(xid: Any) -> Any: ... + def tpc_commit(self, *args: Any, **kwargs: Any) -> Any: ... + def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any: ... + def tpc_recover(self, *args: Any, **kwargs: Any) -> Any: ... + def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def xid(format_id: Any, gtrid: Any, bqual: Any) -> Any: ... + def __enter__(self) -> Any: ... + def __exit__(self, type: Any, value: Any, traceback: Any) -> Any: ... + + +class cursor: + arraysize: Any = ... + binary_types: Any = ... + closed: Any = ... + connection: Any = ... + description: Any = ... + itersize: Any = ... + lastrowid: Any = ... + name: Any = ... + pgresult_ptr: Any = ... + query: Any = ... + row_factory: Any = ... + rowcount: Any = ... + rownumber: Any = ... + scrollable: Any = ... + statusmessage: Any = ... + string_types: Any = ... + typecaster: Any = ... + tzinfo_factory: Any = ... + withhold: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @staticmethod + def callproc(procname: Any, parameters: Any = ...) -> Any: ... 
+ @staticmethod + def cast(oid: Any, s: Any) -> Any: ... + def close(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def copy_expert(sql: Any, file: Any, size: Any = ...) -> Any: ... + + @staticmethod + def copy_from( + file: Any, + table: Any, + sep: Any = ..., + null: Any = ..., + size: Any = ..., + columns: Any = ..., + ) -> Any: ... + + @staticmethod + def copy_to( + file: Any, + table: Any, + sep: Any = ..., + null: Any = ..., + columns: Any = ... + ) -> Any: ... + + @staticmethod + def execute(query: Any, params: Any = ...) -> Any: ... + @staticmethod + def executemany(query: Any, vars_list: Any) -> Any: ... + def fetchall(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def fetchmany(size: Any = ...) -> listoftuple: ... + def fetchone(self, *args: Any, **kwargs: Any) -> Any: ... + def mogrify(self, *args: Any, **kwargs: Any) -> Any: ... + def nextset(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def scroll(value: Any, mode: Any = ...) -> Any: ... + @staticmethod + def setinputsizes(sizes: Any) -> Any: ... + @staticmethod + def setoutputsize(size: Any, column: Any = ...) -> Any: ... + def __enter__(self) -> Any: ... + def __exit__(self, type: Any, value: Any, traceback: Any) -> Any: ... + def __iter__( # pylint: disable=non-iterator-returned + self) -> Iterator[Any]: ... + + def __next__(self) -> Any: ... + + +class ReplicationConnection(connection): + autocommit: Any = ... + isolation_level: Any = ... + replication_type: Any = ... + reset: Any = ... + set_isolation_level: Any = ... + set_session: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + +class ReplicationCursor(cursor): + feedback_timestamp: Any = ... + io_timestamp: Any = ... + wal_end: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + @staticmethod + def consume_stream( + consumer: Any, + keepalive_interval: Any = ... + ) -> Any: ... + + def read_message(self, *args: Any, **kwargs: Any) -> Any: ... 
+ + @staticmethod + def send_feedback( + write_lsn: Any = ..., + flush_lsn: Any = ..., + apply_lsn: Any = ..., + reply: Any = ..., + force: Any = ... + ) -> Any: ... + + @staticmethod + def start_replication_expert( + command: Any, + decode: Any = ..., + status_interval: Any = ..., + ) -> Any: ... + + +class ReplicationMessage: + cursor: Any = ... + data_size: Any = ... + data_start: Any = ... + payload: Any = ... + send_time: Any = ... + wal_end: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + +class TransactionRollbackError(OperationalError): + ... + + +class Warning(Exception): + ... + + +class Xid: + bqual: Any = ... + database: Any = ... + format_id: Any = ... + gtrid: Any = ... + owner: Any = ... + prepared: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def from_string(self, *args: Any, **kwargs: Any) -> Any: ... + def __getitem__(self, index: Any) -> Any: ... + def __len__(self) -> Any: ... + + +listoftuple: Any + + +class lobject: + closed: Any = ... + mode: Any = ... + oid: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def close(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def export(filename: Any) -> Any: ... + @staticmethod + def read(size: Any = ...) -> Any: ... + @staticmethod + def seek(offset: Any, whence: Any = ...) -> Any: ... + def tell(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def truncate(len: Any = ...) -> Any: ... + def unlink(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod + def write(str: str) -> Any: ... 
diff --git a/stubs/psycopg2/_range.pyi b/stubs/psycopg2/_range.pyi new file mode 100644 index 0000000..caefcbc --- /dev/null +++ b/stubs/psycopg2/_range.pyi @@ -0,0 +1,87 @@ +from typing import Any, Optional + +from psycopg2._psycopg import ( + InterfaceError, + ProgrammingError +) + + +class Range: + def __init__(self, lower: Optional[Any] = ..., upper: Optional[Any] + = ..., bounds: str = ..., empty: bool = ...) -> None: ... + + @property + def lower(self) -> Any: ... + @property + def upper(self) -> Any: ... + @property + def isempty(self) -> Any: ... + @property + def lower_inf(self) -> Any: ... + @property + def upper_inf(self) -> Any: ... + @property + def lower_inc(self) -> Any: ... + @property + def upper_inc(self) -> Any: ... + def __contains__(self, x: Any) -> Any: ... + def __bool__(self) -> Any: ... + def __nonzero__(self) -> Any: ... + def __eq__(self, other: Any) -> Any: ... + def __ne__(self, other: Any) -> Any: ... + def __hash__(self) -> Any: ... + def __lt__(self, other: Any) -> Any: ... + def __le__(self, other: Any) -> Any: ... + def __gt__(self, other: Any) -> Any: ... + def __ge__(self, other: Any) -> Any: ... + + +def register_range(pgrange: Any, pyrange: Any, + conn_or_curs: Any, globally: bool = ...) -> Any: ... + + +class RangeAdapter: + name: Any = ... + adapted: Any = ... + def __init__(self, adapted: Any) -> None: ... + def __conform__(self, proto: Any) -> Any: ... + def prepare(self, conn: Any) -> None: ... + def getquoted(self) -> Any: ... + + +class RangeCaster: + subtype_oid: Any = ... + typecaster: Any = ... + array_typecaster: Any = ... + def __init__(self, pgrange: Any, pyrange: Any, oid: Any, + subtype_oid: Any, array_oid: Optional[Any] = ...) -> None: ... + + def parse(self, s: Any, cur: Optional[Any] = ...) -> Any: ... + + +class NumericRange(Range): + ... + + +class DateRange(Range): + ... + + +class DateTimeRange(Range): + ... + + +class DateTimeTZRange(Range): + ... 
+ + +class NumberRangeAdapter(RangeAdapter): + def getquoted(self) -> Any: ... + + +int4range_caster: Any +int8range_caster: Any +numrange_caster: Any +daterange_caster: Any +tsrange_caster: Any +tstzrange_caster: Any diff --git a/stubs/psycopg2/errorcodes.pyi b/stubs/psycopg2/errorcodes.pyi new file mode 100644 index 0000000..3806074 --- /dev/null +++ b/stubs/psycopg2/errorcodes.pyi @@ -0,0 +1,306 @@ +from typing import Any + +def lookup(code: Any, _cache: Any = ...): ... + +CLASS_SUCCESSFUL_COMPLETION: str +CLASS_WARNING: str +CLASS_NO_DATA: str +CLASS_SQL_STATEMENT_NOT_YET_COMPLETE: str +CLASS_CONNECTION_EXCEPTION: str +CLASS_TRIGGERED_ACTION_EXCEPTION: str +CLASS_FEATURE_NOT_SUPPORTED: str +CLASS_INVALID_TRANSACTION_INITIATION: str +CLASS_LOCATOR_EXCEPTION: str +CLASS_INVALID_GRANTOR: str +CLASS_INVALID_ROLE_SPECIFICATION: str +CLASS_DIAGNOSTICS_EXCEPTION: str +CLASS_CASE_NOT_FOUND: str +CLASS_CARDINALITY_VIOLATION: str +CLASS_DATA_EXCEPTION: str +CLASS_INTEGRITY_CONSTRAINT_VIOLATION: str +CLASS_INVALID_CURSOR_STATE: str +CLASS_INVALID_TRANSACTION_STATE: str +CLASS_INVALID_SQL_STATEMENT_NAME: str +CLASS_TRIGGERED_DATA_CHANGE_VIOLATION: str +CLASS_INVALID_AUTHORIZATION_SPECIFICATION: str +CLASS_DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str +CLASS_INVALID_TRANSACTION_TERMINATION: str +CLASS_SQL_ROUTINE_EXCEPTION: str +CLASS_INVALID_CURSOR_NAME: str +CLASS_EXTERNAL_ROUTINE_EXCEPTION: str +CLASS_EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str +CLASS_SAVEPOINT_EXCEPTION: str +CLASS_INVALID_CATALOG_NAME: str +CLASS_INVALID_SCHEMA_NAME: str +CLASS_TRANSACTION_ROLLBACK: str +CLASS_SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str +CLASS_WITH_CHECK_OPTION_VIOLATION: str +CLASS_INSUFFICIENT_RESOURCES: str +CLASS_PROGRAM_LIMIT_EXCEEDED: str +CLASS_OBJECT_NOT_IN_PREREQUISITE_STATE: str +CLASS_OPERATOR_INTERVENTION: str +CLASS_SYSTEM_ERROR: str +CLASS_SNAPSHOT_FAILURE: str +CLASS_CONFIGURATION_FILE_ERROR: str +CLASS_FOREIGN_DATA_WRAPPER_ERROR: str +CLASS_PL_PGSQL_ERROR: str 
+CLASS_INTERNAL_ERROR: str +SUCCESSFUL_COMPLETION: str +WARNING: str +NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: str +STRING_DATA_RIGHT_TRUNCATION_: str +PRIVILEGE_NOT_REVOKED: str +PRIVILEGE_NOT_GRANTED: str +IMPLICIT_ZERO_BIT_PADDING: str +DYNAMIC_RESULT_SETS_RETURNED: str +DEPRECATED_FEATURE: str +NO_DATA: str +NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: str +SQL_STATEMENT_NOT_YET_COMPLETE: str +CONNECTION_EXCEPTION: str +SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: str +CONNECTION_DOES_NOT_EXIST: str +SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: str +CONNECTION_FAILURE: str +TRANSACTION_RESOLUTION_UNKNOWN: str +PROTOCOL_VIOLATION: str +TRIGGERED_ACTION_EXCEPTION: str +FEATURE_NOT_SUPPORTED: str +INVALID_TRANSACTION_INITIATION: str +LOCATOR_EXCEPTION: str +INVALID_LOCATOR_SPECIFICATION: str +INVALID_GRANTOR: str +INVALID_GRANT_OPERATION: str +INVALID_ROLE_SPECIFICATION: str +DIAGNOSTICS_EXCEPTION: str +STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: str +CASE_NOT_FOUND: str +CARDINALITY_VIOLATION: str +DATA_EXCEPTION: str +STRING_DATA_RIGHT_TRUNCATION: str +NULL_VALUE_NO_INDICATOR_PARAMETER: str +NUMERIC_VALUE_OUT_OF_RANGE: str +NULL_VALUE_NOT_ALLOWED_: str +ERROR_IN_ASSIGNMENT: str +INVALID_DATETIME_FORMAT: str +DATETIME_FIELD_OVERFLOW: str +INVALID_TIME_ZONE_DISPLACEMENT_VALUE: str +ESCAPE_CHARACTER_CONFLICT: str +INVALID_USE_OF_ESCAPE_CHARACTER: str +INVALID_ESCAPE_OCTET: str +ZERO_LENGTH_CHARACTER_STRING: str +MOST_SPECIFIC_TYPE_MISMATCH: str +SEQUENCE_GENERATOR_LIMIT_EXCEEDED: str +NOT_AN_XML_DOCUMENT: str +INVALID_XML_DOCUMENT: str +INVALID_XML_CONTENT: str +INVALID_XML_COMMENT: str +INVALID_XML_PROCESSING_INSTRUCTION: str +INVALID_INDICATOR_PARAMETER_VALUE: str +SUBSTRING_ERROR: str +DIVISION_BY_ZERO: str +INVALID_PRECEDING_OR_FOLLOWING_SIZE: str +INVALID_ARGUMENT_FOR_NTILE_FUNCTION: str +INTERVAL_FIELD_OVERFLOW: str +INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION: str +INVALID_CHARACTER_VALUE_FOR_CAST: str +INVALID_ESCAPE_CHARACTER: str 
+INVALID_REGULAR_EXPRESSION: str +INVALID_ARGUMENT_FOR_LOGARITHM: str +INVALID_ARGUMENT_FOR_POWER_FUNCTION: str +INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: str +INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: str +INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: str +INVALID_LIMIT_VALUE: str +CHARACTER_NOT_IN_REPERTOIRE: str +INDICATOR_OVERFLOW: str +INVALID_PARAMETER_VALUE: str +UNTERMINATED_C_STRING: str +INVALID_ESCAPE_SEQUENCE: str +STRING_DATA_LENGTH_MISMATCH: str +TRIM_ERROR: str +ARRAY_SUBSCRIPT_ERROR: str +INVALID_TABLESAMPLE_REPEAT: str +INVALID_TABLESAMPLE_ARGUMENT: str +DUPLICATE_JSON_OBJECT_KEY_VALUE: str +INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: str +INVALID_JSON_TEXT: str +INVALID_SQL_JSON_SUBSCRIPT: str +MORE_THAN_ONE_SQL_JSON_ITEM: str +NO_SQL_JSON_ITEM: str +NON_NUMERIC_SQL_JSON_ITEM: str +NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: str +SINGLETON_SQL_JSON_ITEM_REQUIRED: str +SQL_JSON_ARRAY_NOT_FOUND: str +SQL_JSON_MEMBER_NOT_FOUND: str +SQL_JSON_NUMBER_NOT_FOUND: str +SQL_JSON_OBJECT_NOT_FOUND: str +TOO_MANY_JSON_ARRAY_ELEMENTS: str +TOO_MANY_JSON_OBJECT_MEMBERS: str +SQL_JSON_SCALAR_REQUIRED: str +FLOATING_POINT_EXCEPTION: str +INVALID_TEXT_REPRESENTATION: str +INVALID_BINARY_REPRESENTATION: str +BAD_COPY_FILE_FORMAT: str +UNTRANSLATABLE_CHARACTER: str +NONSTANDARD_USE_OF_ESCAPE_CHARACTER: str +INTEGRITY_CONSTRAINT_VIOLATION: str +RESTRICT_VIOLATION: str +NOT_NULL_VIOLATION: str +FOREIGN_KEY_VIOLATION: str +UNIQUE_VIOLATION: str +CHECK_VIOLATION: str +EXCLUSION_VIOLATION: str +INVALID_CURSOR_STATE: str +INVALID_TRANSACTION_STATE: str +ACTIVE_SQL_TRANSACTION: str +BRANCH_TRANSACTION_ALREADY_ACTIVE: str +INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: str +INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: str +NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: str +READ_ONLY_SQL_TRANSACTION: str +SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: str +HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: str +NO_ACTIVE_SQL_TRANSACTION: str +IN_FAILED_SQL_TRANSACTION: str 
+IDLE_IN_TRANSACTION_SESSION_TIMEOUT: str +INVALID_SQL_STATEMENT_NAME: str +TRIGGERED_DATA_CHANGE_VIOLATION: str +INVALID_AUTHORIZATION_SPECIFICATION: str +INVALID_PASSWORD: str +DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str +DEPENDENT_OBJECTS_STILL_EXIST: str +INVALID_TRANSACTION_TERMINATION: str +SQL_ROUTINE_EXCEPTION: str +MODIFYING_SQL_DATA_NOT_PERMITTED_: str +PROHIBITED_SQL_STATEMENT_ATTEMPTED_: str +READING_SQL_DATA_NOT_PERMITTED_: str +FUNCTION_EXECUTED_NO_RETURN_STATEMENT: str +INVALID_CURSOR_NAME: str +EXTERNAL_ROUTINE_EXCEPTION: str +CONTAINING_SQL_NOT_PERMITTED: str +MODIFYING_SQL_DATA_NOT_PERMITTED: str +PROHIBITED_SQL_STATEMENT_ATTEMPTED: str +READING_SQL_DATA_NOT_PERMITTED: str +EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str +INVALID_SQLSTATE_RETURNED: str +NULL_VALUE_NOT_ALLOWED: str +TRIGGER_PROTOCOL_VIOLATED: str +SRF_PROTOCOL_VIOLATED: str +EVENT_TRIGGER_PROTOCOL_VIOLATED: str +SAVEPOINT_EXCEPTION: str +INVALID_SAVEPOINT_SPECIFICATION: str +INVALID_CATALOG_NAME: str +INVALID_SCHEMA_NAME: str +TRANSACTION_ROLLBACK: str +SERIALIZATION_FAILURE: str +TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION: str +STATEMENT_COMPLETION_UNKNOWN: str +DEADLOCK_DETECTED: str +SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str +INSUFFICIENT_PRIVILEGE: str +SYNTAX_ERROR: str +INVALID_NAME: str +INVALID_COLUMN_DEFINITION: str +NAME_TOO_LONG: str +DUPLICATE_COLUMN: str +AMBIGUOUS_COLUMN: str +UNDEFINED_COLUMN: str +UNDEFINED_OBJECT: str +DUPLICATE_OBJECT: str +DUPLICATE_ALIAS: str +DUPLICATE_FUNCTION: str +AMBIGUOUS_FUNCTION: str +GROUPING_ERROR: str +DATATYPE_MISMATCH: str +WRONG_OBJECT_TYPE: str +INVALID_FOREIGN_KEY: str +CANNOT_COERCE: str +UNDEFINED_FUNCTION: str +GENERATED_ALWAYS: str +RESERVED_NAME: str +UNDEFINED_TABLE: str +UNDEFINED_PARAMETER: str +DUPLICATE_CURSOR: str +DUPLICATE_DATABASE: str +DUPLICATE_PREPARED_STATEMENT: str +DUPLICATE_SCHEMA: str +DUPLICATE_TABLE: str +AMBIGUOUS_PARAMETER: str +AMBIGUOUS_ALIAS: str +INVALID_COLUMN_REFERENCE: str 
+INVALID_CURSOR_DEFINITION: str +INVALID_DATABASE_DEFINITION: str +INVALID_FUNCTION_DEFINITION: str +INVALID_PREPARED_STATEMENT_DEFINITION: str +INVALID_SCHEMA_DEFINITION: str +INVALID_TABLE_DEFINITION: str +INVALID_OBJECT_DEFINITION: str +INDETERMINATE_DATATYPE: str +INVALID_RECURSION: str +WINDOWING_ERROR: str +COLLATION_MISMATCH: str +INDETERMINATE_COLLATION: str +WITH_CHECK_OPTION_VIOLATION: str +INSUFFICIENT_RESOURCES: str +DISK_FULL: str +OUT_OF_MEMORY: str +TOO_MANY_CONNECTIONS: str +CONFIGURATION_LIMIT_EXCEEDED: str +PROGRAM_LIMIT_EXCEEDED: str +STATEMENT_TOO_COMPLEX: str +TOO_MANY_COLUMNS: str +TOO_MANY_ARGUMENTS: str +OBJECT_NOT_IN_PREREQUISITE_STATE: str +OBJECT_IN_USE: str +CANT_CHANGE_RUNTIME_PARAM: str +LOCK_NOT_AVAILABLE: str +UNSAFE_NEW_ENUM_VALUE_USAGE: str +OPERATOR_INTERVENTION: str +QUERY_CANCELED: str +ADMIN_SHUTDOWN: str +CRASH_SHUTDOWN: str +CANNOT_CONNECT_NOW: str +DATABASE_DROPPED: str +SYSTEM_ERROR: str +IO_ERROR: str +UNDEFINED_FILE: str +DUPLICATE_FILE: str +SNAPSHOT_TOO_OLD: str +CONFIG_FILE_ERROR: str +LOCK_FILE_EXISTS: str +FDW_ERROR: str +FDW_OUT_OF_MEMORY: str +FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: str +FDW_INVALID_DATA_TYPE: str +FDW_COLUMN_NAME_NOT_FOUND: str +FDW_INVALID_DATA_TYPE_DESCRIPTORS: str +FDW_INVALID_COLUMN_NAME: str +FDW_INVALID_COLUMN_NUMBER: str +FDW_INVALID_USE_OF_NULL_POINTER: str +FDW_INVALID_STRING_FORMAT: str +FDW_INVALID_HANDLE: str +FDW_INVALID_OPTION_INDEX: str +FDW_INVALID_OPTION_NAME: str +FDW_OPTION_NAME_NOT_FOUND: str +FDW_REPLY_HANDLE: str +FDW_UNABLE_TO_CREATE_EXECUTION: str +FDW_UNABLE_TO_CREATE_REPLY: str +FDW_UNABLE_TO_ESTABLISH_CONNECTION: str +FDW_NO_SCHEMAS: str +FDW_SCHEMA_NOT_FOUND: str +FDW_TABLE_NOT_FOUND: str +FDW_FUNCTION_SEQUENCE_ERROR: str +FDW_TOO_MANY_HANDLES: str +FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: str +FDW_INVALID_ATTRIBUTE_VALUE: str +FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: str +FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: str +PLPGSQL_ERROR: str +RAISE_EXCEPTION: str 
+NO_DATA_FOUND: str +TOO_MANY_ROWS: str +ASSERT_FAILURE: str +INTERNAL_ERROR: str +DATA_CORRUPTED: str +INDEX_CORRUPTED: str diff --git a/stubs/psycopg2/errors.pyi b/stubs/psycopg2/errors.pyi new file mode 100644 index 0000000..b866509 --- /dev/null +++ b/stubs/psycopg2/errors.pyi @@ -0,0 +1,3 @@ +from typing import Any + +def lookup(code: Any): ... diff --git a/stubs/psycopg2/extensions.pyi b/stubs/psycopg2/extensions.pyi new file mode 100644 index 0000000..c597c28 --- /dev/null +++ b/stubs/psycopg2/extensions.pyi @@ -0,0 +1,142 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Optional + +from psycopg2._json import ( + register_default_json, + register_default_jsonb +) + +from psycopg2._psycopg import ( + AsIs, + BINARYARRAY, + BOOLEAN, + BOOLEANARRAY, + BYTES, + BYTESARRAY, + Binary, + Boolean, + Column, + ConnectionInfo, + DATE, + DATEARRAY, + DATETIMEARRAY, + DECIMAL, + DECIMALARRAY, + DateFromPy, + Diagnostics, + FLOAT, + FLOATARRAY, + Float, + INTEGER, + INTEGERARRAY, + INTERVAL, + INTERVALARRAY, + ISQLQuote, + Int, + IntervalFromPy, + LONGINTEGER, + LONGINTEGERARRAY, + Notify, + PYDATE, + PYDATEARRAY, + PYDATETIME, + PYDATETIMEARRAY, + PYDATETIMETZ, + PYDATETIMETZARRAY, + PYINTERVAL, + PYINTERVALARRAY, + PYTIME, + PYTIMEARRAY, + QueryCanceledError, + ROWIDARRAY, + STRINGARRAY, + TIME, + TIMEARRAY, + TimeFromPy, + TimestampFromPy, + TransactionRollbackError, + UNICODE, + UNICODEARRAY, + Xid, + adapt, + adapters, + binary_types, + connection, 
+ cursor, + encodings, + encrypt_password, + get_wait_callback, + libpq_version, + lobject, +) +from psycopg2._range import Range + +ISOLATION_LEVEL_AUTOCOMMIT: int +ISOLATION_LEVEL_READ_UNCOMMITTED: int +ISOLATION_LEVEL_READ_COMMITTED: int +ISOLATION_LEVEL_REPEATABLE_READ: int +ISOLATION_LEVEL_SERIALIZABLE: int +ISOLATION_LEVEL_DEFAULT: Any +STATUS_SETUP: int +STATUS_READY: int +STATUS_BEGIN: int +STATUS_SYNC: int +STATUS_ASYNC: int +STATUS_PREPARED: int +STATUS_IN_TRANSACTION = STATUS_BEGIN +POLL_OK: int +POLL_READ: int +POLL_WRITE: int +POLL_ERROR: int +TRANSACTION_STATUS_IDLE: int +TRANSACTION_STATUS_ACTIVE: int +TRANSACTION_STATUS_INTRANS: int +TRANSACTION_STATUS_INERROR: int +TRANSACTION_STATUS_UNKNOWN: int + +QuotedString: Any +new_array_type: Any +new_type: Any +parse_dsn: Any +quote_ident: Any +register_type: Any +set_wait_callback: Any +string_types: Any + + +def register_adapter(typ: Any, callable: Any) -> None: ... + + +class SQL_IN: + def __init__(self, seq: Any) -> None: ... + def prepare(self, conn: Any) -> None: ... + def getquoted(self) -> Any: ... + + +class NoneAdapter: + def __init__(self, obj: Any) -> None: ... + def getquoted(self, _null: bytes = ...) -> Any: ... + + +def make_dsn(dsn: Optional[Any] = ..., **kwargs: Any) -> Any: ... 
+ + +JSON: Any +JSONARRAY: Any +JSONB: Any +JSONBARRAY: Any +k: Any diff --git a/stubs/psycopg2/extras.pyi b/stubs/psycopg2/extras.pyi index 65451da..a95f203 100644 --- a/stubs/psycopg2/extras.pyi +++ b/stubs/psycopg2/extras.pyi @@ -1,23 +1,62 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + + +from collections import OrderedDict from typing import Any, Optional -# from .compat import PY2 as PY2, PY3 as PY3, lru_cache as lru_cache -# from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident -# from collections import OrderedDict -# from psycopg2._ipaddress import register_ipaddress as register_ipaddress -# from psycopg2._json import Json as Json, json as json, register_default_json as register_default_json, register_default_jsonb as register_default_jsonb, register_json as register_json -# from psycopg2._psycopg import REPLICATION_LOGICAL as REPLICATION_LOGICAL, REPLICATION_PHYSICAL as REPLICATION_PHYSICAL, ReplicationConnection as _replicationConnection, ReplicationCursor as _replicationCursor, ReplicationMessage as ReplicationMessage -# from psycopg2._range import DateRange as DateRange, DateTimeRange as DateTimeRange, DateTimeTZRange as DateTimeTZRange, NumericRange as NumericRange, Range as Range, RangeAdapter as RangeAdapter, RangeCaster as RangeCaster, register_range as register_range + +from psycopg2._range import ( + DateRange, + DateTimeRange, + DateTimeTZRange, + NumericRange, + Range, + RangeAdapter, + RangeCaster, + 
register_range +) +from psycopg2._psycopg import ( + REPLICATION_LOGICAL, + REPLICATION_PHYSICAL, + ReplicationConnection, + ReplicationCursor, + ReplicationMessage, + connection, + cursor, + quote_ident +) +from psycopg2._ipaddress import ( + register_ipaddress +) + + +class RealDictRow(OrderedDict[Any, Any]): + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __setitem__(self, key: Any, value: Any) -> None: ... # # -# class DictCursorBase(_cursor): +# class DictCursorBase(cursor): # row_factory: Any = ... # def __init__(self, *args: Any, **kwargs: Any) -> None: ... # def fetchone(self): ... # def fetchmany(self, size: Optional[Any] = ...): ... # def fetchall(self): ... -# def __iter__(self) -> Any: ... +# def __iter__(self) -> Any: ... # pylint: disable=non-iterator-returned # # -# class DictConnection(_connection): +# class DictConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # @@ -41,7 +80,7 @@ from typing import Any, Optional # def __reduce__(self): ... # # -# class RealDictConnection(_connection): +# class RealDictConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # @@ -57,11 +96,11 @@ from typing import Any, Optional # def __setitem__(self, key: Any, value: Any) -> None: ... # # -# class NamedTupleConnection(_connection): +# class NamedTupleConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # -# class NamedTupleCursor(_cursor): +# class NamedTupleCursor(cursor): # Record: Any = ... # MAX_CACHE: int = ... # def execute(self, query: Any, vars: Optional[Any] = ...): ... @@ -73,14 +112,14 @@ from typing import Any, Optional # def __iter__(self) -> Any: ... # # -# class LoggingConnection(_connection): +# class LoggingConnection(connection): # log: Any = ... # def initialize(self, logobj: Any) -> None: ... # def filter(self, msg: Any, curs: Any): ... # def cursor(self, *args: Any, **kwargs: Any): ... 
# # -# class LoggingCursor(_cursor): +# class LoggingCursor(cursor): # def execute(self, query: Any, vars: Optional[Any] = ...): ... # def callproc(self, procname: Any, vars: Optional[Any] = ...): ... # @@ -111,11 +150,21 @@ from typing import Any, Optional # # class ReplicationCursor(_replicationCursor): # def create_replication_slot( -# self, slot_name: Any, slot_type: Optional[Any] = ..., output_plugin: Optional[Any] = ...) -> None: ... +# self, +# slot_name: Any, +# slot_type: Optional[Any] = ..., +# output_plugin: Optional[Any] = ...) -> None: ... # # def drop_replication_slot(self, slot_name: Any) -> None: ... -# def start_replication(self, slot_name: Optional[Any] = ..., slot_type: Optional[Any] = ..., start_lsn: int = ..., -# timeline: int = ..., options: Optional[Any] = ..., decode: bool = ..., status_interval: int = ...) -> None: ... +# +# def start_replication(self, +# slot_name: Optional[Any] = ..., +# slot_type: Optional[Any] = ..., +# start_lsn: int = ..., +# timeline: int = ..., +# options: Optional[Any] = ..., +# decode: bool = ..., +# status_interval: int = ...) -> None: ... # # def fileno(self): ... # @@ -124,10 +173,8 @@ from typing import Any, Optional # def __init__(self, uuid: Any) -> None: ... # def __conform__(self, proto: Any): ... # def getquoted(self): ... -# -# -# pylint: disable=unsubscriptable-object + def register_uuid(oids: Optional[Any] = ..., conn_or_curs: Optional[Any] = ...) -> None: ... # @@ -161,8 +208,13 @@ def register_uuid(oids: Optional[Any] = ..., # def get_oids(self, conn_or_curs: Any): ... # # -# def register_hstore(conn_or_curs: Any, globally: bool = ..., unicode: bool = ..., -# oid: Optional[Any] = ..., array_oid: Optional[Any] = ...) -> None: ... +# def register_hstore( +# conn_or_curs: Any, +# globally: bool = ..., +# unicode: bool = ..., +# oid: Optional[Any] = ..., +# array_oid: Optional[Any] = ... +# ) -> None: ... 
# # # class CompositeCaster: diff --git a/stubs/psycopg2/pool.pyi b/stubs/psycopg2/pool.pyi new file mode 100644 index 0000000..24023b3 --- /dev/null +++ b/stubs/psycopg2/pool.pyi @@ -0,0 +1,21 @@ +import psycopg2 +from typing import Any, Optional + +class PoolError(psycopg2.Error): ... + +class AbstractConnectionPool: + minconn: Any = ... + maxconn: Any = ... + closed: bool = ... + def __init__(self, minconn: Any, maxconn: Any, *args: Any, **kwargs: Any) -> None: ... + +class SimpleConnectionPool(AbstractConnectionPool): + getconn: Any = ... + putconn: Any = ... + closeall: Any = ... + +class ThreadedConnectionPool(AbstractConnectionPool): + def __init__(self, minconn: Any, maxconn: Any, *args: Any, **kwargs: Any) -> None: ... + def getconn(self, key: Optional[Any] = ...): ... + def putconn(self, conn: Optional[Any] = ..., key: Optional[Any] = ..., close: bool = ...) -> None: ... + def closeall(self) -> None: ... diff --git a/stubs/psycopg2/tz.pyi b/stubs/psycopg2/tz.pyi new file mode 100644 index 0000000..ab6376c --- /dev/null +++ b/stubs/psycopg2/tz.pyi @@ -0,0 +1,26 @@ +import datetime +from typing import Any, Optional + +ZERO: Any + +class FixedOffsetTimezone(datetime.tzinfo): + def __init__(self, offset: Optional[Any] = ..., name: Optional[Any] = ...) -> None: ... + def __new__(cls, offset: Optional[Any] = ..., name: Optional[Any] = ...): ... + def __eq__(self, other: Any) -> Any: ... + def __ne__(self, other: Any) -> Any: ... + def __getinitargs__(self): ... + def utcoffset(self, dt: Any): ... + def tzname(self, dt: Any): ... + def dst(self, dt: Any): ... + +STDOFFSET: Any +DSTOFFSET: Any +DSTOFFSET = STDOFFSET +DSTDIFF: Any + +class LocalTimezone(datetime.tzinfo): + def utcoffset(self, dt: Any): ... + def dst(self, dt: Any): ... + def tzname(self, dt: Any): ... 
+ +LOCAL: Any diff --git a/test/helpers.py b/test/helpers.py index 4499c9b..fcfa5fa 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -9,9 +9,6 @@ ) from unittest.mock import MagicMock -from db_wrapper.connection import ConnectionParameters -from db_wrapper.client import Client - def composed_to_string(seq: Iterable[Any]) -> str: """Test helper to convert a sql query to a string for comparison. @@ -25,6 +22,7 @@ def composed_to_string(seq: Iterable[Any]) -> str: class AsyncMock(MagicMock): """Extend unittest.mock.MagicMock to allow mocking of async functions.""" + # pylint: disable=invalid-overridden-method # pylint: disable=useless-super-delegation @@ -40,9 +38,3 @@ def wrapped(instance: Any) -> None: asyncio.run(test(instance)) return wrapped - - -def get_client() -> Client: - """Create a client with placeholder connection data.""" - conn_params = ConnectionParameters('a', 'a', 'a', 'a') - return Client(conn_params) diff --git a/test/test_model.py b/test/test_model.py index 7abffa7..d9de694 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -10,9 +10,10 @@ from typing import ( cast, Any, - TypeVar, List, Tuple, + Type, + TypeVar, ) from uuid import uuid4 import unittest @@ -20,21 +21,27 @@ import helpers -from db_wrapper.client import Client +from db_wrapper import ConnectionParameters, AsyncClient, SyncClient from db_wrapper.model import ( - Model, ModelData, - Read, + AsyncModel, + AsyncRead, + SyncModel, + SyncRead +) +from db_wrapper.model.base import ( UnexpectedMultipleResults, - NoResultFound) + NoResultFound, +) -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name T = TypeVar('T', bound=ModelData) -def setup(query_result: List[T]) -> Tuple[Model[T], Client]: +def setupAsync( + query_result: List[T], + model_data: Type[T] +) -> Tuple[AsyncModel[T], AsyncClient]: """Setup helper that returns instances of both a Model & a Client. 
Mocks the execute_and_return method on the Client instance to skip @@ -46,14 +53,45 @@ def setup(query_result: List[T]) -> Tuple[Model[T], Client]: here, but still specify a return value for the mocked method on the returned Client instance. """ - client = helpers.get_client() + # create client with placeholder connection data + conn_params = ConnectionParameters('a', 1, 'a', 'a', 'a') + client = AsyncClient(conn_params) # mock client's sql execution method client.execute_and_return = helpers.AsyncMock( # type:ignore - return_value=query_result) + return_value=[i.dict() for i in query_result]) + + # init a real model with mocked client + model = AsyncModel[Any](client, 'test', model_data) + + return model, client + + +def setupSync( + query_result: List[T], + model_data: Type[T] +) -> Tuple[SyncModel[T], SyncClient]: + """Setup helper that returns instances of both a Model & a Client. + + Mocks the execute_and_return method on the Client instance to skip + normal execution & just return the given query_result. + + Using this setup helper that requires manually calling in each test + instance is better than unittest's setUpModule or setUpClass methods + because it allows the caller to skip all the boilerplate contained + here, but still specify a return value for the mocked method on the + returned Client instance. 
+ """ + # create client with placeholder connection data + conn_params = ConnectionParameters('a', 1, 'a', 'a', 'a') + client = SyncClient(conn_params) + + # mock client's sql execution method + client.execute_and_return = helpers.MagicMock( # type:ignore + return_value=[i.dict() for i in query_result]) - # init model with mocked client - model = Model[Any](client, 'test') + # init a real model with mocked client + model = SyncModel[T](client, 'test', model_data) return model, client @@ -64,39 +102,70 @@ class TestReadOneById(TestCase): @helpers.async_test async def test_it_correctly_builds_query_with_given_id(self) -> None: item = ModelData(id=uuid4()) - model, client = setup([item]) - await model.read.one_by_id(str(item.id)) - query_composed = cast( - helpers.AsyncMock, client.execute_and_return).call_args[0][0] - query = helpers.composed_to_string(query_composed) + async_model, async_client = setupAsync([item], ModelData) + sync_model, sync_client = setupSync([item], ModelData) + + await async_model.read.one_by_id(item.id) + sync_model.read.one_by_id(item.id) + + async_query_composed = cast( + helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] + sync_query_composed = cast( + helpers.AsyncMock, sync_client.execute_and_return).call_args[0][0] + + async_query = helpers.composed_to_string(async_query_composed) + sync_query = helpers.composed_to_string(sync_query_composed) - self.assertEqual(query, "SELECT * " - "FROM test " - f"WHERE id = {item.id};") + queries = [async_query, sync_query] + + for query in queries: + with self.subTest(): + self.assertEqual(query, "SELECT * " + "FROM test " + f"WHERE id = {item.id};") @helpers.async_test async def test_it_returns_a_single_result(self) -> None: item = ModelData(id=uuid4()) - model, _ = setup([item]) - result = await model.read.one_by_id(str(item.id)) + async_model, _ = setupAsync([item], ModelData) + sync_model, _ = setupSync([item], ModelData) + results = [await async_model.read.one_by_id(item.id), + 
sync_model.read.one_by_id(item.id)] - self.assertEqual(result, item) + for result in results: + with self.subTest(): + self.assertEqual(result, item) - @ helpers.async_test + @helpers.async_test async def test_it_raises_exception_if_more_than_one_result(self) -> None: item = ModelData(id=uuid4()) - model, _ = setup([item, item]) + async_model, _ = setupAsync([item, item], ModelData) + sync_model, _ = setupSync([item, item], ModelData) - with self.assertRaises(UnexpectedMultipleResults): - await model.read.one_by_id(str(item.id)) + with self.subTest(): + with self.assertRaises(UnexpectedMultipleResults): + await async_model.read.one_by_id(item.id) + + with self.subTest(): + with self.assertRaises(UnexpectedMultipleResults): + sync_model.read.one_by_id(item.id) @ helpers.async_test async def test_it_raises_exception_if_no_result_to_return(self) -> None: - model: Model[ModelData] - model, _ = setup([]) + empty_async: List[ModelData] = [] + empty_sync: List[ModelData] = [] + async_model: AsyncModel[ModelData] + sync_model: SyncModel[ModelData] + async_model, _ = setupAsync(empty_async, ModelData) + sync_model, _ = setupSync(empty_sync, ModelData) - with self.assertRaises(NoResultFound): - await model.read.one_by_id('id') + with self.subTest(): + with self.assertRaises(NoResultFound): + await async_model.read.one_by_id(uuid4()) + + with self.subTest(): + with self.assertRaises(NoResultFound): + sync_model.read.one_by_id(uuid4()) class TestCreateOne(TestCase): @@ -114,16 +183,26 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - model, client = setup([item]) + async_model, async_client = setupAsync([item], TestCreateOne.Item) + sync_model, sync_client = setupSync([item], TestCreateOne.Item) + + await async_model.create.one(item) + sync_model.create.one(item) - await model.create.one(item) - query_composed = cast( - helpers.AsyncMock, client.execute_and_return).call_args[0][0] - query = 
helpers.composed_to_string(query_composed) + async_query_composed = cast( + helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] + sync_query_composed = cast( + helpers.MagicMock, sync_client.execute_and_return).call_args[0][0] - self.assertEqual(query, 'INSERT INTO test (id,a,b) ' - f"VALUES ({item.id},a,b) " - 'RETURNING *;') + queries = [async_query_composed, sync_query_composed] + + for query in queries: + with self.subTest(): + self.assertEqual( + helpers.composed_to_string(query), + 'INSERT INTO test (id,a,b) ' + f"VALUES ({item.id},a,b) " + 'RETURNING *;') @ helpers.async_test async def test_it_returns_the_new_record(self) -> None: @@ -132,11 +211,15 @@ async def test_it_returns_the_new_record(self) -> None: 'a': 'a', 'b': 'b', }) - model, _ = setup([item]) + async_model, _ = setupAsync([item], TestCreateOne.Item) + sync_model, _ = setupSync([item], TestCreateOne.Item) - result = await model.create.one(item) + results = [await async_model.create.one(item), + sync_model.create.one(item)] - self.assertEqual(result, item) + for result in results: + with self.subTest(): + self.assertEqual(result, item) class TestUpdateOne(TestCase): @@ -154,21 +237,27 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - # cast required to avoid mypy error due to unpacking - # TypedDict, see more on GitHub issue - # https://github.com/python/mypy/issues/4122 - updated = TestUpdateOne.Item(**{**item.dict(), 'b': 'c'}) - model, client = setup([updated]) + async_model, async_client = setupAsync([item], TestUpdateOne.Item) + sync_model, sync_client = setupSync([item], TestUpdateOne.Item) + + await async_model.update.one_by_id(item.id, {'b': 'c'}) + sync_model.update.one_by_id(item.id, {'b': 'c'}) + + async_query_composed = cast( + helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] + sync_query_composed = cast( + helpers.AsyncMock, sync_client.execute_and_return).call_args[0][0] - await 
model.update.one_by_id(str(item.id), {'b': 'c'}) - query_composed = cast( - helpers.AsyncMock, client.execute_and_return).call_args[0][0] - query = helpers.composed_to_string(query_composed) + queries = [async_query_composed, sync_query_composed] - self.assertEqual(query, 'UPDATE test ' - 'SET b = c ' - f"WHERE id = {item.id} " - 'RETURNING *;') + for query in queries: + with self.subTest(): + self.assertEqual( + helpers.composed_to_string(query), + 'UPDATE test ' + 'SET b = c ' + f"WHERE id = {item.id} " + 'RETURNING *;') @ helpers.async_test async def test_it_returns_the_new_record(self) -> None: @@ -177,15 +266,19 @@ async def test_it_returns_the_new_record(self) -> None: 'a': 'a', 'b': 'b', }) - # cast required to avoid mypy error due to unpacking - # TypedDict, see more on GitHub issue - # https://github.com/python/mypy/issues/4122 + # mock result updated = TestUpdateOne.Item(**{**item.dict(), 'b': 'c'}) - model, _ = setup([updated]) + async_model, _ = setupAsync([updated], TestUpdateOne.Item) + sync_model, _ = setupSync([updated], TestUpdateOne.Item) - result = await model.update.one_by_id(str(item.id), {'b': 'c'}) + results = [ + await async_model.update.one_by_id(item.id, {'b': 'c'}), + sync_model.update.one_by_id(item.id, {'b': 'c'}) + ] - self.assertEqual(result, updated) + for result in results: + with self.subTest(): + self.assertEqual(result, updated) class TestDeleteOneById(TestCase): @@ -203,17 +296,26 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - model, client = setup([item]) + async_model, async_client = setupAsync([item], TestDeleteOneById.Item) + sync_model, sync_client = setupSync([item], TestDeleteOneById.Item) - await model.delete.one_by_id(str(item.id)) + await async_model.delete.one_by_id(str(item.id)) + sync_model.delete.one_by_id(str(item.id)) - query_composed = cast( - helpers.AsyncMock, client.execute_and_return).call_args[0][0] - query = helpers.composed_to_string(query_composed) + 
async_query_composed = cast( + helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] + sync_query_composed = cast( + helpers.AsyncMock, sync_client.execute_and_return).call_args[0][0] - self.assertEqual(query, 'DELETE FROM test ' - f"WHERE id = {item.id} " - 'RETURNING *;') + queries = [async_query_composed, sync_query_composed] + + for query in queries: + with self.subTest(): + self.assertEqual( + helpers.composed_to_string(query), + 'DELETE FROM test ' + f"WHERE id = {item.id} " + 'RETURNING *;') @ helpers.async_test async def test_it_returns_the_deleted_record(self) -> None: @@ -222,44 +324,78 @@ async def test_it_returns_the_deleted_record(self) -> None: 'a': 'a', 'b': 'b', }) - model, _ = setup([item]) + async_model, _ = setupAsync([item], TestDeleteOneById.Item) + sync_model, _ = setupSync([item], TestDeleteOneById.Item) - result = await model.delete.one_by_id(str(item.id)) + results = [await async_model.delete.one_by_id(str(item.id)), + sync_model.delete.one_by_id(str(item.id))] - self.assertEqual(result, item) + for result in results: + with self.subTest(): + self.assertEqual(result, item) class TestExtendingModel(TestCase): """Testing Model's extensibility.""" - model: Model[ModelData] + + models: List[Any] def setUp(self) -> None: - class ReadExtended(Read[ModelData]): + class Item(ModelData): + """An example model data object.""" + + class AsyncReadExtended(AsyncRead[Item]): """Extending Read with additional query.""" def new_query(self) -> None: pass - model = Model[ModelData](helpers.get_client(), 'test') - model.read = ReadExtended(model.client, model.table) + class AsyncExtendedModel(AsyncModel[Item]): + """A model with extended read queries.""" + read: AsyncReadExtended - self.model = model + def __init__(self, client: AsyncClient) -> None: + super().__init__(client, 'extended_model', Item) + self.read = AsyncReadExtended(self.client, self.table, Item) + + class SyncReadExtended(SyncRead[Item]): + """Extending Read with additional 
query.""" + + def new_query(self) -> None: + pass + + class SyncExtendedModel(SyncModel[Item]): + """A model with extended read queries.""" + read: SyncReadExtended + + def __init__(self, client: SyncClient) -> None: + super().__init__(client, 'extended_model', Item) + self.read = SyncReadExtended(self.client, self.table, Item) + + _, async_client = setupAsync([Item(**{"id": uuid4()})], Item) + _, sync_client = setupSync([Item(**{"id": uuid4()})], Item) + self.models = [AsyncExtendedModel(async_client), + SyncExtendedModel(sync_client)] def test_it_can_add_new_queries_by_replacing_a_crud_property(self) -> None: - new_method = getattr(self.model.read, "new_query", None) + new_methods = [getattr(model.read, "new_query", None) + for model in self.models] - with self.subTest(): - self.assertIsNotNone(new_method) - with self.subTest(): - self.assertTrue(callable(new_method)) + for method in new_methods: + with self.subTest(): + self.assertIsNotNone(method) + with self.subTest(): + self.assertTrue(callable(method)) def test_it_still_has_original_queries_after_extending(self) -> None: - old_method = getattr(self.model.read, "one_by_id", None) - - with self.subTest(): - self.assertIsNotNone(old_method) - with self.subTest(): - self.assertTrue(callable(old_method)) + old_methods = [getattr(model.read, "one_by_id", None) + for model in self.models] + + for method in old_methods: + with self.subTest(): + self.assertIsNotNone(method) + with self.subTest(): + self.assertTrue(callable(method)) if __name__ == '__main__':