|
3 | 3 | from django.contrib.postgres.fields import ArrayField
|
4 | 4 | from django.core import serializers
|
5 | 5 | from django.db import connection, migrations, models
|
6 |
| -from django.db.models import Avg, Sum |
| 6 | +from django.db.models import Avg, Sum, FloatField, DecimalField |
| 7 | +from django.db.models.functions import Cast |
7 | 8 | from django.db.migrations.loader import MigrationLoader
|
8 | 9 | from django.forms import ModelForm
|
9 | 10 | from math import sqrt
|
@@ -48,6 +49,8 @@ class Item(models.Model):
|
48 | 49 | binary_embedding = BitField(length=3, null=True, blank=True)
|
49 | 50 | sparse_embedding = SparseVectorField(dimensions=3, null=True, blank=True)
|
50 | 51 | embeddings = ArrayField(VectorField(dimensions=3), null=True, blank=True)
|
| 52 | + double_embedding = ArrayField(FloatField(), null=True, blank=True) |
| 53 | + numeric_embedding = ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True) |
51 | 54 |
|
52 | 55 | class Meta:
|
53 | 56 | app_label = 'django_app'
|
@@ -85,6 +88,8 @@ class Migration(migrations.Migration):
|
85 | 88 | ('binary_embedding', pgvector.django.BitField(length=3, null=True, blank=True)),
|
86 | 89 | ('sparse_embedding', pgvector.django.SparseVectorField(dimensions=3, null=True, blank=True)),
|
87 | 90 | ('embeddings', ArrayField(pgvector.django.VectorField(dimensions=3), null=True, blank=True)),
|
| 91 | + ('double_embedding', ArrayField(FloatField(), null=True, blank=True)), |
| 92 | + ('numeric_embedding', ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True)), |
88 | 93 | ],
|
89 | 94 | ),
|
90 | 95 | migrations.AddIndex(
|
@@ -448,3 +453,23 @@ def test_vector_array(self):
|
448 | 453 | item = Item.objects.get(pk=1)
|
449 | 454 | assert item.embeddings[0].tolist() == [1, 2, 3]
|
450 | 455 | assert item.embeddings[1].tolist() == [4, 5, 6]
|
| 456 | + |
| 457 | + def test_double_array(self): |
| 458 | + Item(id=1, double_embedding=[1, 1, 1]).save() |
| 459 | + Item(id=2, double_embedding=[2, 2, 2]).save() |
| 460 | + Item(id=3, double_embedding=[1, 1, 2]).save() |
| 461 | + distance = L2Distance(Cast('double_embedding', VectorField()), [1, 1, 1]) |
| 462 | + items = Item.objects.annotate(distance=distance).order_by(distance) |
| 463 | + assert [v.id for v in items] == [1, 3, 2] |
| 464 | + assert [v.distance for v in items] == [0, 1, sqrt(3)] |
| 465 | + assert items[1].double_embedding == [1, 1, 2] |
| 466 | + |
| 467 | + def test_numeric_array(self): |
| 468 | + Item(id=1, numeric_embedding=[1, 1, 1]).save() |
| 469 | + Item(id=2, numeric_embedding=[2, 2, 2]).save() |
| 470 | + Item(id=3, numeric_embedding=[1, 1, 2]).save() |
| 471 | + distance = L2Distance(Cast('numeric_embedding', VectorField()), [1, 1, 1]) |
| 472 | + items = Item.objects.annotate(distance=distance).order_by(distance) |
| 473 | + assert [v.id for v in items] == [1, 3, 2] |
| 474 | + assert [v.distance for v in items] == [0, 1, sqrt(3)] |
| 475 | + assert items[1].numeric_embedding == [1, 1, 2] |
0 commit comments