From 62a1968ab9cef625d9a2f34a50eaf3d6f4a4fc08 Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Tue, 24 Dec 2024 20:58:04 +0000 Subject: [PATCH 1/2] query row is attrdict --- apswutils/db.py | 13 +++++++++---- nbs/index.ipynb | 5 +++-- tests/test_query.py | 12 +++++++++++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/apswutils/db.py b/apswutils/db.py index bd29dfc..eb1e096 100644 --- a/apswutils/db.py +++ b/apswutils/db.py @@ -8,6 +8,7 @@ from functools import cache import contextlib, datetime, decimal, inspect, itertools, json, os, pathlib, re, secrets, textwrap, binascii, uuid, logging import apsw.ext, apsw.bestpractice +from fastcore.basics import AttrDict logger = logging.getLogger('apsw') logger.setLevel(logging.ERROR) @@ -414,10 +415,14 @@ def query( parameters, or a dictionary for ``where id = :id`` """ cursor = self.execute(sql, tuple(params or tuple())) - try: columns = [c[0] for c in cursor.description] - except apsw.ExecutionCompleteError: return [] - for row in cursor: - yield dict(zip(columns, row)) + # Row results will be dataclasses + cursor.row_trace = apsw.ext.DataClassRowFactory( + dataclass_kwargs={"frozen": True} + ) + # Yield attrdict so rows can be accessed as row.id or row['id'] + for row in cursor: yield AttrDict(row.__dict__) + # Cleanup the row_trace + cursor.row_trace = None def execute( self, sql: str, parameters: Optional[Union[Iterable, dict]] = None diff --git a/nbs/index.ipynb b/nbs/index.ipynb index ec28c16..3ae1b0a 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -556,7 +556,8 @@ "- WAL is the default\n", "- Setting `Database(recursive_triggers=False)` works as expected \n", "- Primary keys must be set on a table for it to be a target of a foreign key\n", - "- Errors have been changed minimally, future PRs will change them incrementally" + "- OperationError is gone, errors are more precise\n", + "- Database.query rows are AttrDict, which have Python `dict` equality but can also use dot notation" ] }, { @@ -573,7 +574,7 @@ "|Old/sqlite3/dbapi|New/APSW|Reason|\n", "|---|---|---|\n", "|IntegrityError|apsw.ConstraintError|Caused due to SQL transformation blocked on database constraints|\n", - "|sqlite3.dbapi2.OperationalError|apsw.Error|General error, OperationalError is now proxied to apsw.Error|\n", + "|sqlite3.dbapi2.OperationalError|apsw.Error|General error\n", "|sqlite3.dbapi2.OperationalError|apsw.SQLError|When an error is due to flawed SQL statements|\n", "|sqlite3.ProgrammingError|apsw.ConnectionClosedError|Caused by an improperly closed database file|\n" ] diff --git a/tests/test_query.py b/tests/test_query.py index 9d87460..f5da86a 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,4 +1,5 @@ import types +from fastcore.basics import AttrDict # Basic query tests @@ -9,7 +10,6 @@ def test_query(fresh_db): assert isinstance(results, types.GeneratorType) assert list(results) == [{"name": "Pancakes"}, {"name": "Cleo"}] - def test_execute_returning_dicts(fresh_db): # Like db.query() but returns a list, included for backwards compatibility # see https://github.com/simonw/sqlite-utils/issues/290 @@ -25,3 +25,13 @@ def test_query_no_update(fresh_db): results = fresh_db.query("update message set msg_type='note' where msg_type='md'") assert list(results) == [] assert list(fresh_db["message"].rows) == [{"msg_type": "greeting", "content": "hello"}] + +def test_query_attr_dict(fresh_db): + fresh_db["dogs"].insert_all([{"name": "Cleo"}, {"name": "Pancakes"}]) + results = fresh_db.query("select * from dogs order by name desc") + assert isinstance(results, types.GeneratorType) + results = list(results) + assert results == [{"name": "Pancakes"}, {"name": "Cleo"}] + assert results[0] == {"name": "Pancakes"} + obj = AttrDict({"name": "Pancakes"}) +assert results[0].name == obj.name \ No newline at end of file From b3b1bebc87321cf51e2a082330e22aac4124978f Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Wed, 25 Dec 2024 08:36:39 -0500 Subject: [PATCH 2/2] Add AttrDictRowFactory --- apswutils/_modidx.py | 5 ++- apswutils/db.py | 7 ++-- apswutils/ext.py | 38 ++++++++++++++++++++ nbs/ext.ipynb | 83 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_query.py | 2 +- 5 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 apswutils/ext.py create mode 100644 nbs/ext.ipynb diff --git a/apswutils/_modidx.py b/apswutils/_modidx.py index f35f5c2..9567870 100644 --- a/apswutils/_modidx.py +++ b/apswutils/_modidx.py @@ -5,4 +5,7 @@ 'doc_host': 'https://AnswerDotAI.github.io', 'git_url': 'https://github.com/AnswerDotAI/apswutils', 'lib_path': 'apswutils'}, - 'syms': {'apswutils.db': {}, 'apswutils.utils': {}}} + 'syms': { 'apswutils.db': {}, + 'apswutils.ext': { 'apswutils.ext.AttrDictRowFactory': ('ext.html#attrdictrowfactory', 'apswutils/ext.py'), + 'apswutils.ext.AttrDictRowFactory.__call__': ('ext.html#attrdictrowfactory.__call__', 'apswutils/ext.py')}, + 'apswutils.utils': {}}} diff --git a/apswutils/db.py b/apswutils/db.py index eb1e096..d0c2dd2 100644 --- a/apswutils/db.py +++ b/apswutils/db.py @@ -2,6 +2,7 @@ __all__ = ['Database', 'Queryable', 'Table', 'View'] from .utils import chunks, hash_record, OperationalError, suggest_column_types, types_for_column_types, column_affinity, find_spatialite +from .ext import AttrDictRowFactory from collections import namedtuple from collections.abc import Mapping from typing import cast, Any, Callable, Dict, Generator, Iterable, Union, Optional, List, Tuple, Iterator @@ -416,11 +417,9 @@ def query( """ cursor = self.execute(sql, tuple(params or tuple())) # Row results will be dataclasses - cursor.row_trace = apsw.ext.DataClassRowFactory( - dataclass_kwargs={"frozen": True} - ) + cursor.row_trace = AttrDictRowFactory() # Yield attrdict so rows can be accessed as row.id or row['id'] - for row in cursor: yield AttrDict(row.__dict__) + for row in cursor: yield row # Cleanup the row_trace cursor.row_trace = None diff --git a/apswutils/ext.py b/apswutils/ext.py new file mode 100644 index 0000000..a8f92bb --- /dev/null +++ b/apswutils/ext.py @@ -0,0 +1,38 @@ +"""Extensions to improve what apswutils can do""" + +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/ext.ipynb. + +# %% auto 0 +__all__ = ['AttrDictRowFactory'] + +# %% ../nbs/ext.ipynb +from fastcore.basics import AttrDict + +# %% ../nbs/ext.ipynb +class AttrDictRowFactory: + """Returns each row as a :mod:`dataclass `, accessible by column name. + + To use set an instance as :attr:`Connection.row_trace + ` to affect all :class:`cursors + `, or on a specific cursor:: + + connection.row_trace = apsw.ext.AttrDictRowFactory() + for row in connection.execute("SELECT title, sum(orders) AS total, ..."): + # You can now access by attribute + print(row.title, row.total) + # You can now access by dict notation + print(row['title'], row['total']) + # Equality is as if rows were dicts + assert row == {'title': 'AnswerDotAI', 'total': 8000000} + + You can use as many instances of this class as you want, each across as many + :class:`connections ` as you want. + """ + + def __call__(self, cursor, row) -> AttrDict: + """What the row tracer calls + + Returns an AttrDict representation of each row + """ + columns = [d[0] for d in cursor.get_description()] + return AttrDict(dict(zip(columns, row))) diff --git a/nbs/ext.ipynb b/nbs/ext.ipynb new file mode 100644 index 0000000..0f6dbfe --- /dev/null +++ b/nbs/ext.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp ext" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ext\n", + "> Extensions to improve what apswutils can do" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from fastcore.basics import AttrDict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class AttrDictRowFactory:\n", + " \"\"\"Returns each row as a :mod:`dataclass `, accessible by column name.\n", + "\n", + " To use set an instance as :attr:`Connection.row_trace\n", + " ` to affect all :class:`cursors\n", + " `, or on a specific cursor::\n", + "\n", + " connection.row_trace = apsw.ext.AttrDictRowFactory()\n", + " for row in connection.execute(\"SELECT title, sum(orders) AS total, ...\"):\n", + " # You can now access by attribute\n", + " print(row.title, row.total)\n", + " # You can now access by dict notation\n", + " print(row['title'], row['total'])\n", + " # Equality is as if rows were dicts\n", + " assert row == {'title': 'AnswerDotAI', 'total': 8000000}\n", + "\n", + " You can use as many instances of this class as you want, each across as many\n", + " :class:`connections ` as you want.\n", + " \"\"\"\n", + "\n", + " def __call__(self, cursor, row) -> AttrDict:\n", + " \"\"\"What the row tracer calls\n", + "\n", + " Returns an AttrDict representation of each row\n", + " \"\"\"\n", + " columns = [d[0] for d in cursor.get_description()]\n", + " return AttrDict(dict(zip(columns, row)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_query.py b/tests/test_query.py index f5da86a..6906880 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -34,4 +34,4 @@ def test_query_attr_dict(fresh_db): assert results == [{"name": "Pancakes"}, {"name": "Cleo"}] assert results[0] == {"name": "Pancakes"} obj = AttrDict({"name": "Pancakes"}) -assert results[0].name == obj.name \ No newline at end of file + assert results[0].name == obj.name \ No newline at end of file