Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.avro.reflect;

import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
Expand All @@ -30,7 +31,8 @@
* file. Use of {@link org.apache.avro.io.ValidatingEncoder} is recommended.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@Inherited
@Target({ ElementType.FIELD, ElementType.TYPE })
public @interface AvroEncode {
Class<? extends CustomEncoding<?>> using();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class FieldAccessReflect extends FieldAccess {

@Override
protected FieldAccessor getAccessor(Field field) {
AvroEncode enc = field.getAnnotation(AvroEncode.class);
AvroEncode enc = ReflectionUtil.getAvroEncode(field);
if (enc != null)
try {
return new ReflectionBasesAccessorCustomEncoded(field, enc.using().getDeclaredConstructor().newInstance());
Expand All @@ -47,7 +47,7 @@ public ReflectionBasedAccessor(Field field) {
this.field = field;
this.field.setAccessible(true);
isStringable = field.isAnnotationPresent(Stringable.class);
isCustomEncoded = field.isAnnotationPresent(AvroEncode.class);
isCustomEncoded = ReflectionUtil.getAvroEncode(field) != null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ public class ReflectData extends SpecificData {

private static final String STRING_OUTER_PARENT_REFERENCE = "this$0";

// holds a wrapper so null entries will have a cached value
private final ConcurrentMap<Schema, CustomEncodingWrapper> encoderCache = new ConcurrentHashMap<>();

/**
* Always false since custom coders are not available for {@link ReflectData}.
*/
Expand Down Expand Up @@ -864,7 +867,7 @@ private static Field[] getFields(Class<?> recordClass, boolean excludeJava) {

/** Create a schema for a field. */
protected Schema createFieldSchema(Field field, Map<String, Schema> names) {
AvroEncode enc = field.getAnnotation(AvroEncode.class);
AvroEncode enc = ReflectionUtil.getAvroEncode(field);
if (enc != null)
try {
return enc.using().getDeclaredConstructor().newInstance().getSchema();
Expand Down Expand Up @@ -1042,4 +1045,36 @@ public Object newRecord(Object old, Schema schema) {
}
return super.newRecord(old, schema);
}

public CustomEncoding getCustomEncoding(Schema schema) {

return this.encoderCache.computeIfAbsent(schema, this::populateEncoderCache).get();
}

private CustomEncodingWrapper populateEncoderCache(Schema schema) {
var enc = ReflectionUtil.getAvroEncode(getClass(schema));
if (enc != null) {
try {
return new CustomEncodingWrapper(enc.using().getDeclaredConstructor().newInstance());
} catch (Exception e) {
throw new AvroRuntimeException("Could not instantiate custom Encoding");
}
}
return new CustomEncodingWrapper(null);
}

private static class CustomEncodingWrapper {

private final CustomEncoding customEncoding;

private CustomEncodingWrapper(CustomEncoding customEncoding) {
this.customEncoding = customEncoding;
}

public CustomEncoding get() {
return customEncoding;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public ReflectDatumReader(ReflectData data) {
super(data);
}

private ReflectData getReflectData() {
return (ReflectData) getSpecificData();
}

@Override
protected Object newArray(Object old, int size, Schema schema) {
Class<?> collectionClass = ReflectData.getClassProp(schema, SpecificData.CLASS_PROP);
Expand Down Expand Up @@ -251,6 +255,16 @@ protected Object readBytes(Object old, Schema s, Decoder in) throws IOException
}
}

@Override
protected Object read(Object old, Schema expected, ResolvingDecoder in) throws IOException {
CustomEncoding encoder = getReflectData().getCustomEncoding(expected);
if (encoder != null) {
return encoder.read(old, in);
} else {
return super.read(old, expected, in);
}
}

@Override
protected Object readInt(Object old, Schema expected, Decoder in) throws IOException {
Object value = in.readInt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ protected ReflectDatumWriter(ReflectData reflectData) {
super(reflectData);
}

private ReflectData getReflectData() {
return (ReflectData) getSpecificData();
}

/**
* Called to write a array. May be overridden for alternate array
* representations.
Expand Down Expand Up @@ -158,7 +162,13 @@ else if (datum instanceof Map && ReflectData.isNonStringMapSchema(schema)) {
datum = ((Optional) datum).orElse(null);
}
try {
super.write(schema, datum, out);

CustomEncoding encoder = getReflectData().getCustomEncoding(schema);
if (encoder != null) {
encoder.write(datum, out);
} else {
super.write(schema, datum, out);
}
} catch (NullPointerException e) { // improve error message
throw npe(e, " in " + schema.getFullName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
Expand Down Expand Up @@ -188,4 +189,19 @@ public static <V, R> Function<V, R> getConstructorAsFunction(Class<V> parameterC
}
}

protected static AvroEncode getAvroEncode(Field field) {
var enc = field.getAnnotation(AvroEncode.class);
if (enc != null) {
return enc;
} else {
return getAvroEncode(field.getType());
}
}

protected static AvroEncode getAvroEncode(Class<?> clazz) {
if (clazz == null) {
return null;
}
return clazz.getAnnotation(AvroEncode.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.avro.reflect;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;

import org.apache.avro.AvroTypeException;
import org.apache.avro.Schema;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.junit.jupiter.api.Test;

public class TestAvroEncode {
EncoderFactory factory = new EncoderFactory();

@Test
void testWithinClass() throws IOException {

var wrapper = new Wrapper(new R1("test"));

var read = readWrite(wrapper);

assertEquals("test", wrapper.getR1().getValue());
assertEquals("test used this", read.getR1().getValue());
}

@Test
void testDirect() throws IOException {

var r1 = new R1("test");

var read = readWrite(r1);

assertEquals("test", r1.getValue());
assertEquals("test used this", read.getValue());
}

@Test
void testFieldAnnotationTakesPrecedence() throws IOException {

var wrapper = new OtherWrapper(new R1("test"));

var read = readWrite(wrapper);

assertEquals("test", wrapper.getR1().getValue());
assertEquals("test used other", read.getR1().getValue());
}

public static class Wrapper {

private R1 r1;

public Wrapper() {
}

public Wrapper(R1 r1) {
this.r1 = r1;
}

public R1 getR1() {
return r1;
}

public void setR1(R1 r1) {
this.r1 = r1;
}

}

public static class OtherWrapper {
@AvroEncode(using = R1EncodingOther.class)
private R1 r1;

public OtherWrapper() {
}

public OtherWrapper(R1 r1) {
this.r1 = r1;
}

public R1 getR1() {
return r1;
}

public void setR1(R1 r1) {
this.r1 = r1;
}

}

@AvroEncode(using = R1Encoding.class)
public static class R1 {

private final String value;

public R1(String value) {
this.value = value;
}

public String getValue() {
return value;
}

}

public static class R1Encoding extends CustomEncoding<R1> {

{
schema = Schema.createRecord("R1", null, null, false,
Arrays.asList(new Schema.Field("value", Schema.create(Schema.Type.STRING), null, null)));
}

@Override
protected void write(Object datum, Encoder out) throws IOException {
if (datum instanceof R1) {
out.writeString(((R1) datum).getValue());
} else {
throw new AvroTypeException("Expected R1, got " + datum.getClass());
}

}

@Override
protected R1 read(Object reuse, Decoder in) throws IOException {
return new R1(in.readString() + " used this");
}
}

public static class R1EncodingOther extends CustomEncoding<R1> {

{
schema = Schema.createRecord("R1", null, null, false,
Arrays.asList(new Schema.Field("value", Schema.create(Schema.Type.STRING), null, null)));
}

@Override
protected void write(Object datum, Encoder out) throws IOException {
if (datum instanceof R1) {
out.writeString(((R1) datum).getValue());
} else {
throw new AvroTypeException("Expected R1, got " + datum.getClass());
}
}

@Override
protected R1 read(Object reuse, Decoder in) throws IOException {
return new R1(in.readString() + " used other");
}
}

<T> T readWrite(T object) throws IOException {
var schema = new ReflectData().getSchema(object.getClass());
ReflectDatumWriter<T> writer = new ReflectDatumWriter<>(schema);
ByteArrayOutputStream out = new ByteArrayOutputStream();
writer.write(object, factory.directBinaryEncoder(out, null));
ReflectDatumReader<T> reader = new ReflectDatumReader<>(schema);
return reader.read(null, DecoderFactory.get().binaryDecoder(out.toByteArray(), null));
}
}