From e350968b8ef47c84b70ab727f54790442ae8d958 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 9 Feb 2025 23:31:32 -0800 Subject: [PATCH 1/2] Changed vector type to return Vector class instead of NumPy array [skip ci] --- pgvector/django/vector.py | 29 ++++++---------------- pgvector/vector.py | 8 +++--- tests/test_asyncpg.py | 24 +++++++++--------- tests/test_django.py | 21 ++++++++-------- tests/test_peewee.py | 13 +++++----- tests/test_psycopg.py | 38 ++++++++++++++-------------- tests/test_psycopg2.py | 22 ++++++++--------- tests/test_sqlalchemy.py | 52 +++++++++++++++++++-------------------- tests/test_sqlmodel.py | 16 ++++++------ 9 files changed, 99 insertions(+), 124 deletions(-) diff --git a/pgvector/django/vector.py b/pgvector/django/vector.py index 861cfde..09173fa 100644 --- a/pgvector/django/vector.py +++ b/pgvector/django/vector.py @@ -1,6 +1,5 @@ from django import forms from django.db.models import Field -import numpy as np from .. import Vector @@ -28,9 +27,12 @@ def from_db_value(self, value, expression, connection): return Vector._from_db(value) def to_python(self, value): - if isinstance(value, list): - return np.array(value, dtype=np.float32) - return Vector._from_db(value) + if value is None or isinstance(value, Vector): + return value + elif isinstance(value, str): + return Vector._from_db(value) + else: + return Vector(value) def get_prep_value(self, value): return Vector._to_db(value) @@ -38,35 +40,20 @@ def get_prep_value(self, value): def value_to_string(self, obj): return self.get_prep_value(self.value_from_object(obj)) - def validate(self, value, model_instance): - if isinstance(value, np.ndarray): - value = value.tolist() - super().validate(value, model_instance) - - def run_validators(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() - super().run_validators(value) - def formfield(self, **kwargs): return super().formfield(form_class=VectorFormField, **kwargs) class VectorWidget(forms.TextInput): def format_value(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() + if isinstance(value, Vector): + value = value.to_list() return super().format_value(value) class VectorFormField(forms.CharField): widget = VectorWidget - def has_changed(self, initial, data): - if isinstance(initial, np.ndarray): - initial = initial.tolist() - return super().has_changed(initial, data) - def to_python(self, value): if isinstance(value, str) and value == '': return None diff --git a/pgvector/vector.py b/pgvector/vector.py index ebbcafd..7b0304a 100644 --- a/pgvector/vector.py +++ b/pgvector/vector.py @@ -70,14 +70,14 @@ def _to_db_binary(cls, value): @classmethod def _from_db(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_text(value).to_numpy().astype(np.float32) + return cls.from_text(value) @classmethod def _from_db_binary(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_binary(value).to_numpy().astype(np.float32) + return cls.from_binary(value) diff --git a/tests/test_asyncpg.py b/tests/test_asyncpg.py index 3c36048..e9aa836 100644 --- a/tests/test_asyncpg.py +++ b/tests/test_asyncpg.py @@ -1,6 +1,6 @@ import asyncpg import numpy as np -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.asyncpg import register_vector import pytest @@ -15,12 +15,11 @@ async def test_vector(self): await register_vector(conn) - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding) - assert res[0]['embedding'].dtype == np.float32 + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -38,11 +37,11 @@ async def test_halfvec(self): await register_vector(conn) - embedding = [1.5, 2, 3] + embedding = HalfVector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert res[0]['embedding'].to_list() == [1.5, 2, 3] + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -87,7 +86,7 @@ async def test_sparsevec(self): await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert res[0]['embedding'].to_list() == [1.5, 2, 3] + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -105,12 +104,12 @@ async def test_vector_array(self): await register_vector(conn) - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings[0], embeddings[1]) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embeddings'][0], embeddings[0]) - assert np.array_equal(res[0]['embeddings'][1], embeddings[1]) + assert res[0]['embeddings'][0] == embeddings[0] + assert res[0]['embeddings'][1] == embeddings[1] await conn.close() @@ -126,10 +125,9 @@ async def init(conn): await conn.execute('DROP TABLE IF EXISTS asyncpg_items') await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))') - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding) - assert res[0]['embedding'].dtype == np.float32 + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None diff --git a/tests/test_django.py b/tests/test_django.py index 65082a3..7182e3a 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -12,7 +12,7 @@ import numpy as np import os import pgvector.django -from pgvector import HalfVector, SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance from unittest import mock @@ -165,12 +165,11 @@ def setup_method(self): def test_vector(self): Item(id=1, embedding=[1, 2, 3]).save() item = Item.objects.get(pk=1) - assert np.array_equal(item.embedding, np.array([1, 2, 3])) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() - distance = L2Distance('embedding', [1, 1, 1]) + distance = L2Distance('embedding', Vector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, sqrt(3)] @@ -293,7 +292,7 @@ def test_vector_avg(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] - assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] @@ -301,7 +300,7 @@ def test_vector_sum(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] - assert np.array_equal(sum, np.array([5, 7, 9])) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] @@ -309,7 +308,7 @@ def test_halfvec_avg(self): Item(half_embedding=[1, 2, 3]).save() Item(half_embedding=[4, 5, 6]).save() avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] - assert avg.to_list() == [2.5, 3.5, 4.5] + assert avg == HalfVector([2.5, 3.5, 4.5]) def test_halfvec_sum(self): sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] @@ -317,7 +316,7 @@ def test_halfvec_sum(self): Item(half_embedding=[1, 2, 3]).save() Item(half_embedding=[4, 5, 6]).save() sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] - assert sum.to_list() == [5, 7, 9] + assert sum == HalfVector([5, 7, 9]) def test_serialization(self): create_items() @@ -347,7 +346,7 @@ def test_vector_form_save(self): assert form.has_changed() assert form.is_valid() assert form.save() - assert [4, 5, 6] == Item.objects.get(pk=1).embedding.tolist() + assert [4, 5, 6] == Item.objects.get(pk=1).embedding.to_list() def test_vector_form_save_missing(self): Item(id=1).save() @@ -465,8 +464,8 @@ def test_vector_array(self): # this fails if the driver does not cast arrays item = Item.objects.get(pk=1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0].to_list() == [1, 2, 3] + assert item.embeddings[1].to_list() == [4, 5, 6] def test_double_array(self): Item(id=1, double_embedding=[1, 1, 1]).save() diff --git a/tests/test_peewee.py b/tests/test_peewee.py index 670d880..42b7787 100644 --- a/tests/test_peewee.py +++ b/tests/test_peewee.py @@ -1,7 +1,7 @@ from math import sqrt import numpy as np from peewee import Model, PostgresqlDatabase, fn -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField db = PostgresqlDatabase('pgvector_python_test') @@ -43,8 +43,7 @@ def setup_method(self): def test_vector(self): Item.create(id=1, embedding=[1, 2, 3]) item = Item.get_by_id(1) - assert np.array_equal(item.embedding, np.array([1, 2, 3])) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -170,7 +169,7 @@ def test_vector_avg(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar() - assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() @@ -178,7 +177,7 @@ def test_vector_sum(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() - assert np.array_equal(sum, np.array([5, 7, 9])) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() @@ -186,7 +185,7 @@ def test_halfvec_avg(self): Item.create(half_embedding=[1, 2, 3]) Item.create(half_embedding=[4, 5, 6]) avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() - assert avg.to_list() == [2.5, 3.5, 4.5] + assert avg == HalfVector([2.5, 3.5, 4.5]) def test_halfvec_sum(self): sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() @@ -194,7 +193,7 @@ def test_halfvec_sum(self): Item.create(half_embedding=[1, 2, 3]) Item.create(half_embedding=[4, 5, 6]) sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() - assert sum.to_list() == [5, 7, 9] + assert sum == HalfVector([5, 7, 9]) def test_get_or_create(self): Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]}) diff --git a/tests/test_psycopg.py b/tests/test_psycopg.py index 6a9d0b7..af658b2 100644 --- a/tests/test_psycopg.py +++ b/tests/test_psycopg.py @@ -23,19 +23,18 @@ def test_vector(self): conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) res = conn.execute('SELECT embedding FROM psycopg_items ORDER BY id').fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_binary_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_text_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_binary_format_correct(self): embedding = np.array([1.5, 2, 3]) @@ -46,23 +45,23 @@ def test_vector_text_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([3, 2, 1.5])) + assert res == Vector([3, 2, 1.5]) def test_vector_binary_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %b::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([3, 2, 1.5])) + assert res == Vector([3, 2, 1.5]) def test_vector_class_binary_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, np.array([1.5, 2, 3])) + assert res == embedding def test_vector_class_text_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([1.5, 2, 3])) + assert res == embedding def test_halfvec(self): embedding = HalfVector([1.5, 2, 3]) @@ -156,7 +155,7 @@ def test_text_copy_to(self): assert row[1] == "[1.5,2,3]" def test_binary_copy_to(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() @@ -166,23 +165,23 @@ def test_binary_copy_to(self): assert HalfVector.from_binary(row[1]).to_list() == [1.5, 2, 3] def test_binary_copy_to_set_types(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding) TO STDOUT WITH (FORMAT BINARY)") as copy: copy.set_types(['vector', 'halfvec']) for row in copy.rows(): - assert np.array_equal(row[0], embedding) - assert row[1].to_list() == [1.5, 2, 3] + assert row[0] == embedding + assert row[1] == half_embedding def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] conn.execute('INSERT INTO psycopg_items (embeddings) VALUES (%s)', (embeddings,)) res = conn.execute('SELECT embeddings FROM psycopg_items ORDER BY id').fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_pool(self): def configure(conn): @@ -192,7 +191,7 @@ def configure(conn): with pool.connection() as conn: res = conn.execute("SELECT '[1,2,3]'::vector").fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) pool.close() @@ -206,14 +205,13 @@ async def test_async(self): await register_vector_async(conn) - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) async with conn.cursor() as cur: await cur.execute('SELECT * FROM psycopg_items ORDER BY id') res = await cur.fetchall() - assert np.array_equal(res[0][1], embedding) - assert res[0][1].dtype == np.float32 + assert res[0][1] == embedding assert res[1][1] is None @pytest.mark.asyncio @@ -228,6 +226,6 @@ async def configure(conn): async with conn.cursor() as cur: await cur.execute("SELECT '[1,2,3]'::vector") res = await cur.fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) await pool.close() diff --git a/tests/test_psycopg2.py b/tests/test_psycopg2.py index 1994c87..8ef9911 100644 --- a/tests/test_psycopg2.py +++ b/tests/test_psycopg2.py @@ -26,8 +26,7 @@ def test_vector(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_class(self): @@ -36,17 +35,16 @@ def test_vector_class(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding.to_numpy()) - assert res[0][0].dtype == np.float32 + assert res[0][0] == embedding assert res[1][0] is None def test_halfvec(self): - embedding = [1.5, 2, 3] + embedding = HalfVector([1.5, 2, 3]) cur.execute('INSERT INTO psycopg2_items (half_embedding) VALUES (%s), (NULL)', (embedding,)) cur.execute('SELECT half_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0].to_list() == [1.5, 2, 3] + assert res[0][0] == embedding assert res[1][0] is None def test_bit(self): @@ -55,7 +53,7 @@ def test_bit(self): cur.execute('SELECT binary_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0] == '101' + assert res[0][0] == embedding assert res[1][0] is None def test_sparsevec(self): @@ -64,17 +62,17 @@ def test_sparsevec(self): cur.execute('SELECT sparse_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0].to_list() == [1.5, 2, 3] + assert res[0][0] == embedding assert res[1][0] is None def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] cur.execute('INSERT INTO psycopg2_items (embeddings) VALUES (%s::vector[])', (embeddings,)) cur.execute('SELECT embeddings FROM psycopg2_items ORDER BY id') res = cur.fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_halfvec_array(self): embeddings = [HalfVector([1.5, 2, 3]), HalfVector([4.5, 5, 6])] @@ -122,7 +120,7 @@ def test_pool(self): cur = conn.cursor() cur.execute("SELECT '[1,2,3]'::vector") res = cur.fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) finally: pool.putconn(conn) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 052edd7..f7203f1 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -1,7 +1,7 @@ import asyncpg import numpy as np import os -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum import pytest from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY @@ -190,10 +190,8 @@ def test_orm(self, engine): assert items[0].id % 3 == 1 assert items[1].id % 3 == 2 assert items[2].id % 3 == 0 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self, engine): @@ -201,7 +199,7 @@ def test_vector(self, engine): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self, engine): create_items() @@ -256,7 +254,7 @@ def test_halfvec(self, engine): session.add(Item(id=1, half_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.half_embedding.to_list() == [1, 2, 3] + assert item.half_embedding == HalfVector([1, 2, 3]) def test_halfvec_l2_distance(self, engine): create_items() @@ -348,7 +346,7 @@ def test_sparsevec(self, engine): session.add(Item(id=1, sparse_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.sparse_embedding.to_list() == [1, 2, 3] + assert item.sparse_embedding == SparseVector([1, 2, 3]) def test_sparsevec_l2_distance(self, engine): create_items() @@ -429,7 +427,7 @@ def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(avg(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_avg_orm(self, engine): with Session(engine) as session: @@ -438,7 +436,7 @@ def test_avg_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_sum(self, engine): with Session(engine) as session: @@ -447,7 +445,7 @@ def test_sum(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(sum(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_sum_orm(self, engine): with Session(engine) as session: @@ -456,7 +454,7 @@ def test_sum_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_bad_dimensions(self, engine): item = Item(embedding=[1, 2]) @@ -509,7 +507,7 @@ def test_automap(self, engine): with Session(engine) as session: session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}]) item = session.query(AutoItem).first() - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_half_precision(self, engine): create_items() @@ -541,8 +539,8 @@ def test_vector_array(self, engine): # this fails if the driver does not cast arrays item = session.get(Item, 1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0] == Vector([1, 2, 3]) + assert item.embeddings[1] == Vector([4, 5, 6]) def test_halfvec_array(self, engine): with Session(engine) as session: @@ -551,8 +549,8 @@ def test_halfvec_array(self, engine): # this fails if the driver does not cast arrays item = session.get(Item, 1) - assert item.half_embeddings[0].to_list() == [1, 2, 3] - assert item.half_embeddings[1].to_list() == [4, 5, 6] + assert item.half_embeddings[0] == HalfVector([1, 2, 3]) + assert item.half_embeddings[1] == HalfVector([4, 5, 6]) @pytest.mark.parametrize('engine', async_engines) @@ -566,10 +564,10 @@ async def test_vector(self, engine): async with async_session() as session: async with session.begin(): - embedding = np.array([1, 2, 3]) + embedding = Vector([1, 2, 3]) session.add(Item(id=1, embedding=embedding)) item = await session.get(Item, 1) - assert np.array_equal(item.embedding, embedding) + assert item.embedding == embedding await engine.dispose() @@ -579,10 +577,10 @@ async def test_halfvec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = HalfVector([1, 2, 3]) session.add(Item(id=1, half_embedding=embedding)) item = await session.get(Item, 1) - assert item.half_embedding.to_list() == embedding + assert item.half_embedding == embedding await engine.dispose() @@ -605,10 +603,10 @@ async def test_sparsevec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = SparseVector([1, 2, 3]) session.add(Item(id=1, sparse_embedding=embedding)) item = await session.get(Item, 1) - assert item.sparse_embedding.to_list() == embedding + assert item.sparse_embedding == embedding await engine.dispose() @@ -621,7 +619,7 @@ async def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = await session.scalars(select(avg(Item.embedding))) - assert res.first().tolist() == [2.5, 3.5, 4.5] + assert res.first() == Vector([2.5, 3.5, 4.5]) await engine.dispose() @@ -637,9 +635,9 @@ async def test_vector_array(self, engine): async with async_session() as session: async with session.begin(): - session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) + session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])])) item = await session.get(Item, 1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0] == Vector([1, 2, 3]) + assert item.embeddings[1] == Vector([4, 5, 6]) await engine.dispose() diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index b0e8ccd..8f3e268 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -1,5 +1,5 @@ import numpy as np -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum import pytest from sqlalchemy.exc import StatementError @@ -65,10 +65,8 @@ def test_orm(self): assert items[0].id == 1 assert items[1].id == 2 assert items[2].id == 3 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self): @@ -76,7 +74,7 @@ def test_vector(self): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -107,7 +105,7 @@ def test_halfvec(self): session.add(Item(id=1, half_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.half_embedding.to_list() == [1, 2, 3] + assert item.half_embedding == HalfVector([1, 2, 3]) def test_halfvec_l2_distance(self): create_items() @@ -202,7 +200,7 @@ def test_vector_avg(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): with Session(engine) as session: @@ -211,7 +209,7 @@ def test_vector_sum(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_halfvec_avg(self): with Session(engine) as session: From 3f44e8aa397de40ba64c54da166fee7462eb439f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 10 Feb 2025 19:35:44 -0800 Subject: [PATCH 2/2] Use consistent style [skip ci] --- tests/test_django.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_django.py b/tests/test_django.py index ae42d5b..eff5a98 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -12,7 +12,7 @@ import numpy as np import os import pgvector.django -from pgvector import Vector, HalfVector, SparseVector +from pgvector import HalfVector, SparseVector, Vector from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance from unittest import mock