Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ SELECT unnest(@book_ids::bigint[]), unnest(@tag_ids::bigint[]);
```

The generator will:

- On **PostgreSQL**: delegate `AddBookTags` directly to the underlying sqlc implementation
- On **SQLite/MySQL**: generate a loop that calls `AddBookTag` once per element

Expand Down
6 changes: 3 additions & 3 deletions example/pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
"github.com/kalbasit/sqlc-multi-db/example/pkg/database/postgresdb"
"github.com/kalbasit/sqlc-multi-db/example/pkg/database/sqlitedb"

_ "github.com/go-sql-driver/mysql" // MySQL driver
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
_ "github.com/mattn/go-sqlite3" // SQLite driver
_ "github.com/go-sql-driver/mysql" // MySQL driver
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
_ "github.com/mattn/go-sqlite3" // SQLite driver
)

// Open opens a database connection and returns a Querier.
Expand Down
6 changes: 2 additions & 4 deletions example/pkg/database/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import (
"github.com/mattn/go-sqlite3"
)

var (
// ErrUnsupportedDriver is returned when the database driver is not recognized.
ErrUnsupportedDriver = errors.New("unsupported database driver")
)
// ErrUnsupportedDriver is returned when the database driver is not recognized.
var ErrUnsupportedDriver = errors.New("unsupported database driver")

// IsDeadlockError checks if the error is a deadlock or "database busy" error.
func IsDeadlockError(err error) bool {
Expand Down
1 change: 0 additions & 1 deletion example/pkg/database/postgresdb/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions generator/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ const (
typeAny = "interface{}"
typeBool = "bool"
typeString = "string"
typeInt = "int"
typeBytes = "[]byte"
zeroNil = "nil"
typeInt16 = "int16"
typeInt32 = "int32"
Expand Down
77 changes: 77 additions & 0 deletions generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,80 @@ func TestGenerateFieldConversion(t *testing.T) {
})
}
}

// TestJoinParamsCallFieldMapping tests that JoinParamsCall correctly maps
// struct fields even when field names differ between source and target.
// This is a regression test for the MySQL LIMIT parameter issue where
// sqlc generates different field names (e.g., BatchSize vs Limit).
func TestJoinParamsCallFieldMapping(t *testing.T) {
t.Parallel()

// Source structs (domain) - what the wrapper API uses
sourceStructs := map[string]generator.StructInfo{
"GetStuckNarFilesParams": {
Name: "GetStuckNarFilesParams",
Fields: []generator.FieldInfo{
{Name: "CutoffTime", Type: "time.Time"},
{Name: "BatchSize", Type: "int32"},
},
},
}

// Target structs (adapter) - what the database engine generates
// MySQL generates different names: CreatedAt instead of CutoffTime, Limit instead of BatchSize
targetStructs := map[string]generator.StructInfo{
"GetStuckNarFilesParams": {
Name: "GetStuckNarFilesParams",
Fields: []generator.FieldInfo{
{Name: "CreatedAt", Type: "time.Time"},
{Name: "Limit", Type: "int32"},
},
},
}

// Target method info
targetMethod := generator.MethodInfo{
Name: "GetStuckNarFiles",
Params: []generator.Param{
{Name: "ctx", Type: "context.Context"},
{Name: "arg", Type: "GetStuckNarFilesParams"},
},
}

tests := []struct {
name string
params []generator.Param
engPkg string
want string
wantErr bool
}{
{
name: "Field mapping with different names",
params: []generator.Param{
{Name: "ctx", Type: "context.Context"},
{Name: "arg", Type: "GetStuckNarFilesParams"},
},
engPkg: "mysqldb",
// Expected: both fields should be mapped even though names differ
// The target struct uses its own field names (CreatedAt, Limit), not source names
want: "ctx, mysqldb.GetStuckNarFilesParams{\nCreatedAt: arg.CutoffTime,\nLimit: arg.BatchSize,\n}",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got, err := generator.JoinParamsCall(tt.params, tt.engPkg, targetMethod, targetStructs, sourceStructs)
if (err != nil) != tt.wantErr {
t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr)

return
}

if got != tt.want {
t.Errorf("JoinParamsCall() = %v, want %v", got, tt.want)
}
})
}
}
133 changes: 120 additions & 13 deletions generator/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,106 @@ func JoinParamsCall(
return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs)
}

// findSourceField finds a matching field in available source fields using multiple strategies:
// 1. Exact name match
// 2. Case-insensitive match
// 3. Snake_case match
// 4. Position-based match (fallback when structs have same field count).
// The availableSourceFields map is modified to remove matched fields.
func findSourceField(
targetField FieldInfo,
targetIdx int,
targetStruct StructInfo,
sourceStruct StructInfo,
availableSourceFields map[string]FieldInfo,
) (FieldInfo, bool) {
// Strategy 1: Exact name match
if sf, ok := availableSourceFields[targetField.Name]; ok {
return sf, true
}

// Strategy 2: Case-insensitive match
for _, sf := range availableSourceFields {
if strings.EqualFold(sf.Name, targetField.Name) {
return sf, true
}
}

// Strategy 3: Snake_case match
targetSnake := toSnakeCase(targetField.Name)
for _, sf := range availableSourceFields {
if toSnakeCase(sf.Name) == targetSnake {
return sf, true
}
}

// Strategy 4: Position-based match (fallback when structs have same field count)
// Only use position matching if the structs have the same number of fields
if len(sourceStruct.Fields) != len(targetStruct.Fields) || len(sourceStruct.Fields) == 0 {
return FieldInfo{}, false
}
// Match by position - use the field at the same index in source
if targetIdx >= len(sourceStruct.Fields) {
return FieldInfo{}, false
}

originalSourceField := sourceStruct.Fields[targetIdx]
// Check if it's still available
sf, ok := availableSourceFields[originalSourceField.Name]
if !ok {
return FieldInfo{}, false
}
// Verify types are compatible
if fieldsCompatible(sf.Type, targetField.Type) {
return sf, true
}

return FieldInfo{}, false
}

// fieldsCompatible checks if two field types are compatible for mapping.
func fieldsCompatible(sourceType, targetType string) bool {
// Normalize types for comparison
sourceBase := normalizeType(sourceType)
targetBase := normalizeType(targetType)

return sourceBase == targetBase
}

// normalizeType normalizes a type string for comparison.
func normalizeType(t string) string {
// Remove common prefixes/suffixes
t = strings.TrimPrefix(t, "[]")
t = strings.TrimPrefix(t, "*")

// Handle time types
if strings.Contains(t, "time.Time") || strings.Contains(t, "NullTime") {
return "time"
}

// Handle numeric types
switch t {
case typeInt, "int8", "int16", "int32", "int64",
"uint", "uint8", "uint16", "uint32", "uint64",
sqlNullInt32, sqlNullInt64:
return typeInt
case "float32", "float64", sqlNullFloat64:
return "float"
case typeString, sqlNullString, typeBytes:
return typeString
case typeBool, sqlNullBool:
return typeBool
}

// Remove package prefix if present
parts := strings.Split(t, ".")
if len(parts) > 1 {
return parts[len(parts)-1]
}

return t
}

func joinDomainStructParam(
param Param,
i int,
Expand All @@ -153,23 +253,28 @@ func joinDomainStructParam(

if targetParamType != "" {
sourceStruct := sourceStructs[param.Type]
targetStruct := targetStructs[targetParamType]

var fields []string
// Target struct keys may include the package prefix (e.g., "mysqldb.GetStuckNarFilesParams")
// Try with prefix first, then without
targetStructKey := targetParamType
if engPkg != "" {
if _, ok := targetStructs[engPkg+"."+targetParamType]; ok {
targetStructKey = engPkg + "." + targetParamType
}
// Otherwise keep using targetParamType (no prefix)
}

for _, targetField := range targetStruct.Fields {
var sourceField FieldInfo
targetStruct := targetStructs[targetStructKey]

found := false
// Create a map of available source fields to track which fields have been mapped.
availableSourceFields := make(map[string]FieldInfo, len(sourceStruct.Fields))
for _, sf := range sourceStruct.Fields {
availableSourceFields[sf.Name] = sf
}

for _, sf := range sourceStruct.Fields {
if sf.Name == targetField.Name {
sourceField = sf
found = true
var fields []string

break
}
}
for targetIdx, targetField := range targetStruct.Fields {
sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct, availableSourceFields)

if found {
conversion := generateFieldConversion(
Expand All @@ -179,6 +284,8 @@ func joinDomainStructParam(
fmt.Sprintf("%s.%s", param.Name, sourceField.Name),
)
fields = append(fields, conversion)
// Remove the mapped field so it can't be used again.
delete(availableSourceFields, sourceField.Name)
}
}

Expand Down
Loading