10000 DRY HNSW distance calculations · postgrespro/pgvector@f4b67b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit f4b67b0

Browse files
committed
DRY HNSW distance calculations
1 parent 77688b4 commit f4b67b0

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

src/hnswutils.c

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,15 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
482482
}
483483
}
484484

485+
/*
486+
* Calculate the distance between values
487+
*/
488+
static inline float
489+
HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation)
490+
{
491+
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b));
492+
}
493+
485494
/*
486495
* Load an element and optionally get its distance from q
487496
*/
@@ -507,7 +516,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
507516
if (DatumGetPointer(*q) == NULL)
508517
*distance = 0;
509518
else
510-
*distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data)));
519+
*distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation);
511520
}
512521

513522
/* Load element */
@@ -539,7 +548,7 @@ GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo,
539548
{
540549
Datum value = HnswGetValue(base, element);
541550

542-
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value));
551+
return HnswGetDistance(q, value, procinfo, collation);
543552
}
544553

545554
/*
@@ -921,32 +930,22 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b)
921930
return 0;
922931
}
923932

924-
/*
925-
* Calculate the distance between elements
926-
*/
927-
static float
928-
HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation)
929-
{
930-
Datum aValue = HnswGetValue(base, a);
931-
Datum bValue = HnswGetValue(base, b);
932-
933-
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue));
934-
}
935-
936933
/*
937934
* Check if an element is closer to q than any element from R
938935
*/
939936
static bool
940937
CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation)
941938
{
942939
HnswElement eElement = HnswPtrAccess(base, e->element);
940+
Datum eValue = HnswGetValue(base, eElement);
943941
ListCell *lc2;
944942

945943
foreach(lc2, r)
946944
{
947945
HnswCandidate *ri = lfirst(lc2);
948946
HnswElement riElement = HnswPtrAccess(base, ri->element);
949-
float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation);
947+
Datum riValue = HnswGetValue(base, riElement);
948+
float distance = HnswGetDistance(eValue, riValue, procinfo, collation);
950949

951950
if (distance <= e->distance)
952951
return false;

0 commit comments

Comments
 (0)
0