Last active
April 5, 2024 09:44
-
-
Save ergo70/18bb47f4d6b43d51b7049f2f1b82dd31 to your computer and use it in GitHub Desktop.
Cosine similarity function for float4 vectors stored as PostgreSQL bytea, plus casts between float4[] and bytea.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <math.h> | |
#include "postgres.h" | |
#include "fmgr.h" | |
#include "utils/array.h" | |
#include "access/htup.h" | |
#include "catalog/pg_type.h" | |
#include "utils/lsyscache.h" // Required for building with PGXS (at least on macOS) | |
/* | |
CREATE FUNCTION cast_bytea_to_float4_array(bytea) RETURNS float4[] | |
AS '$libdir/bytea2float4vec', 'cast_bytea_to_float4_array' | |
LANGUAGE C strict immutable parallel safe; | |
CREATE FUNCTION cast_float4_array_to_bytea(float4[]) RETURNS bytea | |
AS '$libdir/bytea2float4vec', 'cast_float4_array_to_bytea' | |
LANGUAGE C strict immutable parallel safe; | |
CREATE FUNCTION cosine_similarity_bytea(bytea, bytea) RETURNS float8 | |
AS '$libdir/bytea2float4vec', 'cosine_similarity_bytea' | |
LANGUAGE C strict immutable parallel safe; | |
CREATE CAST (float4[] AS bytea) WITH FUNCTION cast_float4_array_to_bytea(float4[]) AS assignment; | |
CREATE CAST (bytea AS float4[]) WITH FUNCTION cast_bytea_to_float4_array(bytea) AS assignment; | |
*/ | |
/*
 * Module magic block: must appear exactly once in each PostgreSQL loadable
 * module (see the PostgreSQL "C-Language Functions" documentation).
 */
PG_MODULE_MAGIC; | |
PGDLLEXPORT PG_FUNCTION_INFO_V1(cast_bytea_to_float4_array); | |
Datum cast_bytea_to_float4_array(PG_FUNCTION_ARGS) | |
{ | |
bytea *a = PG_GETARG_BYTEA_PP(0); | |
Oid elemtype = FLOAT4OID; | |
uint32 data_length_a = VARSIZE_ANY(a) - VARHDRSZ; | |
float *readptr = (float *)VARDATA_ANY(a); | |
ArrayType *retval = NULL; | |
Datum *elements = NULL; | |
int16 typlen = 0; | |
bool typbyval; | |
char typalign; | |
int ndims = 1; | |
int dims[MAXDIM]; | |
int lbs[MAXDIM]; | |
int num_elements = data_length_a / sizeof(float); | |
dims[0] = num_elements; | |
lbs[0] = 1; | |
elements = (Datum *)palloc0(num_elements * sizeof(Datum)); | |
for (int i = 0; i < num_elements; i++) | |
{ | |
elements[i] = Float4GetDatum(*readptr); | |
readptr++; | |
} | |
get_typlenbyvalalign(elemtype, &typlen, &typbyval, &typalign); | |
retval = construct_md_array(elements, NULL, ndims, dims, lbs, | |
elemtype, typlen, typbyval, typalign); | |
pfree(elements); | |
PG_RETURN_ARRAYTYPE_P(retval); | |
} | |
PGDLLEXPORT PG_FUNCTION_INFO_V1(cast_float4_array_to_bytea); | |
Datum cast_float4_array_to_bytea(PG_FUNCTION_ARGS) | |
{ | |
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); | |
Oid elemtypeA = ARR_ELEMTYPE(a); | |
Datum *datumsA = NULL; | |
int countA = 0; | |
int16 elemWidthA; | |
bool elemTypeByValA; | |
char elemAlignmentCodeA; | |
bytea *retval = NULL; | |
char *writeptr = NULL; | |
float fieldA = 0.0; | |
if (elemtypeA != FLOAT4OID) | |
ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("float4 OID array needed. Got %d", elemtypeA))); | |
if (ARR_NDIM(a) != 1) | |
ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("One-dimensional array needed. Got %d", ARR_NDIM(a)))); | |
get_typlenbyvalalign(elemtypeA, &elemWidthA, &elemTypeByValA, &elemAlignmentCodeA); | |
deconstruct_array(a, elemtypeA, elemWidthA, elemTypeByValA, elemAlignmentCodeA, &datumsA, NULL, &countA); | |
retval = palloc(VARHDRSZ + (countA * elemWidthA)); | |
writeptr = (char *)VARDATA(retval); | |
for (int i = 0; i < countA; i++) | |
{ | |
fieldA = DatumGetFloat4(datumsA[i]); | |
memcpy(writeptr, &fieldA, elemWidthA); | |
writeptr += sizeof(float); | |
} | |
SET_VARSIZE(retval, VARHDRSZ + (countA * elemWidthA)); | |
PG_RETURN_BYTEA_P(retval); | |
} | |
PGDLLEXPORT PG_FUNCTION_INFO_V1(cosine_similarity_bytea); | |
Datum cosine_similarity_bytea(PG_FUNCTION_ARGS) | |
{ | |
bytea *a = PG_GETARG_BYTEA_PP(0); | |
bytea *b = PG_GETARG_BYTEA_PP(1); | |
uint32 data_length_a = VARSIZE_ANY(a) - VARHDRSZ; | |
uint32 data_length_b = VARSIZE_ANY(b) - VARHDRSZ; | |
float *fa = (float *)VARDATA_ANY(a); | |
float *fb = (float *)VARDATA_ANY(b); | |
float distance = 0.0f; | |
float norma = 0.0f; | |
float normb = 0.0f; | |
float8 similarity = -666.0; | |
if ((data_length_a % sizeof(float) != 0) || (data_length_b % sizeof(float) != 0)) | |
ereport(ERROR, | |
(errcode(ERRCODE_DATA_EXCEPTION), | |
errmsg("Vector size does not match sizeof(float)"))); | |
if (data_length_a != data_length_b) | |
ereport(ERROR, | |
(errcode(ERRCODE_DATA_EXCEPTION), | |
errmsg("Different vector dimensions %d and %d", (data_length_a / sizeof(float)), (data_length_b / sizeof(float))))); | |
for (int i = 0; i < data_length_a; i += sizeof(float)) | |
{ | |
distance += *fa * *fb; | |
norma += *fa * *fa; | |
normb += *fb * *fb; | |
fa++; | |
fb++; | |
} | |
similarity = (double)distance / sqrt((double)norma * (double)normb); | |
PG_RETURN_FLOAT8(similarity); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment