diff --git a/pgvector/django/vector.py b/pgvector/django/vector.py index 861cfde..09173fa 100644 --- a/pgvector/django/vector.py +++ b/pgvector/django/vector.py @@ -1,6 +1,5 @@ from django import forms from django.db.models import Field -import numpy as np from .. import Vector @@ -28,9 +27,12 @@ def from_db_value(self, value, expression, connection): return Vector._from_db(value) def to_python(self, value): - if isinstance(value, list): - return np.array(value, dtype=np.float32) - return Vector._from_db(value) + if value is None or isinstance(value, Vector): + return value + elif isinstance(value, str): + return Vector._from_db(value) + else: + return Vector(value) def get_prep_value(self, value): return Vector._to_db(value) @@ -38,35 +40,20 @@ def get_prep_value(self, value): def value_to_string(self, obj): return self.get_prep_value(self.value_from_object(obj)) - def validate(self, value, model_instance): - if isinstance(value, np.ndarray): - value = value.tolist() - super().validate(value, model_instance) - - def run_validators(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() - super().run_validators(value) - def formfield(self, **kwargs): return super().formfield(form_class=VectorFormField, **kwargs) class VectorWidget(forms.TextInput): def format_value(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() + if isinstance(value, Vector): + value = value.to_list() return super().format_value(value) class VectorFormField(forms.CharField): widget = VectorWidget - def has_changed(self, initial, data): - if isinstance(initial, np.ndarray): - initial = initial.tolist() - return super().has_changed(initial, data) - def to_python(self, value): if isinstance(value, str) and value == '': return None diff --git a/pgvector/vector.py b/pgvector/vector.py index ebbcafd..7b0304a 100644 --- a/pgvector/vector.py +++ b/pgvector/vector.py @@ -70,14 +70,14 @@ def _to_db_binary(cls, value): @classmethod def _from_db(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_text(value).to_numpy().astype(np.float32) + return cls.from_text(value) @classmethod def _from_db_binary(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_binary(value).to_numpy().astype(np.float32) + return cls.from_binary(value) diff --git a/tests/test_asyncpg.py b/tests/test_asyncpg.py index 34d66a1..093fe1a 100644 --- a/tests/test_asyncpg.py +++ b/tests/test_asyncpg.py @@ -20,9 +20,8 @@ async def test_vector(self): await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding.to_numpy()) - assert res[0]['embedding'].dtype == np.float32 - assert np.array_equal(res[1]['embedding'], embedding2) + assert res[0]['embedding'] == embedding + assert res[1]['embedding'] == Vector(embedding2) assert res[2]['embedding'] is None # ensures binary format is correct @@ -116,10 +115,8 @@ async def test_vector_array(self): await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings2[0], embeddings2[1]) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embeddings'][0], embeddings[0].to_numpy()) - assert np.array_equal(res[0]['embeddings'][1], embeddings[1].to_numpy()) - assert np.array_equal(res[1]['embeddings'][0], embeddings2[0]) - assert np.array_equal(res[1]['embeddings'][1], embeddings2[1]) + assert res[0]['embeddings'] == embeddings + assert res[1]['embeddings'] == [Vector(e) for e in embeddings2] await conn.close() @@ -140,7 +137,6 @@ async def init(conn): await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding.to_numpy()) - assert res[0]['embedding'].dtype == np.float32 - assert np.array_equal(res[1]['embedding'], embedding2) + assert res[0]['embedding'] == embedding + assert res[1]['embedding'] == Vector(embedding2) assert res[2]['embedding'] is None diff --git a/tests/test_django.py b/tests/test_django.py index 7a8a6eb..eff5a98 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -12,7 +12,7 @@ import numpy as np import os import pgvector.django -from pgvector import HalfVector, SparseVector +from pgvector import HalfVector, SparseVector, Vector from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance from unittest import mock @@ -165,12 +165,11 @@ def setup_method(self): def test_vector(self): Item(id=1, embedding=[1, 2, 3]).save() item = Item.objects.get(pk=1) - assert np.array_equal(item.embedding, [1, 2, 3]) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() - distance = L2Distance('embedding', [1, 1, 1]) + distance = L2Distance('embedding', Vector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, sqrt(3)] @@ -293,7 +292,7 @@ def test_vector_avg(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] - assert np.array_equal(avg, [2.5, 3.5, 4.5]) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] @@ -301,7 +300,7 @@ def test_vector_sum(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] - assert np.array_equal(sum, [5, 7, 9]) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] @@ -347,7 +346,7 @@ def test_vector_form_save(self): assert form.has_changed() assert form.is_valid() assert form.save() - assert np.array_equal(Item.objects.get(pk=1).embedding, [4, 5, 6]) + assert Item.objects.get(pk=1).embedding == Vector([4, 5, 6]) def test_vector_form_save_missing(self): Item(id=1).save() @@ -465,8 +464,7 @@ def test_vector_array(self): # this fails if the driver does not cast arrays item = Item.objects.get(pk=1) - assert np.array_equal(item.embeddings[0], [1, 2, 3]) - assert np.array_equal(item.embeddings[1], [4, 5, 6]) + assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])] def test_double_array(self): Item(id=1, double_embedding=[1, 1, 1]).save() diff --git a/tests/test_peewee.py b/tests/test_peewee.py index 64fc009..826608a 100644 --- a/tests/test_peewee.py +++ b/tests/test_peewee.py @@ -1,7 +1,7 @@ from math import sqrt import numpy as np from peewee import Model, PostgresqlDatabase, fn -from pgvector import HalfVector, SparseVector +from pgvector import HalfVector, SparseVector, Vector from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField db = PostgresqlDatabase('pgvector_python_test') @@ -43,8 +43,7 @@ def setup_method(self): def test_vector(self): Item.create(id=1, embedding=[1, 2, 3]) item = Item.get_by_id(1) - assert np.array_equal(item.embedding, [1, 2, 3]) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -170,7 +169,7 @@ def test_vector_avg(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar() - assert np.array_equal(avg, [2.5, 3.5, 4.5]) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() @@ -178,7 +177,7 @@ def test_vector_sum(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() - assert np.array_equal(sum, [5, 7, 9]) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() diff --git a/tests/test_psycopg.py b/tests/test_psycopg.py index 698b34f..6d921f6 100644 --- a/tests/test_psycopg.py +++ b/tests/test_psycopg.py @@ -23,19 +23,18 @@ def test_vector(self): conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) res = conn.execute('SELECT embedding FROM psycopg_items ORDER BY id').fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_binary_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_text_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_binary_format_correct(self): embedding = np.array([1.5, 2, 3]) @@ -46,23 +45,23 @@ def test_vector_text_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, [3, 2, 1.5]) + assert res == Vector([3, 2, 1.5]) def test_vector_binary_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %b::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, [3, 2, 1.5]) + assert res == Vector([3, 2, 1.5]) def test_vector_class_binary_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, [1.5, 2, 3]) + assert res == embedding def test_vector_class_text_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, [1.5, 2, 3]) + assert res == embedding def test_halfvec(self): embedding = HalfVector([1.5, 2, 3]) @@ -146,33 +145,33 @@ def test_text_copy_to(self): assert row[1] == "[1.5,2,3]" def test_binary_copy_to(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding) TO STDOUT WITH (FORMAT BINARY)") as copy: for row in copy.rows(): - assert np.array_equal(Vector.from_binary(row[0]).to_numpy(), embedding) + assert Vector.from_binary(row[0]) == embedding assert HalfVector.from_binary(row[1]) == half_embedding def test_binary_copy_to_set_types(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding) TO STDOUT WITH (FORMAT BINARY)") as copy: copy.set_types(['vector', 'halfvec']) for row in copy.rows(): - assert np.array_equal(row[0], embedding) + assert row[0] == embedding assert row[1] == half_embedding def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] conn.execute('INSERT INTO psycopg_items (embeddings) VALUES (%s)', (embeddings,)) res = conn.execute('SELECT embeddings FROM psycopg_items ORDER BY id').fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_pool(self): def configure(conn): @@ -182,7 +181,7 @@ def configure(conn): with pool.connection() as conn: res = conn.execute("SELECT '[1,2,3]'::vector").fetchone() - assert np.array_equal(res[0], [1, 2, 3]) + assert res[0] == Vector([1, 2, 3]) pool.close() @@ -196,14 +195,13 @@ async def test_async(self): await register_vector_async(conn) - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) async with conn.cursor() as cur: await cur.execute('SELECT * FROM psycopg_items ORDER BY id') res = await cur.fetchall() - assert np.array_equal(res[0][1], embedding) - assert res[0][1].dtype == np.float32 + assert res[0][1] == embedding assert res[1][1] is None @pytest.mark.asyncio @@ -218,6 +216,6 @@ async def configure(conn): async with conn.cursor() as cur: await cur.execute("SELECT '[1,2,3]'::vector") res = await cur.fetchone() - assert np.array_equal(res[0], [1, 2, 3]) + assert res[0] == Vector([1, 2, 3]) await pool.close() diff --git a/tests/test_psycopg2.py b/tests/test_psycopg2.py index 8f56ef5..260a6ba 100644 --- a/tests/test_psycopg2.py +++ b/tests/test_psycopg2.py @@ -26,8 +26,7 @@ def test_vector(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_class(self): @@ -36,17 +35,16 @@ def test_vector_class(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding.to_numpy()) - assert res[0][0].dtype == np.float32 + assert res[0][0] == embedding assert res[1][0] is None def test_halfvec(self): - embedding = [1.5, 2, 3] + embedding = HalfVector([1.5, 2, 3]) cur.execute('INSERT INTO psycopg2_items (half_embedding) VALUES (%s), (NULL)', (embedding,)) cur.execute('SELECT half_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0] == HalfVector([1.5, 2, 3]) + assert res[0][0] == embedding assert res[1][0] is None def test_bit(self): @@ -55,7 +53,7 @@ def test_bit(self): cur.execute('SELECT binary_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0] == '101' + assert res[0][0] == embedding assert res[1][0] is None def test_sparsevec(self): @@ -64,17 +62,17 @@ def test_sparsevec(self): cur.execute('SELECT sparse_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0] == SparseVector([1.5, 2, 3]) + assert res[0][0] == embedding assert res[1][0] is None def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] cur.execute('INSERT INTO psycopg2_items (embeddings) VALUES (%s::vector[])', (embeddings,)) cur.execute('SELECT embeddings FROM psycopg2_items ORDER BY id') res = cur.fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_halfvec_array(self): embeddings = [HalfVector([1.5, 2, 3]), HalfVector([4.5, 5, 6])] @@ -122,7 +120,7 @@ def test_pool(self): cur = conn.cursor() cur.execute("SELECT '[1,2,3]'::vector") res = cur.fetchone() - assert np.array_equal(res[0], [1, 2, 3]) + assert res[0] == Vector([1, 2, 3]) finally: pool.putconn(conn) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 41c309f..8043e6f 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -190,10 +190,8 @@ def test_orm(self, engine): assert items[0].id % 3 == 1 assert items[1].id % 3 == 2 assert items[2].id % 3 == 0 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self, engine): @@ -201,7 +199,7 @@ def test_vector(self, engine): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert np.array_equal(item.embedding, [1, 2, 3]) + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self, engine): create_items() @@ -429,7 +427,7 @@ def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(avg(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_avg_orm(self, engine): with Session(engine) as session: @@ -438,7 +436,7 @@ def test_avg_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_sum(self, engine): with Session(engine) as session: @@ -447,7 +445,7 @@ def test_sum(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(sum(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_sum_orm(self, engine): with Session(engine) as session: @@ -456,7 +454,7 @@ def test_sum_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_bad_dimensions(self, engine): item = Item(embedding=[1, 2]) @@ -509,7 +507,7 @@ def test_automap(self, engine): with Session(engine) as session: session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}]) item = session.query(AutoItem).first() - assert np.array_equal(item.embedding, [1, 2, 3]) + assert item.embedding == Vector([1, 2, 3]) def test_half_precision(self, engine): create_items() @@ -541,8 +539,7 @@ def test_vector_array(self, engine): # this fails if the driver does not cast arrays item = session.get(Item, 1) - assert np.array_equal(item.embeddings[0], [1, 2, 3]) - assert np.array_equal(item.embeddings[1], [4, 5, 6]) + assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])] def test_halfvec_array(self, engine): with Session(engine) as session: @@ -566,10 +563,10 @@ async def test_vector(self, engine): async with async_session() as session: async with session.begin(): - embedding = np.array([1, 2, 3]) + embedding = Vector([1, 2, 3]) session.add(Item(id=1, embedding=embedding)) item = await session.get(Item, 1) - assert np.array_equal(item.embedding, embedding) + assert item.embedding == embedding await engine.dispose() @@ -579,10 +576,10 @@ async def test_halfvec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = HalfVector([1, 2, 3]) session.add(Item(id=1, half_embedding=embedding)) item = await session.get(Item, 1) - assert item.half_embedding == HalfVector(embedding) + assert item.half_embedding == embedding await engine.dispose() @@ -605,10 +602,10 @@ async def test_sparsevec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = SparseVector([1, 2, 3]) session.add(Item(id=1, sparse_embedding=embedding)) item = await session.get(Item, 1) - assert item.sparse_embedding == SparseVector(embedding) + assert item.sparse_embedding == embedding await engine.dispose() @@ -621,7 +618,7 @@ async def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = await session.scalars(select(avg(Item.embedding))) - assert np.array_equal(res.first(), [2.5, 3.5, 4.5]) + assert res.first() == Vector([2.5, 3.5, 4.5]) await engine.dispose() @@ -639,12 +636,10 @@ async def test_vector_array(self, engine): async with session.begin(): session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])])) item = await session.get(Item, 1) - assert np.array_equal(item.embeddings[0], [1, 2, 3]) - assert np.array_equal(item.embeddings[1], [4, 5, 6]) + assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])] session.add(Item(id=2, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) item = await session.get(Item, 2) - assert np.array_equal(item.embeddings[0], [1, 2, 3]) - assert np.array_equal(item.embeddings[1], [4, 5, 6]) + assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])] await engine.dispose() diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index f4994f4..23312a2 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -1,5 +1,5 @@ import numpy as np -from pgvector import HalfVector, SparseVector +from pgvector import HalfVector, SparseVector, Vector from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum import pytest from sqlalchemy.exc import StatementError @@ -65,10 +65,8 @@ def test_orm(self): assert items[0].id == 1 assert items[1].id == 2 assert items[2].id == 3 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self): @@ -76,7 +74,7 @@ def test_vector(self): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert np.array_equal(item.embedding, np.array([1, 2, 3])) + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -202,7 +200,7 @@ def test_vector_avg(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): with Session(engine) as session: @@ -211,7 +209,7 @@ def test_vector_sum(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_halfvec_avg(self): with Session(engine) as session: