Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pgdog/src/frontend/client/query_engine/test/set_schema_sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ async fn test_set_works_cross_shard_disabled() {
let reply = client.read_until('Z').await.unwrap();
assert_eq!(reply.len(), 2);
}

/// An unqualified `SELECT` against a table that exists in two schemas must be
/// rejected with the "cross-shard queries are disabled" error when the client
/// is created via `TestClient::new_cross_shard_disabled`.
///
/// NOTE(review): presumably `acustomer` and `bcustomer` are configured as
/// sharded schemas in the test cluster — confirm against the test config.
#[tokio::test]
async fn test_ambiguous_schema_sharded_query_errors_when_cross_shard_disabled() {
    let table = "schema_shard_ambiguous_test";

    // Create the same table under both schemas so the bare table name
    // cannot be pinned to a single schema.
    let mut setup = TestClient::new_sharded(Parameters::default()).await;
    let ddl = [
        "CREATE SCHEMA IF NOT EXISTS acustomer".to_string(),
        "CREATE SCHEMA IF NOT EXISTS bcustomer".to_string(),
        format!("CREATE TABLE IF NOT EXISTS acustomer.{table} (id INT)"),
        format!("CREATE TABLE IF NOT EXISTS bcustomer.{table} (id INT)"),
    ];
    for sql in ddl {
        setup.send_simple(Query::new(&sql)).await;
        // Drain the response up to ReadyForQuery before sending the next statement.
        setup.read_until('Z').await.unwrap();
    }

    let select = format!("SELECT * FROM {table}");
    let mut client = TestClient::new_cross_shard_disabled(Parameters::default()).await;
    client.send_simple(Query::new(&select)).await;

    // The read must fail, carrying the router's error code and message.
    let err = client.read_until('Z').await.unwrap_err();
    assert_eq!(err.code, "58000");
    assert_eq!(err.message, "cross-shard queries are disabled");
}
34 changes: 24 additions & 10 deletions pgdog/src/frontend/router/parser/query/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ impl QueryParser {
self.recorder_mut(),
);

let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);

let shard = parser.shard()?;

if let Some(shard) = shard {
Expand All @@ -32,14 +26,34 @@ impl QueryParser {
.shards_calculator
.push(ShardWithPriority::new_table(shard));
} else {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "DELETE fell back to broadcast");
}
if is_sharded {
let schema_shard_state = parser.schema_shard_state(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
if let SchemaShardState::Resolved { shard, schema } = schema_shard_state {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(Some(shard.clone()), "DELETE matched schema");
}
context
.shards_calculator
.push(ShardWithPriority::new_search_path(shard, &schema));
} else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "DELETE fell back to broadcast");
}
context
.shards_calculator
.push(ShardWithPriority::new_table(Shard::All));
} else {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "DELETE fell back to omnisharded");
}
context
.shards_calculator
.push(ShardWithPriority::new_rr_omni(Shard::All));
Expand Down
32 changes: 22 additions & 10 deletions pgdog/src/frontend/router/parser/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
plugin::plugins,
};

use super::statement::SchemaShardState;
use super::{
explain_trace::{ExplainRecorder, ExplainSummary},
*,
Expand Down Expand Up @@ -530,18 +531,29 @@ impl QueryParser {
)
.with_schema_lookup(schema_lookup);

let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);

let shard = parser.shard()?.unwrap_or(Shard::All);
let shard = parser.shard()?;

context.shards_calculator.push(if is_sharded {
ShardWithPriority::new_table(shard.clone())
context.shards_calculator.push(if let Some(shard) = shard {
ShardWithPriority::new_table(shard)
} else {
ShardWithPriority::new_table_omni(shard)
let schema_shard_state = parser.schema_shard_state(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);

if let SchemaShardState::Resolved { shard, schema } = schema_shard_state {
ShardWithPriority::new_search_path(shard, &schema)
} else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded {
ShardWithPriority::new_table(Shard::All)
} else {
ShardWithPriority::new_table_omni(Shard::All)
}
});

let shard = context.shards_calculator.shard();
Expand Down
22 changes: 20 additions & 2 deletions pgdog/src/frontend/router/parser/query/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl QueryParser {

let mut shards = HashSet::new();

let (shard, is_sharded, tables, advisory_locks) = {
let (shard, schema_shard_state, is_sharded, tables, advisory_locks) = {
let mut statement_parser = StatementParser::from_select(
stmt,
context.router_context.bind,
Expand All @@ -72,10 +72,16 @@ impl QueryParser {
let advisory_locks = statement_parser.extract_advisory_locks();

if shard.is_some() {
(shard, true, vec![], advisory_locks)
(shard, SchemaShardState::None, true, vec![], advisory_locks)
} else {
let schema_shard_state = statement_parser.schema_shard_state(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
(
None,
schema_shard_state,
statement_parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
Expand Down Expand Up @@ -148,6 +154,18 @@ impl QueryParser {
context
.shards_calculator
.push(ShardWithPriority::new_table(shard));
} else if let SchemaShardState::Resolved { shard, schema } = schema_shard_state {
debug!("resolved schema-sharded query via search_path/default schema");

context
.shards_calculator
.push(ShardWithPriority::new_search_path(shard, &schema));
} else if matches!(schema_shard_state, SchemaShardState::Ambiguous) {
debug!("schema-sharded query is ambiguous, routing as cross-shard");

context
.shards_calculator
.push(ShardWithPriority::new_table(Shard::All));
} else if is_sharded {
debug!("table is sharded, but no sharding key detected");

Expand Down
34 changes: 24 additions & 10 deletions pgdog/src/frontend/router/parser/query/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ impl QueryParser {
self.recorder_mut(),
);

let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);

let shard = parser.shard()?;
if let Some(shard) = shard {
if let Some(recorder) = self.recorder_mut() {
Expand All @@ -31,14 +25,34 @@ impl QueryParser {
.shards_calculator
.push(ShardWithPriority::new_table(shard));
} else {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "UPDATE fell back to broadcast");
}
if is_sharded {
let schema_shard_state = parser.schema_shard_state(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
let is_sharded = parser.is_sharded(
&context.router_context.schema,
context.router_context.cluster.user(),
context.router_context.parameter_hints.search_path,
);
if let SchemaShardState::Resolved { shard, schema } = schema_shard_state {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(Some(shard.clone()), "UPDATE matched schema");
}
context
.shards_calculator
.push(ShardWithPriority::new_search_path(shard, &schema));
} else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "UPDATE fell back to broadcast");
}
context
.shards_calculator
.push(ShardWithPriority::new_table(Shard::All));
} else {
if let Some(recorder) = self.recorder_mut() {
recorder.record_entry(None, "UPDATE fell back to omnisharded");
}
context
.shards_calculator
.push(ShardWithPriority::new_table_omni(Shard::All));
Expand Down
133 changes: 133 additions & 0 deletions pgdog/src/frontend/router/parser/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ enum Statement<'a> {
Insert(&'a InsertStmt),
}

/// Result of trying to route a statement by the schema(s) its tables belong to,
/// as produced by `StatementParser::schema_shard_state`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum SchemaShardState {
/// No schema-based decision: either no sharded schemas are configured, or
/// the schema sharder produced no match for the statement's tables.
None,
/// The schema sharder settled on a shard; `schema` is the name of the
/// schema it reported alongside that shard.
Resolved { shard: Shard, schema: String },
/// At least one unqualified table could not be resolved through the
/// database schema lookup, yet it may still live in a sharded schema
/// (a mapping for the `None` schema key exists, or a configured sharded
/// schema contains a table with that name).
Ambiguous,
}

/// Context for looking up table columns from the database schema.
/// Used for INSERT statements without explicit column lists.
pub struct SchemaLookupContext<'a> {
Expand Down Expand Up @@ -604,6 +611,52 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> {
Ok(None)
}

/// Work out whether the statement can be routed by the schema of the tables
/// it references.
///
/// * `db_schema` - loaded database schema used to resolve unqualified tables.
/// * `user` / `search_path` - passed through to the `db_schema.table` lookup.
///
/// Returns [`SchemaShardState::None`] when no sharded schemas are configured
/// or nothing matched, [`SchemaShardState::Ambiguous`] when some unqualified
/// table could not be resolved but may still live in a sharded schema, and
/// [`SchemaShardState::Resolved`] otherwise.
pub(crate) fn schema_shard_state(
    &mut self,
    db_schema: &Schema,
    user: &str,
    search_path: Option<&ParameterValue>,
) -> SchemaShardState {
    // Nothing to do when schema-based sharding isn't configured at all.
    if self.schema.schemas.is_empty() {
        return SchemaShardState::None;
    }

    // Copy the table list out so `self` is no longer mutably borrowed below.
    let referenced = self.tables().to_vec();
    let sharded_schemas = &self.schema.schemas;
    let has_default_mapping = sharded_schemas.get(None).is_some();

    let mut sharder = SchemaSharder::default();
    let mut unresolved = false;

    for table in referenced {
        if let Some(explicit) = table.schema {
            // Schema-qualified name: use the schema exactly as written.
            sharder.resolve(Some(explicit.into()), sharded_schemas);
        } else if let Some(relation) = db_schema.table(table, user, search_path) {
            // Unqualified name that the database schema lookup could place.
            sharder.resolve(Some(relation.schema().into()), sharded_schemas);
        } else {
            // Unqualified and not resolvable: it is ambiguous if a mapping for
            // the `None` schema key exists or any configured sharded schema
            // has a table with this name.
            let maybe_sharded = has_default_mapping
                || sharded_schemas
                    .keys()
                    .any(|name| db_schema.get(name, table.name).is_some());
            unresolved |= maybe_sharded;
        }
    }

    // Ambiguity takes precedence over whatever the sharder resolved.
    if unresolved {
        return SchemaShardState::Ambiguous;
    }

    match sharder.get() {
        Some((shard, schema)) => SchemaShardState::Resolved {
            shard,
            schema: schema.to_owned(),
        },
        None => SchemaShardState::None,
    }
}

/// Check that the query references a table that contains a sharded
/// column. This check is needed in case sharded tables config
/// doesn't specify a table name and should short-circuit if it does.
Expand Down Expand Up @@ -2549,6 +2602,86 @@ mod test {
assert_eq!(result2, Some(Shard::Direct(2)));
}

/// Build a database schema fixture with a `products` table in both the
/// `sales` and `inventory` schemas, plus `public.unsharded_table`, using the
/// search path `["$user", "public"]`.
fn make_test_schema_with_sharded_relations() -> crate::backend::Schema {
    // (schema, table) pairs; every relation is created without columns.
    const FIXTURE_TABLES: [(&str, &str); 3] = [
        ("sales", "products"),
        ("inventory", "products"),
        ("public", "unsharded_table"),
    ];

    let mut relations = HashMap::new();
    for (schema, table) in FIXTURE_TABLES {
        relations.insert(
            (schema.into(), table.into()),
            Relation::test_table(schema, table, IndexMap::new()),
        );
    }

    let search_path = vec!["$user".into(), "public".into()];
    crate::backend::Schema::from_parts(search_path, relations)
}

/// Parse `stmt` and run `StatementParser::schema_shard_state` on it.
///
/// The sharding config has three shards with `sales` mapped to shard 1 and
/// `inventory` mapped to shard 2; the database schema comes from
/// `make_test_schema_with_sharded_relations`. `search_path` is forwarded to
/// the lookup as-is, and the user is always `"pgdog"`.
fn run_schema_shard_state_test(
    stmt: &str,
    search_path: Option<ParameterValue>,
) -> Result<SchemaShardState, Error> {
    // Small constructor for a sharded schema entry pinned to one shard.
    let pinned = |name: &str, shard| ShardedSchema {
        database: "test".to_string(),
        name: Some(name.to_string()),
        shard,
        all: false,
    };

    let schema = ShardingSchema {
        shards: 3,
        schemas: ShardedSchemas::new(vec![pinned("sales", 1), pinned("inventory", 2)]),
        ..Default::default()
    };
    let db_schema = make_test_schema_with_sharded_relations();

    // Only the first statement of the parsed input is examined.
    let parsed = pg_query::parse(stmt).unwrap();
    let raw = parsed.protobuf.stmts.first().cloned().unwrap();

    let mut parser = StatementParser::from_raw(&raw, None, &schema, None)?;
    let state = parser.schema_shard_state(&db_schema, "pgdog", search_path.as_ref());
    Ok(state)
}

/// `products` exists in both sharded schemas and nothing selects one of them,
/// so the state must be `Ambiguous`.
#[test]
fn test_schema_shard_state_ambiguous_without_search_path() {
    let query = "SELECT * FROM products";
    let state = run_schema_shard_state_test(query, None).unwrap();
    assert_eq!(state, SchemaShardState::Ambiguous);
}

/// With `sales` first in the search path, the unqualified `products` table
/// resolves to the `sales` schema and therefore to shard 1.
#[test]
fn test_schema_shard_state_resolved_from_search_path() {
    let search_path = ParameterValue::Tuple(vec!["sales".into(), "public".into()]);
    let expected = SchemaShardState::Resolved {
        shard: Shard::Direct(1),
        schema: "sales".into(),
    };

    let state = run_schema_shard_state_test("SELECT * FROM products", Some(search_path)).unwrap();
    assert_eq!(state, expected);
}

/// A table that only exists in `public` (not a sharded schema in this
/// fixture) yields no schema-based routing decision.
#[test]
fn test_schema_shard_state_none_for_unsharded_table() {
    let query = "SELECT * FROM unsharded_table";
    let state = run_schema_shard_state_test(query, None).unwrap();
    assert_eq!(state, SchemaShardState::None);
}

// Column-only sharded table detection tests (using loaded schema)

fn run_test_column_only(stmt: &str, bind: Option<&Bind>) -> Result<Option<Shard>, Error> {
Expand Down