Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow_io/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ cc_library(
"@avro",
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/types:any",
"@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc",
"@com_google_googleapis//google/cloud/bigquery/storage/v1:storage_cc_grpc",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
Expand Down Expand Up @@ -219,7 +219,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:any",
"@com_google_absl//absl/types:variant",
"@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc",
"@com_google_googleapis//google/cloud/bigquery/storage/v1:storage_cc_grpc",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
Expand Down
14 changes: 7 additions & 7 deletions tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
std::vector<string> default_values_;
std::vector<absl::any> typed_default_values_;
int64 offset_;
apiv1beta1::DataFormat data_format_;
apiv1::DataFormat data_format_;

class Dataset : public DatasetBase {
public:
Expand All @@ -120,7 +120,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
std::vector<string> selected_fields,
std::vector<DataType> output_types,
std::vector<absl::any> typed_default_values, int64 offset_,
apiv1beta1::DataFormat data_format)
apiv1::DataFormat data_format)
: DatasetBase(DatasetContext(ctx)),
client_resource_(client_resource),
output_types_vector_(output_types_vector),
Expand All @@ -134,10 +134,10 @@ class BigQueryDatasetOp : public DatasetOpKernel {
data_format_(data_format) {
client_resource_->Ref();

if (data_format == apiv1beta1::DataFormat::AVRO) {
if (data_format == apiv1::DataFormat::AVRO) {
std::istringstream istream(schema);
avro::compileJsonSchema(istream, *avro_schema_);
} else if (data_format == apiv1beta1::DataFormat::ARROW) {
} else if (data_format == apiv1::DataFormat::ARROW) {
auto buffer_ = std::make_shared<arrow::Buffer>(
reinterpret_cast<const uint8_t *>(&schema[0]), schema.length());

Expand All @@ -158,11 +158,11 @@ class BigQueryDatasetOp : public DatasetOpKernel {

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string &prefix) const override {
if (data_format_ == apiv1beta1::DataFormat::AVRO) {
if (data_format_ == apiv1::DataFormat::AVRO) {
return std::unique_ptr<IteratorBase>(
new BigQueryReaderAvroDatasetIterator<Dataset>(
{this, strings::StrCat(prefix, "::BigQueryAvroDataset")}));
} else if (data_format_ == apiv1beta1::DataFormat::ARROW) {
} else if (data_format_ == apiv1::DataFormat::ARROW) {
return std::unique_ptr<IteratorBase>(
new BigQueryReaderArrowDatasetIterator<Dataset>(
{this, strings::StrCat(prefix, "::BigQueryArrowDataset")}));
Expand Down Expand Up @@ -229,7 +229,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
const std::unique_ptr<avro::ValidSchema> avro_schema_;
const int64 offset_;
std::shared_ptr<::arrow::Schema> arrow_schema_;
const apiv1beta1::DataFormat data_format_;
const apiv1::DataFormat data_format_;
};
};

Expand Down
47 changes: 21 additions & 26 deletions tensorflow_io/core/kernels/bigquery/bigquery_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace {

namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1;
namespace apiv1 = ::google::cloud::bigquery::storage::v1;

class BigQueryClientOp : public OpKernel {
public:
Expand Down Expand Up @@ -105,35 +105,30 @@ class BigQueryReadSessionOp : public OpKernel {
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
core::ScopedUnref scoped_unref(client_resource);

apiv1beta1::CreateReadSessionRequest createReadSessionRequest;
createReadSessionRequest.mutable_table_reference()->set_project_id(
project_id_);
createReadSessionRequest.mutable_table_reference()->set_dataset_id(
dataset_id_);
createReadSessionRequest.mutable_table_reference()->set_table_id(table_id_);
apiv1::CreateReadSessionRequest createReadSessionRequest;
createReadSessionRequest.set_parent(parent_);
*createReadSessionRequest.mutable_read_options()
->mutable_selected_fields() = {selected_fields_.begin(),
selected_fields_.end()};
createReadSessionRequest.mutable_read_options()->set_row_restriction(
row_restriction_);
createReadSessionRequest.set_requested_streams(requested_streams_);
createReadSessionRequest.set_sharding_strategy(
apiv1beta1::ShardingStrategy::BALANCED);
createReadSessionRequest.set_format(data_format_);
apiv1::ReadSession* read_session =
createReadSessionRequest.mutable_read_session();
read_session->set_table(strings::Printf(
"projects/%s/datasets/%s/tables/%s", project_id_.c_str(),
dataset_id_.c_str(), table_id_.c_str()));
read_session->set_data_format(data_format_);
*read_session->mutable_read_options()->mutable_selected_fields() = {
selected_fields_.begin(), selected_fields_.end()};
read_session->mutable_read_options()->set_row_restriction(row_restriction_);
createReadSessionRequest.set_max_stream_count(requested_streams_);

VLOG(3) << "createReadSessionRequest: "
<< createReadSessionRequest.DebugString();
::grpc::ClientContext context;
context.AddMetadata(
"x-goog-request-params",
strings::Printf("table_reference.dataset_id=%s&table_"
"reference.project_id=%s",
dataset_id_.c_str(), project_id_.c_str()));
context.AddMetadata("x-goog-request-params",
strings::Printf("read_session.table=%s",
read_session->table().c_str()));
context.set_deadline(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_seconds(60, GPR_TIMESPAN)));

std::shared_ptr<apiv1beta1::ReadSession> readSessionResponse =
std::make_shared<apiv1beta1::ReadSession>();
std::shared_ptr<apiv1::ReadSession> readSessionResponse =
std::make_shared<apiv1::ReadSession>();
VLOG(3) << "calling readSession";
::grpc::Status status = client_resource->GetStub("")->CreateReadSession(
&context, createReadSessionRequest, readSessionResponse.get());
Expand All @@ -155,13 +150,13 @@ class BigQueryReadSessionOp : public OpKernel {
Tensor* schema_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("schema", {}, &schema_t));

if (data_format_ == apiv1beta1::DataFormat::AVRO) {
if (data_format_ == apiv1::DataFormat::AVRO) {
OP_REQUIRES(ctx, readSessionResponse->has_avro_schema(),
errors::InvalidArgument("AVRO schema is missing"));
VLOG(3) << "avro schema:" << readSessionResponse->avro_schema().schema();
schema_t->scalar<tstring>()() =
readSessionResponse->avro_schema().schema();
} else if (data_format_ == apiv1beta1::DataFormat::ARROW) {
} else if (data_format_ == apiv1::DataFormat::ARROW) {
OP_REQUIRES(ctx, readSessionResponse->has_arrow_schema(),
errors::InvalidArgument("ARROW schema is missing"));
VLOG(3) << "arrow schema:"
Expand All @@ -183,7 +178,7 @@ class BigQueryReadSessionOp : public OpKernel {
std::vector<DataType> output_types_;
string row_restriction_;
int requested_streams_;
apiv1beta1::DataFormat data_format_;
apiv1::DataFormat data_format_;

mutex mu_;
ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_io/core/kernels/bigquery/bigquery_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ string GrpcStatusToString(const ::grpc::Status& status) {
}

Status GetDataFormat(string data_format_str,
apiv1beta1::DataFormat* data_format) {
apiv1::DataFormat* data_format) {
if (data_format_str == "ARROW") {
*data_format = apiv1beta1::DataFormat::ARROW;
*data_format = apiv1::DataFormat::ARROW;
} else if (data_format_str == "AVRO") {
*data_format = apiv1beta1::DataFormat::AVRO;
*data_format = apiv1::DataFormat::AVRO;
} else {
return errors::Internal("Unsupported data format: " + data_format_str);
}
Expand Down
44 changes: 21 additions & 23 deletions tensorflow_io/core/kernels/bigquery/bigquery_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ limitations under the License.
#include "arrow/buffer.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/api.h"
#include "google/cloud/bigquery/storage/v1beta1/storage.grpc.pb.h"
#include "google/cloud/bigquery/storage/v1/storage.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
Expand All @@ -50,18 +50,18 @@ limitations under the License.

namespace tensorflow {

namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1;
namespace apiv1 = ::google::cloud::bigquery::storage::v1;
static constexpr int kMaxReceiveMessageSize = -1; // Disabled

Status GrpcStatusToTfStatus(const ::grpc::Status &status);
string GrpcStatusToString(const ::grpc::Status &status);
Status GetDataFormat(string data_format_str,
apiv1beta1::DataFormat *data_format);
apiv1::DataFormat *data_format);

class BigQueryClientResource : public ResourceBase {
public:
explicit BigQueryClientResource(
std::function<std::unique_ptr<apiv1beta1::BigQueryStorage::Stub>(
std::function<std::unique_ptr<apiv1::BigQueryRead::Stub>(
const string &read_stream)>
stub_factory)
: stub_factory_(stub_factory) {}
Expand All @@ -80,10 +80,10 @@ class BigQueryClientResource : public ResourceBase {
args.SetString("read_stream", read_stream);
auto channel = ::grpc::CreateCustomChannel(server_name, creds, args);
VLOG(3) << "Creating GRPC channel";
return absl::make_unique<apiv1beta1::BigQueryStorage::Stub>(channel);
return absl::make_unique<apiv1::BigQueryRead::Stub>(channel);
}) {}

apiv1beta1::BigQueryStorage::Stub *GetStub(const string &read_stream)
apiv1::BigQueryRead::Stub *GetStub(const string &read_stream)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (stubs_.find(read_stream) == stubs_.end()) {
auto stub = stub_factory_(read_stream);
Expand All @@ -95,11 +95,11 @@ class BigQueryClientResource : public ResourceBase {
string DebugString() const override { return "BigQueryClientResource"; }

private:
std::function<std::unique_ptr<apiv1beta1::BigQueryStorage::Stub>(
std::function<std::unique_ptr<apiv1::BigQueryRead::Stub>(
const string &)>
stub_factory_;
mutex mu_;
std::unordered_map<string, std::unique_ptr<apiv1beta1::BigQueryStorage::Stub>>
std::unordered_map<string, std::unique_ptr<apiv1::BigQueryRead::Stub>>
stubs_ TF_GUARDED_BY(mu_);
};

Expand Down Expand Up @@ -156,11 +156,9 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
return OkStatus();
}

apiv1beta1::ReadRowsRequest readRowsRequest;
readRowsRequest.mutable_read_position()->mutable_stream()->set_name(
this->dataset()->stream());
readRowsRequest.mutable_read_position()->set_offset(
this->dataset()->offset());
apiv1::ReadRowsRequest readRowsRequest;
readRowsRequest.set_read_stream(this->dataset()->stream());
readRowsRequest.set_offset(this->dataset()->offset());

read_rows_context_ = absl::make_unique<::grpc::ClientContext>();
// The deadline is for the entire ReadRows (not a single message receipt),
Expand All @@ -169,14 +167,14 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
std::chrono::hours(24));
read_rows_context_->AddMetadata(
"x-goog-request-params",
absl::StrCat("read_position.stream.name=",
readRowsRequest.read_position().stream().name()));
absl::StrCat("read_stream=",
readRowsRequest.read_stream()));

VLOG(3) << "getting reader, stream: "
<< readRowsRequest.read_position().stream().DebugString();
<< readRowsRequest.read_stream();
reader_ = this->dataset()
->client_resource()
->GetStub(readRowsRequest.read_position().stream().name())
->GetStub(readRowsRequest.read_stream())
->ReadRows(read_rows_context_.get(), readRowsRequest);

return OkStatus();
Expand All @@ -191,9 +189,9 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
int current_row_index_ = 0;
mutex mu_;
std::unique_ptr<::grpc::ClientContext> read_rows_context_ TF_GUARDED_BY(mu_);
std::unique_ptr<::grpc::ClientReader<apiv1beta1::ReadRowsResponse>> reader_
std::unique_ptr<::grpc::ClientReader<apiv1::ReadRowsResponse>> reader_
TF_GUARDED_BY(mu_);
std::unique_ptr<apiv1beta1::ReadRowsResponse> response_ TF_GUARDED_BY(mu_);
std::unique_ptr<apiv1::ReadRowsResponse> response_ TF_GUARDED_BY(mu_);
};

// BigQuery reader for Arrow serialized data.
Expand All @@ -213,11 +211,11 @@ class BigQueryReaderArrowDatasetIterator
TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
if (this->response_ && this->response_->has_arrow_record_batch() &&
this->current_row_index_ <
this->response_->arrow_record_batch().row_count()) {
this->response_->row_count()) {
return OkStatus();
}

this->response_ = absl::make_unique<apiv1beta1::ReadRowsResponse>();
this->response_ = absl::make_unique<apiv1::ReadRowsResponse>();
if (!this->reader_->Read(this->response_.get())) {
*end_of_sequence = true;
return GrpcStatusToTfStatus(this->reader_->Finish());
Expand Down Expand Up @@ -315,11 +313,11 @@ class BigQueryReaderAvroDatasetIterator
Status EnsureHasRow(bool *end_of_sequence)
TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
if (this->response_ &&
this->current_row_index_ < this->response_->avro_rows().row_count()) {
this->current_row_index_ < this->response_->row_count()) {
return OkStatus();
}

this->response_ = absl::make_unique<apiv1beta1::ReadRowsResponse>();
this->response_ = absl::make_unique<apiv1::ReadRowsResponse>();
VLOG(3) << "calling read";
if (!this->reader_->Read(this->response_.get())) {
VLOG(3) << "no data";
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_io/core/kernels/tests/bigquery_test_client_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
namespace {

namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1;
namespace apiv1 = ::google::cloud::bigquery::storage::v1;

class BigQueryTestClientOp : public OpKernel {
public:
Expand Down Expand Up @@ -58,13 +58,13 @@ class BigQueryTestClientOp : public OpKernel {
std::shared_ptr<grpc::Channel> channel =
::grpc::CreateChannel(this->fake_server_address_,
grpc::InsecureChannelCredentials());
auto stub = apiv1beta1::BigQueryStorage::NewStub(channel);
auto stub = apiv1::BigQueryRead::NewStub(channel);
LOG(INFO) << "BigQueryTestClientOp waiting for connections";
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_seconds(15, GPR_TIMESPAN)));
LOG(INFO) << "Done creating BigQueryTestClientOp Fake client";
return absl::make_unique<apiv1beta1::BigQueryStorage::Stub>(
return absl::make_unique<apiv1::BigQueryRead::Stub>(
channel);
});
return OkStatus();
Expand Down
Loading