Created
July 10, 2024 20:08
-
-
Save zopieux/971956ec8105b9931ad818fefc36a805 to your computer and use it in GitHub Desktop.
pgvecto.rs & sqlx interop
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::ops::Deref; | |
use sqlx::encode::IsNull; | |
use sqlx::error::BoxDynError; | |
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef}; | |
use sqlx::{Decode, Encode, Postgres, Type}; | |
/// A dense `f32` vector, mapped to the pgvecto.rs `vector` Postgres type.
///
/// Dereferences to the inner `Vec<f32>` for read-only access.
#[derive(Clone, PartialEq, Default)]
pub struct Vector(Vec<f32>);

impl Vector {
    /// Creates an empty vector.
    pub fn new() -> Self {
        Self::default()
    }
}

impl From<Vec<f32>> for Vector {
    fn from(value: Vec<f32>) -> Self {
        Self(value)
    }
}

// Implementing `From` (rather than `Into`) also yields `Into<Vec<f32>> for Vector`
// through the standard blanket impl, so `let v: Vec<f32> = vector.into();` keeps working.
impl From<Vector> for Vec<f32> {
    fn from(value: Vector) -> Self {
        value.0
    }
}

impl Deref for Vector {
    type Target = Vec<f32>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl core::fmt::Debug for Vector {
    // Formats as `vector((<len>) [<elements>])`, e.g. `vector((2) [1.0, 2.5])`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "vector(({}) {:?})", self.0.len(), &self.0)
    }
}
impl Type<Postgres> for Vector { | |
fn type_info() -> PgTypeInfo { | |
PgTypeInfo::with_name("vector") | |
} | |
} | |
const F32_SIZE: usize = std::mem::size_of::<f32>(); | |
impl<'r> Encode<'r, Postgres> for Vector { | |
// https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=send | |
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { | |
let dims = self.0.len(); | |
let bytes = dims * F32_SIZE; | |
let mut out = vec![0u8; U16_SIZE + bytes]; | |
out[..U16_SIZE].copy_from_slice(&(dims as u16).to_ne_bytes()); | |
out[U16_SIZE..].copy_from_slice(unsafe { | |
std::slice::from_raw_parts(self.0.as_ptr() as *const u8, bytes) | |
}); | |
buf.extend(out); | |
IsNull::No | |
} | |
} | |
impl<'r> Decode<'r, Postgres> for Vector { | |
// https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=recv | |
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> { | |
let buf = value.as_bytes()?; | |
let (dims_bytes, rest) = buf.split_at(U16_SIZE); | |
let dims = u16::from_ne_bytes(dims_bytes.try_into()?); | |
let bytes = F32_SIZE * (dims as usize); | |
let mut slice = Vec::<f32>::with_capacity(dims as usize); | |
unsafe { | |
std::ptr::copy(rest.as_ptr(), slice.as_mut_ptr().cast(), bytes); | |
slice.set_len(dims as usize); | |
}; | |
Ok(slice.into()) | |
} | |
} | |
/// A binary vector, mapped to the pgvecto.rs `bvector` Postgres type.
///
/// Bits are packed least-significant-bit first into `usize` words: bit `i` is stored in
/// `data[i / USIZE_WIDTH]` at bit position `i % USIZE_WIDTH`.
#[derive(Clone, PartialEq, Default)]
pub struct BVector {
    /// Number of meaningful bits.
    dims: u16,
    /// Packed bits; `dims.div_ceil(USIZE_WIDTH)` words when built by [`BVector::from_bools`].
    data: Vec<usize>,
}

/// Number of bits in one storage word.
const USIZE_WIDTH: usize = usize::BITS as usize;
/// Number of bytes in one storage word.
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
/// Size in bytes of the `u16` dimension header of the binary wire format.
const U16_SIZE: usize = std::mem::size_of::<u16>();

impl core::fmt::Debug for BVector {
    // Formats as `bvector((<dims>) <bits>)`, e.g. `bvector((3) 101)`. Bits are read in
    // place rather than cloning `self` just to call the consuming `iter`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bits: String = (0..usize::from(self.dims))
            .map(|i| if self.bit(i) { '1' } else { '0' })
            .collect();
        write!(f, "bvector(({}) {})", self.dims, bits)
    }
}

impl BVector {
    /// Creates an empty binary vector (zero dimensions).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a binary vector from bytes, treating every non-zero byte as a set bit.
    ///
    /// # Panics
    /// Panics if `bools_as_u8` has more than `u16::MAX` elements.
    pub fn from_bits(bools_as_u8: &[u8]) -> Self {
        let bools: Vec<bool> = bools_as_u8.iter().map(|&byte| byte != 0).collect();
        Self::from_bools(&bools)
    }

    /// Builds a binary vector from booleans, one bit per element.
    ///
    /// # Panics
    /// Panics if `bools` has more than `u16::MAX` elements: the dimension count is a `u16`,
    /// and truncating it would leave `dims` and `data` inconsistent (breaking encoding).
    pub fn from_bools(bools: &[bool]) -> Self {
        let dims = u16::try_from(bools.len())
            .expect("pgvecto.rs bvectors cannot have more than u16::MAX dimensions");
        let data = bools
            .chunks(USIZE_WIDTH)
            .map(|word| {
                word.iter()
                    .enumerate()
                    .filter(|&(_, &bit)| bit)
                    .fold(0usize, |acc, (i, _)| acc | (1 << i))
            })
            .collect();
        Self { dims, data }
    }

    /// Consumes the vector and yields its `dims` bits in order.
    pub fn iter(self) -> impl Iterator<Item = bool> {
        (0..usize::from(self.dims)).map(move |i| self.bit(i))
    }

    /// Returns bit `index`.
    ///
    /// Panics if `data` holds no word for `index`.
    fn bit(&self, index: usize) -> bool {
        self.data[index / USIZE_WIDTH] & (1 << (index % USIZE_WIDTH)) != 0
    }
}
impl Type<Postgres> for BVector { | |
fn type_info() -> PgTypeInfo { | |
PgTypeInfo::with_name("bvector") | |
} | |
} | |
impl<'r> Encode<'r, Postgres> for BVector { | |
// https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=send | |
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { | |
let bytes = (self.dims as usize).div_ceil(USIZE_WIDTH) * std::mem::size_of::<usize>(); | |
let mut out = vec![0u8; U16_SIZE + bytes]; | |
out[..U16_SIZE].copy_from_slice(&self.dims.to_ne_bytes()); | |
out[U16_SIZE..].copy_from_slice(unsafe { | |
std::slice::from_raw_parts(self.data.as_ptr() as *const u8, bytes) | |
}); | |
buf.extend(out); | |
IsNull::No | |
} | |
} | |
impl<'r> Decode<'r, Postgres> for BVector { | |
// https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=recv | |
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> { | |
let buf = value.as_bytes()?; | |
let (dims_bytes, rest) = buf.split_at(U16_SIZE); | |
let dims = u16::from_ne_bytes(dims_bytes.try_into()?); | |
let usizes = (dims as usize).div_ceil(USIZE_WIDTH); | |
let bytes = usizes * USIZE_SIZE; | |
let mut data = Vec::<usize>::with_capacity(usizes); | |
unsafe { | |
std::ptr::copy(rest.as_ptr(), data.as_mut_ptr().cast(), bytes); | |
data.set_len(usizes); | |
}; | |
Ok(Self { dims, data }) | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    // These tests need a live Postgres with the pgvecto.rs extension installed,
    // reachable through the `DATABASE_URL` environment variable.

    #[tokio::test]
    #[allow(clippy::approx_constant)] // 3.14 is test data, not an approximation of PI.
    async fn test_vector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();
        let vec: Vec<f32> = vec![1.1, 2.2, 3.14, 1e6, 1e10];

        // Decode path: server-side literal -> binary wire format -> `Vector`.
        let v_out: Vector = sqlx::query_scalar(r#"select '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(*v_out, vec);

        // Encode path: bind a `Vector`. Binding the raw `Vec<f32>` would be sent as
        // `real[]` and never exercise `Encode for Vector`.
        let is_same: bool = sqlx::query_scalar(r#"select $1 = '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .bind(Vector::from(vec.clone()))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }

    #[tokio::test]
    async fn test_bvector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();

        // Decode path.
        let bv_out: BVector = sqlx::query_scalar(r#"select binarize('[1,0,1,1,1,1,1,1,0,0,1]')"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(
            bv_out,
            BVector::from_bits(&[1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1])
        );

        // Encode path.
        let v_in = BVector::from_bools(&[true, true, false, false, false, true, false]);
        let is_same: bool = sqlx::query_scalar(r#"select $1 = binarize('[1,1,0,0,0,1,0]')"#)
            .bind(&v_in)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }
}
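Using Claude 3.5 Sonnet, I converted the unsafe code to a safe alternative. It passed all of my tests, so I believe it behaves the same. Note that it targets sqlx 0.8, where `Encode::encode_by_ref` returns `Result<IsNull, BoxDynError>` instead of a bare `IsNull`, which is why the signatures differ from the version above.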
use std::ops::Deref;
use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
use sqlx::{Decode, Encode, Postgres, Type};
/// A dense `f32` vector, mapped to the pgvecto.rs `vector` Postgres type.
///
/// Dereferences to the inner `Vec<f32>` for read-only access.
#[derive(Clone, PartialEq, Default)]
pub struct Vector(Vec<f32>);

impl Vector {
    /// Creates an empty vector.
    pub fn new() -> Self {
        Self::default()
    }
}

impl From<Vec<f32>> for Vector {
    fn from(value: Vec<f32>) -> Self {
        Self(value)
    }
}

// Implementing `From` (rather than `Into`) also yields `Into<Vec<f32>> for Vector`
// through the standard blanket impl, so `let v: Vec<f32> = vector.into();` keeps working.
impl From<Vector> for Vec<f32> {
    fn from(value: Vector) -> Self {
        value.0
    }
}

impl Deref for Vector {
    type Target = Vec<f32>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl core::fmt::Debug for Vector {
    // Formats as `vector((<len>) [<elements>])`, e.g. `vector((2) [1.0, 2.5])`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "vector(({}) {:?})", self.0.len(), &self.0)
    }
}
impl Type<Postgres> for Vector {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("vector")
    }
}

/// Size in bytes of one vector element on the wire.
const F32_SIZE: usize = std::mem::size_of::<f32>();

/// Binary encoding for pgvecto.rs `vector`: a native-endian `u16` dimension count
/// followed by that many native-endian `f32` values.
impl<'r> Encode<'r, Postgres> for Vector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=send
    //
    // Errors if the vector has more than `u16::MAX` elements: the count cannot fit in the
    // header, and a plain `as u16` cast would silently send a truncated (corrupt) length.
    // Nothing is written to `buf` in that case.
    fn encode_by_ref(
        &self,
        buf: &mut PgArgumentBuffer,
    ) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let dims = u16::try_from(self.0.len()).map_err(|_| {
            format!(
                "vector has {} dimensions; pgvecto.rs supports at most {}",
                self.0.len(),
                u16::MAX
            )
        })?;
        // Write straight into the argument buffer instead of staging two temporary Vecs.
        buf.reserve(U16_SIZE + self.0.len() * F32_SIZE);
        buf.extend_from_slice(&dims.to_ne_bytes());
        for value in &self.0 {
            buf.extend_from_slice(&value.to_ne_bytes());
        }
        Ok(IsNull::No)
    }
}

impl<'r> Decode<'r, Postgres> for Vector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=recv
    //
    // Every length is checked against the received buffer, so a short or malformed
    // payload is reported as an error instead of panicking on a slice index.
    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
        let buf = value.as_bytes()?;
        if buf.len() < U16_SIZE {
            return Err(format!(
                "vector payload too short for its header: {} bytes",
                buf.len()
            )
            .into());
        }
        let (dims_bytes, rest) = buf.split_at(U16_SIZE);
        let dims = usize::from(u16::from_ne_bytes(dims_bytes.try_into()?));
        let payload = rest.get(..dims * F32_SIZE).ok_or_else(|| {
            format!(
                "vector payload declares {dims} dimensions but carries only {} data bytes",
                rest.len()
            )
        })?;
        let mut values = Vec::with_capacity(dims);
        for chunk in payload.chunks_exact(F32_SIZE) {
            values.push(f32::from_ne_bytes(chunk.try_into()?));
        }
        Ok(Self(values))
    }
}
/// A binary vector, mapped to the pgvecto.rs `bvector` Postgres type.
///
/// Bits are packed least-significant-bit first into `usize` words: bit `i` is stored in
/// `data[i / USIZE_WIDTH]` at bit position `i % USIZE_WIDTH`.
#[derive(Clone, PartialEq, Default)]
pub struct BVector {
    /// Number of meaningful bits.
    dims: u16,
    /// Packed bits; `dims.div_ceil(USIZE_WIDTH)` words when built by [`BVector::from_bools`].
    data: Vec<usize>,
}

/// Number of bits in one storage word.
const USIZE_WIDTH: usize = usize::BITS as usize;
/// Number of bytes in one storage word.
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
/// Size in bytes of the `u16` dimension header of the binary wire format.
const U16_SIZE: usize = std::mem::size_of::<u16>();

impl core::fmt::Debug for BVector {
    // Formats as `bvector((<dims>) <bits>)`, e.g. `bvector((3) 101)`. Bits are read in
    // place rather than cloning `self` just to call the consuming `iter`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bits: String = (0..usize::from(self.dims))
            .map(|i| if self.bit(i) { '1' } else { '0' })
            .collect();
        write!(f, "bvector(({}) {})", self.dims, bits)
    }
}

impl BVector {
    /// Creates an empty binary vector (zero dimensions).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a binary vector from bytes, treating every non-zero byte as a set bit.
    ///
    /// # Panics
    /// Panics if `bools_as_u8` has more than `u16::MAX` elements.
    pub fn from_bits(bools_as_u8: &[u8]) -> Self {
        let bools: Vec<bool> = bools_as_u8.iter().map(|&byte| byte != 0).collect();
        Self::from_bools(&bools)
    }

    /// Builds a binary vector from booleans, one bit per element.
    ///
    /// # Panics
    /// Panics if `bools` has more than `u16::MAX` elements: the dimension count is a `u16`,
    /// and truncating it would leave `dims` and `data` inconsistent (breaking encoding).
    pub fn from_bools(bools: &[bool]) -> Self {
        let dims = u16::try_from(bools.len())
            .expect("pgvecto.rs bvectors cannot have more than u16::MAX dimensions");
        let data = bools
            .chunks(USIZE_WIDTH)
            .map(|word| {
                word.iter()
                    .enumerate()
                    .filter(|&(_, &bit)| bit)
                    .fold(0usize, |acc, (i, _)| acc | (1 << i))
            })
            .collect();
        Self { dims, data }
    }

    /// Consumes the vector and yields its `dims` bits in order.
    pub fn iter(self) -> impl Iterator<Item = bool> {
        (0..usize::from(self.dims)).map(move |i| self.bit(i))
    }

    /// Returns bit `index`.
    ///
    /// Panics if `data` holds no word for `index`.
    fn bit(&self, index: usize) -> bool {
        self.data[index / USIZE_WIDTH] & (1 << (index % USIZE_WIDTH)) != 0
    }
}
impl Type<Postgres> for BVector {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("bvector")
    }
}

/// Binary encoding for pgvecto.rs `bvector`: a native-endian `u16` dimension count
/// followed by `dims.div_ceil(USIZE_WIDTH)` native-endian `usize` words.
impl<'r> Encode<'r, Postgres> for BVector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=send
    fn encode_by_ref(
        &self,
        buf: &mut PgArgumentBuffer,
    ) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let words = usize::from(self.dims).div_ceil(USIZE_WIDTH);
        buf.reserve(U16_SIZE + words * USIZE_SIZE);
        buf.extend_from_slice(&self.dims.to_ne_bytes());
        // Send exactly `words` words so the payload always matches the header: extra words
        // in `data` are ignored (instead of panicking on an out-of-range index) and a short
        // `data` is padded with zeros.
        for word in self.data.iter().copied().chain(std::iter::repeat(0)).take(words) {
            buf.extend_from_slice(&word.to_ne_bytes());
        }
        Ok(IsNull::No)
    }
}

impl<'r> Decode<'r, Postgres> for BVector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=recv
    //
    // Reads exactly the number of words the header declares. A short or malformed payload
    // is reported as an error; trailing bytes are not turned into extra words.
    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
        let buf = value.as_bytes()?;
        if buf.len() < U16_SIZE {
            return Err(format!(
                "bvector payload too short for its header: {} bytes",
                buf.len()
            )
            .into());
        }
        let (dims_bytes, rest) = buf.split_at(U16_SIZE);
        let dims = u16::from_ne_bytes(dims_bytes.try_into()?);
        let words = usize::from(dims).div_ceil(USIZE_WIDTH);
        let payload = rest.get(..words * USIZE_SIZE).ok_or_else(|| {
            format!(
                "bvector payload declares {dims} dimensions but carries only {} data bytes",
                rest.len()
            )
        })?;
        let mut data = Vec::with_capacity(words);
        for chunk in payload.chunks_exact(USIZE_SIZE) {
            data.push(usize::from_ne_bytes(chunk.try_into()?));
        }
        Ok(Self { dims, data })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests need a live Postgres with the pgvecto.rs extension installed,
    // reachable through the `DATABASE_URL` environment variable.

    #[tokio::test]
    #[allow(clippy::approx_constant)] // 3.14 is test data, not an approximation of PI.
    async fn test_vector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();
        let vec: Vec<f32> = vec![1.1, 2.2, 3.14, 1e6, 1e10];

        // Decode path: server-side literal -> binary wire format -> `Vector`.
        let v_out: Vector = sqlx::query_scalar(r#"select '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(*v_out, vec);

        // Encode path: bind a `Vector`. Binding the raw `Vec<f32>` would be sent as
        // `real[]` and never exercise `Encode for Vector`.
        let is_same: bool = sqlx::query_scalar(r#"select $1 = '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .bind(Vector::from(vec.clone()))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }

    #[tokio::test]
    async fn test_bvector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();

        // Decode path.
        let bv_out: BVector = sqlx::query_scalar(r#"select binarize('[1,0,1,1,1,1,1,1,0,0,1]')"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(
            bv_out,
            BVector::from_bits(&[1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1])
        );

        // Encode path.
        let v_in = BVector::from_bools(&[true, true, false, false, false, true, false]);
        let is_same: bool = sqlx::query_scalar(r#"select $1 = binarize('[1,1,0,0,0,1,0]')"#)
            .bind(&v_in)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for this!