Skip to content
Open
2 changes: 1 addition & 1 deletion pkg/rain/coverage_target_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ func TestCoverageDDLMethodsAndHelpers(t *testing.T) {
if _, err := columnDefinitionSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "broken_default", Type: schema.ColumnType{DataType: schema.TypeText}, HasDefault: true, Default: struct{}{}}, false); err == nil {
t.Fatalf("expected columnDefinitionSQL default error")
}
if got := columnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" {
if got := ddlColumnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" {
t.Fatalf("unexpected sqlite timestamp type: %q", got)
}
if shouldEmitAutoIncrementKeyword(pg, &schema.ColumnDef{Name: "id", Type: schema.ColumnType{DataType: schema.TypeBigSerial}}, true) {
Expand Down
68 changes: 61 additions & 7 deletions pkg/rain/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func (db *DB) CreateTableSQL(table schema.TableReference) (string, error) {
return "", errors.New("rain: create table requires a non-nil table")
}

if table.TableDef().IsView {
return createViewSQL(db.dialect, table.TableDef())
}

return createTableSQL(db.dialect, table.TableDef())
}

Expand All @@ -32,6 +36,10 @@ func (db *DB) CreateIndexesSQL(table schema.TableReference) ([]string, error) {
return nil, errors.New("rain: create indexes requires a non-nil table")
}

if table.TableDef().IsView {
return nil, nil
}

return createIndexesSQL(db.dialect, table.TableDef())
}

Expand All @@ -50,6 +58,10 @@ func (db *DB) ColumnDefinitionSQL(table schema.TableReference, columnName string
return "", fmt.Errorf("rain: table %q has no column %q", tableDef.Name, columnName)
}

if tableDef.IsView {
return db.dialect.QuoteIdentifier(column.Name) + " " + ddlColumnTypeSQL(db.dialect, column), nil
}

inlinePrimaryKey := false
tablePrimaryKey, err := tablePrimaryKeyConstraint(tableDef)
if err != nil {
Expand All @@ -73,6 +85,10 @@ func (db *DB) AddConstraintSQL(table schema.TableReference, constraintName strin
}

tableDef := table.TableDef()
if tableDef.IsView {
return "", fmt.Errorf("rain: view %q does not support constraints", tableDef.Name)
}

for _, constraint := range tableDef.Constraints {
if constraint.Name != constraintName {
continue
Expand All @@ -97,6 +113,10 @@ func (db *DB) AddForeignKeySQL(table schema.TableReference, foreignKeyName strin
}

tableDef := table.TableDef()
if tableDef.IsView {
return "", fmt.Errorf("rain: view %q does not support foreign keys", tableDef.Name)
}

for _, foreignKey := range tableDef.ForeignKeys {
if foreignKey.Name != foreignKeyName {
continue
Expand Down Expand Up @@ -132,6 +152,35 @@ func (db *DB) ColumnDefaultSQL(table schema.TableReference, columnName string) (
return columnDefaultSQL(db.dialect, column)
}

func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
if d == nil {
return "", errors.New("rain: create view requires a configured dialect")
}
if table == nil {
return "", errors.New("rain: create view requires a non-nil table")
}
if !table.IsView {
return "", fmt.Errorf("rain: table %q is not a view", table.Name)
}
if table.ViewQuery == nil {
return "", fmt.Errorf("rain: view %q requires a defining query", table.Name)
}

ctx := newCompileContext(d)
ctx.useLiterals = true
if err := ctx.writeExpressionInContext(table.ViewQuery, expressionContext{noParens: true}); err != nil {
return "", err
}

var builder strings.Builder
builder.WriteString("CREATE VIEW ")
builder.WriteString(d.QuoteIdentifier(table.Name))
builder.WriteString(" AS ")
builder.WriteString(ctx.String())

return builder.String(), nil
}

func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
if d == nil {
return "", errors.New("rain: create table requires a configured dialect")
Expand Down Expand Up @@ -297,7 +346,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche
var parts []string
parts = append(parts, d.QuoteIdentifier(column.Name))

typeSQL := columnTypeSQL(d, column)
typeSQL := ddlColumnTypeSQL(d, column)
parts = append(parts, typeSQL)

if inlinePrimaryKey {
Expand Down Expand Up @@ -338,7 +387,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche
return strings.Join(parts, " "), nil
}

func columnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string {
func ddlColumnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string {
typeSQL := d.DataType(column.Type)

if column.Type.DataType == schema.TypeVarChar && column.Type.Size > 0 && strings.EqualFold(typeSQL, "VARCHAR") {
Expand All @@ -363,20 +412,25 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef,
if !inlinePrimaryKey {
return false
}
if column.Type.DataType != schema.TypeBigSerial {
return true
}

switch d.Name() {
case "postgres":
return false
return !isPostgresSerialType(column.Type.DataType)
case "sqlite":
return true
default:
return true
}
}

func isPostgresSerialType(dataType schema.DataType) bool {
switch dataType {
case schema.TypeBigSerial, schema.TypeSerial, schema.TypeSmallSerial:
return true
default:
return false
}
}

func columnDefaultSQL(d dialect.Dialect, column *schema.ColumnDef) (string, error) {
if column.DefaultSQL != "" {
return column.DefaultSQL, nil
Expand Down
106 changes: 106 additions & 0 deletions pkg/rain/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ type ddlMembershipsTable struct {
Active *schema.Column[bool]
}

type ddlSerialTable struct {
schema.TableModel
ID *schema.Column[int32]
}

type ddlSmallSerialTable struct {
schema.TableModel
ID *schema.Column[int16]
}

type ddlUserEmailView struct {
schema.TableModel
Email *schema.Column[string]
}

func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) {
users := schema.Define("users", func(t *ddlUsersTable) {
t.ID = t.BigSerial("id").PrimaryKey()
Expand Down Expand Up @@ -78,6 +93,97 @@ func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) {
return users, posts, memberships
}

func TestCreateViewSQLRawExprUsesLiterals(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect("postgres")
if err != nil {
t.Fatalf("OpenDialect(postgres): %v", err)
}
users, _, _ := defineDDLTables()
query := db.Select().
Table(users).
Column(users.Email).
Where(schema.Raw("? = ?", users.Email, "alice@example.com"))
view := schema.DefineView("user_email_view", query, func(v *ddlUserEmailView) {
v.Email = v.VarChar("email", 255)
})

sql, err := db.CreateTableSQL(view)
if err != nil {
t.Fatalf("CreateTableSQL(view): %v", err)
}
if strings.Contains(sql, "$1") || strings.Contains(sql, "$2") {
t.Fatalf("expected view DDL to inline raw args, got:\n%s", sql)
}
if !strings.Contains(sql, `"users"."email" = 'alice@example.com'`) {
t.Fatalf("expected view DDL to include literalized raw predicate, got:\n%s", sql)
}
}

func TestAliasViewWithSelectQueryDoesNotPanic(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect("postgres")
if err != nil {
t.Fatalf("OpenDialect(postgres): %v", err)
}
users, _, _ := defineDDLTables()
query := db.Select().Table(users).Column(users.Email)
view := schema.DefineView("user_email_view_alias_source", query, func(v *ddlUserEmailView) {
v.Email = v.VarChar("email", 255)
})

aliased := schema.Alias(view, "uev")
sql, args, err := db.Select().Table(aliased).Column(aliased.Email).ToSQL()
if err != nil {
t.Fatalf("Select aliased view: %v", err)
}
if len(args) != 0 {
t.Fatalf("expected no args, got %#v", args)
}
if !strings.Contains(sql, `FROM "user_email_view_alias_source" AS "uev"`) {
t.Fatalf("expected aliased view table source, got:\n%s", sql)
}
}

func TestCreateTableSQLPostgresSerialPrimaryKeysDoNotRepeatSerialKeyword(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect("postgres")
if err != nil {
t.Fatalf("OpenDialect(postgres): %v", err)
}
serialTable := schema.Define("serial_ids", func(t *ddlSerialTable) {
t.ID = t.Serial("id").PrimaryKey()
})
smallSerialTable := schema.Define("small_serial_ids", func(t *ddlSmallSerialTable) {
t.ID = t.SmallSerial("id").PrimaryKey()
})

for _, tc := range []struct {
name string
table schema.TableReference
want string
}{
{name: "serial", table: serialTable, want: `"id" SERIAL PRIMARY KEY`},
{name: "smallserial", table: smallSerialTable, want: `"id" SMALLSERIAL PRIMARY KEY`},
} {
t.Run(tc.name, func(t *testing.T) {
sql, err := db.CreateTableSQL(tc.table)
if err != nil {
t.Fatalf("CreateTableSQL: %v", err)
}
if !strings.Contains(sql, tc.want) {
t.Fatalf("expected SQL to contain %q, got:\n%s", tc.want, sql)
}
if strings.Contains(sql, "PRIMARY KEY SERIAL") || strings.Contains(sql, "PRIMARY KEY SMALLSERIAL") {
t.Fatalf("expected SQL not to repeat serial keyword, got:\n%s", sql)
}
})
}
}

func TestCreateTableSQLAcrossDialects(t *testing.T) {
t.Parallel()

Expand Down
34 changes: 29 additions & 5 deletions pkg/rain/query_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ func (q compiledQuery) bind(args PreparedArgs) ([]any, error) {
}

type compileContext struct {
builder strings.Builder
dialect dialect.Dialect
argPlan []compiledArg
err error
skipCTEs bool
builder strings.Builder
dialect dialect.Dialect
argPlan []compiledArg
err error
skipCTEs bool
useLiterals bool
}

func newCompileContext(d dialect.Dialect) *compileContext {
Expand Down Expand Up @@ -180,6 +181,14 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex
case schema.ColumnReference:
c.writeColumn(value)
case schema.ValueExpr:
if c.useLiterals {
literal, err := literalDDLSQL(c.dialect, value.Value)
if err != nil {
return err
}
c.writeString(literal)
return nil
}
index := c.nextPlaceholderIndex()
c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value})
c.writeString(c.dialect.Placeholder(index))
Expand Down Expand Up @@ -386,6 +395,21 @@ func (c *compileContext) writeRaw(raw schema.RawExpr) error {
if argIndex >= len(raw.Args) {
return errors.New("rain: raw SQL placeholder count does not match args")
}
if c.useLiterals {
if expr, ok := raw.Args[argIndex].(schema.Expression); ok {
if err := c.writeExpression(expr); err != nil {
return err
}
} else {
literal, err := literalDDLSQL(c.dialect, raw.Args[argIndex])
if err != nil {
return err
}
c.writeString(literal)
}
argIndex++
continue
}
index := c.nextPlaceholderIndex()
c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: raw.Args[argIndex]})
c.writeString(c.dialect.Placeholder(index))
Expand Down
6 changes: 6 additions & 0 deletions pkg/rain/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ func (q *SelectQuery) clone() *SelectQuery {
return &newQ
}

// CloneExpressionForTable preserves SELECT subqueries when schema metadata is
// cloned for an alias. The query's own table sources remain unchanged.
func (q *SelectQuery) CloneExpressionForTable(*schema.TableDef) schema.Expression {
return q
}

func (q *SelectQuery) withSQLiteInsertSelectConflictWhere() *SelectQuery {
rewritten, _ := q.withSQLiteInsertSelectConflictWhereChanged()
return rewritten
Expand Down
Loading