Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;

public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
Expand Down Expand Up @@ -106,6 +107,8 @@ else if(parts.length == 4)
return IndexingOOCInstruction.parseInstruction(str);
case Rand:
return DataGenOOCInstruction.parseInstruction(str);
case Append:
return AppendOOCInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Out-of-core append instruction. Only the column-wise append (cbind) of two
 * matrices is implemented; {@link #parseInstruction(String)} rejects every other
 * combination.
 *
 * <p>The result is produced as a merged stream of blocks. When the column count of
 * the first input is a multiple of the block size, the blocks of the second input
 * only get their column-block index shifted. Otherwise the last column block of the
 * first input is filled up with leading columns of the second input, and every
 * following output block is stitched together from two neighboring column blocks
 * of the second input.
 */
public class AppendOOCInstruction extends BinaryOOCInstruction {

	/** Append variants known to this instruction. */
	public enum AppendType {
		CBIND
	}

	protected final AppendType _type;

	protected AppendOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, AppendType type,
		String opcode, String istr) {
		super(OOCType.Append, op, in1, in2, out, opcode, istr);
		_type = type;
	}

	/**
	 * Parses an append instruction string.
	 *
	 * <p>Field layout as read here: opcode, first input, second input, then the output as
	 * second-to-last field and a boolean cbind flag as last field.
	 *
	 * @param str the instruction string
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if an input is not a matrix or the cbind flag is false
	 */
	public static AppendOOCInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		// presumably accepts either 5 or 4 operand fields -- confirm against InstructionUtils
		InstructionUtils.checkNumFields(parts, 5, 4);

		String opcode = parts[0];
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[parts.length - 2]);
		boolean cbind = Boolean.parseBoolean(parts[parts.length - 1]);

		if(in1.getDataType() != Types.DataType.MATRIX || in2.getDataType() != Types.DataType.MATRIX || !cbind) {
			throw new DMLRuntimeException("Only matrix-matrix cbind is supported");
		}
		AppendType type = AppendType.CBIND;

		Operator op = new ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
		return new AppendOOCInstruction(op, in1, in2, out, type, opcode, str);
	}

	/**
	 * Wires the output stream of the cbind from the two input streams.
	 *
	 * @param ec execution context providing the input and output matrix objects
	 * @throws DMLRuntimeException if the inputs have a different number of rows
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		MatrixObject in1 = ec.getMatrixObject(input1);
		MatrixObject in2 = ec.getMatrixObject(input2);
		validateInput(in1, in2);
		if(handleZeroDims(in1, in2, ec))
			return;

		OOCStream<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
		OOCStream<IndexedMatrixValue> qIn2 = in2.getStreamHandle();

		// NOTE(review): only the block size of the first input is consulted; presumably both
		// inputs share the same block size -- confirm against the compiler/reblock guarantees.
		final int blksize = in1.getBlocksize();
		// Take the remainder in long arithmetic and narrow afterwards: a cast binds tighter
		// than '%', so "(int) cols % blksize" would truncate the column count first.
		final int rem1 = (int) (in1.getNumColumns() % blksize);
		final int rem2 = (int) (in2.getNumColumns() % blksize);
		final int cblk1 = (int) in1.getDataCharacteristics().getNumColBlocks();
		final int cblk2 = (int) in2.getDataCharacteristics().getNumColBlocks();
		final int cblkRes = (int) Math.ceil((double) (in1.getNumColumns() + in2.getNumColumns()) / blksize);

		if(rem1 == 0) {
			// first input ends on a block boundary: blocks of the second input keep their
			// content and are only moved behind the column blocks of the first input
			OOCStream<IndexedMatrixValue> shifted = new SubscribableTaskQueue<>();
			mapOOC(qIn2, shifted, imv -> new IndexedMatrixValue(
				new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1 + imv.getIndexes().getColumnIndex()),
				imv.getValue()));

			ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(List.of(qIn1, shifted)));
			return;
		}

		// first input: partition 0 = all but the last column block, partition 1 = last column block
		List<OOCStream<IndexedMatrixValue>> split1 = splitOOCStream(qIn1,
			imv -> imv.getIndexes().getColumnIndex() == cblk1 ? 1 : 0, 2);
		// second input: one partition per column block (partition = column-block index - 1)
		List<OOCStream<IndexedMatrixValue>> split2 = splitOOCStream(qIn2,
			imv -> (int) (imv.getIndexes().getColumnIndex() - 1), cblk2);

		OOCStream<IndexedMatrixValue> head = split1.get(0);
		OOCStream<IndexedMatrixValue> lastCol = split1.get(1);
		OOCStream<IndexedMatrixValue> firstCol = split2.get(0);

		// the first column block of the second input is needed twice: once to fill up the
		// last block of the first input, once as left part of the next output block
		CachingStream firstColCache = new CachingStream(firstCol);
		OOCStream<IndexedMatrixValue> firstColForCritical = firstColCache.getReadStream();
		OOCStream<IndexedMatrixValue> firstColForTail = firstColCache.getReadStream();

		SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
		// blocks are joined per block row; the column component of the key is fixed to 1
		Function<IndexedMatrixValue, MatrixIndexes> rowKey = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), 1);

		// number of columns held by the last column block of the second input
		final int fullRem2 = rem2 == 0 ? blksize : rem2;
		// number of columns missing to fill a partially filled block of width rem1
		final int fill = blksize - rem1;
		// the last column block of the second input may hold fewer than 'fill' columns
		final int fillFromLast = Math.min(fill, fullRem2);

		// combine the last column block of the first input with leading columns of the second input
		final int firstStop = cblk2 == 1 ? fillFromLast : fill;
		joinOOC(lastCol, firstColForCritical, out, (left, right) -> {
			MatrixBlock lb = (MatrixBlock) left.getValue();
			MatrixBlock rb = (MatrixBlock) right.getValue();
			MatrixBlock combined = cbindBlocks(lb, sliceCols(rb, 0, firstStop));
			return new IndexedMatrixValue(
				new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
		}, rowKey);

		List<OOCStream<IndexedMatrixValue>> outStreams = new ArrayList<>();
		outStreams.add(head);
		outStreams.add(out);

		// shift the columns of the second input: every output block consists of the trailing
		// rem1 columns of one column block and the leading columns of its right neighbor
		OOCStream<IndexedMatrixValue> fst = firstColForTail;
		for(int i = 0; i < cblk2 - 1; i++) {
			out = new SubscribableTaskQueue<>();
			CachingStream secCachingStream = new CachingStream(split2.get(i + 1));
			OOCStream<IndexedMatrixValue> sec = secCachingStream.getReadStream();

			// right neighbor is the last column block of the second input iff i+2 == cblk2
			final int stop = (i + 2 == cblk2) ? fillFromLast : fill;
			joinOOC(fst, sec, out, (left, right) -> {
				MatrixBlock lb = (MatrixBlock) left.getValue();
				MatrixBlock rb = (MatrixBlock) right.getValue();
				MatrixBlock combined = cbindBlocks(sliceCols(lb, fill, blksize), sliceCols(rb, 0, stop));
				return new IndexedMatrixValue(
					new MatrixIndexes(left.getIndexes().getRowIndex(), cblk1 + left.getIndexes().getColumnIndex()),
					combined);
			}, rowKey);

			// the right block of this join becomes the left block of the next one
			fst = secCachingStream.getReadStream();
			outStreams.add(out);
		}

		if(cblk1 + cblk2 == cblkRes) {
			// overflow: the last column block of the second input has columns left over
			// that form an additional, final output column block
			int remSize = (rem1 + rem2) % blksize;
			out = new SubscribableTaskQueue<>();
			mapOOC(fst, out, imv -> new IndexedMatrixValue(
				new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1 + imv.getIndexes().getColumnIndex()),
				sliceCols((MatrixBlock) imv.getValue(), fullRem2 - remSize, fullRem2)));

			outStreams.add(out);
		}
		ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(outStreams));
	}

	/**
	 * @return the append variant of this instruction
	 */
	public AppendType getAppendType() {
		return _type;
	}

	/**
	 * Checks that both inputs of a cbind have the same number of rows.
	 *
	 * @param m1 first input
	 * @param m2 second input
	 * @throws DMLRuntimeException if the row counts differ
	 */
	private void validateInput(MatrixObject m1, MatrixObject m2) {
		if(_type == AppendType.CBIND && m1.getNumRows() != m2.getNumRows()) {
			throw new DMLRuntimeException(
				"Append-cbind is not possible for input matrices " + input1.getName() + " and " + input2.getName()
					+ " with different number of rows: " + m1.getNumRows() + " vs " + m2.getNumRows());
		}
	}

	/**
	 * Handles inputs with zero rows or zero columns by directly setting the output stream.
	 *
	 * <p>If there are no rows, or neither input has columns, the output gets an empty stream
	 * whose input side is closed immediately. If exactly one input has no columns, the output
	 * reuses the stream handle of the other input.
	 *
	 * @param m1 first input
	 * @param m2 second input
	 * @param ec execution context providing the output matrix object
	 * @return true if the output stream was set here, false if regular processing is required
	 */
	private boolean handleZeroDims(MatrixObject m1, MatrixObject m2, ExecutionContext ec) {
		long rows = m1.getNumRows();
		long cols1 = m1.getNumColumns();
		long cols2 = m2.getNumColumns();
		if(rows == 0 || (cols1 == 0 && cols2 == 0)) {
			OOCStream<IndexedMatrixValue> empty = createWritableStream();
			empty.closeInput();
			ec.getMatrixObject(output).setStreamHandle(empty);
		}
		else if(cols1 == 0) {
			ec.getMatrixObject(output).setStreamHandle(m2.getStreamHandle());
		}
		else if(cols2 == 0) {
			ec.getMatrixObject(output).setStreamHandle(m1.getStreamHandle());
		}
		else
			return false;

		return true;
	}

	/**
	 * Slices all rows of the columns {@code [colStart, colEndExclusive)} out of a block.
	 *
	 * @param in              source block
	 * @param colStart        first column (0-based, inclusive)
	 * @param colEndExclusive end column (0-based, exclusive)
	 * @return the sliced block
	 */
	private static MatrixBlock sliceCols(MatrixBlock in, int colStart, int colEndExclusive) {
		// MatrixBlock.slice is called with inclusive upper bounds
		return in.slice(0, in.getNumRows() - 1, colStart, colEndExclusive - 1);
	}

	/**
	 * Appends {@code right} to {@code left} into a newly allocated result block.
	 *
	 * @param left  left block
	 * @param right right block
	 * @return the result of {@code left.append(right, ...)}
	 */
	private static MatrixBlock cbindBlocks(MatrixBlock left, MatrixBlock right) {
		return left.append(right, new MatrixBlock());
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.sysds.runtime.ooc.stream.SourceOOCStream;
import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
import org.apache.sysds.runtime.ooc.util.OOCUtils;
import org.apache.sysds.runtime.util.IndexRange;
import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;

Expand Down Expand Up @@ -453,7 +454,7 @@ public void findCachedAsync(MatrixIndexes idx, Consumer<OOCStream.QueueCallback<
private void validateBlockCountOnClose() {
DataCharacteristics dc = _source.getDataCharacteristics();
if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
long expected = dc.getNumBlocks();
long expected = OOCUtils.getNumBlocks(dc);
if (expected >= 0 && _numBlocks != expected) {
throw new DMLRuntimeException("CachingStream block count mismatch: expected "
+ expected + " but saw " + _numBlocks + " (" + dc.getRows() + "x" + dc.getCols() + ")");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public abstract class OOCInstruction extends Instruction {

public enum OOCType {
Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
import org.apache.sysds.runtime.ooc.util.OOCUtils;
import org.apache.sysds.runtime.util.IndexRange;

import java.util.LinkedList;
Expand Down Expand Up @@ -166,7 +167,7 @@ public synchronized void closeInput() {
private void validateBlockCountOnClose() {
DataCharacteristics dc = getDataCharacteristics();
if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
long expected = dc.getNumBlocks();
long expected = OOCUtils.getNumBlocks(dc);
if (expected >= 0 && _blockCount.get() != expected) {
throw new DMLRuntimeException("OOCStream block count mismatch: expected "
+ expected + " but saw " + _blockCount.get() + " (" + dc.getRows() + "x" + dc.getCols() + ")");
Expand All @@ -180,6 +181,7 @@ public void setSubscriber(Consumer<QueueCallback<T>> subscriber) {
throw new IllegalArgumentException("Cannot set subscriber to null");

LinkedList<T> data;
boolean needsEos;

synchronized(this) {
if(_subscriber != null)
Expand All @@ -189,12 +191,20 @@ public void setSubscriber(Consumer<QueueCallback<T>> subscriber) {
throw _failure;
data = _data;
_data = new LinkedList<>();
// If this stream was already closed with no buffered data, no further
// onDeliveryFinished() call will happen, so emit EOS immediately.
needsEos = _closed.get() && data.isEmpty() && _availableCtr.get() == 0;
if(needsEos)
_availableCtr.incrementAndGet(); // route terminal emission via onDeliveryFinished
}

for (T t : data) {
subscriber.accept(new SimpleQueueCallback<>(t, _failure));
onDeliveryFinished();
}

if(needsEos)
onDeliveryFinished();
}

@SuppressWarnings("unchecked")
Expand All @@ -214,6 +224,9 @@ private void onDeliveryFinished() {

@Override
public synchronized void propagateFailure(DMLRuntimeException re) {
// Ignore late failures
if(_closed.get() && _availableCtr.get() == 0)
return;
super.propagateFailure(re);
Consumer<QueueCallback<T>> s = _subscriber;
if(s != null)
Expand Down
11 changes: 7 additions & 4 deletions src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ public final void writeEmptyMatrixToHDFS(String fname, long rlen, long clen, int
FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
final Writer writer = IOUtilFunctions.getSeqWriter(path, job, _replication);
try {
MatrixIndexes index = new MatrixIndexes(1, 1);
MatrixBlock block = new MatrixBlock((int) Math.max(Math.min(rlen, blen), 1),
(int) Math.max(Math.min(clen, blen), 1), true);
writer.append(index, block);
// For 0xN or Nx0, emit a valid sequence file header only (no blocks).
if(rlen > 0 && clen > 0) {
MatrixIndexes index = new MatrixIndexes(1, 1);
MatrixBlock block = new MatrixBlock((int) Math.max(Math.min(rlen, blen), 1),
(int) Math.max(Math.min(clen, blen), 1), true);
writer.append(index, block);
}
}
finally {
IOUtilFunctions.closeSilently(writer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ protected void writeBinaryBlockMatrixToHDFS( Path path, JobConf job, MatrixBlock
public long writeMatrixFromStream(String fname, OOCStream<IndexedMatrixValue> stream, long rlen, long clen, int blen)
throws IOException {
Path path = new Path(fname);

// For empty dimensions, no stream tiles are expected but the output must still exist.
if(rlen <= 0 || clen <= 0) {
while(stream.dequeue() != LocalTaskQueue.NO_MORE_TASKS) {
// Drain any unexpected records to keep stream producers unblocked.
}
writeEmptyMatrixToHDFS(fname, rlen, clen, blen);
return 0;
}

long nnz = -1;
DataCharacteristics dc = stream.getDataCharacteristics();
if(dc != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ public MergedOOCStream(List<OOCStream<T>> sources) {
if(_failed.get())
return;

if(cb instanceof OOCStream.GroupQueueCallback<?>) {
OOCStream.GroupQueueCallback<T> group = (OOCStream.GroupQueueCallback<T>) cb;
for(int i = 0; i < group.size(); i++) {
OOCStream.QueueCallback<T> sub = group.getCallback(i);
try(sub) {
_taskQueue.enqueue(sub.keepOpen());
}
}
return;
}

_taskQueue.enqueue(cb.keepOpen());
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.sysds.runtime.ooc.util;

import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;

import java.util.ArrayList;
Expand Down Expand Up @@ -60,4 +61,13 @@ public static Collection<MatrixIndexes> getTilesOfRange(IndexRange range, long b
list.add(new MatrixIndexes(r, c));
return list;
}

/**
 * Returns the number of blocks expected for the given data characteristics.
 *
 * @param dc data characteristics, may be null
 * @return -1 if {@code dc} is null, its dimensions are unknown, or its block size is not
 *         positive; 0 if it has zero rows or zero columns; otherwise {@code dc.getNumBlocks()}
 */
public static long getNumBlocks(DataCharacteristics dc) {
	// without known dimensions and a positive block size no count can be derived
	if(dc == null || !dc.dimsKnown() || dc.getBlocksize() <= 0)
		return -1;

	// an empty matrix is reported as having no blocks at all
	boolean empty = dc.getCols() == 0 || dc.getRows() == 0;
	return empty ? 0 : dc.getNumBlocks();
}
}
Loading