Skip to content

Commit 69f380b

Browse files
committed
feat: prevent duplicated table name across many schema
1 parent 3d16bf2 commit 69f380b

9 files changed

Lines changed: 179 additions & 27 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/sqlx_gen/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sqlx-gen"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
edition = "2021"
55
description = "Generate Rust structs from database schema introspection"
66
license = "MIT"

crates/sqlx_gen/src/codegen/crud_gen.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ mod tests {
638638
ParsedEntity {
639639
struct_name: "Users".to_string(),
640640
table_name: "users".to_string(),
641+
schema_name: None,
641642
is_view: false,
642643
fields: vec![
643644
make_field("id", "id", "i32", false, true),
@@ -1043,6 +1044,7 @@ mod tests {
10431044
let entity = ParsedEntity {
10441045
struct_name: "Connector".to_string(),
10451046
table_name: "connector".to_string(),
1047+
schema_name: None,
10461048
is_view: false,
10471049
fields: vec![
10481050
make_field("id", "id", "i32", false, true),
@@ -1064,6 +1066,7 @@ mod tests {
10641066
let entity = ParsedEntity {
10651067
struct_name: "Logs".to_string(),
10661068
table_name: "logs".to_string(),
1069+
schema_name: None,
10671070
is_view: false,
10681071
fields: vec![
10691072
make_field("message", "message", "String", false, false),
@@ -1082,6 +1085,7 @@ mod tests {
10821085
let entity = ParsedEntity {
10831086
struct_name: "Logs".to_string(),
10841087
table_name: "logs".to_string(),
1088+
schema_name: None,
10851089
is_view: false,
10861090
fields: vec![
10871091
make_field("message", "message", "String", false, false),
@@ -1107,6 +1111,7 @@ mod tests {
11071111
let entity = ParsedEntity {
11081112
struct_name: "Users".to_string(),
11091113
table_name: "users".to_string(),
1114+
schema_name: None,
11101115
is_view: false,
11111116
fields: vec![
11121117
make_field("id", "id", "Uuid", false, true),

crates/sqlx_gen/src/codegen/entity_parser.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ pub struct ParsedEntity {
2626
pub struct_name: String,
2727
/// Original table/view name from `#[sqlx_gen(table = "...")]`
2828
pub table_name: String,
29+
/// Schema name from `#[sqlx_gen(schema = "...")]`
30+
pub schema_name: Option<String>,
2931
/// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
3032
pub is_view: bool,
3133
/// Parsed fields
@@ -84,7 +86,7 @@ fn extract_use_imports(file: &syn::File) -> Vec<String> {
8486
if let syn::Item::Use(use_item) = item {
8587
let text = use_item.to_token_stream().to_string();
8688
// Skip serde and sqlx imports — the CRUD generator adds those itself
87-
if text.contains("serde") || text.contains("sqlx") {
89+
if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
8890
return None;
8991
}
9092
// Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
@@ -112,7 +114,7 @@ fn normalize_use_statement(s: &str) -> String {
112114
fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
113115
let struct_name = item.ident.to_string();
114116

115-
let (kind, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
117+
let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
116118
let is_view = kind.as_deref() == Some("view");
117119

118120
// Fall back to struct name if no table annotation
@@ -132,16 +134,18 @@ fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
132134
Ok(ParsedEntity {
133135
struct_name,
134136
table_name,
137+
schema_name,
135138
is_view,
136139
fields,
137140
imports: Vec::new(), // filled by parse_entity_source
138141
})
139142
}
140143

141-
/// Parse `#[sqlx_gen(kind = "...", table = "...")]` from struct attributes.
142-
/// Returns (kind, table_name).
143-
fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>) {
144+
/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
145+
/// Returns (kind, schema_name, table_name).
146+
fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
144147
let mut kind = None;
148+
let mut schema_name = None;
145149
let mut table_name = None;
146150

147151
for attr in attrs {
@@ -150,13 +154,16 @@ fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Opt
150154
if let Some(k) = extract_attr_value(&tokens, "kind") {
151155
kind = Some(k);
152156
}
157+
if let Some(s) = extract_attr_value(&tokens, "schema") {
158+
schema_name = Some(s);
159+
}
153160
if let Some(t) = extract_attr_value(&tokens, "table") {
154161
table_name = Some(t);
155162
}
156163
}
157164
}
158165

159-
(kind, table_name)
166+
(kind, schema_name, table_name)
160167
}
161168

162169
/// Extract a named string value from an attribute token string.
@@ -549,6 +556,23 @@ mod tests {
549556
assert!(entity.imports.is_empty());
550557
}
551558

559+
#[test]
560+
fn test_imports_keep_serde_json() {
561+
let source = r#"
562+
use serde::{Serialize, Deserialize};
563+
use serde_json::Value;
564+
565+
#[derive(Debug, Clone, sqlx::FromRow)]
566+
#[sqlx_gen(kind = "table", table = "users")]
567+
pub struct Users {
568+
pub data: Value,
569+
}
570+
"#;
571+
let entity = parse_entity_source(source).unwrap();
572+
assert_eq!(entity.imports.len(), 1);
573+
assert!(entity.imports[0].contains("serde_json"));
574+
}
575+
552576
#[test]
553577
fn test_imports_exclude_sqlx() {
554578
let source = r#"

crates/sqlx_gen/src/codegen/mod.rs

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ pub fn normalize_module_name(name: &str) -> String {
6262
result
6363
}
6464

65+
/// Well-known default schemas that don't need a prefix in filenames.
66+
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];
67+
68+
/// Returns true if the schema is a well-known default (public, main, dbo).
69+
pub fn is_default_schema(schema: &str) -> bool {
70+
DEFAULT_SCHEMAS.contains(&schema)
71+
}
72+
73+
/// Build a module name, prefixing with schema when there are multiple schemas
74+
/// and the schema is not a well-known default.
75+
pub fn build_module_name(schema_name: &str, table_name: &str, has_multiple_schemas: bool) -> String {
76+
if !has_multiple_schemas || DEFAULT_SCHEMAS.contains(&schema_name) {
77+
normalize_module_name(table_name)
78+
} else {
79+
normalize_module_name(&format!("{}_{}", schema_name, table_name))
80+
}
81+
}
82+
6583
/// A generated code file with its content and required imports.
6684
#[derive(Debug, Clone)]
6785
pub struct GeneratedFile {
@@ -81,13 +99,23 @@ pub fn generate(
8199
) -> Vec<GeneratedFile> {
82100
let mut files = Vec::new();
83101

102+
// Detect if multiple schemas are present
103+
let mut schemas = BTreeSet::new();
104+
for t in &schema_info.tables {
105+
schemas.insert(t.schema_name.as_str());
106+
}
107+
for v in &schema_info.views {
108+
schemas.insert(v.schema_name.as_str());
109+
}
110+
let has_multiple_schemas = schemas.len() > 1;
111+
84112
// Generate struct files for each table
85113
for table in &schema_info.tables {
86114
let (tokens, imports) =
87115
struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
88116
let imports = filter_imports(&imports, single_file);
89117
let code = format_tokens_with_imports(&tokens, &imports);
90-
let module_name = normalize_module_name(&table.name);
118+
let module_name = build_module_name(&table.schema_name, &table.name, has_multiple_schemas);
91119
let origin = format!("Table: {}.{}", table.schema_name, table.name);
92120
files.push(GeneratedFile {
93121
filename: format!("{}.rs", module_name),
@@ -102,7 +130,7 @@ pub fn generate(
102130
struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
103131
let imports = filter_imports(&imports, single_file);
104132
let code = format_tokens_with_imports(&tokens, &imports);
105-
let module_name = normalize_module_name(&view.name);
133+
let module_name = build_module_name(&view.schema_name, &view.name, has_multiple_schemas);
106134
let origin = format!("View: {}.{}", view.schema_name, view.name);
107135
files.push(GeneratedFile {
108136
filename: format!("{}.rs", module_name),
@@ -417,6 +445,55 @@ mod tests {
417445
assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
418446
}
419447

448+
// ========== build_module_name ==========
449+
450+
#[test]
451+
fn test_build_single_schema_no_prefix() {
452+
assert_eq!(build_module_name("public", "users", false), "users");
453+
}
454+
455+
#[test]
456+
fn test_build_multi_schema_default_no_prefix() {
457+
assert_eq!(build_module_name("public", "users", true), "users");
458+
}
459+
460+
#[test]
461+
fn test_build_multi_schema_non_default_prefixed() {
462+
assert_eq!(build_module_name("billing", "users", true), "billing_users");
463+
}
464+
465+
#[test]
466+
fn test_build_multi_schema_dbo_no_prefix() {
467+
assert_eq!(build_module_name("dbo", "users", true), "users");
468+
}
469+
470+
#[test]
471+
fn test_build_multi_schema_main_no_prefix() {
472+
assert_eq!(build_module_name("main", "users", true), "users");
473+
}
474+
475+
#[test]
476+
fn test_build_normalizes_double_underscore() {
477+
assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
478+
}
479+
480+
// ========== is_default_schema ==========
481+
482+
#[test]
483+
fn test_default_schema_public() {
484+
assert!(is_default_schema("public"));
485+
}
486+
487+
#[test]
488+
fn test_default_schema_main() {
489+
assert!(is_default_schema("main"));
490+
}
491+
492+
#[test]
493+
fn test_non_default_schema() {
494+
assert!(!is_default_schema("billing"));
495+
}
496+
420497
// ========== imports_for_derives ==========
421498

422499
#[test]
@@ -851,6 +928,39 @@ mod tests {
851928
assert!(files[0].code.contains("Option<String>"));
852929
}
853930

931+
#[test]
932+
fn test_generate_multi_schema_prefixes_non_default() {
933+
let schema = SchemaInfo {
934+
tables: vec![
935+
make_table("users", vec![make_col("id", "int4")]),
936+
TableInfo {
937+
schema_name: "billing".to_string(),
938+
name: "users".to_string(),
939+
columns: vec![make_col("id", "int4")],
940+
},
941+
],
942+
..Default::default()
943+
};
944+
let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
945+
let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
946+
assert!(filenames.contains(&"users.rs"));
947+
assert!(filenames.contains(&"billing_users.rs"));
948+
}
949+
950+
#[test]
951+
fn test_generate_single_schema_no_prefix() {
952+
let schema = SchemaInfo {
953+
tables: vec![
954+
make_table("users", vec![make_col("id", "int4")]),
955+
make_table("posts", vec![make_col("id", "int4")]),
956+
],
957+
..Default::default()
958+
};
959+
let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
960+
assert_eq!(files[0].filename, "users.rs");
961+
assert_eq!(files[1].filename, "posts.rs");
962+
}
963+
854964
#[test]
855965
fn test_generate_view_single_file_mode() {
856966
let schema = SchemaInfo {

crates/sqlx_gen/src/codegen/struct_gen.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ pub fn generate_struct(
9494
.collect();
9595

9696
let table_name_str = &table.name;
97+
let schema_name_str = &table.schema_name;
9798
let kind_str = if is_view { "view" } else { "table" };
9899

99100
let tokens = quote! {
100101
#[derive(#(#derive_tokens),*)]
101-
#[sqlx_gen(kind = #kind_str, table = #table_name_str)]
102+
#[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
102103
pub struct #struct_name {
103104
#(#fields)*
104105
}

crates/sqlx_gen/src/main.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,14 @@ fn run_crud(args: CrudArgs) -> Result<()> {
148148
} else {
149149
std::fs::create_dir_all(&args.output_dir)?;
150150

151-
let filename = format!("{}_repository.rs", entity.table_name);
151+
let base_name = match &entity.schema_name {
152+
Some(schema) if !codegen::is_default_schema(schema) => {
153+
format!("{}_{}", schema, entity.table_name)
154+
}
155+
_ => entity.table_name.clone(),
156+
};
157+
let normalized = codegen::normalize_module_name(&base_name);
158+
let filename = format!("{}_repository.rs", normalized);
152159
let file_path = args.output_dir.join(&filename);
153160

154161
let content = format!(
@@ -160,7 +167,7 @@ fn run_crud(args: CrudArgs) -> Result<()> {
160167
let edition = detect_edition(&args.output_dir);
161168
rustfmt_file(&file_path, &edition);
162169

163-
let mod_name = format!("{}_repository", entity.table_name);
170+
let mod_name = format!("{}_repository", normalized);
164171
update_mod_rs(&args.output_dir, &mod_name)?;
165172
}
166173

0 commit comments

Comments
 (0)