VectorsAsBinarySerTest.java
package com.fasterxml.jackson.databind.ser.jdk;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.core.Base64Variants;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.testutil.DatabindTestUtil;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests for serialization (and deserialization) of {@code float[]}
 * and {@code double[]} as "packed binary" data, as per [databind#5242].
 */
public class VectorsAsBinarySerTest extends DatabindTestUtil
{
    // Sample vectors and the JSON array text the tests expect for them
    // when no binary shape applies.
    private final static float[] FLOAT_VECTOR = new float[] { 1.0f, 0.5f, -1.25f };
    private final static String FLOAT_VECTOR_STR = "[1.0,0.5,-1.25]";

    private final static double[] DOUBLE_VECTOR = new double[] { -1.0, 1.5, 0.0125 };
    private final static String DOUBLE_VECTOR_STR = "[-1.0,1.5,0.0125]";

    /**
     * Bean whose {@code float[]} property is explicitly annotated with
     * {@link JsonFormat.Shape#NATURAL}.
     */
    static class BeanWithArrayFloatVector {
        @JsonFormat(shape = JsonFormat.Shape.NATURAL) // or ARRAY
        public float[] vector;

        protected BeanWithArrayFloatVector() { }

        public BeanWithArrayFloatVector(float[] v) {
            vector = v;
        }
    }

    /**
     * Bean whose {@code float[]} property is annotated with
     * {@link JsonFormat.Shape#BINARY}.
     */
    static class BeanWithBinaryFloatVector {
        @JsonFormat(shape = JsonFormat.Shape.BINARY)
        public float[] vector;

        protected BeanWithBinaryFloatVector() { }

        public BeanWithBinaryFloatVector(float[] v) {
            vector = v;
        }
    }

    /**
     * Bean whose {@code double[]} property is explicitly annotated with
     * {@link JsonFormat.Shape#NATURAL}.
     */
    static class BeanWithArrayDoubleVector {
        @JsonFormat(shape = JsonFormat.Shape.NATURAL) // or ARRAY
        public double[] vector;

        protected BeanWithArrayDoubleVector() { }

        public BeanWithArrayDoubleVector(double[] v) {
            vector = v;
        }
    }

    /**
     * Bean whose {@code double[]} property is annotated with
     * {@link JsonFormat.Shape#BINARY}.
     */
    static class BeanWithBinaryDoubleVector {
        @JsonFormat(shape = JsonFormat.Shape.BINARY)
        public double[] vector;

        protected BeanWithBinaryDoubleVector() { }

        public BeanWithBinaryDoubleVector(double[] v) {
            vector = v;
        }
    }

    // Mapper without any per-type format overrides
    private final ObjectMapper VANILLA_MAPPER = sharedMapper();

    // Mapper with config overrides that set Shape.BINARY as the format
    // for both float[] and double[]
    private final ObjectMapper BINARY_VECTOR_MAPPER = jsonMapperBuilder()
            .withConfigOverride(float[].class,
                    c -> c.setFormat(JsonFormat.Value.forShape(JsonFormat.Shape.BINARY)))
            .withConfigOverride(double[].class,
                    c -> c.setFormat(JsonFormat.Value.forShape(JsonFormat.Shape.BINARY)))
            .build();

    // // // Float Vector tests, as-Array

    /**
     * Root-level {@code float[]} through the vanilla mapper: expects the
     * JSON array text and an equal array when read back.
     */
    @Test
    public void defaultFloatVectorSerialization() throws Exception {
        String json = VANILLA_MAPPER.writeValueAsString(FLOAT_VECTOR);
        assertEquals(FLOAT_VECTOR_STR, json);
        float[] result = VANILLA_MAPPER.readValue(json, float[].class);
        assertArrayEquals(FLOAT_VECTOR, result);
    }

    /**
     * NATURAL-annotated {@code float[]} property: expects JSON array output
     * from both the vanilla mapper and the binary-override mapper, plus a
     * successful round-trip through the vanilla mapper.
     */
    @Test
    public void asArrayFloatVectorSerialization() throws Exception {
        final String exp = a2q("{'vector':"+FLOAT_VECTOR_STR+"}");
        String json = VANILLA_MAPPER.writeValueAsString(new BeanWithArrayFloatVector(FLOAT_VECTOR));
        assertEquals(exp, json);

        // And annotation overrides default shape override
        assertEquals(exp,
                BINARY_VECTOR_MAPPER.writeValueAsString(new BeanWithArrayFloatVector(FLOAT_VECTOR)));

        BeanWithArrayFloatVector result = VANILLA_MAPPER.readValue(json, BeanWithArrayFloatVector.class);
        assertArrayEquals(FLOAT_VECTOR, result.vector);
    }

    // // // Float Vector tests, as-Binary

    /**
     * Root-level {@code float[]} through the binary-override mapper: expects
     * a quoted Base64 string of the packed bytes, and an equal array when
     * read back with the same mapper.
     */
    @Test
    public void asBinaryFloatVectorSerializationRoot() throws Exception {
        String json = BINARY_VECTOR_MAPPER.writeValueAsString(FLOAT_VECTOR);
        assertEquals(q(base64Encode(asBinary(FLOAT_VECTOR))), json);
        float[] result = BINARY_VECTOR_MAPPER.readValue(json, float[].class);
        assertArrayEquals(FLOAT_VECTOR, result);
    }

    /**
     * BINARY-annotated {@code float[]} property through the vanilla mapper:
     * expects a Base64 string value and a successful round-trip into the
     * same bean type.
     */
    @Test
    public void asBinaryFloatVectorSerializationPOJO() throws Exception {
        String json = VANILLA_MAPPER.writeValueAsString(new BeanWithBinaryFloatVector(FLOAT_VECTOR));
        assertEquals(a2q("{'vector':'"+base64Encode(asBinary(FLOAT_VECTOR))+"'}"), json);

        // Read back into the BINARY-annotated bean (not the NATURAL one) so the
        // round-trip exercises the same annotated property that was written,
        // mirroring the double[] variant of this test.
        BeanWithBinaryFloatVector result = VANILLA_MAPPER.readValue(json, BeanWithBinaryFloatVector.class);
        assertArrayEquals(FLOAT_VECTOR, result.vector);
    }

    // // // Double Vector tests, as-Array

    /**
     * Root-level {@code double[]} through the vanilla mapper: expects the
     * JSON array text and an equal array when read back.
     */
    @Test
    public void defaultDoubleVectorSerialization() throws Exception {
        String json = VANILLA_MAPPER.writeValueAsString(DOUBLE_VECTOR);
        assertEquals(DOUBLE_VECTOR_STR, json);
        double[] result = VANILLA_MAPPER.readValue(json, double[].class);
        assertArrayEquals(DOUBLE_VECTOR, result);
    }

    /**
     * NATURAL-annotated {@code double[]} property: expects JSON array output
     * from both the vanilla mapper and the binary-override mapper, plus a
     * successful round-trip through the vanilla mapper.
     */
    @Test
    public void asArrayDoubleVectorSerialization() throws Exception {
        String exp = a2q("{'vector':"+DOUBLE_VECTOR_STR+"}");
        String json = VANILLA_MAPPER.writeValueAsString(new BeanWithArrayDoubleVector(DOUBLE_VECTOR));
        assertEquals(exp, json);

        // And annotation overrides default shape override
        assertEquals(exp,
                BINARY_VECTOR_MAPPER.writeValueAsString(new BeanWithArrayDoubleVector(DOUBLE_VECTOR)));

        BeanWithArrayDoubleVector result = VANILLA_MAPPER.readValue(json, BeanWithArrayDoubleVector.class);
        assertArrayEquals(DOUBLE_VECTOR, result.vector);
    }

    // // // Double Vector tests, as-Binary

    /**
     * Root-level {@code double[]} through the binary-override mapper: expects
     * a quoted Base64 string of the packed bytes, and an equal array when
     * read back with the same mapper.
     */
    @Test
    public void asBinaryDoubleVectorSerializationRoot() throws Exception {
        String json = BINARY_VECTOR_MAPPER.writeValueAsString(DOUBLE_VECTOR);
        assertEquals(q(base64Encode(asBinary(DOUBLE_VECTOR))), json);
        double[] result = BINARY_VECTOR_MAPPER.readValue(json, double[].class);
        assertArrayEquals(DOUBLE_VECTOR, result);
    }

    /**
     * BINARY-annotated {@code double[]} property through the vanilla mapper:
     * expects a Base64 string value and a successful round-trip into the
     * same bean type.
     */
    @Test
    public void asBinaryDoubleVectorSerializationPOJO() throws Exception {
        String json = VANILLA_MAPPER.writeValueAsString(new BeanWithBinaryDoubleVector(DOUBLE_VECTOR));
        assertEquals(a2q("{'vector':'"+base64Encode(asBinary(DOUBLE_VECTOR))+"'}"), json);

        BeanWithBinaryDoubleVector result = VANILLA_MAPPER.readValue(json, BeanWithBinaryDoubleVector.class);
        assertArrayEquals(DOUBLE_VECTOR, result.vector);
    }

    // // // Helper methods

    /**
     * Packs each float into 4 consecutive bytes, most significant byte first,
     * using the bit pattern from {@link Float#floatToIntBits(float)}.
     *
     * @param vector values to pack
     * @return array of {@code vector.length * 4} bytes
     */
    private static byte[] asBinary(float[] vector) {
        byte[] result = new byte[vector.length * 4];
        for (int i = 0; i < vector.length; i++) {
            int bits = Float.floatToIntBits(vector[i]);
            result[i * 4] = (byte) (bits >> 24);
            result[i * 4 + 1] = (byte) (bits >> 16);
            result[i * 4 + 2] = (byte) (bits >> 8);
            result[i * 4 + 3] = (byte) bits;
        }
        return result;
    }

    /**
     * Packs each double into 8 consecutive bytes, most significant byte first,
     * using the bit pattern from {@link Double#doubleToLongBits(double)}.
     *
     * @param vector values to pack
     * @return array of {@code vector.length * 8} bytes
     */
    private static byte[] asBinary(double[] vector) {
        byte[] result = new byte[vector.length * 8];
        for (int i = 0; i < vector.length; i++) {
            long bits = Double.doubleToLongBits(vector[i]);
            result[i * 8] = (byte) (bits >> 56);
            result[i * 8 + 1] = (byte) (bits >> 48);
            result[i * 8 + 2] = (byte) (bits >> 40);
            result[i * 8 + 3] = (byte) (bits >> 32);
            result[i * 8 + 4] = (byte) (bits >> 24);
            result[i * 8 + 5] = (byte) (bits >> 16);
            result[i * 8 + 6] = (byte) (bits >> 8);
            result[i * 8 + 7] = (byte) bits;
        }
        return result;
    }

    /**
     * Encodes bytes with Jackson's default Base64 variant. The tests add
     * their own surrounding quotes, so the encoder is asked (second argument
     * {@code false}) not to add any.
     *
     * @param data bytes to encode
     * @return Base64 text for {@code data}
     */
    private static String base64Encode(byte[] data) {
        return Base64Variants.getDefaultVariant().encode(data, false);
    }
}