Skip to content

Instantly share code, notes, and snippets.

@mp911de
Created October 30, 2024 11:04
Show Gist options
  • Save mp911de/c302dbfbdc9590f038054e6d016cd6b5 to your computer and use it in GitHub Desktop.
Save mp911de/c302dbfbdc9590f038054e6d016cd6b5 to your computer and use it in GitHub Desktop.
Vectors 😱🤯
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import java.util.Arrays;
import org.springframework.util.ObjectUtils;
/**
 * {@link Vector} backed by a primitive {@code double[]}. The constructor wraps the given array by reference; use
 * {@link #copy(double[])} to obtain a vector that is isolated from later changes to the caller's array.
 *
 * @author Mark Paluch
 */
class DoubleVector implements Vector {

	private final double[] v;

	/**
	 * Wrap {@code v} without copying it.
	 *
	 * @param v the backing values.
	 */
	public DoubleVector(double[] v) {
		this.v = v;
	}

	/**
	 * Create a {@link Vector} over a private copy of {@code v}.
	 *
	 * @param v the values to copy.
	 * @return a vector that does not share its backing array with {@code v}.
	 */
	static Vector copy(double[] v) {
		return new DoubleVector(Arrays.copyOf(v, v.length));
	}

	@Override
	public Class<Double> getType() {
		// double.class is the same Class instance as Double.TYPE.
		return double.class;
	}

	@Override
	public Object getSource() {
		return this.v;
	}

	@Override
	public int size() {
		return this.v.length;
	}

	@Override
	public float[] toFloatArray() {

		float[] narrowed = new float[this.v.length];
		int index = 0;
		for (double element : this.v) {
			// Narrowing conversion: may lose precision.
			narrowed[index++] = (float) element;
		}
		return narrowed;
	}

	@Override
	public double[] toDoubleArray() {
		return Arrays.copyOf(this.v, this.v.length);
	}

	@Override
	public boolean equals(Object o) {

		if (o == this) {
			return true;
		}
		return o instanceof DoubleVector other && ObjectUtils.nullSafeEquals(this.v, other.v);
	}

	@Override
	public int hashCode() {
		return Arrays.hashCode(this.v);
	}

	@Override
	public String toString() {
		return "D".concat(Arrays.toString(this.v));
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import static org.assertj.core.api.Assertions.*;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link DoubleVector}.
 *
 * @author Mark Paluch
 */
class DoubleVectorUnitTests {

	double[] values = new double[] { 1.1, 2.2, 3.3, 4.4, 5.5 };
	float[] floats = new float[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d };

	@Test
	void shouldCreateVector() {

		Vector vector = Vector.of(values);

		assertThat(vector.size()).isEqualTo(5);
		assertThat(vector.getType()).isEqualTo(Double.TYPE);
	}

	@Test
	void shouldCreateUnsafeVector() {

		Vector vector = Vector.unsafe(values);

		assertThat(vector.getSource()).isSameAs(values);
	}

	@Test
	void shouldCopyVectorValues() {

		// Work on a clone so the shared fixture is never mutated.
		double[] input = values.clone();
		Vector vector = Vector.of(input);

		// Compare against the input array, not the vector: a Vector is never the same instance as its source array,
		// so isNotSameAs(vector) would always pass and never verify the defensive copy.
		assertThat(vector.getSource()).isNotSameAs(input).isEqualTo(input);

		// Changes to the caller's array after creation must not be observed through the vector.
		input[0] = 42.0;
		assertThat(vector.toDoubleArray()[0]).isEqualTo(1.1);
	}

	@Test
	void shouldRenderToString() {

		Vector vector = Vector.of(values);

		assertThat(vector).hasToString("D[1.1, 2.2, 3.3, 4.4, 5.5]");
	}

	@Test
	void shouldCompareVector() {

		Vector vector = Vector.of(values);

		assertThat(vector).isEqualTo(Vector.of(values));
		assertThat(vector).hasSameHashCodeAs(Vector.of(values));
	}

	@Test
	void sourceShouldReturnSource() {

		Vector vector = new DoubleVector(values);

		assertThat(vector.getSource()).isSameAs(values);
	}

	@Test
	void shouldCreateFloatArray() {

		Vector vector = Vector.of(values);

		assertThat(vector.toFloatArray()).isEqualTo(floats).isNotSameAs(floats);
	}

	@Test
	void shouldCreateDoubleArray() {

		Vector vector = Vector.of(values);

		assertThat(vector.toDoubleArray()).isEqualTo(values).isNotSameAs(values);
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import java.util.Arrays;
import org.springframework.util.ObjectUtils;
/**
 * {@link Vector} backed by a primitive {@code float[]}. The constructor wraps the given array by reference; use
 * {@link #copy(float[])} to obtain a vector that is isolated from later changes to the caller's array.
 *
 * @author Mark Paluch
 */
class FloatVector implements Vector {

	private final float[] v;

	/**
	 * Wrap {@code v} without copying it.
	 *
	 * @param v the backing values.
	 */
	public FloatVector(float[] v) {
		this.v = v;
	}

	/**
	 * Create a {@link Vector} over a private copy of {@code v}.
	 *
	 * @param v the values to copy.
	 * @return a vector that does not share its backing array with {@code v}.
	 */
	static Vector copy(float[] v) {
		return new FloatVector(Arrays.copyOf(v, v.length));
	}

	@Override
	public Class<Float> getType() {
		// float.class is the same Class instance as Float.TYPE.
		return float.class;
	}

	@Override
	public Object getSource() {
		return this.v;
	}

	@Override
	public int size() {
		return this.v.length;
	}

	@Override
	public float[] toFloatArray() {
		return Arrays.copyOf(this.v, this.v.length);
	}

	@Override
	public double[] toDoubleArray() {

		double[] widened = new double[this.v.length];
		int index = 0;
		for (float element : this.v) {
			// Widening conversion: every float is exactly representable as a double.
			widened[index++] = element;
		}
		return widened;
	}

	@Override
	public boolean equals(Object o) {

		if (o == this) {
			return true;
		}
		return o instanceof FloatVector other && ObjectUtils.nullSafeEquals(this.v, other.v);
	}

	@Override
	public int hashCode() {
		return Arrays.hashCode(this.v);
	}

	@Override
	public String toString() {
		return "F".concat(Arrays.toString(this.v));
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import static org.assertj.core.api.Assertions.*;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link FloatVector}.
 *
 * @author Mark Paluch
 */
class FloatVectorUnitTests {

	float[] values = new float[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f };
	// float literals widened to double, matching what FloatVector.toDoubleArray() produces.
	double[] doubles = new double[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f };

	@Test
	void shouldCreateVector() {

		Vector vector = Vector.of(values);

		assertThat(vector.size()).isEqualTo(5);
		assertThat(vector.getType()).isEqualTo(Float.TYPE);
	}

	@Test
	void shouldCreateUnsafeVector() {

		Vector vector = Vector.unsafe(values);

		assertThat(vector.getSource()).isSameAs(values);
	}

	@Test
	void shouldCopyVectorValues() {

		// Work on a clone so the shared fixture is never mutated.
		float[] input = values.clone();
		Vector vector = Vector.of(input);

		// Compare against the input array, not the vector: a Vector is never the same instance as its source array,
		// so isNotSameAs(vector) would always pass and never verify the defensive copy.
		assertThat(vector.getSource()).isNotSameAs(input).isEqualTo(input);

		// Changes to the caller's array after creation must not be observed through the vector.
		input[0] = 42f;
		assertThat(vector.toFloatArray()[0]).isEqualTo(1.1f);
	}

	@Test
	void shouldRenderToString() {

		Vector vector = Vector.of(values);

		assertThat(vector).hasToString("F[1.1, 2.2, 3.3, 4.4, 5.5]");
	}

	@Test
	void shouldCompareVector() {

		Vector vector = Vector.of(values);

		assertThat(vector).isEqualTo(Vector.of(values));
		assertThat(vector).hasSameHashCodeAs(Vector.of(values));
	}

	@Test
	void sourceShouldReturnSource() {

		Vector vector = new FloatVector(values);

		assertThat(vector.getSource()).isSameAs(values);
	}

	@Test
	void shouldCreateFloatArray() {

		Vector vector = Vector.of(values);

		assertThat(vector.toFloatArray()).isEqualTo(values).isNotSameAs(values);
	}

	@Test
	void shouldCreateDoubleArray() {

		Vector vector = Vector.of(values);

		assertThat(vector.toDoubleArray()).isEqualTo(doubles).isNotSameAs(doubles);
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import java.util.Arrays;
import java.util.Collection;
import org.springframework.util.ObjectUtils;
/**
 * {@link Vector} implementation based on {@link Number} array.
 * <p>
 * Conversions read each element through {@link Number#floatValue()} or {@link Number#doubleValue()}, so a
 * {@literal null} element causes a {@link NullPointerException} in {@link #toFloatArray()} and
 * {@link #toDoubleArray()}.
 *
 * @author Mark Paluch
 */
class NumberVector implements Vector {

	private final Number[] v;

	/**
	 * Wrap {@code v} without copying it.
	 *
	 * @param v the backing values.
	 */
	public NumberVector(Number[] v) {
		this.v = v;
	}

	/**
	 * Copy the given {@link Number} array and wrap it within a Vector.
	 *
	 * @param v the values to copy.
	 * @return a vector over a private {@code Number[]} copy of {@code v}.
	 */
	static Vector copy(Number[] v) {

		// Copy into an exact Number[] rather than using Arrays.copyOf, which would keep the runtime component type of v
		// (e.g. Double[]).
		Number[] copy = new Number[v.length];
		System.arraycopy(v, 0, copy, 0, copy.length);
		return new NumberVector(copy);
	}

	/**
	 * Copy the given {@link Number numbers} and wrap them within a Vector.
	 *
	 * @param numbers the values to copy; any collection of {@link Number} subtypes is accepted.
	 * @return a vector over a private {@code Number[]} holding the elements in the collection's iteration order.
	 */
	static Vector copy(Collection<? extends Number> numbers) {

		// toArray sizes the result itself. Pre-sizing from size() and then iterating can overflow the array or leave
		// trailing nulls when the two disagree, e.g. for a concurrent collection modified during the copy.
		return new NumberVector(numbers.toArray(new Number[0]));
	}

	/**
	 * Returns the runtime type shared by all non-null elements. Returns {@code Number.class} if the elements have
	 * differing types or if there is no non-null element (e.g. an empty vector).
	 *
	 * @return the common element type, never {@literal null}.
	 */
	@Override
	public Class<? extends Number> getType() {

		Class<? extends Number> candidate = null;
		for (Number value : v) {
			if (value == null) {
				continue;
			}
			if (candidate == null) {
				candidate = value.getClass();
			} else if (candidate != value.getClass()) {
				return Number.class;
			}
		}
		return candidate != null ? candidate : Number.class;
	}

	@Override
	public Object getSource() {
		return v;
	}

	@Override
	public int size() {
		return v.length;
	}

	@Override
	public float[] toFloatArray() {

		float[] copy = new float[this.v.length];
		for (int i = 0; i < this.v.length; i++) {
			copy[i] = this.v[i].floatValue();
		}
		return copy;
	}

	@Override
	public double[] toDoubleArray() {

		double[] copy = new double[this.v.length];
		for (int i = 0; i < this.v.length; i++) {
			copy[i] = this.v[i].doubleValue();
		}
		return copy;
	}

	@Override
	public boolean equals(Object o) {

		if (this == o) {
			return true;
		}
		if (!(o instanceof NumberVector that)) {
			return false;
		}
		return ObjectUtils.nullSafeEquals(v, that.v);
	}

	@Override
	public int hashCode() {
		return Arrays.hashCode(v);
	}

	@Override
	public String toString() {
		return "N" + Arrays.toString(v);
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import static org.assertj.core.api.Assertions.*;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link NumberVector}.
 *
 * @author Mark Paluch
 */
class NumberVectorUnitTests {

	Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5 };
	Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d };

	@Test
	void shouldCreateVector() {

		Vector vector = Vector.of(values);

		assertThat(vector.size()).isEqualTo(5);
		assertThat(vector.getType()).isEqualTo(Double.class);
	}

	@Test
	void shouldCreateVectorFromCollection() {

		// Vector.of(Collection) had no coverage; it must produce the same vector as the array-based factory.
		Vector vector = Vector.of(java.util.Arrays.asList(values));

		assertThat(vector).isEqualTo(Vector.of(values));
		assertThat(vector.getType()).isEqualTo(Double.class);
	}

	@Test
	void shouldCopyVectorValues() {

		// Work on a clone so the shared fixture is never mutated.
		Number[] input = values.clone();
		Vector vector = Vector.of(input);

		// Compare against the input array, not the vector: a Vector is never the same instance as its source array,
		// so isNotSameAs(vector) would always pass and never verify the defensive copy.
		assertThat(vector.getSource()).isNotSameAs(input).isEqualTo(input);

		// Changes to the caller's array after creation must not be observed through the vector.
		input[0] = 42;
		assertThat(vector.toDoubleArray()[0]).isEqualTo(1.1);
	}

	@Test
	void shouldRenderToString() {

		Vector vector = Vector.of(values);

		assertThat(vector).hasToString("N[1.1, 2.2, 3.3, 4.4, 5.5]");
	}

	@Test
	void shouldCompareVector() {

		Vector vector = Vector.of(values);

		assertThat(vector).isEqualTo(Vector.of(values));
		assertThat(vector).hasSameHashCodeAs(Vector.of(values));
	}

	@Test
	void sourceShouldReturnSource() {

		Vector vector = new NumberVector(values);

		assertThat(vector.getSource()).isSameAs(values);
	}

	@Test
	void shouldCreateFloatArray() {

		Vector vector = Vector.of(values);

		// Named "expected" to avoid shadowing the "values" fixture field.
		float[] expected = new float[this.floats.length];
		for (int i = 0; i < expected.length; i++) {
			expected[i] = this.floats[i].floatValue();
		}

		assertThat(vector.toFloatArray()).isEqualTo(expected).isNotSameAs(expected);
	}

	@Test
	void shouldCreateDoubleArray() {

		Vector vector = Vector.of(values);

		// Named "expected" to avoid shadowing the "values" fixture field.
		double[] expected = new double[this.values.length];
		for (int i = 0; i < expected.length; i++) {
			expected[i] = this.values[i].doubleValue();
		}

		assertThat(vector.toDoubleArray()).isEqualTo(expected).isNotSameAs(expected);
	}
}
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.domain;
import java.util.Collection;
import org.springframework.util.Assert;
/**
 * A vector is a fixed-length array of non-null numeric values. Vectors represent a point in a multidimensional space
 * and are commonly used in machine learning and statistics.
 * <p>
 * Vector properties do not map cleanly to an existing class in the standard JDK Collections hierarchy. When used with
 * embeddings (machine learning), vectors represent an opaque point in the vector space that neither exposes meaningful
 * properties nor guarantees computational values to the outside world.
 * <p>
 * Vectors should be treated as opaque values and should not be modified. They can be created from an array of numbers
 * (typically {@code double} or {@code float} values) and used by components that need to provide the vector for storage
 * or computation.
 *
 * @author Mark Paluch
 */
public interface Vector {

	/**
	 * Creates a new {@link Vector} from the given float {@code values}. Vector values are duplicated to avoid capturing a
	 * mutable array instance and to prevent mutability.
	 *
	 * @param values float vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector of(float... values) {

		Assert.notNull(values, "float vector values must not be null");
		return FloatVector.copy(values);
	}

	/**
	 * Creates a new {@link Vector} from the given double {@code values}. Vector values are duplicated to avoid capturing
	 * a mutable array instance and to prevent mutability.
	 *
	 * @param values double vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector of(double... values) {

		Assert.notNull(values, "double vector values must not be null");
		return DoubleVector.copy(values);
	}

	/**
	 * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing
	 * a mutable array instance and to prevent mutability.
	 *
	 * @param values number vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector of(Number... values) {

		Assert.notNull(values, "Vector values must not be null");
		return NumberVector.copy(values);
	}

	/**
	 * Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing
	 * a mutable collection instance and to prevent mutability. Collections of any {@link Number} subtype (for example
	 * {@code List<Double>}) are accepted.
	 *
	 * @param values number vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector of(Collection<? extends Number> values) {

		Assert.notNull(values, "Vector values must not be null");

		// toArray(T[]) returns a newly allocated Number[] (or the supplied empty one), so the vector never shares state
		// with the caller's collection and no further defensive copy is needed.
		return new NumberVector(values.toArray(new Number[0]));
	}

	/**
	 * Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array
	 * values and are merely a view on the source array: later changes to {@code values} are visible through the returned
	 * vector.
	 *
	 * @param values float vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector unsafe(float[] values) {

		Assert.notNull(values, "float vector values must not be null");
		return new FloatVector(values);
	}

	/**
	 * Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array
	 * values and are merely a view on the source array: later changes to {@code values} are visible through the returned
	 * vector.
	 *
	 * @param values double vector values, must not be {@literal null}.
	 * @return the {@link Vector} for the given vector values.
	 * @throws IllegalArgumentException if {@code values} is {@literal null}.
	 */
	static Vector unsafe(double[] values) {

		Assert.notNull(values, "double vector values must not be null");
		return new DoubleVector(values);
	}

	/**
	 * Returns the type of the underlying vector source. Vectors created from {@code float[]} or {@code double[]} report
	 * the primitive type ({@code float.class} or {@code double.class}).
	 *
	 * @return the type of the underlying vector source.
	 */
	Class<? extends Number> getType();

	/**
	 * Returns the source array of the vector. The source array is not copied and should not be modified to avoid
	 * mutability issues. This method is intended for performance-sensitive access.
	 *
	 * @return the source array of the vector.
	 */
	Object getSource();

	/**
	 * Returns the number of dimensions.
	 *
	 * @return the number of dimensions.
	 */
	int size();

	/**
	 * Convert the vector to a {@code float} array. The returned array is a copy of the {@link #getSource() source} array
	 * and can be modified safely.
	 * <p>
	 * Conversion to {@code float} can incorporate loss of precision or result in values with a slight offset due to data
	 * type conversion if the source is not a {@code float} array.
	 *
	 * @return a new {@code float} array representing the vector point.
	 */
	float[] toFloatArray();

	/**
	 * Convert the vector to a {@code double} array. The returned array is a copy of the {@link #getSource() source} array
	 * and can be modified safely.
	 * <p>
	 * Conversion to {@code double} can incorporate loss of precision or result in values with a slight offset due to data
	 * type conversion if the source is not a {@code double} array.
	 *
	 * @return a new {@code double} array representing the vector point.
	 */
	double[] toDoubleArray();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment