@zopieux
Created July 10, 2024 20:08
pgvecto.rs & sqlx interop
use std::ops::Deref;

use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
use sqlx::{Decode, Encode, Postgres, Type};

/// A vector.
#[derive(Clone, PartialEq, Default)]
pub struct Vector(Vec<f32>);

impl Vector {
    pub fn new() -> Self {
        Self::default()
    }
}

impl From<Vec<f32>> for Vector {
    fn from(value: Vec<f32>) -> Self {
        Self(value)
    }
}

impl Into<Vec<f32>> for Vector {
    fn into(self) -> Vec<f32> {
        self.0
    }
}

impl Deref for Vector {
    type Target = Vec<f32>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl core::fmt::Debug for Vector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "vector(({}) {:?})", self.0.len(), &self.0)
    }
}

impl Type<Postgres> for Vector {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("vector")
    }
}

const F32_SIZE: usize = std::mem::size_of::<f32>();

impl<'r> Encode<'r, Postgres> for Vector {
    // Wire format: dimension count as a native-endian u16, followed by the raw f32 values.
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=send
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        let dims = self.0.len();
        let bytes = dims * F32_SIZE;
        let mut out = vec![0u8; U16_SIZE + bytes];
        out[..U16_SIZE].copy_from_slice(&(dims as u16).to_ne_bytes());
        out[U16_SIZE..].copy_from_slice(unsafe {
            // Reinterpret the f32 payload as raw bytes.
            std::slice::from_raw_parts(self.0.as_ptr() as *const u8, bytes)
        });
        buf.extend(out);
        IsNull::No
    }
}

impl<'r> Decode<'r, Postgres> for Vector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_vecf32.rs#:~:text=recv
    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
        let buf = value.as_bytes()?;
        let (dims_bytes, rest) = buf.split_at(U16_SIZE);
        let dims = u16::from_ne_bytes(dims_bytes.try_into()?);
        let bytes = F32_SIZE * (dims as usize);
        let mut slice = Vec::<f32>::with_capacity(dims as usize);
        unsafe {
            // Copy the raw f32 payload into the buffer, then mark it as initialized.
            std::ptr::copy(rest.as_ptr(), slice.as_mut_ptr().cast(), bytes);
            slice.set_len(dims as usize);
        };
        Ok(slice.into())
    }
}

/// A binary vector.
#[derive(Clone, PartialEq, Default)]
pub struct BVector {
    dims: u16,
    data: Vec<usize>,
}

const USIZE_WIDTH: usize = usize::BITS as usize;
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
const U16_SIZE: usize = std::mem::size_of::<u16>();

impl core::fmt::Debug for BVector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let x: String = self
            .clone()
            .iter()
            .map(|b| if b { '1' } else { '0' })
            .collect();
        write!(f, "bvector(({}) {})", self.dims, x)
    }
}

impl BVector {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn from_bits(bools_as_u8: &[u8]) -> Self {
        let bools: Vec<_> = bools_as_u8.into_iter().map(|i| *i != 0u8).collect();
        Self::from_bools(&bools)
    }

    pub fn from_bools(bools: &[bool]) -> Self {
        Self {
            dims: bools.len() as u16,
            data: {
                // Pack the booleans into usize words, least significant bit first.
                let mut data = Vec::new();
                let mut current_usize = 0;
                let mut bit_index = 0;
                for &bit in bools {
                    if bit {
                        current_usize |= 1 << bit_index;
                    }
                    bit_index += 1;
                    if bit_index == USIZE_WIDTH {
                        data.push(current_usize);
                        current_usize = 0;
                        bit_index = 0;
                    }
                }
                if bit_index > 0 {
                    data.push(current_usize);
                }
                data
            },
        }
    }

    pub fn iter(self) -> impl Iterator<Item = bool> {
        let mut index = 0;
        std::iter::from_fn(move || {
            if index < self.dims as usize {
                let result = self.data[index / USIZE_WIDTH] & (1 << (index % USIZE_WIDTH)) != 0;
                index += 1;
                Some(result)
            } else {
                None
            }
        })
    }
}

impl Type<Postgres> for BVector {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("bvector")
    }
}

impl<'r> Encode<'r, Postgres> for BVector {
    // Wire format: dimension count as a native-endian u16, followed by the packed usize words.
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=send
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        let bytes = (self.dims as usize).div_ceil(USIZE_WIDTH) * std::mem::size_of::<usize>();
        let mut out = vec![0u8; U16_SIZE + bytes];
        out[..U16_SIZE].copy_from_slice(&self.dims.to_ne_bytes());
        out[U16_SIZE..].copy_from_slice(unsafe {
            std::slice::from_raw_parts(self.data.as_ptr() as *const u8, bytes)
        });
        buf.extend(out);
        IsNull::No
    }
}

impl<'r> Decode<'r, Postgres> for BVector {
    // https://github.com/tensorchord/pgvecto.rs/blob/main/src/datatype/binary_bvecf32.rs#:~:text=recv
    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
        let buf = value.as_bytes()?;
        let (dims_bytes, rest) = buf.split_at(U16_SIZE);
        let dims = u16::from_ne_bytes(dims_bytes.try_into()?);
        let usizes = (dims as usize).div_ceil(USIZE_WIDTH);
        let bytes = usizes * USIZE_SIZE;
        let mut data = Vec::<usize>::with_capacity(usizes);
        unsafe {
            std::ptr::copy(rest.as_ptr(), data.as_mut_ptr().cast(), bytes);
            data.set_len(usizes);
        };
        Ok(Self { dims, data })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_vector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();

        let vec: Vec<f32> = vec![1.1, 2.2, 3.14, 1e6, 1e10];
        let bv_out: Vector = sqlx::query_scalar(r#"select '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(*bv_out, vec);

        let is_same: bool = sqlx::query_scalar(r#"select $1 = '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
            .bind(&vec)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }

    #[tokio::test]
    async fn test_bvector() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();

        let bv_out: BVector = sqlx::query_scalar(r#"select binarize('[1,0,1,1,1,1,1,1,0,0,1]')"#)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(
            bv_out,
            BVector::from_bits(&[1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1])
        );

        let v_in = BVector::from_bools(&[true, true, false, false, false, true, false]);
        let is_same: bool = sqlx::query_scalar(r#"select $1 = binarize('[1,1,0,0,0,1,0]')"#)
            .bind(&v_in)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(is_same);
    }
}
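
For completeness, here is how these types might be wired into application queries. This is only a minimal sketch: it assumes pgvecto.rs is installed under the `vectors` extension name, and the `items` table, its `embedding vector(3)` column, and the use of the `<->` distance operator are made up for illustration.

// Hypothetical usage sketch; table/column names are assumptions, not part of the gist.
async fn example(pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
    sqlx::query("create extension if not exists vectors")
        .execute(pool)
        .await?;
    sqlx::query("create table if not exists items (id bigserial primary key, embedding vector(3))")
        .execute(pool)
        .await?;

    // Bind a Vector directly; the Encode impl above produces the pgvecto.rs binary format.
    sqlx::query("insert into items (embedding) values ($1)")
        .bind(Vector::from(vec![0.1_f32, 0.2, 0.3]))
        .execute(pool)
        .await?;

    // Fetch the nearest row back as a Vector via the Decode impl.
    let nearest: Vector =
        sqlx::query_scalar("select embedding from items order by embedding <-> $1 limit 1")
            .bind(Vector::from(vec![0.1_f32, 0.2, 0.3]))
            .fetch_one(pool)
            .await?;
    println!("{nearest:?}");
    Ok(())
}
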
@spikecodes

Thanks for this!

@spikecodes

Using Claude 3.5 Sonnet, I converted the unsafe code to a safe alternative. It passed all of my tests, so I believe it functions the same.

use std::ops::Deref;

use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef};
use sqlx::{Decode, Encode, Postgres, Type};

/// A vector.
#[derive(Clone, PartialEq, Default)]
pub struct Vector(Vec<f32>);

impl Vector {
	pub fn new() -> Self {
		Self::default()
	}
}

impl From<Vec<f32>> for Vector {
	fn from(value: Vec<f32>) -> Self {
		Self(value)
	}
}

impl Into<Vec<f32>> for Vector {
	fn into(self) -> Vec<f32> {
		self.0
	}
}

impl Deref for Vector {
	type Target = Vec<f32>;

	fn deref(&self) -> &Self::Target {
		&self.0
	}
}

impl core::fmt::Debug for Vector {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "vector(({}) {:?})", self.0.len(), &self.0)
	}
}

impl Type<Postgres> for Vector {
	fn type_info() -> PgTypeInfo {
		PgTypeInfo::with_name("vector")
	}
}

const F32_SIZE: usize = std::mem::size_of::<f32>();

impl<'r> Encode<'r, Postgres> for Vector {
	fn encode_by_ref(
		&self,
		buf: &mut PgArgumentBuffer,
	) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
		let dims = self.0.len();
		let bytes = dims * F32_SIZE;
		let mut out = vec![0u8; U16_SIZE + bytes];
		out[..U16_SIZE].copy_from_slice(&(dims as u16).to_ne_bytes());
		let float_bytes: Vec<u8> = self.0.iter().flat_map(|&f| f.to_ne_bytes()).collect();
		out[U16_SIZE..].copy_from_slice(&float_bytes);
		buf.extend(out);
		Ok(IsNull::No)
	}
}

impl<'r> Decode<'r, Postgres> for Vector {
	fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
		let buf = value.as_bytes()?;
		let (dims_bytes, rest) = buf.split_at(U16_SIZE);
		let dims = u16::from_ne_bytes(dims_bytes.try_into()?);
		let bytes = F32_SIZE * (dims as usize);
		let mut slice = Vec::<f32>::with_capacity(dims as usize);
		for chunk in rest[..bytes].chunks_exact(F32_SIZE) {
			slice.push(f32::from_ne_bytes(chunk.try_into()?));
		}
		Ok(slice.into())
	}
}

/// A binary vector.
#[derive(Clone, PartialEq, Default)]
pub struct BVector {
	dims: u16,
	data: Vec<usize>,
}

const USIZE_WIDTH: usize = usize::BITS as usize;
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
const U16_SIZE: usize = std::mem::size_of::<u16>();

impl core::fmt::Debug for BVector {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let x: String = self
			.clone()
			.iter()
			.map(|b| if b { '1' } else { '0' })
			.collect();
		write!(f, "bvector(({}) {})", self.dims, x)
	}
}

impl BVector {
	pub fn new() -> Self {
		Self::default()
	}

	pub fn from_bits(bools_as_u8: &[u8]) -> Self {
		let bools: Vec<_> = bools_as_u8.into_iter().map(|i| *i != 0u8).collect();
		Self::from_bools(&bools)
	}

	pub fn from_bools(bools: &[bool]) -> Self {
		Self {
			dims: bools.len() as u16,
			data: {
				let mut data = Vec::new();
				let mut current_usize = 0;
				let mut bit_index = 0;
				for &bit in bools {
					if bit {
						current_usize |= 1 << bit_index;
					}
					bit_index += 1;
					if bit_index == USIZE_WIDTH {
						data.push(current_usize);
						current_usize = 0;
						bit_index = 0;
					}
				}
				if bit_index > 0 {
					data.push(current_usize);
				}
				data
			},
		}
	}

	pub fn iter(self) -> impl Iterator<Item = bool> {
		let mut index = 0;
		std::iter::from_fn(move || {
			if index < self.dims as usize {
				let result = self.data[index / USIZE_WIDTH] & (1 << (index % USIZE_WIDTH)) != 0;
				index += 1;
				Some(result)
			} else {
				None
			}
		})
	}
}

impl Type<Postgres> for BVector {
	fn type_info() -> PgTypeInfo {
		PgTypeInfo::with_name("bvector")
	}
}

impl<'r> Encode<'r, Postgres> for BVector {
	fn encode_by_ref(
		&self,
		buf: &mut PgArgumentBuffer,
	) -> Result<IsNull, Box<dyn std::error::Error + Send + Sync + 'static>> {
		let bytes = (self.dims as usize).div_ceil(USIZE_WIDTH) * std::mem::size_of::<usize>();
		let mut out = vec![0u8; U16_SIZE + bytes];
		out[..U16_SIZE].copy_from_slice(&self.dims.to_ne_bytes());
		for (i, &usize_val) in self.data.iter().enumerate() {
			let start = U16_SIZE + i * USIZE_SIZE;
			let end = start + USIZE_SIZE;
			out[start..end].copy_from_slice(&usize_val.to_ne_bytes());
		}
		buf.extend(out);
		Ok(IsNull::No)
	}
}

impl<'r> Decode<'r, Postgres> for BVector {
	fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
		let buf = value.as_bytes()?;
		let (dims_bytes, rest) = buf.split_at(U16_SIZE);
		let dims = u16::from_ne_bytes(dims_bytes.try_into()?);
		let usizes = (dims as usize).div_ceil(USIZE_WIDTH);
		let mut data = Vec::<usize>::with_capacity(usizes);
		for chunk in rest.chunks_exact(USIZE_SIZE) {
			data.push(usize::from_ne_bytes(chunk.try_into()?));
		}
		Ok(Self { dims, data })
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_vector() {
		let pool = sqlx::postgres::PgPoolOptions::new()
			.connect(&std::env::var("DATABASE_URL").unwrap())
			.await
			.unwrap();

		let vec: Vec<f32> = vec![1.1, 2.2, 3.14, 1e6, 1e10];
		let bv_out: Vector = sqlx::query_scalar(r#"select '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
			.fetch_one(&pool)
			.await
			.unwrap();
		assert_eq!(*bv_out, vec);

		let is_same: bool = sqlx::query_scalar(r#"select $1 = '[1.1,2.2,3.14,1e6,1e10]'::vector"#)
			.bind(&vec)
			.fetch_one(&pool)
			.await
			.unwrap();
		assert!(is_same);
	}

	#[tokio::test]
	async fn test_bvector() {
		let pool = sqlx::postgres::PgPoolOptions::new()
			.connect(&std::env::var("DATABASE_URL").unwrap())
			.await
			.unwrap();

		let bv_out: BVector = sqlx::query_scalar(r#"select binarize('[1,0,1,1,1,1,1,1,0,0,1]')"#)
			.fetch_one(&pool)
			.await
			.unwrap();
		assert_eq!(
			bv_out,
			BVector::from_bits(&[1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1])
		);

		let v_in = BVector::from_bools(&[true, true, false, false, false, true, false]);
		let is_same: bool = sqlx::query_scalar(r#"select $1 = binarize('[1,1,0,0,0,1,0]')"#)
			.bind(&v_in)
			.fetch_one(&pool)
			.await
			.unwrap();
		assert!(is_same);
	}
}
