8000 Added tests for double and numeric arrays · pgvector/pgvector-python@79d4111 · GitHub
[go: up one dir, main page]

Skip to content

Commit 79d4111

Browse files
committed
Added tests for double and numeric arrays
1 parent 59a3efc commit 79d4111

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

tests/test_django.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from django.contrib.postgres.fields import ArrayField
44
from django.core import serializers
55
from django.db import connection, migrations, models
6-
from django.db.models import Avg, Sum
6+
from django.db.models import Avg, Sum, FloatField, DecimalField
7+
from django.db.models.functions import Cast
78
from django.db.migrations.loader import MigrationLoader
89
from django.forms import ModelForm
910
from math import sqrt
@@ -48,6 +49,8 @@ class Item(models.Model):
4849
binary_embedding = BitField(length=3, null=True, blank=True)
4950
sparse_embedding = SparseVectorField(dimensions=3, null=True, blank=True)
5051
embeddings = ArrayField(VectorField(dimensions=3), null=True, blank=True)
52+
double_embedding = ArrayField(FloatField(), null=True, blank=True)
53+
numeric_embedding = ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True)
5154

5255
class Meta:
5356
app_label = 'django_app'
@@ -85,6 +88,8 @@ class Migration(migrations.Migration):
8588
('binary_embedding', pgvector.django.BitField(length=3, null=True, blank=True)),
8689
('sparse_embedding', pgvector.django.SparseVectorField(dimensions=3, null=True, blank=True)),
8790
('embeddings', ArrayField(pgvector.django.VectorField(dimensions=3), null=True, blank=True)),
91+
('double_embedding', ArrayField(FloatField(), null=True, blank=True)),
92+
('numeric_embedding', ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True)),
8893
],
8994
),
9095
migrations.AddIndex(
@@ -448,3 +453,23 @@ def test_vector_array(self):
448453
item = Item.objects.get(pk=1)
449454
assert item.embeddings[0].tolist() == [1, 2, 3]
450455
assert item.embeddings[1].tolist() == [4, 5, 6]
456+
457+
def test_double_array(self):
458+
Item(id=1, double_embedding=[1, 1, 1]).save()
459+
Item(id=2, double_embedding=[2, 2, 2]).save()
460+
Item(id=3, double_embedding=[1, 1, 2]).save()
461+
distance = L2Distance(Cast('double_embedding', VectorField()), [1, 1, 1])
462+
items = Item.objects.annotate(distance=distance).order_by(distance)
463+
assert [v.id for v in items] == [1, 3, 2]
464+
assert [v.distance for v in items] == [0, 1, sqrt(3)]
465+
assert items[1].double_embedding == [1, 1, 2]
466+
467+
def test_numeric_array(self):
468+
Item(id=1, numeric_embedding=[1, 1, 1]).save()
469+
Item(id=2, numeric_embedding=[2, 2, 2]).save()
470+
Item(id=3, numeric_embedding=[1, 1, 2]).save()
471+
distance = L2Distance(Cast('numeric_embedding', VectorField()), [1, 1, 1])
472+
items = Item.objects.annotate(distance=distance).order_by(distance)
473+
assert [v.id for v in items] == [1, 3, 2]
474+
assert [v.distance for v in items] == [0, 1, sqrt(3)]
475+
assert items[1].numeric_embedding == [1, 1, 2]

0 commit comments

Comments
 (0)
0