From 94084bcdc3aea3863ea050561928c7f6bc40ae6d Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 20 May 2024 16:19:10 -0400 Subject: [PATCH] Added Bit class for Psycopg 2 --- pgvector/psycopg2/__init__.py | 5 ++++- pgvector/psycopg2/bit.py | 14 ++++++++++++++ tests/test_psycopg2.py | 4 ++-- 3 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 pgvector/psycopg2/bit.py diff --git a/pgvector/psycopg2/__init__.py b/pgvector/psycopg2/__init__.py index 9d0473a..ebab3d1 100644 --- a/pgvector/psycopg2/__init__.py +++ b/pgvector/psycopg2/__init__.py @@ -1,8 +1,9 @@ import psycopg2 +from .bit import register_bit_info from .halfvec import register_halfvec_info from .sparsevec import register_sparsevec_info from .vector import register_vector_info -from ..utils import SparseVector +from ..utils import Bit, SparseVector __all__ = ['register_vector'] @@ -16,6 +17,8 @@ def register_vector(conn_or_curs=None): except psycopg2.errors.UndefinedObject: raise psycopg2.ProgrammingError('vector type not found in the database') + register_bit_info() + try: cur.execute('SELECT NULL::halfvec') register_halfvec_info(cur.description[0][1]) diff --git a/pgvector/psycopg2/bit.py b/pgvector/psycopg2/bit.py new file mode 100644 index 0000000..5b10d6d --- /dev/null +++ b/pgvector/psycopg2/bit.py @@ -0,0 +1,14 @@ +from psycopg2.extensions import adapt, register_adapter +from ..utils import Bit + + +class BitAdapter: + def __init__(self, value): + self._value = value + + def getquoted(self): + return adapt(Bit.to_db(self._value)).getquoted() + + +def register_bit_info(): + register_adapter(Bit, BitAdapter) diff --git a/tests/test_psycopg2.py b/tests/test_psycopg2.py index cc7e5c0..cf0a2b1 100644 --- a/tests/test_psycopg2.py +++ b/tests/test_psycopg2.py @@ -1,5 +1,5 @@ import numpy as np -from pgvector.psycopg2 import register_vector, SparseVector +from pgvector.psycopg2 import register_vector, Bit, SparseVector import psycopg2 conn = psycopg2.connect(dbname='pgvector_python_test') @@ -37,7 +37,7 @@ def test_halfvec(self): assert res[1][0] is None def test_bit(self): - embedding = '101' + embedding = Bit('101') cur.execute('INSERT INTO psycopg2_items (binary_embedding) VALUES (%s), (NULL)', (embedding,)) cur.execute('SELECT binary_embedding FROM psycopg2_items ORDER BY id')