Created
October 30, 2024 11:04
-
-
Save mp911de/c302dbfbdc9590f038054e6d016cd6b5 to your computer and use it in GitHub Desktop.
Vectors 😱🤯
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import java.util.Arrays; | |
import org.springframework.util.ObjectUtils; | |
/** | |
* {@link Vector} implementation based on {@code double} array. | |
* | |
* @author Mark Paluch | |
*/ | |
class DoubleVector implements Vector { | |
private final double[] v; | |
public DoubleVector(double[] v) { | |
this.v = v; | |
} | |
/** | |
* Copy the given {@code double} array and wrap it within a Vector. | |
*/ | |
static Vector copy(double[] v) { | |
double[] copy = new double[v.length]; | |
System.arraycopy(v, 0, copy, 0, copy.length); | |
return new DoubleVector(copy); | |
} | |
@Override | |
public Class<Double> getType() { | |
return Double.TYPE; | |
} | |
@Override | |
public Object getSource() { | |
return v; | |
} | |
@Override | |
public int size() { | |
return v.length; | |
} | |
@Override | |
public float[] toFloatArray() { | |
float[] copy = new float[this.v.length]; | |
for (int i = 0; i < this.v.length; i++) { | |
copy[i] = (float) this.v[i]; | |
} | |
return copy; | |
} | |
@Override | |
public double[] toDoubleArray() { | |
double[] copy = new double[this.v.length]; | |
System.arraycopy(this.v, 0, copy, 0, copy.length); | |
return copy; | |
} | |
@Override | |
public boolean equals(Object o) { | |
if (this == o) { | |
return true; | |
} | |
if (!(o instanceof DoubleVector that)) { | |
return false; | |
} | |
return ObjectUtils.nullSafeEquals(v, that.v); | |
} | |
@Override | |
public int hashCode() { | |
return Arrays.hashCode(v); | |
} | |
@Override | |
public String toString() { | |
return "D" + Arrays.toString(v); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import static org.assertj.core.api.Assertions.*; | |
import org.junit.jupiter.api.Test; | |
/** | |
* Unit tests for {@link DoubleVector}. | |
* | |
* @author Mark Paluch | |
*/ | |
class DoubleVectorUnitTests { | |
double[] values = new double[] { 1.1, 2.2, 3.3, 4.4, 5.5 }; | |
float[] floats = new float[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d }; | |
@Test | |
void shouldCreateVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.size()).isEqualTo(5); | |
assertThat(vector.getType()).isEqualTo(Double.TYPE); | |
} | |
@Test | |
void shouldCreateUnsafeVector() { | |
Vector vector = Vector.unsafe(values); | |
assertThat(vector.getSource()).isSameAs(values); | |
} | |
@Test | |
void shouldCopyVectorValues() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); | |
} | |
@Test | |
void shouldRenderToString() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).hasToString("D[1.1, 2.2, 3.3, 4.4, 5.5]"); | |
} | |
@Test | |
void shouldCompareVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).isEqualTo(Vector.of(values)); | |
assertThat(vector).hasSameHashCodeAs(Vector.of(values)); | |
} | |
@Test | |
void sourceShouldReturnSource() { | |
Vector vector = new DoubleVector(values); | |
assertThat(vector.getSource()).isSameAs(values); | |
} | |
@Test | |
void shouldCreateFloatArray() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.toFloatArray()).isEqualTo(floats).isNotSameAs(floats); | |
} | |
@Test | |
void shouldCreateDoubleArray() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.toDoubleArray()).isEqualTo(values).isNotSameAs(values); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import java.util.Arrays; | |
import org.springframework.util.ObjectUtils; | |
/** | |
* {@link Vector} implementation based on {@code float} array. | |
* | |
* @author Mark Paluch | |
*/ | |
class FloatVector implements Vector { | |
private final float[] v; | |
public FloatVector(float[] v) { | |
this.v = v; | |
} | |
/** | |
* Copy the given {@code float} array and wrap it within a Vector. | |
*/ | |
static Vector copy(float[] v) { | |
float[] copy = new float[v.length]; | |
System.arraycopy(v, 0, copy, 0, copy.length); | |
return new FloatVector(copy); | |
} | |
@Override | |
public Class<Float> getType() { | |
return Float.TYPE; | |
} | |
@Override | |
public Object getSource() { | |
return v; | |
} | |
@Override | |
public int size() { | |
return v.length; | |
} | |
@Override | |
public float[] toFloatArray() { | |
float[] copy = new float[this.v.length]; | |
System.arraycopy(this.v, 0, copy, 0, copy.length); | |
return copy; | |
} | |
@Override | |
public double[] toDoubleArray() { | |
double[] copy = new double[this.v.length]; | |
for (int i = 0; i < this.v.length; i++) { | |
copy[i] = this.v[i]; | |
} | |
return copy; | |
} | |
@Override | |
public boolean equals(Object o) { | |
if (this == o) { | |
return true; | |
} | |
if (!(o instanceof FloatVector that)) { | |
return false; | |
} | |
return ObjectUtils.nullSafeEquals(v, that.v); | |
} | |
@Override | |
public int hashCode() { | |
return Arrays.hashCode(v); | |
} | |
@Override | |
public String toString() { | |
return "F" + Arrays.toString(v); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import static org.assertj.core.api.Assertions.*; | |
import org.junit.jupiter.api.Test; | |
/** | |
* Unit tests for {@link FloatVector}. | |
* | |
* @author Mark Paluch | |
*/ | |
class FloatVectorUnitTests { | |
float[] values = new float[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f }; | |
double[] doubles = new double[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f }; | |
@Test | |
void shouldCreateVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.size()).isEqualTo(5); | |
assertThat(vector.getType()).isEqualTo(Float.TYPE); | |
} | |
@Test | |
void shouldCreateUnsafeVector() { | |
Vector vector = Vector.unsafe(values); | |
assertThat(vector.getSource()).isSameAs(values); | |
} | |
@Test | |
void shouldCopyVectorValues() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); | |
} | |
@Test | |
void shouldRenderToString() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).hasToString("F[1.1, 2.2, 3.3, 4.4, 5.5]"); | |
} | |
@Test | |
void shouldCompareVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).isEqualTo(Vector.of(values)); | |
assertThat(vector).hasSameHashCodeAs(Vector.of(values)); | |
} | |
@Test | |
void sourceShouldReturnSource() { | |
Vector vector = new FloatVector(values); | |
assertThat(vector.getSource()).isSameAs(values); | |
} | |
@Test | |
void shouldCreateFloatArray() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.toFloatArray()).isEqualTo(values).isNotSameAs(values); | |
} | |
@Test | |
void shouldCreateDoubleArray() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.toDoubleArray()).isEqualTo(doubles).isNotSameAs(doubles); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import java.util.Arrays; | |
import java.util.Collection; | |
import org.springframework.util.ObjectUtils; | |
/** | |
* {@link Vector} implementation based on {@link Number} array. | |
* | |
* @author Mark Paluch | |
*/ | |
class NumberVector implements Vector { | |
private final Number[] v; | |
public NumberVector(Number[] v) { | |
this.v = v; | |
} | |
/** | |
* Copy the given {@link Number} array and wrap it within a Vector. | |
*/ | |
static Vector copy(Number[] v) { | |
Number[] copy = new Number[v.length]; | |
System.arraycopy(v, 0, copy, 0, copy.length); | |
return new NumberVector(copy); | |
} | |
/** | |
* Copy the given {@link Number} and wrap it within a Vector. | |
*/ | |
static Vector copy(Collection<Number> numbers) { | |
Number[] copy = new Number[numbers.size()]; | |
int i = 0; | |
for (Number number : numbers) { | |
copy[i++] = number; | |
} | |
return new NumberVector(copy); | |
} | |
@Override | |
public Class<? extends Number> getType() { | |
Class<?> candidate = null; | |
for (Object val : v) { | |
if (val != null) { | |
if (candidate == null) { | |
candidate = val.getClass(); | |
} else if (candidate != val.getClass()) { | |
return Number.class; | |
} | |
} | |
} | |
return (Class<? extends Number>) candidate; | |
} | |
@Override | |
public Object getSource() { | |
return v; | |
} | |
@Override | |
public int size() { | |
return v.length; | |
} | |
@Override | |
public float[] toFloatArray() { | |
float[] copy = new float[this.v.length]; | |
for (int i = 0; i < this.v.length; i++) { | |
copy[i] = this.v[i].floatValue(); | |
} | |
return copy; | |
} | |
@Override | |
public double[] toDoubleArray() { | |
double[] copy = new double[this.v.length]; | |
for (int i = 0; i < this.v.length; i++) { | |
copy[i] = this.v[i].doubleValue(); | |
} | |
return copy; | |
} | |
@Override | |
public boolean equals(Object o) { | |
if (this == o) { | |
return true; | |
} | |
if (!(o instanceof NumberVector that)) { | |
return false; | |
} | |
return ObjectUtils.nullSafeEquals(v, that.v); | |
} | |
@Override | |
public int hashCode() { | |
return Arrays.hashCode(v); | |
} | |
@Override | |
public String toString() { | |
return "N" + Arrays.toString(v); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import static org.assertj.core.api.Assertions.*; | |
import org.junit.jupiter.api.Test; | |
/** | |
* Unit tests for {@link NumberVector}. | |
* | |
* @author Mark Paluch | |
*/ | |
class NumberVectorUnitTests { | |
Number[] values = new Number[] { 1.1, 2.2, 3.3, 4.4, 5.5 }; | |
Number[] floats = new Number[] { (float) 1.1d, (float) 2.2d, (float) 3.3d, (float) 4.4d, (float) 5.5d }; | |
@Test | |
void shouldCreateVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.size()).isEqualTo(5); | |
assertThat(vector.getType()).isEqualTo(Double.class); | |
} | |
@Test | |
void shouldCopyVectorValues() { | |
Vector vector = Vector.of(values); | |
assertThat(vector.getSource()).isNotSameAs(vector).isEqualTo(values); | |
} | |
@Test | |
void shouldRenderToString() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).hasToString("N[1.1, 2.2, 3.3, 4.4, 5.5]"); | |
} | |
@Test | |
void shouldCompareVector() { | |
Vector vector = Vector.of(values); | |
assertThat(vector).isEqualTo(Vector.of(values)); | |
assertThat(vector).hasSameHashCodeAs(Vector.of(values)); | |
} | |
@Test | |
void sourceShouldReturnSource() { | |
Vector vector = new NumberVector(values); | |
assertThat(vector.getSource()).isSameAs(values); | |
} | |
@Test | |
void shouldCreateFloatArray() { | |
Vector vector = Vector.of(values); | |
float[] values = new float[this.floats.length]; | |
for (int i = 0; i < values.length; i++) { | |
values[i] = this.floats[i].floatValue(); | |
} | |
assertThat(vector.toFloatArray()).isEqualTo(values).isNotSameAs(floats); | |
} | |
@Test | |
void shouldCreateDoubleArray() { | |
Vector vector = Vector.of(values); | |
double[] values = new double[this.values.length]; | |
for (int i = 0; i < values.length; i++) { | |
values[i] = this.values[i].doubleValue(); | |
} | |
assertThat(vector.toDoubleArray()).isEqualTo(values).isNotSameAs(values); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2024 the original author or authors. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* https://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.springframework.data.domain; | |
import java.util.Collection; | |
import org.springframework.util.Assert; | |
/** | |
* A vector is a fixed-length array of non-null numeric values. Vectors are represent a point in a multidimensional | |
* space that is commonly used in machine learning and statistics. | |
* <p> | |
* Vector properties do not map cleanly to an existing class in the standard JDK Collections hierarchy. Vectors when | |
* used with embeddings (machine learning) represent an opaque point in the vector space that does not expose meaningful | |
* properties nor guarantees computational values to the outside world. | |
* <p> | |
* Vectors should be treated as opaque values and should not be modified. They can be created from an array of numbers | |
* (typically {@code double} or {@code float} values) and used by components that need to provide the vector for storage | |
* or computation. | |
* | |
* @author Mark Paluch | |
*/ | |
public interface Vector { | |
/** | |
* Creates a new {@link Vector} from the given float {@code values}. Vector values are duplicated to avoid capturing a | |
* mutable array instance and to prevent mutability. | |
* | |
* @param values float vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector of(float... values) { | |
Assert.notNull(values, "float vector values must not be null"); | |
return FloatVector.copy(values); | |
} | |
/** | |
* Creates a new {@link Vector} from the given double {@code values}. Vector values are duplicated to avoid capturing | |
* a mutable array instance and to prevent mutability. | |
* | |
* @param values double vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector of(double... values) { | |
Assert.notNull(values, "double vector values must not be null"); | |
return DoubleVector.copy(values); | |
} | |
/** | |
* Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing | |
* a mutable array instance and to prevent mutability. | |
* | |
* @param values number vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector of(Number... values) { | |
Assert.notNull(values, "Vector values must not be null"); | |
return NumberVector.copy(values); | |
} | |
/** | |
* Creates a new {@link Vector} from the given number {@code values}. Vector values are duplicated to avoid capturing | |
* a mutable collection instance and to prevent mutability. | |
* | |
* @param values number vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector of(Collection<Number> values) { | |
Assert.notNull(values, "Vector values must not be null"); | |
return NumberVector.copy(values); | |
} | |
/** | |
* Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array | |
* values and are merely a view on the source array. | |
* <p> | |
* Supported source type | |
* | |
* @param values vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector unsafe(float[] values) { | |
Assert.notNull(values, "float vector values must not be null"); | |
return new FloatVector(values); | |
} | |
/** | |
* Creates a new unsafe {@link Vector} wrapper from the given {@code values}. Unsafe wrappers do not duplicate array | |
* values and are merely a view on the source array. | |
* <p> | |
* Supported source type | |
* | |
* @param values vector values. | |
* @return the {@link Vector} for the given vector values. | |
*/ | |
static Vector unsafe(double[] values) { | |
Assert.notNull(values, "double vector values must not be null"); | |
return new DoubleVector(values); | |
} | |
/** | |
* Returns the type of the underlying vector source. | |
* | |
* @return the type of the underlying vector source. | |
*/ | |
Class<? extends Number> getType(); | |
/** | |
* Returns the source array of the vector. The source array is not copied and should not be modified to avoid | |
* mutability issues. This method should be used for performance access. | |
* | |
* @return the source array of the vector. | |
*/ | |
Object getSource(); | |
/** | |
* Returns the number of dimensions. | |
* | |
* @return the number of dimensions. | |
*/ | |
int size(); | |
/** | |
* Convert the vector to a {@code float} array. The returned array is a copy of the {@link #getSource() source} array | |
* and can be modified safely. | |
* <p> | |
* Conversion to {@code float} can incorporate loss of precision or result in values with a slight offset due to data | |
* type conversion if the source is not a {@code float} array. | |
* | |
* @return a new {@code float} array representing the vector point. | |
*/ | |
float[] toFloatArray(); | |
/** | |
* Convert the vector to a {@code double} array. The returned array is a copy of the {@link #getSource() source} array | |
* and can be modified safely. | |
* <p> | |
* Conversion to {@code double} can incorporate loss of precision or result in values with a slight offset due to data | |
* type conversion if the source is not a {@code double} array. | |
* | |
* @return a new {@code double} array representing the vector point. | |
*/ | |
double[] toDoubleArray(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment