@@ -21,6 +21,7 @@
import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.createTempTableReference;
import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG;
import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.RECORDING_ROUTER;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

@@ -114,9 +115,14 @@
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.ProjectionProducer;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.StateId;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
@@ -138,6 +144,7 @@
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;
@@ -932,11 +939,18 @@ abstract Builder<T> setBadRecordErrorHandler(
DynamicRead() {}

class CreateBoundedSourceForTable
extends DoFn<KV<String, BigQueryDynamicReadDescriptor>, BigQueryStorageStreamSource<T>> {
extends DoFn<
KV<String, BigQueryDynamicReadDescriptor>, KV<String, BigQueryStorageStreamSource<T>>> {

// Resolved by id against the streamTag ("mainOutput") created in expand(); TupleTags
// compare by id, so a single field avoids constructing a new tag per output element.
private final TupleTag<KV<String, BigQueryStorageStreamSource<T>>> mainOutputTag =
    new TupleTag<>("mainOutput");
private final TupleTag<KV<String, CleanupOperationMessage>> cleanupInfoTag;

CreateBoundedSourceForTable(TupleTag<KV<String, CleanupOperationMessage>> cleanupInfoTag) {
this.cleanupInfoTag = cleanupInfoTag;
}

@ProcessElement
public void processElement(
OutputReceiver<BigQueryStorageStreamSource<T>> receiver,
MultiOutputReceiver receiver,
@Element KV<String, BigQueryDynamicReadDescriptor> kv,
PipelineOptions options)
throws Exception {
@@ -961,7 +975,14 @@ public void processElement(
// shards
long desiredChunkSize = getDesiredChunkSize(options, output);
List<BigQueryStorageStreamSource<T>> split = output.split(desiredChunkSize, options);
split.stream().forEach(source -> receiver.output(source));
split.stream()
    .forEach(source -> receiver.get(mainOutputTag).output(KV.of(kv.getKey(), source)));
} else {
// run query
BigQueryStorageQuerySource<T> querySource =
@@ -997,7 +1018,21 @@ public void processElement(
// shards
long desiredChunkSize = getDesiredChunkSize(options, output);
List<BigQueryStorageStreamSource<T>> split = output.split(desiredChunkSize, options);
split.stream().forEach(source -> receiver.output(source));
split.stream()
    .forEach(
        source ->
            receiver.get(mainOutputTag).output(KV.of(kv.getKey(), source.withFromQuery())));
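// Tell the stateful cleanup DoFn how many stream-completion messages to expect
// before the temporary table holding the query results can be deleted.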
boolean datasetCreatedByBeam = getQueryTempDataset() == null;
CleanupInfo cleanupInfo =
new CleanupInfo(
queryResultTable.getTableReference(), datasetCreatedByBeam, split.size());
receiver
.get(cleanupInfoTag)
.output(KV.of(kv.getKey(), CleanupOperationMessage.initialize(cleanupInfo)));
}
}

@@ -1010,6 +1045,9 @@ private long getDesiredChunkSize(
@Override
public PCollection<T> expand(PCollection<BigQueryDynamicReadDescriptor> input) {
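// rowTag carries parsed rows, streamTag the per-table stream sources to read, and
// cleanupInfoTag the messages that drive deletion of temporary query-result tables.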
TupleTag<T> rowTag = new TupleTag<>();
TupleTag<KV<String, BigQueryStorageStreamSource<T>>> streamTag = new TupleTag<>("mainOutput");
TupleTag<KV<String, CleanupOperationMessage>> cleanupInfoTag = new TupleTag<>();

PCollection<KV<String, BigQueryDynamicReadDescriptor>> addJobId =
input
.apply(
@@ -1024,25 +1062,208 @@ public String apply(BigQueryDynamicReadDescriptor input) {
.apply("Checkpoint", Redistribute.byKey());

PCollectionTuple resultTuple =
addJobId
.apply("Create streams", ParDo.of(new CreateBoundedSourceForTable()))
addJobId.apply(
"Create streams",
ParDo.of(new CreateBoundedSourceForTable(cleanupInfoTag))
.withOutputTags(streamTag, TupleTagList.of(cleanupInfoTag)));

PCollection<KV<String, BigQueryStorageStreamSource<T>>> streams =
resultTuple
.get(streamTag)
.setCoder(
SerializableCoder.of(new TypeDescriptor<BigQueryStorageStreamSource<T>>() {}))
.apply("Redistribute", Redistribute.arbitrarily())
.apply(
"Read Streams with storage read api",
ParDo.of(
new TypedRead.ReadTableSource<T>(
rowTag, getParseFn(), getBadRecordRouter()))
.withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG)));
KvCoder.of(
StringUtf8Coder.of(),
SerializableCoder.of(
new TypeDescriptor<BigQueryStorageStreamSource<T>>() {})))
.apply("Redistribute", Redistribute.arbitrarily());

PCollectionTuple readResultTuple =
streams.apply(
"Read Streams with storage read api",
ParDo.of(
new ReadDynamicStreamSource<T>(
rowTag, getParseFn(), getBadRecordRouter(), cleanupInfoTag))
.withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG).and(cleanupInfoTag)));

PCollection<KV<String, CleanupOperationMessage>> cleanupMessages1 =
resultTuple
.get(cleanupInfoTag)
.setCoder(
KvCoder.of(
StringUtf8Coder.of(), SerializableCoder.of(CleanupOperationMessage.class)));

PCollection<KV<String, CleanupOperationMessage>> cleanupMessages2 =
readResultTuple
.get(cleanupInfoTag)
.setCoder(
KvCoder.of(
StringUtf8Coder.of(), SerializableCoder.of(CleanupOperationMessage.class)));

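// Merge the initial cleanup descriptors with the per-stream completion signals and
// route them, keyed by job id, to the stateful DoFn that deletes the temp tables.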
PCollectionList.of(cleanupMessages1)
.and(cleanupMessages2)
.apply(Flatten.pCollections())
.apply("CleanupTempTables", ParDo.of(new CleanupTempTableDoFn(getBigQueryServices())));

getBadRecordErrorHandler()
.addErrorCollection(
resultTuple.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline())));
return resultTuple.get(rowTag).setCoder(getOutputCoder());
readResultTuple
.get(BAD_RECORD_TAG)
.setCoder(BadRecord.getCoder(input.getPipeline())));
return readResultTuple.get(rowTag).setCoder(getOutputCoder());
}
}

/**
 * Summary of a query's temporary result table: its reference, whether Beam created the
 * containing dataset, and how many read streams must complete before it can be deleted.
 */
static class CleanupInfo implements Serializable {
private final String projectId;
private final String datasetId;
private final String tableId;
private final boolean datasetCreatedByBeam;
private final int totalStreams;

public CleanupInfo(TableReference tableRef, boolean datasetCreatedByBeam, int totalStreams) {
if (tableRef != null) {
this.projectId = tableRef.getProjectId();
this.datasetId = tableRef.getDatasetId();
this.tableId = tableRef.getTableId();
} else {
this.projectId = null;
this.datasetId = null;
this.tableId = null;
}
this.datasetCreatedByBeam = datasetCreatedByBeam;
this.totalStreams = totalStreams;
}

public TableReference getTableReference() {
if (projectId == null || datasetId == null || tableId == null) {
return null;
}
return new TableReference()
.setProjectId(projectId)
.setDatasetId(datasetId)
.setTableId(tableId);
}

public boolean isDatasetCreatedByBeam() {
return datasetCreatedByBeam;
}

public int getTotalStreams() {
return totalStreams;
}
}

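/**
 * Message consumed by {@link CleanupTempTableDoFn}: either the initial {@link CleanupInfo} for a
 * query's temporary result table, or a signal that one of its streams has been fully read.
 */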
static class CleanupOperationMessage implements Serializable {
private final @Nullable CleanupInfo cleanupInfo;
private final boolean isStreamCompletion;

private CleanupOperationMessage(@Nullable CleanupInfo cleanupInfo, boolean isStreamCompletion) {
this.cleanupInfo = cleanupInfo;
this.isStreamCompletion = isStreamCompletion;
}

public static CleanupOperationMessage streamComplete() {
return new CleanupOperationMessage(null, true);
}

public static CleanupOperationMessage initialize(CleanupInfo cleanupInfo) {
return new CleanupOperationMessage(cleanupInfo, false);
}

public @Nullable CleanupInfo getCleanupInfo() {
return cleanupInfo;
}

public boolean isStreamCompletion() {
return isStreamCompletion;
}
}

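/**
 * Stateful DoFn that, per key, stores the {@link CleanupInfo} and counts stream-completion
 * messages; once all streams of a query result have been read, it deletes the temporary table
 * (and the temporary dataset, if it was created by Beam).
 */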
static class CleanupTempTableDoFn extends DoFn<KV<String, CleanupOperationMessage>, Void> {
private final BigQueryServices bqServices;
private static final Logger LOG = LoggerFactory.getLogger(CleanupTempTableDoFn.class);

@StateId("cleanupInfo")
private final StateSpec<ValueState<CleanupInfo>> cleanupInfoSpec = StateSpecs.value();

@StateId("completedStreams")
private final StateSpec<ValueState<Integer>> completedStreamsSpec = StateSpecs.value();

CleanupTempTableDoFn(BigQueryServices bqServices) {
this.bqServices = bqServices;
}

@ProcessElement
public void processElement(
@Element KV<String, CleanupOperationMessage> element,
@StateId("cleanupInfo") ValueState<CleanupInfo> cleanupInfoState,
@StateId("completedStreams") ValueState<Integer> completedStreamsState,
PipelineOptions options)
throws Exception {

CleanupOperationMessage msg = element.getValue();
CleanupInfo cleanupInfo = cleanupInfoState.read();
int completed = firstNonNull(completedStreamsState.read(), 0);

if (msg.isStreamCompletion()) {
completed += 1;
completedStreamsState.write(completed);
} else {
cleanupInfoState.write(msg.getCleanupInfo());
cleanupInfo = msg.getCleanupInfo();
}

if (cleanupInfo != null
&& cleanupInfo.getTotalStreams() > 0
&& completed == cleanupInfo.getTotalStreams()) {
TableReference tempTable = cleanupInfo.getTableReference();
try (DatasetService datasetService =
bqServices.getDatasetService(options.as(BigQueryOptions.class))) {
LOG.info("Deleting temporary table with query results {}", tempTable);
datasetService.deleteTable(tempTable);
if (cleanupInfo.isDatasetCreatedByBeam()) {
LOG.info("Deleting temporary dataset with query results {}", tempTable.getDatasetId());
datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
}
} catch (Exception e) {
LOG.warn("Failed to delete temporary BigQuery table {}", tempTable, e);
}
cleanupInfoState.clear();
completedStreamsState.clear();
}
}
}

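/**
 * Reads a single {@link BigQueryStorageStreamSource} by delegating to {@link
 * TypedRead.ReadTableSource}, and emits a stream-completion message for sources that originate
 * from a query so the temporary result table can be cleaned up afterwards.
 */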
private static class ReadDynamicStreamSource<T>
extends DoFn<KV<String, BigQueryStorageStreamSource<T>>, T> {
private final TypedRead.ReadTableSource<T> readTableSource;
private final TupleTag<KV<String, CleanupOperationMessage>> streamCompletionTag;

ReadDynamicStreamSource(
TupleTag<T> rowTag,
SerializableFunction<SchemaAndRecord, T> parseFn,
BadRecordRouter badRecordRouter,
TupleTag<KV<String, CleanupOperationMessage>> streamCompletionTag) {
this.readTableSource = new TypedRead.ReadTableSource<>(rowTag, parseFn, badRecordRouter);
this.streamCompletionTag = streamCompletionTag;
}

@ProcessElement
public void processElement(
@Element KV<String, BigQueryStorageStreamSource<T>> element,
MultiOutputReceiver receiver,
PipelineOptions options)
throws Exception {
readTableSource.processElement(element.getValue(), receiver, options);
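// Only sources created from a query are backed by a temporary result table; signal
// completion so CleanupTempTableDoFn knows this stream has been fully read.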
if (element.getValue().getFromQuery()) {
receiver
.get(streamCompletionTag)
.output(KV.of(element.getKey(), CleanupOperationMessage.streamComplete()));
}
}
}

/** Implementation of {@link BigQueryIO#read()}. */
public static class Read extends PTransform<PBegin, PCollection<TableRow>> {
private final TypedRead<TableRow> inner;

@@ -77,7 +77,13 @@ public static <T> BigQueryStorageStreamSource<T> create(
toJsonString(Preconditions.checkArgumentNotNull(tableSchema, "tableSchema")),
parseFn,
outputCoder,
bqServices);
bqServices,
false);
}

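/**
 * Returns a copy of this source marked as reading query results that were materialized into a
 * temporary table.
 */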
public BigQueryStorageStreamSource<T> withFromQuery() {
return new BigQueryStorageStreamSource<>(
readSession, readStream, jsonTableSchema, parseFn, outputCoder, bqServices, true);
}

@Override
@@ -106,13 +112,13 @@ public int hashCode() {
*/
public BigQueryStorageStreamSource<T> fromExisting(ReadStream newReadStream) {
return new BigQueryStorageStreamSource<>(
readSession, newReadStream, jsonTableSchema, parseFn, outputCoder, bqServices);
readSession, newReadStream, jsonTableSchema, parseFn, outputCoder, bqServices, fromQuery);
}

public BigQueryStorageStreamSource<T> fromExisting(
SerializableFunction<SchemaAndRecord, T> parseFn) {
return new BigQueryStorageStreamSource<>(
readSession, readStream, jsonTableSchema, parseFn, outputCoder, bqServices);
readSession, readStream, jsonTableSchema, parseFn, outputCoder, bqServices, fromQuery);
}

private final ReadSession readSession;
@@ -121,20 +127,27 @@ public BigQueryStorageStreamSource<T> fromExisting(
private final SerializableFunction<SchemaAndRecord, T> parseFn;
private final Coder<T> outputCoder;
private final BigQueryServices bqServices;
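// Whether this stream reads query results materialized into a temporary table.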
private final boolean fromQuery;

private BigQueryStorageStreamSource(
ReadSession readSession,
ReadStream readStream,
String jsonTableSchema,
SerializableFunction<SchemaAndRecord, T> parseFn,
Coder<T> outputCoder,
BigQueryServices bqServices) {
BigQueryServices bqServices,
boolean fromQuery) {
this.readSession = Preconditions.checkArgumentNotNull(readSession, "readSession");
this.readStream = Preconditions.checkArgumentNotNull(readStream, "stream");
this.jsonTableSchema = Preconditions.checkArgumentNotNull(jsonTableSchema, "jsonTableSchema");
this.parseFn = Preconditions.checkArgumentNotNull(parseFn, "parseFn");
this.outputCoder = Preconditions.checkArgumentNotNull(outputCoder, "outputCoder");
this.bqServices = Preconditions.checkArgumentNotNull(bqServices, "bqServices");
this.fromQuery = fromQuery;
}

public boolean getFromQuery() {
return fromQuery;
}

@Override