Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.GenericArrayType;
Expand Down Expand Up @@ -69,6 +70,24 @@ public class ReflectData extends SpecificData {

private static final String STRING_OUTER_PARENT_REFERENCE = "this$0";

private static final Method IS_SEALED_METHOD;
private static final Method GET_PERMITTED_SUBCLASSES_METHOD;

static {
Class<? extends Class> classClass = Class.class;
Method isSealed;
Method getPermittedSubclasses;
try {
isSealed = classClass.getMethod("isSealed");
getPermittedSubclasses = classClass.getMethod("getPermittedSubclasses");
} catch (NoSuchMethodException | SecurityException e) {
isSealed = null;
getPermittedSubclasses = null;
}
IS_SEALED_METHOD = isSealed;
GET_PERMITTED_SUBCLASSES_METHOD = getPermittedSubclasses;
}

// holds a wrapper so null entries will have a cached value
private final ConcurrentMap<Schema, CustomEncodingWrapper> encoderCache = new ConcurrentHashMap<>();

Expand Down Expand Up @@ -713,7 +732,7 @@ protected Schema createSchema(Type type, Map<String, Schema> names) {
String space = c.getPackage() == null ? "" : c.getPackage().getName();
if (c.getEnclosingClass() != null) // nested class
space = c.getEnclosingClass().getName().replace('$', '.');
Union union = c.getAnnotation(Union.class);
Class[] union = getUnion(c);
if (union != null) { // union annotated
return getAnnotatedUnion(union, names);
} else if (isStringable(c)) { // Stringable
Expand Down Expand Up @@ -819,10 +838,46 @@ private void setElement(Schema schema, Type element) {
schema.addProp(ELEMENT_PROP, c.getName());
}

private Class[] getUnion(AnnotatedElement element) {
Union union = element.getAnnotation(Union.class);
if (union != null) {
return union.value();
}

if (element instanceof Class) {
// automatic sealed class polymorphic
try {
if (IS_SEALED_METHOD != null && Boolean.TRUE.equals(IS_SEALED_METHOD.invoke(element))) {
var subclasses = (Class<?>[]) GET_PERMITTED_SUBCLASSES_METHOD.invoke(element);

List<Class> subclassList = new ArrayList<>();

for (Class<?> subclass : subclasses) {
if (Modifier.isAbstract(subclass.getModifiers()) || Modifier.isInterface(subclass.getModifiers())) {

var subUnion = getUnion(subclass); // recursively process subclasses
if (subUnion != null) {
subclassList.addAll(List.of(subUnion));
}
continue;
}
subclassList.add(subclass);
}
if (!subclassList.isEmpty()) {
return subclassList.toArray(new Class[0]);
}
}
} catch (ReflectiveOperationException e) {
throw new AvroRuntimeException(e);
}
}
return null;
}

// construct a schema from a union annotation
private Schema getAnnotatedUnion(Union union, Map<String, Schema> names) {
private Schema getAnnotatedUnion(Class[] union, Map<String, Schema> names) {
List<Schema> branches = new ArrayList<>();
for (Class branch : union.value())
for (Class branch : union)
branches.add(createSchema(branch, names));
return Schema.createUnion(branches);
}
Expand Down Expand Up @@ -889,7 +944,7 @@ protected Schema createFieldSchema(Field field, Map<String, Schema> names) {

Union union = field.getAnnotation(Union.class);
if (union != null)
return getAnnotatedUnion(union, names);
return getAnnotatedUnion(union.value(), names);

Schema schema = createSchema(field.getGenericType(), names);
if (field.isAnnotationPresent(Stringable.class)) { // Stringable
Expand Down Expand Up @@ -936,7 +991,7 @@ private Message getMessage(Method method, Protocol protocol, Map<String, Schema>
if (annotation instanceof AvroSchema) // explicit schema
paramSchema = new Schema.Parser().parse(((AvroSchema) annotation).value());
else if (annotation instanceof Union) // union
paramSchema = getAnnotatedUnion(((Union) annotation), names);
paramSchema = getAnnotatedUnion(((Union) annotation).value(), names);
else if (annotation instanceof Nullable) // nullable
paramSchema = makeNullable(paramSchema);
}
Expand All @@ -948,7 +1003,7 @@ else if (annotation instanceof Nullable) // nullable
Type genericReturnType = method.getGenericReturnType();
Type returnType = genericTypeMap.getOrDefault(genericReturnType, genericReturnType);
Union union = method.getAnnotation(Union.class);
Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union, names);
Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union.value(), names);
if (method.isAnnotationPresent(Nullable.class)) // nullable
response = makeNullable(response);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.avro.reflect;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import org.apache.avro.Schema;
import org.apache.avro.file.DataFileStream;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.io.DatumReader;
import org.junit.jupiter.api.Test;

public class TestPolymorphicEncoding {

@Test
public void testPolymorphicEncoding() throws IOException {
List<Animal> expected = Arrays.asList(new Cat("Green"), new Dog(5));
byte[] encoded = write(Animal.class, expected);
List<Animal> decoded = read(encoded);

assertEquals(expected, decoded);
}

@Test
public void testPolymorphicEncodingMultipleLevels() throws IOException {
List<Animal> expected = Arrays.asList(new Cat("Calico"), new Takahe(3.2));
byte[] encoded = write(Animal.class, expected);
List<Animal> decoded = read(encoded);

assertEquals(expected, decoded);
}

private <T> List<T> read(byte[] toDecode) throws IOException {
DatumReader<T> datumReader = new ReflectDatumReader<>();
try (DataFileStream<T> dataFileReader = new DataFileStream<>(new ByteArrayInputStream(toDecode, 0, toDecode.length),
datumReader)) {
List<T> toReturn = new ArrayList<>();
while (dataFileReader.hasNext()) {
toReturn.add(dataFileReader.next());
}
return toReturn;
}
}

private <T> byte[] write(Class<?> type, List<T> custom) {
Schema schema = ReflectData.get().getSchema(type);
ReflectDatumWriter<T> datumWriter = new ReflectDatumWriter<>();
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataFileWriter<T> writer = new DataFileWriter<>(datumWriter)) {
writer.create(schema, baos);
for (T c : custom) {
writer.append(c);
}
writer.flush();
return baos.toByteArray();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

public static sealed interface Animal permits Cat,Dog,Bird {
}

public static final class Dog implements Animal {

private int size;

public Dog() {
}

public Dog(int size) {
this.size = size;
}

public int getSize() {
return size;
}

@Override
public int hashCode() {
return Objects.hash(size);
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Dog other = (Dog) obj;
return size == other.size;
}

}

public static final class Cat implements Animal {

private String color;

public Cat() {
}

public Cat(String color) {
super();
this.color = color;
}

public String getColor() {
return color;
}

@Override
public int hashCode() {
return Objects.hash(color);
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Cat other = (Cat) obj;
return Objects.equals(color, other.color);
}

}

public static sealed interface Bird extends Animal permits Kea,Takahe {
}

public static final class Kea implements Bird {
private int age;

public Kea() {
}

public Kea(int age) {
this.age = age;
}

public int getAge() {
return age;
}

@Override
public int hashCode() {
return Objects.hash(age);
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Kea other = (Kea) obj;
return age == other.age;
}

}

public static final class Takahe implements Bird {
private double weight;

public Takahe() {
}

public Takahe(double weight) {
this.weight = weight;
}

public double getWeight() {
return weight;
}

@Override
public int hashCode() {
return Objects.hash(weight);
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Takahe other = (Takahe) obj;
return Double.doubleToLongBits(weight) == Double.doubleToLongBits(other.weight);
}

}

}