@@ -317,7 +317,11 @@ public SQLDatasetProducer getProducer(SQLPullRequest pullRequest, PullCapability
 
     String table = datasets.get(pullRequest.getDatasetName()).getBigQueryTable();
 
-    return new BigQuerySparkDatasetProducer(sqlEngineConfig, datasetProject, dataset, table);
+    return new BigQuerySparkDatasetProducer(sqlEngineConfig,
+                                            datasetProject,
+                                            dataset,
+                                            table,
+                                            pullRequest.getDatasetSchema());
   }
 
   @Override
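
Note: the engine-side change above threads the dataset's record schema into the producer via pullRequest.getDatasetSchema(); the producer (below) uses it to decide which columns need their Spark types narrowed after the result table is read back from BigQuery.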
@@ -16,17 +16,20 @@
 
 package io.cdap.plugin.gcp.bigquery.sqlengine;
 
+import io.cdap.cdap.api.data.schema.Schema;
 import io.cdap.cdap.etl.api.engine.sql.dataset.RecordCollection;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetDescription;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetProducer;
 import io.cdap.cdap.etl.api.sql.engine.dataset.SparkRecordCollectionImpl;
-import io.cdap.plugin.gcp.common.GCPConfig;
 import org.apache.spark.SparkContext;
 import org.apache.spark.sql.DataFrameReader;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
@@ -39,6 +42,8 @@
 public class BigQuerySparkDatasetProducer
   implements SQLDatasetProducer, Serializable {
 
+  private static final Logger LOG = LoggerFactory.getLogger(BigQuerySparkDatasetProducer.class);
+
   private static final String FORMAT = "bigquery";
   private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";
   private static final String CONFIG_CREDENTIALS = "credentials";
@@ -47,15 +52,19 @@ public class BigQuerySparkDatasetProducer
   private String project;
   private String bqDataset;
   private String bqTable;
+  private Schema schema;
+
 
   public BigQuerySparkDatasetProducer(BigQuerySQLEngineConfig config,
                                       String project,
                                       String bqDataset,
-                                      String bqTable) {
+                                      String bqTable,
+                                      Schema schema) {
     this.config = config;
    this.project = project;
     this.bqDataset = bqDataset;
     this.bqTable = bqTable;
+    this.schema = schema;
   }
 
   @Override
@@ -87,6 +96,7 @@ public RecordCollection produce(SQLDataset sqlDataset) {
 
     // Load path into dataset.
     Dataset<Row> ds = bqReader.load(path);
+    ds = convertFieldTypes(ds);
 
     return new SparkRecordCollectionImpl(ds);
   }
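
Background for the new convertFieldTypes call: the Spark BigQuery connector surfaces BigQuery INT64 columns as Spark LongType and FLOAT64 columns as DoubleType, while the CDAP schema may declare the corresponding fields as 32-bit INT or FLOAT. A minimal, self-contained sketch of the narrowing cast (the local SparkSession, column name, and class name are illustrative, not part of this change):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

public class NarrowingCastSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[1]")
        .appName("narrowing-cast-sketch")
        .getOrCreate();

    // Stand-in for a dataframe read back from BigQuery:
    // "id" comes back as LongType, the way the connector surfaces INT64.
    Dataset<Row> ds = spark.range(3).toDF("id");

    // The same cast convertFieldTypes applies for a CDAP Schema.Type.INT field.
    ds = ds.withColumn("id", ds.col("id").cast(DataTypes.IntegerType));

    ds.printSchema(); // "id" is now reported as integer
    spark.stop();
  }
}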
@@ -95,4 +105,37 @@ public RecordCollection produce(SQLDataset sqlDataset) {
   private String encodeBase64(String serviceAccountJson) {
     return Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8));
   }
+
+  /**
+   * Adjusts CDAP int and float fields to their 32-bit Spark types.
+   *
+   * @param ds input dataframe
+   * @return dataframe with updated schema
+   */
+  private Dataset<Row> convertFieldTypes(Dataset<Row> ds) {
+    for (Schema.Field field : schema.getFields()) {
+      String fieldName = field.getName();
+      Schema fieldSchema = field.getSchema();
+
+      // For nullable types, check the underlying type.
+      if (fieldSchema.isNullable()) {
+        fieldSchema = fieldSchema.getNonNullable();
+      }
+
+      // Handle int types.
+      if (fieldSchema.getType() == Schema.Type.INT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Integer", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.IntegerType));
+      }
+
+      // Handle float types.
+      if (fieldSchema.getType() == Schema.Type.FLOAT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Float", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.FloatType));
+      }
+    }
+
+    return ds;
+  }
+
 }
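
Two details of convertFieldTypes worth noting: nullable CDAP schemas are unions, so the non-nullable component is unwrapped before the type check, and the getLogicalType() == null guard keeps logical types that share these physical types (for example, CDAP's date logical type is carried as an INT) from being cast away from the Spark type the connector gave them. For illustration, a hypothetical schema the producer might receive (field names invented):

import io.cdap.cdap.api.data.schema.Schema;

public class SchemaSketch {
  // Hypothetical record schema, for illustration only.
  static final Schema SCHEMA = Schema.recordOf("output",
      Schema.Field.of("id", Schema.of(Schema.Type.INT)),
      Schema.Field.of("score", Schema.nullableOf(Schema.of(Schema.Type.FLOAT))),
      Schema.Field.of("created", Schema.of(Schema.LogicalType.DATE)),
      Schema.Field.of("name", Schema.of(Schema.Type.STRING)));
}

Against this schema, convertFieldTypes would cast id to IntegerType and score to FloatType (after unwrapping the nullable union), skip created because its physical INT carries the DATE logical type, and leave name untouched.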