Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion src/spanner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ categories.workspace = true
rust-version.workspace = true

[features]
unstable-stream = ["dep:futures"]
default = ["default-rustls-provider"]
# Enabled by default. Use the default rustls crypto provider ([aws-lc-rs]) for
# TLS and authentication. Applications with specific requirements for
# cryptography (such as exclusively using the [ring] crate) should disable this
# default and call `rustls::CryptoProvider::install_default()`.
default-rustls-provider = ["gaxi/_default-rustls-provider"]
unstable-stream = ["dep:futures"]

[dependencies]
async-trait.workspace = true
Expand All @@ -54,6 +60,7 @@ wkt = { workspace = true, features = ["time"] }

[dev-dependencies]
anyhow.workspace = true
google-cloud-spanner = { path = ".", features = ["default-rustls-provider"] }
google-cloud-test-macros.workspace = true
mockall.workspace = true
spanner-grpc-mock = { path = "grpc-mock" }
Expand Down
54 changes: 39 additions & 15 deletions src/spanner/src/batch_read_only_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ impl Partition {
.send()
.await?;

Ok(ResultSet::new(ResultSetParams {
ResultSet::create(ResultSetParams {
stream,
transaction_selector: Some(ReadContextTransactionSelector::Fixed(
req.transaction
Expand All @@ -399,7 +399,8 @@ impl Partition {
operation: StreamOperation::Query(req.clone()),
channel_hint,
gax_options,
}))
})
.await
}

async fn execute_read(
Expand All @@ -414,7 +415,7 @@ impl Partition {
.send()
.await?;

Ok(ResultSet::new(ResultSetParams {
ResultSet::create(ResultSetParams {
stream,
transaction_selector: Some(ReadContextTransactionSelector::Fixed(
req.transaction
Expand All @@ -429,7 +430,8 @@ impl Partition {
operation: StreamOperation::Read(req.clone()),
channel_hint,
gax_options,
}))
})
.await
}
}

Expand All @@ -451,7 +453,8 @@ pub(crate) mod tests {
use google_cloud_test_macros::tokio_test_no_panics;
use prost_types::Timestamp;
use spanner_grpc_mock::google::spanner::v1::{
Partition as MockPartition, PartitionResponse, Transaction,
PartialResultSet, Partition as MockPartition, PartitionResponse, ResultSetMetadata,
StructType, Transaction,
};
use static_assertions::assert_impl_all;
use std::fmt::Debug;
Expand Down Expand Up @@ -487,6 +490,22 @@ pub(crate) mod tests {
Ok(())
}

fn setup_select1() -> PartialResultSet {
PartialResultSet {
metadata: Some(ResultSetMetadata {
row_type: Some(StructType {
fields: vec![Default::default()],
}),
..Default::default()
}),
values: vec![prost_types::Value {
kind: Some(prost_types::value::Kind::StringValue("1".to_string())),
}],
last: true,
..Default::default()
}
}

#[tokio_test_no_panics]
async fn partition_execute_respects_options() -> anyhow::Result<()> {
use gaxi::grpc::tonic::Response;
Expand All @@ -499,8 +518,9 @@ pub(crate) mod tests {
assert!(timeout.is_some(), "Missing grpc-timeout header");
assert_eq!(timeout.unwrap(), "5000000u"); // 5 seconds in micros

let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
Ok(Response::from(crate::result_set::tests::adapt([Ok(
setup_select1(),
)])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down Expand Up @@ -606,8 +626,9 @@ pub(crate) mod tests {
assert!(req.transaction.is_some());
assert_eq!(req.sql, "SELECT * FROM Users");

let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
Ok(Response::from(crate::result_set::tests::adapt([Ok(
setup_select1(),
)])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down Expand Up @@ -648,8 +669,9 @@ pub(crate) mod tests {
assert!(req.transaction.is_some());
assert_eq!(req.table, "Users");

let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
Ok(Response::from(crate::result_set::tests::adapt([Ok(
setup_select1(),
)])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down Expand Up @@ -812,8 +834,9 @@ pub(crate) mod tests {
mock.expect_execute_streaming_sql().once().returning(|req| {
let req = req.into_inner();
assert!(req.data_boost_enabled, "data_boost_enabled should be true");
let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
Ok(Response::from(crate::result_set::tests::adapt([Ok(
setup_select1(),
)])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down Expand Up @@ -844,8 +867,9 @@ pub(crate) mod tests {
mock.expect_streaming_read().once().returning(|req| {
let req = req.into_inner();
assert!(req.data_boost_enabled, "data_boost_enabled should be true");
let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(Response::from(rx))
Ok(Response::from(crate::result_set::tests::adapt([Ok(
setup_select1(),
)])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down
12 changes: 10 additions & 2 deletions src/spanner/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ pub use crate::read_only_transaction::SingleUseReadOnlyTransaction;
pub use crate::read_only_transaction::SingleUseReadOnlyTransactionBuilder;
pub use crate::read_write_transaction::ReadWriteTransaction;
pub use crate::result_set::ResultSet;
pub use crate::result_set::ResultSetError;
pub use crate::result_set_metadata::ResultSetMetadata;
pub use crate::row::Row;
pub use crate::statement::Statement;
Expand Down Expand Up @@ -1087,7 +1086,16 @@ mod tests {
"grpc-timeout header should be present for read"
);

let (_tx, rx) = tokio::sync::mpsc::channel(1);
let (tx, rx) = tokio::sync::mpsc::channel(1);
let metadata = mock_v1::ResultSetMetadata {
transaction: None,
..Default::default()
};
let prs = mock_v1::PartialResultSet {
metadata: Some(metadata),
..Default::default()
};
tx.try_send(Ok(prs)).unwrap();
Ok(Response::new(rx))
});

Expand Down
66 changes: 42 additions & 24 deletions src/spanner/src/read_only_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,20 +567,24 @@ impl ReadContextTransactionSelector {
TransactionState::Started(_, _) | TransactionState::NotStarted(_) => unreachable!(),
}
}
}

pub(crate) struct ExplicitBeginParams {
pub(crate) client: crate::database_client::DatabaseClient,
pub(crate) session_name: String,
pub(crate) transaction_tag: Option<String>,
pub(crate) channel_hint: usize,
pub(crate) request_options: crate::RequestOptions,
pub(crate) is_stream_fallback: bool,
pub(crate) precommit_token_tracker: crate::precommit::PrecommitTokenTracker,
}

impl ReadContextTransactionSelector {
/// Explicitly begins a transaction if the transaction selector is a `Lazy`
/// selector and the transaction has not yet been started. This is used by
/// the client to force the start of a transaction if the first statement
/// failed.
pub(crate) async fn begin_explicitly(
&self,
client: &crate::database_client::DatabaseClient,
session_name: String,
transaction_tag: Option<String>,
channel_hint: usize,
request_options: crate::RequestOptions,
is_stream_fallback: bool,
) -> crate::Result<()> {
pub(crate) async fn begin_explicitly(&self, params: ExplicitBeginParams) -> crate::Result<()> {
let Self::Lazy(lazy) = self else {
return Ok(());
};
Expand Down Expand Up @@ -613,7 +617,7 @@ impl ReadContextTransactionSelector {
// and must wait for the leader. If this call originated from a stream resume fallback
// (`is_stream_fallback = true`), this thread is the stream leader whose initial query failed,
// and it must proceed with an explicit BeginTransaction RPC.
if !is_stream_fallback {
if !params.is_stream_fallback {
FallbackAction::Wait(Arc::clone(notify))
} else {
FallbackAction::Begin(options.clone(), Some(Arc::clone(notify)))
Expand All @@ -640,12 +644,12 @@ impl ReadContextTransactionSelector {
// Waiters are blocked in `poll_selector_status` waiting for the result,
// and already completed states return early above.
let response = match execute_begin_transaction(
client,
session_name.clone(),
&params.client,
params.session_name,
options,
transaction_tag,
channel_hint,
request_options,
params.transaction_tag,
params.channel_hint,
params.request_options,
)
.await
{
Expand All @@ -671,6 +675,9 @@ impl ReadContextTransactionSelector {
};

self.update(response.id, response.read_timestamp)?;
params
.precommit_token_tracker
.update(response.precommit_token);

Ok(())
}
Expand Down Expand Up @@ -847,14 +854,15 @@ impl ReadContext {
}

self.transaction_selector
.begin_explicitly(
&self.client,
self.session_name.clone(),
self.transaction_tag.clone(),
self.channel_hint,
.begin_explicitly(ExplicitBeginParams {
client: self.client.clone(),
session_name: self.session_name.clone(),
transaction_tag: self.transaction_tag.clone(),
channel_hint: self.channel_hint,
request_options,
is_stream_fallback,
)
precommit_token_tracker: self.precommit_token_tracker.clone(),
})
.await?;
Ok(true)
}
Expand Down Expand Up @@ -892,7 +900,7 @@ macro_rules! execute_stream_with_retry {
}
};

Ok(ResultSet::new(ResultSetParams {
ResultSet::create(ResultSetParams {
stream,
transaction_selector: Some($self.transaction_selector.clone()),
precommit_token_tracker: $self.precommit_token_tracker.clone(),
Expand All @@ -902,7 +910,8 @@ macro_rules! execute_stream_with_retry {
operation: $operation_variant($request),
channel_hint: $self.channel_hint,
gax_options: $gax_options,
}))
})
.await
}};
}

Expand Down Expand Up @@ -2559,7 +2568,16 @@ pub(crate) mod tests {
selector: Some(mock_v1::transaction_selector::Selector::Id(vec![42])),
})
);
let (_tx, rx) = mpsc::channel(1);
let (tx, rx) = mpsc::channel(1);
let metadata = mock_v1::ResultSetMetadata {
row_type: Some(mock_v1::StructType { fields: vec![] }),
..Default::default()
};
let prs = mock_v1::PartialResultSet {
metadata: Some(metadata),
..Default::default()
};
tx.try_send(Ok(prs)).expect("send should succeed");
Ok(tonic::Response::new(rx))
});

Expand Down
30 changes: 26 additions & 4 deletions src/spanner/src/read_write_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,19 +521,33 @@ impl ReadWriteTransaction {

/// Commits the transaction.
pub(crate) async fn commit(self) -> crate::Result<CommitResponse> {
let mutations = take(&mut *self.mutations.lock().unwrap());
let mut id = self.context.transaction_selector.get_id_no_wait()?;
if id.is_none() {
if self.is_starting()? {
return Err(crate::error::internal_error(
"Commit called while an asynchronous statement is still starting the transaction",
));
}
if self.begin_explicitly_if_not_started(false).await? {
// TODO(#5821): Include mutation_key during explicit transaction initialization fallback to preserve blind write intent on multiplexed sessions.
let mut begin_options = crate::RequestOptions::default();
if let Some(d) = self.deadline {
let remaining = d.saturating_duration_since(Instant::now());
begin_options.set_attempt_timeout(remaining);
}
begin_options = amend_request_options_for_lar(
self.context.client.leader_aware_routing_enabled,
begin_options,
);
if self
.context
.begin_explicitly_if_not_started(begin_options, false)
.await?
{
id = self.context.transaction_selector.get_id_no_wait()?;
}
}
let transaction_id = id.ok_or_else(|| internal_error("Transaction ID is missing"))?;
let mutations = take(&mut *self.mutations.lock().unwrap());
let precommit_token = self.context.precommit_token_tracker.get();
let request = CommitRequest::default()
.set_session(self.context.session_name.clone())
Expand Down Expand Up @@ -1492,8 +1506,16 @@ mod tests {
})
);

let (_, rx) = tokio::sync::mpsc::channel(1);
Ok(tonic::Response::from(rx))
let prs = v1::PartialResultSet {
metadata: Some(v1::ResultSetMetadata {
row_type: Some(v1::StructType { fields: vec![] }),
..Default::default()
}),
..Default::default()
};
Ok(tonic::Response::from(crate::result_set::tests::adapt([
Ok(prs),
])))
});

let (db_client, _server) = setup_db_client(mock).await;
Expand Down
Loading
Loading