diff --git a/db_wrapper/client/async_client.py b/db_wrapper/client/async_client.py index 96896fb..afed995 100644 --- a/db_wrapper/client/async_client.py +++ b/db_wrapper/client/async_client.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import ( + cast, Any, TypeVar, Union, @@ -10,20 +11,16 @@ List, Dict) -import aiopg # type: ignore -from psycopg2.extras import register_uuid +import aiopg +from psycopg2.extras import register_uuid, RealDictCursor, RealDictRow # type: ignore from psycopg2 import sql -from db_wrapper.connection import ConnectionParameters, connect +from db_wrapper.connection import ConnectionParameters, get_pool # add uuid support to psycopg2 & Postgres register_uuid() -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name -T = TypeVar('T') - Query = Union[str, sql.Composed] @@ -36,18 +33,19 @@ class AsyncClient: """ _connection_params: ConnectionParameters - _connection: aiopg.Connection + _pool: aiopg.Pool def __init__(self, connection_params: ConnectionParameters) -> None: self._connection_params = connection_params async def connect(self) -> None: - """Connect to the database.""" - self._connection = await connect(self._connection_params) + """Create a database connection pool.""" + self._pool = await get_pool(self._connection_params) async def disconnect(self) -> None: - """Disconnect from the database.""" - await self._connection.close() + """Close database connection pool.""" + self._pool.close() + await self._pool.wait_closed() # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object @@ -57,6 +55,11 @@ async def _execute_query( query: Query, params: Optional[Dict[Hashable, Any]] = None, ) -> None: + # aiopg type is incorrect & thinks execute only takes str + # when in the query is passed through to psycopg2's + # cursor.execute which does accept sql.Composed objects. + query = cast(str, query) + if params: await cursor.execute(query, params) else: @@ -79,7 +82,7 @@ async def execute( Returns: None """ - async with self._connection.cursor() as cursor: + with (await self._pool.cursor(cursor_factory=RealDictCursor) ) as cursor: await self._execute_query(cursor, query, params) # PENDS python 3.9 support in pylint @@ -88,7 +91,7 @@ async def execute_and_return( self, query: Query, params: Optional[Dict[Hashable, Any]] = None, - ) -> List[T]: + ) -> List[RealDictRow]: """Execute the given SQL query & return the result. Arguments: @@ -99,8 +102,8 @@ async def execute_and_return( Returns: List containing all the rows that matched the query. """ - async with self._connection.cursor() as cursor: + with (await self._pool.cursor(cursor_factory=RealDictCursor) ) as cursor: await self._execute_query(cursor, query, params) - result: List[T] = await cursor.fetchall() + result: List[RealDictRow] = await cursor.fetchall() return result diff --git a/db_wrapper/client/sync_client.py b/db_wrapper/client/sync_client.py index aa3e644..72d64fd 100644 --- a/db_wrapper/client/sync_client.py +++ b/db_wrapper/client/sync_client.py @@ -3,14 +3,14 @@ from __future__ import annotations from typing import ( Any, - TypeVar, - Union, - Optional, + Dict, Hashable, List, - Dict) + Optional, + Union, +) -from psycopg2.extras import register_uuid +from psycopg2.extras import register_uuid, RealDictRow from psycopg2 import sql # pylint can't seem to find the items in psycopg2 despite being available from psycopg2._psycopg import cursor # pylint: disable=no-name-in-module @@ -24,10 +24,6 @@ register_uuid() -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name -T = TypeVar('T') - Query = Union[str, sql.Composed] @@ -60,7 +56,7 @@ def _execute_query( params: Optional[Dict[Hashable, Any]] = None, ) -> None: if params: - db_cursor.execute(query, params) # type: ignore + db_cursor.execute(query, params) else: db_cursor.execute(query) @@ -88,7 +84,7 @@ def execute_and_return( self, query: Query, params: Optional[Dict[Hashable, Any]] = None, - ) -> List[T]: + ) -> List[RealDictRow]: """Execute the given SQL query & return the result. Arguments: @@ -102,5 +98,5 @@ def execute_and_return( with self._connection.cursor() as db_cursor: self._execute_query(db_cursor, query, params) - result: List[T] = db_cursor.fetchall() + result: List[RealDictRow] = db_cursor.fetchall() return result diff --git a/db_wrapper/connection.py b/db_wrapper/connection.py index 18777e0..0574d55 100644 --- a/db_wrapper/connection.py +++ b/db_wrapper/connection.py @@ -41,18 +41,18 @@ async def _try_connect( dsn = f"dbname={database} user={user} password={password} " \ f"host={host} port={port}" + # return await aiopg.create_pool(dsn) + # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object - connection: Optional[aiopg.Connection] = None + pool: Optional[aiopg.Connection] = None LOGGER.info(f"Attempting to connect to database {database} as " f"{user}@{host}:{port}...") - while connection is None: + while pool is None: try: - connection = await aiopg.connect( - dsn, - cursor_factory=RealDictCursor) + pool = await aiopg.create_pool(dsn) except psycopg2OpError as err: print(type(err)) if retries > 12: @@ -67,7 +67,7 @@ async def _try_connect( await asyncio.sleep(5) return await _try_connect(connection_params, retries + 1) - return connection + return pool def _sync_try_connect( @@ -112,10 +112,10 @@ def _sync_try_connect( # PENDS python 3.9 support in pylint # pylint: disable=unsubscriptable-object -async def connect( +async def get_pool( connection_params: ConnectionParameters -) -> aiopg.Connection: - """Establish database connection.""" +) -> aiopg.Pool: + """Establish database connection pool.""" return await _try_connect(connection_params) diff --git a/db_wrapper/model/__init__.py b/db_wrapper/model/__init__.py index 1fe77a6..6bae307 100644 --- a/db_wrapper/model/__init__.py +++ b/db_wrapper/model/__init__.py @@ -1,5 +1,6 @@ """Convenience objects to simplify database interactions w/ given interface.""" +from psycopg2.extras import RealDictRow from .async_model import ( AsyncModel, AsyncCreate, diff --git a/db_wrapper/model/async_model.py b/db_wrapper/model/async_model.py index 703b8fa..73cb23c 100644 --- a/db_wrapper/model/async_model.py +++ b/db_wrapper/model/async_model.py @@ -1,18 +1,20 @@ """Asynchronous Model objects.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from uuid import UUID +from psycopg2.extras import RealDictRow + from db_wrapper.client import AsyncClient from .base import ( ensure_exactly_one, + sql, T, CreateABC, ReadABC, UpdateABC, DeleteABC, ModelABC, - sql, ) @@ -23,16 +25,22 @@ class AsyncCreate(CreateABC[T]): _client: AsyncClient - def __init__(self, client: AsyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client async def one(self, item: T) -> T: """Create one new record with a given item.""" - result: List[T] = await self._client.execute_and_return( - self._query_one(item)) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(self._query_one(item)) + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class AsyncRead(ReadABC[T]): @@ -42,19 +50,27 @@ class AsyncRead(ReadABC[T]): _client: AsyncClient - def __init__(self, client: AsyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client async def one_by_id(self, id_value: UUID) -> T: """Read a row by it's id.""" - result: List[T] = await self._client.execute_and_return( - self._query_one_by_id(id_value)) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value)) # Should only return one item from DB - ensure_exactly_one(result) + ensure_exactly_one(query_result) + + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class AsyncUpdate(UpdateABC[T]): @@ -64,11 +80,16 @@ class AsyncUpdate(UpdateABC[T]): _client: AsyncClient - def __init__(self, client: AsyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client - async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T: + async def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T: """Apply changes to row with given id. Arguments: @@ -79,12 +100,14 @@ async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T: Returns: full value of row updated """ - result: List[T] = await self._client.execute_and_return( - self._query_one_by_id(id_value, changes)) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value, changes)) - ensure_exactly_one(result) + ensure_exactly_one(query_result) + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class AsyncDelete(DeleteABC[T]): @@ -94,19 +117,26 @@ class AsyncDelete(DeleteABC[T]): _client: AsyncClient - def __init__(self, client: AsyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: AsyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client async def one_by_id(self, id_value: str) -> T: """Delete one record with matching ID.""" - result: List[T] = await self._client.execute_and_return( - self._query_one_by_id(id_value)) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return( + self._query_one_by_id(id_value)) # Should only return one item from DB - ensure_exactly_one(result) + ensure_exactly_one(query_result) + result = self._return_constructor(**query_result[0]) - return result[0] + return result class AsyncModel(ModelABC[T]): @@ -122,19 +152,22 @@ class AsyncModel(ModelABC[T]): _update: AsyncUpdate[T] _delete: AsyncDelete[T] - # PENDS python 3.9 support in pylint - # pylint: disable=unsubscriptable-object def __init__( self, client: AsyncClient, table: str, + return_constructor: Type[T], ) -> None: super().__init__(client, table) - self._create = AsyncCreate[T](self.client, self.table) - self._read = AsyncRead[T](self.client, self.table) - self._update = AsyncUpdate[T](self.client, self.table) - self._delete = AsyncDelete[T](self.client, self.table) + self._create = AsyncCreate[T]( + self.client, self.table, return_constructor) + self._read = AsyncRead[T]( + self.client, self.table, return_constructor) + self._update = AsyncUpdate[T]( + self.client, self.table, return_constructor) + self._delete = AsyncDelete[T]( + self.client, self.table, return_constructor) @property def create(self) -> AsyncCreate[T]: diff --git a/db_wrapper/model/base.py b/db_wrapper/model/base.py index 46f4a24..fb24805 100644 --- a/db_wrapper/model/base.py +++ b/db_wrapper/model/base.py @@ -3,12 +3,13 @@ # std lib dependencies from __future__ import annotations from typing import ( - TypeVar, - Generic, Any, - Tuple, - List, Dict, + Generic, + List, + Tuple, + Type, + TypeVar, ) from uuid import UUID @@ -54,7 +55,7 @@ def __init__(self) -> None: super().__init__(self, message) -def ensure_exactly_one(result: List[T]) -> None: +def ensure_exactly_one(result: List[Any]) -> None: """Raise appropriate Exceptions if list longer than 1.""" if len(result) > 1: raise UnexpectedMultipleResults(result) @@ -62,15 +63,27 @@ def ensure_exactly_one(result: List[T]) -> None: raise NoResultFound() -class CreateABC(Generic[T]): - """Encapsulate Create operations for Model.read.""" +class CRUDABC(Generic[T]): + """Encapsulate object creation behavior for all CRUD objects.""" # pylint: disable=too-few-public-methods _table: sql.Composable + _return_constructor: Type[T] - def __init__(self, table: sql.Composable) -> None: + def __init__( + self, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: self._table = table + self._return_constructor = return_constructor + + +class CreateABC(CRUDABC[T]): + """Encapsulate Create operations for Model.create.""" + + # pylint: disable=too-few-public-methods def _query_one(self, item: T) -> sql.Composed: """Build query to create one new record with a given item.""" @@ -95,16 +108,11 @@ def _query_one(self, item: T) -> sql.Composed: return query -class ReadABC(Generic[T]): +class ReadABC(CRUDABC[T]): """Encapsulate Read operations for Model.read.""" # pylint: disable=too-few-public-methods - _table: sql.Composable - - def __init__(self, table: sql.Composable) -> None: - self._table = table - def _query_one_by_id(self, id_value: UUID) -> sql.Composed: """Build query to read a row by it's id.""" query = sql.SQL( @@ -119,19 +127,14 @@ def _query_one_by_id(self, id_value: UUID) -> sql.Composed: return query -class UpdateABC(Generic[T]): +class UpdateABC(CRUDABC[T]): """Encapsulate Update operations for Model.read.""" # pylint: disable=too-few-public-methods - _table: sql.Composable - - def __init__(self, table: sql.Composable) -> None: - self._table = table - def _query_one_by_id( self, - id_value: str, + id_value: UUID, changes: Dict[str, Any] ) -> sql.Composed: """Build Query to apply changes to row with given id.""" @@ -154,22 +157,17 @@ def compose_changes(changes: Dict[str, Any]) -> sql.Composed: ).format( table=self._table, changes=compose_changes(changes), - id_value=sql.Literal(id_value), + id_value=sql.Literal(str(id_value)), ) return query -class DeleteABC(Generic[T]): +class DeleteABC(CRUDABC[T]): """Encapsulate Delete operations for Model.read.""" # pylint: disable=too-few-public-methods - _table: sql.Composable - - def __init__(self, table: sql.Composable) -> None: - self._table = table - def _query_one_by_id(self, id_value: str) -> sql.Composed: """Build query to delete one record with matching ID.""" query = sql.SQL( diff --git a/db_wrapper/model/sync_model.py b/db_wrapper/model/sync_model.py index 81d0563..7d65602 100644 --- a/db_wrapper/model/sync_model.py +++ b/db_wrapper/model/sync_model.py @@ -1,18 +1,20 @@ """Synchronous Model objects.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from uuid import UUID +from psycopg2.extras import RealDictRow + from db_wrapper.client import SyncClient from .base import ( ensure_exactly_one, + sql, T, CreateABC, + DeleteABC, ReadABC, UpdateABC, - DeleteABC, ModelABC, - sql, ) @@ -23,16 +25,22 @@ class SyncCreate(CreateABC[T]): _client: SyncClient - def __init__(self, client: SyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client def one(self, item: T) -> T: """Create one new record with a given item.""" - result: List[T] = self._client.execute_and_return( + query_result: List[RealDictRow] = self._client.execute_and_return( self._query_one(item)) + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class SyncRead(ReadABC[T]): @@ -42,19 +50,26 @@ class SyncRead(ReadABC[T]): _client: SyncClient - def __init__(self, client: SyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client def one_by_id(self, id_value: UUID) -> T: """Read a row by it's id.""" - result: List[T] = self._client.execute_and_return( + query_result: List[RealDictRow] = self._client.execute_and_return( self._query_one_by_id(id_value)) # Should only return one item from DB - ensure_exactly_one(result) + ensure_exactly_one(query_result) - return result[0] + result: T = self._return_constructor(**query_result[0]) + + return result class SyncUpdate(UpdateABC[T]): @@ -64,11 +79,16 @@ class SyncUpdate(UpdateABC[T]): _client: SyncClient - def __init__(self, client: SyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client - def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T: + def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T: """Apply changes to row with given id. Arguments: @@ -79,12 +99,14 @@ def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T: Returns: full value of row updated """ - result: List[T] = self._client.execute_and_return( + query_result: List[RealDictRow] = self._client.execute_and_return( self._query_one_by_id(id_value, changes)) - ensure_exactly_one(result) + ensure_exactly_one(query_result) + + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class SyncDelete(DeleteABC[T]): @@ -94,19 +116,25 @@ class SyncDelete(DeleteABC[T]): _client: SyncClient - def __init__(self, client: SyncClient, table: sql.Composable) -> None: - super().__init__(table) + def __init__( + self, + client: SyncClient, + table: sql.Composable, + return_constructor: Type[T] + ) -> None: + super().__init__(table, return_constructor) self._client = client def one_by_id(self, id_value: str) -> T: """Delete one record with matching ID.""" - result: List[T] = self._client.execute_and_return( + query_result: List[RealDictRow] = self._client.execute_and_return( self._query_one_by_id(id_value)) - # Should only return one item from DB - ensure_exactly_one(result) + ensure_exactly_one(query_result) + + result: T = self._return_constructor(**query_result[0]) - return result[0] + return result class SyncModel(ModelABC[T]): @@ -122,19 +150,22 @@ class SyncModel(ModelABC[T]): _update: SyncUpdate[T] _delete: SyncDelete[T] - # PENDS python 3.9 support in pylint - # pylint: disable=unsubscriptable-object def __init__( self, client: SyncClient, table: str, + return_constructor: Type[T], ) -> None: super().__init__(client, table) - self._create = SyncCreate[T](self.client, self.table) - self._read = SyncRead[T](self.client, self.table) - self._update = SyncUpdate[T](self.client, self.table) - self._delete = SyncDelete[T](self.client, self.table) + self._create = SyncCreate[T]( + self.client, self.table, return_constructor) + self._read = SyncRead[T]( + self.client, self.table, return_constructor) + self._update = SyncUpdate[T]( + self.client, self.table, return_constructor) + self._delete = SyncDelete[T]( + self.client, self.table, return_constructor) @property def create(self) -> SyncCreate[T]: diff --git a/example/.python-version b/example/.python-version index 6bd1074..f69abe4 100644 --- a/example/.python-version +++ b/example/.python-version @@ -1 +1 @@ -3.9.1 +3.9.7 diff --git a/example/example/__init__.py b/example/example/__init__.py new file mode 100644 index 0000000..8aef547 --- /dev/null +++ b/example/example/__init__.py @@ -0,0 +1,2 @@ +from . import models +from .main import run diff --git a/example/example/example.py b/example/example/main.py similarity index 79% rename from example/example/example.py rename to example/example/main.py index ab29399..7123ad7 100644 --- a/example/example/example.py +++ b/example/example/main.py @@ -1,4 +1,4 @@ -"""An example of how to use Client & Model together.""" +"""An example of how to use AsyncClient & AsyncModel together.""" import asyncio import json @@ -7,9 +7,13 @@ from uuid import uuid4, UUID from typing import Any, List -from db_wrapper import ConnectionParameters, Client, Model +from db_wrapper import ConnectionParameters, AsyncClient, AsyncModel -from models import AModel, ExtendedModel, ExtendedModelData +from example.models import ( + AModel, + ExtendedModel, + ExtendedModelData, +) logging.basicConfig(level=logging.INFO) @@ -33,17 +37,17 @@ def default(self, obj: Any) -> Any: user=os.getenv('DB_USER', 'test'), password=os.getenv('DB_PASS', 'pass'), database=os.getenv('DB_NAME', 'dev')) -client = Client(conn_params) +client = AsyncClient(conn_params) -a_model = Model[AModel](client, 'a_model') +a_model = AsyncModel[AModel](client, 'a_model', AModel) extended_model = ExtendedModel(client) async def create_a_model_record() -> UUID: """ - Show how to use a simple Model instance. + Show how to use a simple AsyncModel instance. - Create a new record using the default Model.create.one method. + Create a new record using the default AsyncModel.create.one method. """ new_record = AModel(**{ 'id': uuid4(), @@ -59,13 +63,13 @@ async def create_a_model_record() -> UUID: async def read_a_model(id_value: UUID) -> AModel: """Show how to read a record with a given id value.""" - # read.one_by_id expects a string, so UUID values need - # converted using str() - return await a_model.read.one_by_id(str(id_value)) + return await a_model.read.one_by_id(id_value) async def create_extended_models() -> None: - """Show how using an extended Model can be the same as the defaults.""" + """ + Show how using an extended AsyncModel can be the same as the defaults. + """ dicts = [{ 'id': uuid4(), 'string': 'something', @@ -101,7 +105,7 @@ async def read_extended_models() -> List[ExtendedModelData]: """Show how to use an extended Model's new methods.""" # We defined read.all in ./models/extended_model.py's ExtendedRead class, # then replaced ExtendedModel's read property with ExtendedRead. - # As a result, we can call it just like any other method on Model.read. + # As a result, we can call it just like any other method on AsyncModel.read return await extended_model.read.all() @@ -119,13 +123,15 @@ async def run() -> None: new_id = await create_a_model_record() created_a_model = await read_a_model(new_id) await create_extended_models() - extended_models = await read_extended_models() + created_extended_models = await read_extended_models() finally: await client.disconnect() # Print results to stdout - print(json.dumps(created_a_model, cls=UUIDJsonEncoder)) - print(json.dumps(extended_models, cls=UUIDJsonEncoder)) + print(json.dumps(created_a_model.dict(), cls=UUIDJsonEncoder)) + print(json.dumps([model.dict() + for model in created_extended_models], + cls=UUIDJsonEncoder)) if __name__ == '__main__': # A simple app can be run using asyncio's run method. diff --git a/example/example/models/extended_model.py b/example/example/models/extended_model.py index 67e7cba..97ebd6f 100644 --- a/example/example/models/extended_model.py +++ b/example/example/models/extended_model.py @@ -5,10 +5,10 @@ from psycopg2 import sql from psycopg2.extensions import register_adapter -from psycopg2.extras import Json +from psycopg2.extras import Json # type: ignore from db_wrapper import AsyncClient, AsyncModel, ModelData -from db_wrapper.model import AsyncRead, AsyncCreate +from db_wrapper.model import AsyncRead, AsyncCreate, RealDictRow # tell psycopg2 to adapt all dictionaries to json instead of # the default hstore @@ -55,10 +55,11 @@ async def one(self, item: ExtendedModelData) -> ExtendedModelData: values=sql.SQL(',').join(values), ) - result: List[ExtendedModelData] = \ + query_result: List[RealDictRow] = \ await self._client.execute_and_return(query) + result = self._return_constructor(**query_result[0]) - return result[0] + return result class ExtendedReader(AsyncRead[ExtendedModelData]): @@ -75,8 +76,10 @@ async def all_by_string(self, string: str) -> List[ExtendedModelData]: string=sql.Identifier(string) ) - result: List[ExtendedModelData] = await self \ - ._client.execute_and_return(query) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(query) + result = [self._return_constructor(**row) + for row in query_result] return result @@ -85,8 +88,10 @@ async def all(self) -> List[ExtendedModelData]: query = sql.SQL('SELECT * FROM {table}').format( table=self._table) - result: List[ExtendedModelData] = await self \ - ._client.execute_and_return(query) + query_result: List[RealDictRow] = \ + await self._client.execute_and_return(query) + result = [self._return_constructor(**row) + for row in query_result] return result @@ -98,6 +103,8 @@ class ExtendedModel(AsyncModel[ExtendedModelData]): create: ExtendedCreator def __init__(self, client: AsyncClient) -> None: - super().__init__(client, 'extended_model') - self.read = ExtendedReader(self.client, self.table) - self.create = ExtendedCreator(self.client, self.table) + super().__init__(client, 'extended_model', ExtendedModelData) + self.read = ExtendedReader( + self.client, self.table, ExtendedModelData) + self.create = ExtendedCreator( + self.client, self.table, ExtendedModelData) diff --git a/example/requirements/prod.txt b/example/requirements/prod.txt index d875b7a..c642fd0 100644 --- a/example/requirements/prod.txt +++ b/example/requirements/prod.txt @@ -1 +1 @@ -https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.1.0/db_wrapper-2.1.0-py3-none-any.whl +-e ../ diff --git a/example_sync/example/__init__.py b/example_sync/example/__init__.py new file mode 100644 index 0000000..8aef547 --- /dev/null +++ b/example_sync/example/__init__.py @@ -0,0 +1,2 @@ +from . import models +from .main import run diff --git a/example_sync/example/example.py b/example_sync/example/main.py similarity index 90% rename from example_sync/example/example.py rename to example_sync/example/main.py index 89f8ef0..e5a96dd 100644 --- a/example_sync/example/example.py +++ b/example_sync/example/main.py @@ -9,7 +9,7 @@ from db_wrapper import SyncClient, ConnectionParameters from db_wrapper.model import SyncModel as Model -from models import ( +from example.models import ( AModel, ExtendedModel, ExtendedModelData, @@ -39,7 +39,7 @@ def default(self, obj: Any) -> Any: database=os.getenv('DB_NAME', 'dev')) client = SyncClient(conn_params) -a_model = Model[AModel](client, 'a_model') +a_model = Model[AModel](client, 'a_model', AModel) extended_model = ExtendedModel(client) @@ -63,9 +63,7 @@ def create_a_model_record() -> UUID: def read_a_model(id_value: UUID) -> AModel: """Show how to read a record with a given id value.""" - # read.one_by_id expects a string, so UUID values need - # converted using str() - return a_model.read.one_by_id(str(id_value)) + return a_model.read.one_by_id(id_value) def create_extended_models() -> None: @@ -122,13 +120,15 @@ def run() -> None: new_id = create_a_model_record() created_a_model = read_a_model(new_id) create_extended_models() - extended_models = read_extended_models() + created_extended_models = read_extended_models() finally: client.disconnect() # Print results to stdout - print(json.dumps(created_a_model, cls=UUIDJsonEncoder)) - print(json.dumps(extended_models, cls=UUIDJsonEncoder)) + print(json.dumps(created_a_model.dict(), cls=UUIDJsonEncoder)) + print(json.dumps([model.dict() + for model in created_extended_models], + cls=UUIDJsonEncoder)) if __name__ == '__main__': diff --git a/example_sync/example/models/extended_model.py b/example_sync/example/models/extended_model.py index 10ac2b3..fd44311 100644 --- a/example_sync/example/models/extended_model.py +++ b/example_sync/example/models/extended_model.py @@ -5,10 +5,10 @@ from psycopg2 import sql from psycopg2.extensions import register_adapter -from psycopg2.extras import Json +from psycopg2.extras import Json # type: ignore -from db_wrapper import SyncClient -from db_wrapper.model import ModelData, SyncModel, SyncRead, SyncCreate +from db_wrapper import SyncClient, SyncModel, ModelData +from db_wrapper.model import RealDictRow, SyncRead, SyncCreate # tell psycopg2 to adapt all dictionaries to json instead of # the default hstore @@ -51,10 +51,12 @@ def one(self, item: ExtendedModelData) -> ExtendedModelData: values=sql.SQL(',').join(values), ) - result: List[ExtendedModelData] = \ + query_result: List[RealDictRow] = \ self._client.execute_and_return(query) - return result[0] + result = self._return_constructor(**query_result[0]) + + return result class ExtendedReader(SyncRead[ExtendedModelData]): @@ -81,8 +83,10 @@ def all(self) -> List[ExtendedModelData]: query = sql.SQL('SELECT * FROM {table}').format( table=self._table) - result: List[ExtendedModelData] = self \ - ._client.execute_and_return(query) + query_result: List[RealDictRow] = \ + self._client.execute_and_return(query) + + result = [self._return_constructor(**item) for item in query_result] return result @@ -94,6 +98,14 @@ class ExtendedModel(SyncModel[ExtendedModelData]): create: ExtendedCreator def __init__(self, client: SyncClient) -> None: - super().__init__(client, 'extended_model') - self.read = ExtendedReader(self.client, self.table) - self.create = ExtendedCreator(self.client, self.table) + super().__init__(client, 'extended_model', ExtendedModelData) + self.read = ExtendedReader( + self.client, + self.table, + ExtendedModelData + ) + self.create = ExtendedCreator( + self.client, + self.table, + ExtendedModelData + ) diff --git a/example_sync/requirements/prod.txt b/example_sync/requirements/prod.txt index d875b7a..e60595e 100644 --- a/example_sync/requirements/prod.txt +++ b/example_sync/requirements/prod.txt @@ -1 +1 @@ -https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.1.0/db_wrapper-2.1.0-py3-none-any.whl +https://github.com/cheese-drawer/lib-python-db-wrapper/releases/download/2.3.0/db_wrapper-2.3.0-py3-none-any.whl diff --git a/scripts/install b/scripts/install new file mode 100755 index 0000000..b5d3924 --- /dev/null +++ b/scripts/install @@ -0,0 +1,66 @@ +#!/usr/bin/env bash + +# +# NAVIGATE TO CORRECT DIRECTORY +# + +# start by going to script dir so all movements +# from here are relative +SCRIPT_DIR=`dirname $(realpath "$0")` +cd $SCRIPT_DIR +# go up to root +cd .. + + +# +# INSTALL FROM LISTS +# + +function dev { + echo "" + echo "Installing dev requirements..." + echo "" + dev_result=0 + + # install from dev list + pip install -r requirements/dev.txt +} + +function prod { + echo "" + echo "Installing prod requirements..." + echo "" + prod_result=0 + + # install from dev list + pip install -r requirements/prod.txt +} + +# Install dev, prod, or all requirements depending on argument given +if [ $# -eq 0 ]; then + dev + prod + + if [[ $dev_result != 0 && $prod_result != 0 ]]; then + echo "Errors found in both dev & prod installation. See output above." + exit $dev_result + elif [[ $dev_result != 0 && $prod_result == 0 ]]; then + echo "Errors found in dev installation. See output above." + exit $dev_result + elif [[ $dev_result == 0 && $prod_result != 0 ]]; then + echo "Errors found in prod installation. See output above." + exit $prod_result + else + exit 0 + fi + +elif [[ $1 == 'dev' || $1 == 'development' ]]; then + dev ${@:2} + exit $dev_result +elif [[ $1 == 'prod' || $1 == 'production' ]]; then + prod ${@:2} + exit $prod_result +else + echo "Bad argument given, either specify \`dev\` or \`prod\` requirements by giving either word as your first argument to this script, or run both by giving no arguments." + exit 1 +fi diff --git a/setup.py b/setup.py index d82eeff..ead47ee 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setuptools.setup( name="db_wrapper", - version="2.2.0", + version="2.4.0", author="Andrew Chang-DeWitt", author_email="andrew@andrew-chang-dewitt.dev", description=short_description, diff --git a/stubs/psycopg2/__init__.pyi b/stubs/psycopg2/__init__.pyi index 078e53d..2ec209d 100644 --- a/stubs/psycopg2/__init__.pyi +++ b/stubs/psycopg2/__init__.pyi @@ -1,3 +1,19 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Optional from psycopg2._psycopg import ( BINARY, Binary, @@ -27,7 +43,6 @@ from psycopg2._psycopg import ( paramstyle, threadsafety ) -from typing import Any, Optional connection = connection OperationalError = OperationalError diff --git a/stubs/psycopg2/_ipaddress.pyi b/stubs/psycopg2/_ipaddress.pyi index 59af263..0af4c8c 100644 --- a/stubs/psycopg2/_ipaddress.pyi +++ b/stubs/psycopg2/_ipaddress.pyi @@ -1,9 +1,33 @@ -from psycopg2.extensions import QuotedString as QuotedString, new_array_type as new_array_type, new_type as new_type, register_adapter as register_adapter, register_type as register_type +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-module-docstring +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called +# pylint: disable=unused-import +# pylint: disable=useless-import-alias +# pylint: disable=line-too-long + from typing import Any, Optional +from psycopg2.extensions import ( + QuotedString, + new_array_type, + new_type, + register_adapter, + register_type, +) + ipaddress: Any + def register_ipaddress(conn_or_curs: Optional[Any] = ...) -> None: ... -def cast_interface(s: Any, cur: Optional[Any] = ...): ... -def cast_network(s: Any, cur: Optional[Any] = ...): ... -def adapt_ipaddress(obj: Any): ... +def cast_interface(s: Any, cur: Optional[Any] = ...) -> Any: ... +def cast_network(s: Any, cur: Optional[Any] = ...) -> Any: ... +def adapt_ipaddress(obj: Any) -> Any: ... diff --git a/stubs/psycopg2/_json.pyi b/stubs/psycopg2/_json.pyi index f9ae3f0..276088e 100644 --- a/stubs/psycopg2/_json.pyi +++ b/stubs/psycopg2/_json.pyi @@ -1,3 +1,19 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-arguments +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + from typing import Any, Optional from psycopg2._psycopg import ( @@ -23,13 +39,23 @@ class Json: def getquoted(self) -> Any: ... -def register_json(conn_or_curs: Optional[Any] = ..., globally: bool = ..., loads: Optional[Any] - = ..., oid: Optional[Any] = ..., array_oid: Optional[Any] = ..., name: str = ...) -> Any: ... +def register_json( + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ..., + oid: Optional[Any] = ..., + array_oid: Optional[Any] = ..., + name: str = ... +) -> Any: ... def register_default_json( - conn_or_curs: Optional[Any] = ..., globally: bool = ..., loads: Optional[Any] = ...) -> Any: ... + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ...) -> Any: ... def register_default_jsonb( - conn_or_curs: Optional[Any] = ..., globally: bool = ..., loads: Optional[Any] = ...) -> Any: ... + conn_or_curs: Optional[Any] = ..., + globally: bool = ..., + loads: Optional[Any] = ...) -> Any: ... diff --git a/stubs/psycopg2/_psycopg.pyi b/stubs/psycopg2/_psycopg.pyi index 1249e59..b2b617b 100644 --- a/stubs/psycopg2/_psycopg.pyi +++ b/stubs/psycopg2/_psycopg.pyi @@ -1,6 +1,17 @@ -from typing import Any - -import psycopg2.extensions +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-module-docstring +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + +from typing import Any, Iterator __libpq_version__: Any BINARY: Any @@ -124,6 +135,7 @@ class Binary: buffer: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... def getquoted(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def prepare(conn: Any) -> prepareforbinaryencodingusingconn: ... def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... @@ -185,14 +197,26 @@ class ConnectionInfo: def ssl_attribute(self, *args: Any, **kwargs: Any) -> Any: ... -class DataError(DatabaseError): - ... +class Error(Exception): + cursor: Any = ... + diag: Any = ... + pgcode: Any = ... + pgerror: Any = ... + def __init__(self, *args: Any, ** + kwargs: Any) -> None: ... + + def __reduce__(self) -> Any: ... + def __setstate__(self, state: Any) -> Any: ... class DatabaseError(Error): ... +class DataError(DatabaseError): + ... + + class Decimal: adapted: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... @@ -222,16 +246,6 @@ class Diagnostics: def __init__(self, *args: Any, **kwargs: Any) -> None: ... -class Error(Exception): - cursor: Any = ... - diag: Any = ... - pgcode: Any = ... - pgerror: Any = ... - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - def __reduce__(self) -> Any: ... - def __setstate__(self, state: Any) -> Any: ... - - class Float: adapted: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... @@ -287,7 +301,7 @@ class Notify: def __ge__(self, other: Any) -> Any: ... def __getitem__(self, index: Any) -> Any: ... def __gt__(self, other: Any) -> Any: ... - def __hash__(self) -> Any: ... + def __hash__(self) -> Any: ... # pylint: disable=invalid-hash-returned def __le__(self, other: Any) -> Any: ... def __len__(self) -> Any: ... def __lt__(self, other: Any) -> Any: ... @@ -316,75 +330,6 @@ class QuotedString: def __conform__(self, *args: Any, **kwargs: Any) -> Any: ... -class ReplicationConnection(connection): - autocommit: Any = ... - isolation_level: Any = ... - replication_type: Any = ... - reset: Any = ... - set_isolation_level: Any = ... - set_session: Any = ... - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - - -class ReplicationCursor(cursor): - feedback_timestamp: Any = ... - io_timestamp: Any = ... - wal_end: Any = ... - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - - def consume_stream( - consumer: Any, - keepalive_interval: Any = ... - ) -> Any: ... - - def read_message(self, *args: Any, **kwargs: Any) -> Any: ... - - def send_feedback( - write_lsn: Any = ..., - flush_lsn: Any = ..., - apply_lsn: Any = ..., - reply: Any = ..., - force: Any = ... - ) -> Any: ... - - def start_replication_expert( - command: Any, - decode: Any = ..., - status_interval: Any = ..., - ) -> Any: ... - - -class ReplicationMessage: - cursor: Any = ... - data_size: Any = ... - data_start: Any = ... - payload: Any = ... - send_time: Any = ... - wal_end: Any = ... - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - - -class TransactionRollbackError(OperationalError): - ... - - -class Warning(Exception): - ... - - -class Xid: - bqual: Any = ... - database: Any = ... - format_id: Any = ... - gtrid: Any = ... - owner: Any = ... - prepared: Any = ... - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - def from_string(self, *args: Any, **kwargs: Any) -> Any: ... - def __getitem__(self, index: Any) -> Any: ... - def __len__(self) -> Any: ... - - class connection: DataError: Any = ... DatabaseError: Any = ... @@ -419,6 +364,7 @@ class connection: def close(self, *args: Any, **kwargs: Any) -> Any: ... def commit(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def cursor( name: Any = ..., cursor_factory: Any = ..., @@ -429,10 +375,12 @@ class connection: def get_backend_pid(self, *args: Any, **kwargs: Any) -> Any: ... def get_dsn_parameters(self, *args: Any, **kwargs: Any) -> Any: ... def get_native_connection(self, *args: Any, **kwargs: Any) -> Any: ... - def get_parameter_status(parameter) -> Any: ... + @staticmethod + def get_parameter_status(parameter: Any) -> Any: ... def get_transaction_status(self, *args: Any, **kwargs: Any) -> Any: ... def isexecuting(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def lobject( oid: Any = ..., mode: Any = ..., @@ -444,24 +392,23 @@ class connection: def poll(self, *args: Any, **kwargs: Any) -> Any: ... def reset(self, *args: Any, **kwargs: Any) -> Any: ... def rollback(self, *args: Any, **kwargs: Any) -> Any: ... - def set_client_encoding(encoding) -> Any: ... - def set_isolation_level(level) -> Any: ... + @staticmethod + def set_client_encoding(encoding: Any) -> Any: ... + @staticmethod + def set_isolation_level(level: Any) -> Any: ... def set_session(self, *args: Any, **kwargs: Any) -> Any: ... - def tpc_begin(xid) -> Any: ... + @staticmethod + def tpc_begin(xid: Any) -> Any: ... def tpc_commit(self, *args: Any, **kwargs: Any) -> Any: ... def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any: ... def tpc_recover(self, *args: Any, **kwargs: Any) -> Any: ... def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def xid(format_id: Any, gtrid: Any, bqual: Any) -> Any: ... def __enter__(self) -> Any: ... def __exit__(self, type: Any, value: Any, traceback: Any) -> Any: ... -value: Any - -listoftuple: Any - - class cursor: arraysize: Any = ... binary_types: Any = ... @@ -483,11 +430,15 @@ class cursor: tzinfo_factory: Any = ... withhold: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @staticmethod def callproc(procname: Any, parameters: Any = ...) -> Any: ... - def cast(oid: Any, s: Any) -> value: ... + @staticmethod + def cast(oid: Any, s: Any) -> Any: ... def close(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def copy_expert(sql: Any, file: Any, size: Any = ...) -> Any: ... + @staticmethod def copy_from( file: Any, table: Any, @@ -497,6 +448,7 @@ class cursor: columns: Any = ..., ) -> Any: ... + @staticmethod def copy_to( file: Any, table: Any, @@ -505,32 +457,120 @@ class cursor: columns: Any = ... ) -> Any: ... + @staticmethod def execute(query: Any, params: Any = ...) -> Any: ... + @staticmethod def executemany(query: Any, vars_list: Any) -> Any: ... def fetchall(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def fetchmany(size: Any = ...) -> listoftuple: ... def fetchone(self, *args: Any, **kwargs: Any) -> Any: ... def mogrify(self, *args: Any, **kwargs: Any) -> Any: ... def nextset(self, *args: Any, **kwargs: Any) -> Any: ... + @staticmethod def scroll(value: Any, mode: Any = ...) -> Any: ... + @staticmethod def setinputsizes(sizes: Any) -> Any: ... + @staticmethod def setoutputsize(size: Any, column: Any = ...) -> Any: ... def __enter__(self) -> Any: ... def __exit__(self, type: Any, value: Any, traceback: Any) -> Any: ... - def __iter__(self) -> Any: ... + def __iter__( # pylint: disable=non-iterator-returned + self) -> Iterator[Any]: ... + def __next__(self) -> Any: ... +class ReplicationConnection(connection): + autocommit: Any = ... + isolation_level: Any = ... + replication_type: Any = ... + reset: Any = ... + set_isolation_level: Any = ... + set_session: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + +class ReplicationCursor(cursor): + feedback_timestamp: Any = ... + io_timestamp: Any = ... + wal_end: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + @staticmethod + def consume_stream( + consumer: Any, + keepalive_interval: Any = ... + ) -> Any: ... + + def read_message(self, *args: Any, **kwargs: Any) -> Any: ... + + @staticmethod + def send_feedback( + write_lsn: Any = ..., + flush_lsn: Any = ..., + apply_lsn: Any = ..., + reply: Any = ..., + force: Any = ... + ) -> Any: ... + + @staticmethod + def start_replication_expert( + command: Any, + decode: Any = ..., + status_interval: Any = ..., + ) -> Any: ... + + +class ReplicationMessage: + cursor: Any = ... + data_size: Any = ... + data_start: Any = ... + payload: Any = ... + send_time: Any = ... + wal_end: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + +class TransactionRollbackError(OperationalError): + ... + + +class Warning(Exception): + ... + + +class Xid: + bqual: Any = ... + database: Any = ... + format_id: Any = ... + gtrid: Any = ... + owner: Any = ... + prepared: Any = ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def from_string(self, *args: Any, **kwargs: Any) -> Any: ... + def __getitem__(self, index: Any) -> Any: ... + def __len__(self) -> Any: ... + + +listoftuple: Any + + class lobject: closed: Any = ... mode: Any = ... oid: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... def close(self, *args: Any, **kwargs: Any) -> Any: ... - def export(filename) -> Any: ... - def read(size=...) -> Any: ... + @staticmethod + def export(filename: Any) -> Any: ... + @staticmethod + def read(size: Any = ...) -> Any: ... + @staticmethod def seek(offset: Any, whence: Any = ...) -> Any: ... def tell(self, *args: Any, **kwargs: Any) -> Any: ... - def truncate(len=...) -> Any: ... + @staticmethod + def truncate(len: Any = ...) -> Any: ... def unlink(self, *args: Any, **kwargs: Any) -> Any: ... - def write(str) -> Any: ... + @staticmethod + def write(str: str) -> Any: ... diff --git a/stubs/psycopg2/extensions.pyi b/stubs/psycopg2/extensions.pyi index 02fef37..c597c28 100644 --- a/stubs/psycopg2/extensions.pyi +++ b/stubs/psycopg2/extensions.pyi @@ -1,9 +1,25 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + from typing import Any, Optional from psycopg2._json import ( register_default_json, register_default_jsonb ) + from psycopg2._psycopg import ( AsIs, BINARYARRAY, @@ -46,7 +62,6 @@ from psycopg2._psycopg import ( PYTIME, PYTIMEARRAY, QueryCanceledError, - QuotedString, ROWIDARRAY, STRINGARRAY, TIME, @@ -67,13 +82,6 @@ from psycopg2._psycopg import ( get_wait_callback, libpq_version, lobject, - new_array_type, - new_type, - parse_dsn, - quote_ident, - register_type, - set_wait_callback, - string_types ) from psycopg2._range import Range @@ -100,6 +108,15 @@ TRANSACTION_STATUS_INTRANS: int TRANSACTION_STATUS_INERROR: int TRANSACTION_STATUS_UNKNOWN: int +QuotedString: Any +new_array_type: Any +new_type: Any +parse_dsn: Any +quote_ident: Any +register_type: Any +set_wait_callback: Any +string_types: Any + def register_adapter(typ: Any, callable: Any) -> None: ... diff --git a/stubs/psycopg2/extras.pyi b/stubs/psycopg2/extras.pyi index 65451da..a95f203 100644 --- a/stubs/psycopg2/extras.pyi +++ b/stubs/psycopg2/extras.pyi @@ -1,23 +1,62 @@ +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=no-name-in-module +# pylint: disable=unused-import +# pylint: disable=unused-argument +# pylint: disable=multiple-statements +# pylint: disable=invalid-name +# pylint: disable=invalid-length-returned +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-public-methods +# pylint: disable=no-self-use +# pylint: disable=redefined-builtin +# pylint: disable=super-init-not-called + + +from collections import OrderedDict from typing import Any, Optional -# from .compat import PY2 as PY2, PY3 as PY3, lru_cache as lru_cache -# from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident -# from collections import OrderedDict -# from psycopg2._ipaddress import register_ipaddress as register_ipaddress -# from psycopg2._json import Json as Json, json as json, register_default_json as register_default_json, register_default_jsonb as register_default_jsonb, register_json as register_json -# from psycopg2._psycopg import REPLICATION_LOGICAL as REPLICATION_LOGICAL, REPLICATION_PHYSICAL as REPLICATION_PHYSICAL, ReplicationConnection as _replicationConnection, ReplicationCursor as _replicationCursor, ReplicationMessage as ReplicationMessage -# from psycopg2._range import DateRange as DateRange, DateTimeRange as DateTimeRange, DateTimeTZRange as DateTimeTZRange, NumericRange as NumericRange, Range as Range, RangeAdapter as RangeAdapter, RangeCaster as RangeCaster, register_range as register_range + +from psycopg2._range import ( + DateRange, + DateTimeRange, + DateTimeTZRange, + NumericRange, + Range, + RangeAdapter, + RangeCaster, + register_range +) +from psycopg2._psycopg import ( + REPLICATION_LOGICAL, + REPLICATION_PHYSICAL, + ReplicationConnection, + ReplicationCursor, + ReplicationMessage, + connection, + cursor, + quote_ident +) +from psycopg2._ipaddress import ( + register_ipaddress +) + + +class RealDictRow(OrderedDict[Any, Any]): + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def __setitem__(self, key: Any, value: Any) -> None: ... # # -# class DictCursorBase(_cursor): +# class DictCursorBase(cursor): # row_factory: Any = ... # def __init__(self, *args: Any, **kwargs: Any) -> None: ... # def fetchone(self): ... # def fetchmany(self, size: Optional[Any] = ...): ... # def fetchall(self): ... -# def __iter__(self) -> Any: ... +# def __iter__(self) -> Any: ... # pylint: disable=non-iterator-returned # # -# class DictConnection(_connection): +# class DictConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # @@ -41,7 +80,7 @@ from typing import Any, Optional # def __reduce__(self): ... # # -# class RealDictConnection(_connection): +# class RealDictConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # @@ -57,11 +96,11 @@ from typing import Any, Optional # def __setitem__(self, key: Any, value: Any) -> None: ... # # -# class NamedTupleConnection(_connection): +# class NamedTupleConnection(connection): # def cursor(self, *args: Any, **kwargs: Any): ... # # -# class NamedTupleCursor(_cursor): +# class NamedTupleCursor(cursor): # Record: Any = ... # MAX_CACHE: int = ... # def execute(self, query: Any, vars: Optional[Any] = ...): ... @@ -73,14 +112,14 @@ from typing import Any, Optional # def __iter__(self) -> Any: ... # # -# class LoggingConnection(_connection): +# class LoggingConnection(connection): # log: Any = ... # def initialize(self, logobj: Any) -> None: ... # def filter(self, msg: Any, curs: Any): ... # def cursor(self, *args: Any, **kwargs: Any): ... # # -# class LoggingCursor(_cursor): +# class LoggingCursor(cursor): # def execute(self, query: Any, vars: Optional[Any] = ...): ... # def callproc(self, procname: Any, vars: Optional[Any] = ...): ... # @@ -111,11 +150,21 @@ from typing import Any, Optional # # class ReplicationCursor(_replicationCursor): # def create_replication_slot( -# self, slot_name: Any, slot_type: Optional[Any] = ..., output_plugin: Optional[Any] = ...) -> None: ... +# self, +# slot_name: Any, +# slot_type: Optional[Any] = ..., +# output_plugin: Optional[Any] = ...) -> None: ... # # def drop_replication_slot(self, slot_name: Any) -> None: ... -# def start_replication(self, slot_name: Optional[Any] = ..., slot_type: Optional[Any] = ..., start_lsn: int = ..., -# timeline: int = ..., options: Optional[Any] = ..., decode: bool = ..., status_interval: int = ...) -> None: ... +# +# def start_replication(self, +# slot_name: Optional[Any] = ..., +# slot_type: Optional[Any] = ..., +# start_lsn: int = ..., +# timeline: int = ..., +# options: Optional[Any] = ..., +# decode: bool = ..., +# status_interval: int = ...) -> None: ... # # def fileno(self): ... # @@ -124,10 +173,8 @@ from typing import Any, Optional # def __init__(self, uuid: Any) -> None: ... # def __conform__(self, proto: Any): ... # def getquoted(self): ... -# -# -# pylint: disable=unsubscriptable-object + def register_uuid(oids: Optional[Any] = ..., conn_or_curs: Optional[Any] = ...) -> None: ... # @@ -161,8 +208,13 @@ def register_uuid(oids: Optional[Any] = ..., # def get_oids(self, conn_or_curs: Any): ... # # -# def register_hstore(conn_or_curs: Any, globally: bool = ..., unicode: bool = ..., -# oid: Optional[Any] = ..., array_oid: Optional[Any] = ...) -> None: ... +# def register_hstore( +# conn_or_curs: Any, +# globally: bool = ..., +# unicode: bool = ..., +# oid: Optional[Any] = ..., +# array_oid: Optional[Any] = ... +# ) -> None: ... # # # class CompositeCaster: diff --git a/test/test_model.py b/test/test_model.py index 0f9e686..d9de694 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -12,6 +12,7 @@ Any, List, Tuple, + Type, TypeVar, ) from uuid import uuid4 @@ -34,12 +35,13 @@ ) -# Generic doesn't need a more descriptive name -# pylint: disable=invalid-name T = TypeVar('T', bound=ModelData) -def setupAsync(query_result: List[T]) -> Tuple[AsyncModel[T], AsyncClient]: +def setupAsync( + query_result: List[T], + model_data: Type[T] +) -> Tuple[AsyncModel[T], AsyncClient]: """Setup helper that returns instances of both a Model & a Client. Mocks the execute_and_return method on the Client instance to skip @@ -57,15 +59,18 @@ def setupAsync(query_result: List[T]) -> Tuple[AsyncModel[T], AsyncClient]: # mock client's sql execution method client.execute_and_return = helpers.AsyncMock( # type:ignore - return_value=query_result) + return_value=[i.dict() for i in query_result]) # init a real model with mocked client - model = AsyncModel[Any](client, 'test') + model = AsyncModel[Any](client, 'test', model_data) return model, client -def setupSync(query_result: List[T]) -> Tuple[SyncModel[T], SyncClient]: +def setupSync( + query_result: List[T], + model_data: Type[T] +) -> Tuple[SyncModel[T], SyncClient]: """Setup helper that returns instances of both a Model & a Client. Mocks the execute_and_return method on the Client instance to skip @@ -83,10 +88,10 @@ def setupSync(query_result: List[T]) -> Tuple[SyncModel[T], SyncClient]: # mock client's sql execution method client.execute_and_return = helpers.MagicMock( # type:ignore - return_value=query_result) + return_value=[i.dict() for i in query_result]) # init a real model with mocked client - model = SyncModel[Any](client, 'test') + model = SyncModel[T](client, 'test', model_data) return model, client @@ -97,11 +102,11 @@ class TestReadOneById(TestCase): @helpers.async_test async def test_it_correctly_builds_query_with_given_id(self) -> None: item = ModelData(id=uuid4()) - async_model, async_client = setupAsync([item]) - sync_model, sync_client = setupSync([item]) + async_model, async_client = setupAsync([item], ModelData) + sync_model, sync_client = setupSync([item], ModelData) - await async_model.read.one_by_id(str(item.id)) - sync_model.read.one_by_id(str(item.id)) + await async_model.read.one_by_id(item.id) + sync_model.read.one_by_id(item.id) async_query_composed = cast( helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] @@ -122,10 +127,10 @@ async def test_it_correctly_builds_query_with_given_id(self) -> None: @helpers.async_test async def test_it_returns_a_single_result(self) -> None: item = ModelData(id=uuid4()) - async_model, _ = setupAsync([item]) - sync_model, _ = setupSync([item]) - results = [await async_model.read.one_by_id(str(item.id)), - sync_model.read.one_by_id(str(item.id))] + async_model, _ = setupAsync([item], ModelData) + sync_model, _ = setupSync([item], ModelData) + results = [await async_model.read.one_by_id(item.id), + sync_model.read.one_by_id(item.id)] for result in results: with self.subTest(): @@ -134,31 +139,33 @@ async def test_it_returns_a_single_result(self) -> None: @helpers.async_test async def test_it_raises_exception_if_more_than_one_result(self) -> None: item = ModelData(id=uuid4()) - async_model, _ = setupAsync([item, item]) - sync_model, _ = setupSync([item, item]) + async_model, _ = setupAsync([item, item], ModelData) + sync_model, _ = setupSync([item, item], ModelData) with self.subTest(): with self.assertRaises(UnexpectedMultipleResults): - await async_model.read.one_by_id(str(item.id)) + await async_model.read.one_by_id(item.id) with self.subTest(): with self.assertRaises(UnexpectedMultipleResults): - sync_model.read.one_by_id(str(item.id)) + sync_model.read.one_by_id(item.id) @ helpers.async_test async def test_it_raises_exception_if_no_result_to_return(self) -> None: + empty_async: List[ModelData] = [] + empty_sync: List[ModelData] = [] async_model: AsyncModel[ModelData] sync_model: SyncModel[ModelData] - async_model, _ = setupAsync([]) - sync_model, _ = setupSync([]) + async_model, _ = setupAsync(empty_async, ModelData) + sync_model, _ = setupSync(empty_sync, ModelData) with self.subTest(): with self.assertRaises(NoResultFound): - await async_model.read.one_by_id('id') + await async_model.read.one_by_id(uuid4()) with self.subTest(): with self.assertRaises(NoResultFound): - sync_model.read.one_by_id('id') + sync_model.read.one_by_id(uuid4()) class TestCreateOne(TestCase): @@ -176,8 +183,8 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - async_model, async_client = setupAsync([item]) - sync_model, sync_client = setupSync([item]) + async_model, async_client = setupAsync([item], TestCreateOne.Item) + sync_model, sync_client = setupSync([item], TestCreateOne.Item) await async_model.create.one(item) sync_model.create.one(item) @@ -204,8 +211,8 @@ async def test_it_returns_the_new_record(self) -> None: 'a': 'a', 'b': 'b', }) - async_model, _ = setupAsync([item]) - sync_model, _ = setupSync([item]) + async_model, _ = setupAsync([item], TestCreateOne.Item) + sync_model, _ = setupSync([item], TestCreateOne.Item) results = [await async_model.create.one(item), sync_model.create.one(item)] @@ -230,11 +237,11 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - async_model, async_client = setupAsync([item]) - sync_model, sync_client = setupSync([item]) + async_model, async_client = setupAsync([item], TestUpdateOne.Item) + sync_model, sync_client = setupSync([item], TestUpdateOne.Item) - await async_model.update.one_by_id(str(item.id), {'b': 'c'}) - sync_model.update.one_by_id(str(item.id), {'b': 'c'}) + await async_model.update.one_by_id(item.id, {'b': 'c'}) + sync_model.update.one_by_id(item.id, {'b': 'c'}) async_query_composed = cast( helpers.AsyncMock, async_client.execute_and_return).call_args[0][0] @@ -261,12 +268,12 @@ async def test_it_returns_the_new_record(self) -> None: }) # mock result updated = TestUpdateOne.Item(**{**item.dict(), 'b': 'c'}) - async_model, _ = setupAsync([updated]) - sync_model, _ = setupSync([updated]) + async_model, _ = setupAsync([updated], TestUpdateOne.Item) + sync_model, _ = setupSync([updated], TestUpdateOne.Item) results = [ - await async_model.update.one_by_id(str(item.id), {'b': 'c'}), - sync_model.update.one_by_id(str(item.id), {'b': 'c'}) + await async_model.update.one_by_id(item.id, {'b': 'c'}), + sync_model.update.one_by_id(item.id, {'b': 'c'}) ] for result in results: @@ -289,8 +296,8 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None: 'a': 'a', 'b': 'b', }) - async_model, async_client = setupAsync([item]) - sync_model, sync_client = setupSync([item]) + async_model, async_client = setupAsync([item], TestDeleteOneById.Item) + sync_model, sync_client = setupSync([item], TestDeleteOneById.Item) await async_model.delete.one_by_id(str(item.id)) sync_model.delete.one_by_id(str(item.id)) @@ -317,8 +324,8 @@ async def test_it_returns_the_deleted_record(self) -> None: 'a': 'a', 'b': 'b', }) - async_model, _ = setupAsync([item]) - sync_model, _ = setupSync([item]) + async_model, _ = setupAsync([item], TestDeleteOneById.Item) + sync_model, _ = setupSync([item], TestDeleteOneById.Item) results = [await async_model.delete.one_by_id(str(item.id)), sync_model.delete.one_by_id(str(item.id))] @@ -348,8 +355,8 @@ class AsyncExtendedModel(AsyncModel[Item]): read: AsyncReadExtended def __init__(self, client: AsyncClient) -> None: - super().__init__(client, 'extended_model') - self.read = AsyncReadExtended(self.client, self.table) + super().__init__(client, 'extended_model', Item) + self.read = AsyncReadExtended(self.client, self.table, Item) class SyncReadExtended(SyncRead[Item]): """Extending Read with additional query.""" @@ -362,11 +369,11 @@ class SyncExtendedModel(SyncModel[Item]): read: SyncReadExtended def __init__(self, client: SyncClient) -> None: - super().__init__(client, 'extended_model') - self.read = SyncReadExtended(self.client, self.table) + super().__init__(client, 'extended_model', Item) + self.read = SyncReadExtended(self.client, self.table, Item) - _, async_client = setupAsync([Item(**{"id": uuid4()})]) - _, sync_client = setupSync([Item(**{"id": uuid4()})]) + _, async_client = setupAsync([Item(**{"id": uuid4()})], Item) + _, sync_client = setupSync([Item(**{"id": uuid4()})], Item) self.models = [AsyncExtendedModel(async_client), SyncExtendedModel(sync_client)]