Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ func (d *BaseDialect) DataType(columnType schema.ColumnType) string {
typ := normalizeType(columnType.DataType)

switch typ {
case "smallserial":
return "SMALLINT"
case "serial":
return "INTEGER"
case "bigserial":
return "BIGSERIAL"
Comment on lines +64 to 69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 bigserial still returns "BIGSERIAL" in BaseDialect while the two new serial types return plain integer variants. A custom dialect that embeds BaseDialect would get auto-increment semantics for bigserial columns but not for smallserial or serial, silently dropping auto-increment behavior on those columns.

Suggested change
case "smallserial":
return "SMALLINT"
case "serial":
return "INTEGER"
case "bigserial":
return "BIGSERIAL"
case "smallserial":
return "SMALLINT"
case "serial":
return "INTEGER"
case "bigserial":
return "BIGINT"
Prompt To Fix With AI
This is a comment left during a code review.
Path: pkg/dialect/dialect.go
Line: 64-69

Comment:
`bigserial` still returns "BIGSERIAL" in `BaseDialect` while the two new serial types return plain integer variants. A custom dialect that embeds `BaseDialect` would get auto-increment semantics for `bigserial` columns but not for `smallserial` or `serial`, silently dropping auto-increment behavior on those columns.

```suggestion
	case "smallserial":
		return "SMALLINT"
	case "serial":
		return "INTEGER"
	case "bigserial":
		return "BIGINT"
```

How can I resolve this? If you propose a fix, please make it concise.

Fix in Codex

case "smallint":
Expand Down
4 changes: 4 additions & 0 deletions pkg/dialect/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func (d *MySQLDialect) DataType(columnType schema.ColumnType) string {
typ := normalizeType(columnType.DataType)

switch typ {
case "smallserial":
return "SMALLINT"
case "serial":
return "INT"
case "bigserial":
return "BIGINT"
case "smallint":
Expand Down
4 changes: 4 additions & 0 deletions pkg/dialect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ func (d *PostgresDialect) DataType(columnType schema.ColumnType) string {
typ := normalizeType(columnType.DataType)

switch typ {
case "smallserial":
return "SMALLSERIAL"
case "serial":
return "SERIAL"
case "bigserial":
return "BIGSERIAL"
case "smallint":
Expand Down
2 changes: 1 addition & 1 deletion pkg/dialect/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (d *SQLiteDialect) DataType(columnType schema.ColumnType) string {
typ := normalizeType(columnType.DataType)

switch typ {
case "bigserial":
case "smallserial", "serial", "bigserial":
return "INTEGER"
case "string", "varchar", "text":
return "TEXT"
Expand Down
15 changes: 15 additions & 0 deletions pkg/migrator/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ func planCreateAll(snapshot Snapshot) Plan {
}

func diffTable(previous, current TableSnapshot, dialectName string) ([]string, error) {
if previous.IsView != current.IsView {
return nil, fmt.Errorf("migrator: changing %q from view=%v to view=%v is not supported", current.Name, previous.IsView, current.IsView)
}

if current.IsView {
if normalizeSQL(previous.CreateTableSQL) == normalizeSQL(current.CreateTableSQL) {
return nil, nil
}
// View changed - drop and recreate
return []string{
"DROP VIEW " + quoteIdentifier(dialectName, current.Name),
current.CreateTableSQL,
}, nil
}

var statements []string

previousColumns := make(map[string]ColumnSnapshot, len(previous.Columns))
Expand Down
2 changes: 2 additions & 0 deletions pkg/migrator/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Snapshot struct {
// TableSnapshot stores a portable, deterministic representation of one table.
type TableSnapshot struct {
Name string `json:"name"`
IsView bool `json:"is_view,omitempty"`
CreateTableSQL string `json:"create_table_sql"`
Columns []ColumnSnapshot `json:"columns"`
Constraints []ConstraintSnapshot `json:"constraints"`
Expand Down Expand Up @@ -167,6 +168,7 @@ func BuildSnapshot(dialectName string, tables []schema.TableReference) (Snapshot

tableSnapshots = append(tableSnapshots, TableSnapshot{
Name: tableDef.Name,
IsView: tableDef.IsView,
CreateTableSQL: createTableSQL,
Columns: columnSnapshots,
Constraints: constraintSnapshots,
Expand Down
37 changes: 35 additions & 2 deletions pkg/rain/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
return "", errors.New("rain: create table requires a non-nil table")
}

if table.IsView {
return createViewSQL(d, table)
}

var definitions []string
tablePrimaryKey, err := tablePrimaryKeyConstraint(table)
if err != nil {
Expand Down Expand Up @@ -349,7 +353,7 @@ func columnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string {
typeSQL = fmt.Sprintf("%s(%d)", typeSQL, column.Type.TimePrecision)
}

if column.AutoIncrement && d.Name() == "sqlite" && column.Type.DataType == schema.TypeBigSerial {
if column.AutoIncrement && d.Name() == "sqlite" && (column.Type.DataType == schema.TypeSmallSerial || column.Type.DataType == schema.TypeSerial || column.Type.DataType == schema.TypeBigSerial) {
return "INTEGER"
}

Expand All @@ -363,7 +367,7 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef,
if !inlinePrimaryKey {
return false
}
if column.Type.DataType != schema.TypeBigSerial {
if column.Type.DataType != schema.TypeSmallSerial && column.Type.DataType != schema.TypeSerial && column.Type.DataType != schema.TypeBigSerial {
return true
}

Expand Down Expand Up @@ -568,8 +572,37 @@ func predicateDDLSQL(d dialect.Dialect, table *schema.TableDef, predicate schema
return expressionDDLSQL(d, table, predicate)
}

func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
if table.ViewQuery == nil {
return "", fmt.Errorf("rain: view %q requires a defining query", table.Name)
}

ctx := newCompileContext(d)
ctx.useLiterals = true
// Views usually don't support or need parentheses around the entire SELECT
// across all dialects, and SQLite specifically rejects them.
if selectQuery, ok := table.ViewQuery.(*SelectQuery); ok {
if err := selectQuery.writeSQL(ctx); err != nil {
return "", err
}
} else {
if err := ctx.writeExpression(table.ViewQuery); err != nil {
return "", err
}
}

return "CREATE VIEW " + d.QuoteIdentifier(table.Name) + " AS " + ctx.String(), nil
}

func expressionDDLSQL(d dialect.Dialect, table *schema.TableDef, expr schema.Expression) (string, error) {
switch value := expr.(type) {
case *SelectQuery:
ctx := newCompileContext(d)
ctx.useLiterals = true
if err := value.writeSQL(ctx); err != nil {
return "", err
}
return ctx.String(), nil
case schema.ColumnReference:
column := value.ColumnDef()
if column == nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/rain/model_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func supportsScanForColumn(column *schema.ColumnDef, fieldType reflect.Type) boo
baseType, _ := unwrapFieldType(fieldType)

switch column.Type.DataType {
case schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt:
case schema.TypeSmallSerial, schema.TypeSerial, schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt:
return isIntegerKind(baseType.Kind())
case schema.TypeReal, schema.TypeDouble:
return baseType.Kind() == reflect.Float32 || baseType.Kind() == reflect.Float64
Expand Down Expand Up @@ -242,7 +242,7 @@ func supportsWriteForColumn(column *schema.ColumnDef, fieldType reflect.Type) bo
baseType, _ := unwrapFieldType(fieldType)

switch column.Type.DataType {
case schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt:
case schema.TypeSmallSerial, schema.TypeSerial, schema.TypeBigSerial, schema.TypeSmallInt, schema.TypeInteger, schema.TypeBigInt:
return isIntegerKind(baseType.Kind())
case schema.TypeReal, schema.TypeDouble:
return baseType.Kind() == reflect.Float32 || baseType.Kind() == reflect.Float64
Expand Down
35 changes: 27 additions & 8 deletions pkg/rain/query_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ func (q compiledQuery) bind(args PreparedArgs) ([]any, error) {
}

type compileContext struct {
builder strings.Builder
dialect dialect.Dialect
argPlan []compiledArg
err error
skipCTEs bool
builder strings.Builder
dialect dialect.Dialect
argPlan []compiledArg
err error
skipCTEs bool
useLiterals bool
}

func newCompileContext(d dialect.Dialect) *compileContext {
Expand Down Expand Up @@ -137,6 +138,15 @@ func (c *compileContext) writeTable(table *schema.TableDef) {
}
}

func (c *compileContext) writeLiteral(value any) error {
literal, err := literalDDLSQL(c.dialect, value)
if err != nil {
return err
}
c.writeString(literal)
return nil
}

func (c *compileContext) writeReturning(exprs []schema.Expression, clause returningClause) error {
if len(exprs) == 0 {
return nil
Expand Down Expand Up @@ -180,9 +190,15 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex
case schema.ColumnReference:
c.writeColumn(value)
case schema.ValueExpr:
index := c.nextPlaceholderIndex()
c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value})
c.writeString(c.dialect.Placeholder(index))
if c.useLiterals {
if err := c.writeLiteral(value.Value); err != nil {
return err
}
} else {
index := c.nextPlaceholderIndex()
c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value})
c.writeString(c.dialect.Placeholder(index))
}
case schema.PlaceholderExpr:
if strings.TrimSpace(value.Name) == "" {
return errors.New("rain: placeholder name cannot be empty")
Expand Down Expand Up @@ -260,9 +276,12 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex
if !context.noParens {
c.writeByte('(')
}
prevSkip := c.skipCTEs
c.skipCTEs = true
if err := value.writeSQL(c); err != nil {
return err
}
c.skipCTEs = prevSkip
if !context.noParens {
c.writeByte(')')
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/rain/query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/hyperlocalise/rain-orm/pkg/dialect"
"github.com/hyperlocalise/rain-orm/pkg/schema"
Expand Down Expand Up @@ -48,6 +49,9 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) {
if q.table == nil {
return "", nil, errors.New("rain: delete query requires a table")
}
if q.table.IsView {
return "", nil, fmt.Errorf("rain: cannot delete from view %q", q.table.Name)
}
if len(q.where) == 0 && !q.unbounded {
return "", nil, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows")
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/rain/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ func (q *InsertQuery) validateSources() error {
if q.table == nil {
return errors.New("rain: insert query requires a table")
}
if q.table.IsView {
return fmt.Errorf("rain: cannot insert into view %q", q.table.Name)
}

sources := 0
if q.model != nil || len(q.values) > 0 {
Expand Down
4 changes: 4 additions & 0 deletions pkg/rain/query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/hyperlocalise/rain-orm/pkg/dialect"
"github.com/hyperlocalise/rain-orm/pkg/schema"
Expand Down Expand Up @@ -62,6 +63,9 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) {
if q.table == nil {
return "", nil, errors.New("rain: update query requires a table")
}
if q.table.IsView {
return "", nil, fmt.Errorf("rain: cannot update view %q", q.table.Name)
}
if len(q.values) == 0 {
return "", nil, errors.New("rain: update query requires at least one assignment")
}
Expand Down
46 changes: 45 additions & 1 deletion pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type TimestampKind string

// Supported schema data types.
const (
TypeSmallSerial DataType = "SMALLSERIAL"
TypeSerial DataType = "SERIAL"
TypeBigSerial DataType = "BIGSERIAL"
TypeSmallInt DataType = "SMALLINT"
TypeInteger DataType = "INTEGER"
Expand Down Expand Up @@ -111,6 +113,8 @@ type ColumnType struct {
type TableDef struct {
Name string
Alias string
IsView bool
ViewQuery Expression
Columns []*ColumnDef
Indexes []IndexDef
Constraints []ConstraintDef
Expand Down Expand Up @@ -258,6 +262,16 @@ func (t *TableModel) C(name string) *AnyColumn {
return &AnyColumn{def: col}
}

// SmallSerial adds a SMALLSERIAL column.
func (t *TableModel) SmallSerial(name string) *Column[int16] {
return addColumn[int16](t.def, name, ColumnType{DataType: TypeSmallSerial}, false, true)
}

// Serial adds a SERIAL column.
func (t *TableModel) Serial(name string) *Column[int32] {
return addColumn[int32](t.def, name, ColumnType{DataType: TypeSerial}, false, true)
}

// BigSerial adds a BIGSERIAL column.
func (t *TableModel) BigSerial(name string) *Column[int64] {
return addColumn[int64](t.def, name, ColumnType{DataType: TypeBigSerial}, false, true)
Expand Down Expand Up @@ -483,6 +497,31 @@ func Define[T any](name string, fn func(*T)) *T {
return handle
}

// DefineView creates a typed view handle backed by schema metadata and a defining query.
func DefineView[T any](name string, query Expression, fn func(*T)) *T {
if query == nil {
panic("schema: DefineView requires a query")
}

handle := new(T)
def := &TableDef{
Name: name,
IsView: true,
ViewQuery: query,
Columns: make([]*ColumnDef, 0, 8),
Indexes: make([]IndexDef, 0),
Constraints: make([]ConstraintDef, 0),
ForeignKeys: make([]ForeignKeyDef, 0),
Relations: make([]RelationDef, 0, 4),
columnsByName: make(map[string]*ColumnDef, 8),
relationsByName: make(map[string]RelationDef, 4),
}
bindTableModel(handle, def)
fn(handle)

return handle
}

// Alias clones a typed table handle with a SQL alias.
func Alias[T any](src *T, alias string) *T {
clone := new(T)
Expand Down Expand Up @@ -578,7 +617,7 @@ func (c *Column[T]) ColumnDef() *ColumnDef {
func (c *Column[T]) PrimaryKey() *Column[T] {
c.def.PrimaryKey = true
c.def.Nullable = false
if c.def.Type.DataType == TypeBigSerial {
if c.def.Type.DataType == TypeSmallSerial || c.def.Type.DataType == TypeSerial || c.def.Type.DataType == TypeBigSerial {
c.def.AutoIncrement = true
}

Expand Down Expand Up @@ -1384,6 +1423,11 @@ func cloneTableDef(src *TableDef, alias string) *TableDef {
relationsByName: make(map[string]RelationDef, len(src.Relations)),
}

cloned.IsView = src.IsView
if src.ViewQuery != nil {
cloned.ViewQuery = cloneExpressionForTable(src.ViewQuery, cloned)
}

for _, column := range src.Columns {
copyColumn := *column
copyColumn.Type.EnumValues = append([]string(nil), column.Type.EnumValues...)
Expand Down