Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
}

/**
 * Enumerates the kinds of column encoders known to the transform framework.
 *
 * NOTE(review): the ordinal of a constant appears to be used as an integer
 * type tag for encoder instances (e.g. EncoderType.UDF.ordinal() is returned
 * for a ColumnEncoderUDF), so new constants should presumably be appended at
 * the end rather than inserted in the middle - confirm against the
 * serialization code before reordering.
 */
public enum EncoderType {
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords, UDF
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;

import org.apache.sysds.api.DMLScript;
Expand Down Expand Up @@ -46,7 +49,7 @@ public class ColumnEncoderUDF extends ColumnEncoder {
//TODO pass the execution context through the encoder factory to support arbitrary functions, not just builtins
//TODO integrate into IPA to ensure the existence of unoptimized functions

private final String _fName;
private String _fName;
public int _domainSize = 1;

protected ColumnEncoderUDF(int ptCols, String name) {
Expand Down Expand Up @@ -165,4 +168,20 @@ protected double getCode(CacheBlock<?> in, int row) {
protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) {
throw new DMLRuntimeException("UDF encoders only support full column access.");
}

/**
 * Writes this encoder to the given output: first the state written by the
 * superclass, then the function name, then the domain size.
 *
 * @param out target to write the encoder state to
 * @throws IOException if writing to the output fails
 */
@Override
public void writeExternal(ObjectOutput out) throws IOException {
	super.writeExternal(out);

	// writeUTF cannot take null, so a missing function name is written as ""
	String serializedName = "";
	if(_fName != null)
		serializedName = _fName;

	out.writeUTF(serializedName);
	out.writeInt(_domainSize);
}

/**
 * Restores this encoder from the given input: first the state read by the
 * superclass, then the function name, then the domain size.
 *
 * @param in source to read the encoder state from
 * @throws IOException if reading from the input fails
 */
@Override
public void readExternal(ObjectInput in) throws IOException {
	super.readExternal(in);

	final String storedName = in.readUTF();
	// an empty string on the wire stands for "no function name"
	_fName = storedName.isEmpty() ? null : storedName;

	_domainSize = in.readInt();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ else if(columnEncoder instanceof ColumnEncoderWordEmbedding)
return EncoderType.WordEmbedding.ordinal();
else if(columnEncoder instanceof ColumnEncoderBagOfWords)
return EncoderType.BagOfWords.ordinal();
else if(columnEncoder instanceof ColumnEncoderUDF)
return EncoderType.UDF.ordinal();
throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
}

Expand All @@ -276,19 +278,21 @@ public static ColumnEncoder createInstance(int type) {
case Bin:
return new ColumnEncoderBin();
case Dummycode:
return new ColumnEncoderDummycode();
case FeatureHash:
return new ColumnEncoderFeatureHash();
case PassThrough:
return new ColumnEncoderPassThrough();
case Recode:
return new ColumnEncoderRecode();
case WordEmbedding:
return new ColumnEncoderWordEmbedding();
case BagOfWords:
return new ColumnEncoderBagOfWords();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
return new ColumnEncoderDummycode();
case FeatureHash:
return new ColumnEncoderFeatureHash();
case PassThrough:
return new ColumnEncoderPassThrough();
case Recode:
return new ColumnEncoderRecode();
case WordEmbedding:
return new ColumnEncoderWordEmbedding();
case BagOfWords:
return new ColumnEncoderBagOfWords();
case UDF:
return new ColumnEncoderUDF();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderUDF;
import org.apache.sysds.runtime.util.LocalFileUtils;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -46,6 +48,9 @@
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Collections;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;

public class SerializeTest extends AutomatedTestBase
{
Expand Down Expand Up @@ -113,6 +118,11 @@ public void testWEEncoderSerialization(){
runSerializeWEEncoder();
}

/**
 * Round-trip serialization test for a UDF column encoder; the actual checks
 * live in {@code runSerializeUDFEncoder()}.
 */
@Test
public void testUDFEncoderSerialization()
{
	this.runSerializeUDFEncoder();
}

private void runSerializeTest( int rows, int cols, double sparsity )
{
try
Expand Down Expand Up @@ -188,6 +198,63 @@ private void runSerializeWEEncoder(){
}
}

/**
 * Serializes a MultiColumnEncoder that wraps a single composite with one UDF
 * encoder, deserializes it into a fresh MultiColumnEncoder, and asserts that
 * the column id, domain size, and function name of the UDF encoder survive
 * the round trip.
 *
 * Both the output and the input stream are managed by try-with-resources, so
 * they are closed even if (de)serialization throws.
 */
private void runSerializeUDFEncoder(){
	final String udfName = "dummyUdf";
	final int colId = 2;
	final int domainSize = 5;

	try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
		ObjectOutput out = new ObjectOutputStream(bos)) {
		ColumnEncoderUDF udf = createUdf(colId, udfName, domainSize);
		ColumnEncoderComposite composite = new ColumnEncoderComposite(Collections.singletonList(udf));
		MultiColumnEncoder encoder = new MultiColumnEncoder(Collections.singletonList(composite));

		encoder.writeExternal(out);
		// flush so that all buffered object-stream bytes reach the byte array before it is copied
		out.flush();

		MultiColumnEncoder encoderSer = new MultiColumnEncoder();
		try (ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
			encoderSer.readExternal(in);
		}

		ColumnEncoderComposite decodedComposite = encoderSer.getColumnEncoders().get(0);
		ColumnEncoderUDF decodedUdf = decodedComposite.getEncoder(ColumnEncoderUDF.class);

		Assert.assertNotNull(decodedUdf);
		Assert.assertEquals(colId, decodedUdf.getColID());
		Assert.assertEquals(domainSize, decodedUdf._domainSize);
		Assert.assertEquals(udfName, getUdfName(decodedUdf));
	}
	catch(IOException | ClassNotFoundException e) {
		throw new RuntimeException(e);
	}
}

/**
 * Builds a ColumnEncoderUDF for the test by reflectively invoking its
 * (int, String) constructor, then sets the public domain-size field.
 *
 * @param colId      first constructor argument (column id)
 * @param name       second constructor argument (function name)
 * @param domainSize value assigned to {@code _domainSize}
 * @return the newly created encoder
 * @throws RuntimeException wrapping any exception raised during reflection
 */
private ColumnEncoderUDF createUdf(int colId, String name, int domainSize) {
	final ColumnEncoderUDF instance;
	try {
		final Constructor<ColumnEncoderUDF> hiddenCtor =
			ColumnEncoderUDF.class.getDeclaredConstructor(int.class, String.class);
		// the constructor is not public, so access has to be forced
		hiddenCtor.setAccessible(true);
		instance = hiddenCtor.newInstance(colId, name);
	}
	catch(Exception e) {
		throw new RuntimeException(e);
	}
	instance._domainSize = domainSize;
	return instance;
}

/**
 * Reads the {@code _fName} field of the given UDF encoder via reflection.
 *
 * @param udf encoder whose function name is read
 * @return current value of the {@code _fName} field
 * @throws RuntimeException wrapping any exception raised during the lookup
 */
private String getUdfName(ColumnEncoderUDF udf) {
	try {
		final Field nameField = ColumnEncoderUDF.class.getDeclaredField("_fName");
		nameField.setAccessible(true);
		final Object rawValue = nameField.get(udf);
		return (String) rawValue;
	}
	catch(Exception e) {
		throw new RuntimeException(e);
	}
}

private void runSerializeDedupDenseTest( int rows, int cols )
{
try
Expand Down