Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pgdog/src/backend/pool/connection/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

use crate::{
frontend::{client::query_engine::TwoPcPhase, ClientRequest},
net::{parameter::Parameters, BackendKeyData, ProtocolMessage, Query},
net::{parameter::Parameters, BackendKeyData, Message, ProtocolMessage, Query},
state::State,
};

use futures::future::join_all;
use std::collections::HashMap;

use super::*;

Expand Down Expand Up @@ -53,6 +54,13 @@ impl Binding {
self.disconnect();
}

pub fn forward_with_shard(&self) -> Option<HashMap<String, Vec<usize>>> {
match self {
Binding::MultiShard(_shards, state) => state.table_shard_map(),
_ => None,
}
}

/// Are we connected to a backend?
pub fn connected(&self) -> bool {
match self {
Expand Down Expand Up @@ -91,13 +99,25 @@ impl Binding {
return Ok(message);
}
let mut read = false;
for server in shards.iter_mut() {

for (shard, server) in shards.iter_mut().enumerate() {
if !server.has_more_messages() {
continue;
}

let message = server.read().await?;

if state.display_table() {
if let Some(table_name) = message.table_name_from_dt().unwrap() {
let mut map: HashMap<String, Vec<usize>> =
state.table_shard_map().unwrap_or_default();
map.entry(table_name.clone())
.or_insert_with(Vec::new)
.push(shard);
state.set_table_shard_map(Some(map));
}
}

read = true;
if let Some(message) = state.forward(message)? {
return Ok(message);
Expand Down
13 changes: 13 additions & 0 deletions pgdog/src/backend/pool/connection/multi_shard/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Multi-shard connection state.

use context::Context;
use std::collections::HashMap;

use crate::{
frontend::{router::Route, PreparedStatements},
Expand Down Expand Up @@ -345,4 +346,16 @@ impl MultiShard {
}
}
}

pub fn display_table(&self) -> bool {
self.route.display_table()
}

pub fn set_table_shard_map(&mut self, map: Option<HashMap<String, Vec<usize>>>) {
self.route.set_table_shard_map(map);
}

pub fn table_shard_map(&self) -> Option<HashMap<String, Vec<usize>>> {
self.route.table_shard_map()
}
}
3 changes: 3 additions & 0 deletions pgdog/src/frontend/client/query_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
state::State,
};

use std::collections::HashSet;
use tracing::debug;

pub mod connect;
Expand Down Expand Up @@ -78,6 +79,7 @@ pub struct QueryEngine {
notify_buffer: NotifyBuffer,
pending_explain: Option<ExplainResponseState>,
hooks: QueryEngineHooks,
seen_tables: HashSet<String>,
}

impl QueryEngine {
Expand Down Expand Up @@ -105,6 +107,7 @@ impl QueryEngine {
pending_explain: None,
begin_stmt: None,
router: Router::default(),
seen_tables: HashSet::new(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it's reasonable to add this for \dt because the original command doesn't do any aggregating or deduplication handling

})
}

Expand Down
51 changes: 49 additions & 2 deletions pgdog/src/frontend/client/query_engine/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
router::parser::{explain_trace::ExplainTrace, rewrite::statement::plan::RewriteResult},
},
net::{
DataRow, FromBytes, Message, Protocol, ProtocolMessage, Query, ReadyForQuery,
DataRow, Field, FromBytes, Message, Protocol, ProtocolMessage, Query, ReadyForQuery,
RowDescription, ToBytes, TransactionState,
},
state::State,
Expand Down Expand Up @@ -36,7 +36,7 @@ impl QueryEngine {
// We need to run a query now.
if context.in_transaction() {
// Connect to one shard if not sharded or to all shards
// for a cross-shard tranasction.
// for a cross-shard transaction.
if !self.connect_transaction(context).await? {
return Ok(());
}
Expand Down Expand Up @@ -123,8 +123,23 @@ impl QueryEngine {
) -> Result<(), Error> {
self.streaming = message.streaming();

let should_rewrite_for_display_table =
if let Some(route) = context.client_request.route.as_ref() {
route.display_table()
} else {
false
};

let code = message.code();
let payload = if code == 'T' {
if should_rewrite_for_display_table {
let mut fields = RowDescription::from_bytes(message.payload())
.unwrap()
.fields
.to_vec();
fields.push(Field::text("Shard"));
message = RowDescription::new(&fields).message()?;
}
Some(message.payload())
} else {
None
Expand Down Expand Up @@ -152,6 +167,38 @@ impl QueryEngine {
self.pending_explain = None;
}

if code == 'D' {
if should_rewrite_for_display_table {
let mut dr = DataRow::from_bytes(message.payload()).unwrap();
let col = dr.column(1).unwrap();

let shard_map = self.backend.forward_with_shard();
let table_lookup = std::str::from_utf8(&col).unwrap();

if let Some(map) = shard_map {
if self.seen_tables.contains(table_lookup) {
return Ok(());
}

self.seen_tables.insert(table_lookup.to_string());

let mut new_col = String::new();
for (i, val) in map[table_lookup].iter().enumerate() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking, maybe it would be better to have a more unique key for the HashMap, maybe instead of just the table name I should use the schema + table name? Thoughts?

if i > 0 {
new_col.push_str(", ")
}
new_col.push_str(&val.to_string());
}
dr.add(new_col);
} else {
dr.add(None);
}

message = dr.message()?;
Some(message.payload());
}
}

// Messages that we need to send to the client immediately.
// ReadyForQuery (B) | CopyInResponse (B) | ErrorResponse(B) | NoticeResponse(B) | NotificationResponse (B)
let flush = matches!(code, 'Z' | 'G' | 'E' | 'N' | 'A')
Expand Down
15 changes: 15 additions & 0 deletions pgdog/src/frontend/router/parser/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ impl QueryParser {
Command::default()
};

// Check if we are executing \dt command
if let Command::Query(route) = &mut command {
let query = match context.query() {
Ok(res) => res,
Err(e) => return Err(e),
};
if query.contains("pg_catalog.pg_class")
&& query.contains("pg_catalog.pg_namespace")
&& query.contains("relkind")
&& query.contains("pg_toast")
{
route.set_display_table(true);
}
}

if let Command::Query(route) = &mut command {
if route.is_cross_shard() && context.shards == 1 {
context
Expand Down
23 changes: 22 additions & 1 deletion pgdog/src/frontend/router/parser/route.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, ops::Deref};
use std::{collections::HashMap, fmt::Display, ops::Deref};

use lazy_static::lazy_static;

Expand Down Expand Up @@ -90,6 +90,8 @@ pub struct Route {
rollback_savepoint: bool,
search_path_driven: bool,
schema_changed: bool,
display_table: bool,
table_shard_map: Option<HashMap<String, Vec<usize>>>,
}

impl Display for Route {
Expand Down Expand Up @@ -326,6 +328,25 @@ impl Route {
ShardSource::Table(TableReason::Omni) | ShardSource::RoundRobin(RoundRobinReason::Omni)
)
}
pub fn set_display_table(&mut self, v: bool) {
self.display_table = v;
}

pub fn display_table(&self) -> bool {
self.display_table
}

pub fn table_shard_map(&self) -> Option<HashMap<String, Vec<usize>>> {
if self.table_shard_map == None {
Some(HashMap::new())
} else {
self.table_shard_map.clone()
}
}

pub fn set_table_shard_map(&mut self, map: Option<HashMap<String, Vec<usize>>>) {
self.table_shard_map = map;
}
}

/// Shard source.
Expand Down
11 changes: 11 additions & 0 deletions pgdog/src/net/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,17 @@ impl Message {
pub fn transaction_error(&self) -> bool {
self.code() == 'Z' && self.payload[5] as char == 'E'
}

pub fn table_name_from_dt(&self) -> Result<Option<String>, Error> {
if self.code() != 'D' {
return Ok(None);
}
let byte_name = DataRow::from_bytes(self.payload()).unwrap().column(1);

let table_name = std::str::from_utf8(&byte_name.unwrap())?.to_string();

return Ok(Some(table_name));
}
}

/// Check that the message we received is what we expected.
Expand Down
Loading