Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestPostgresDialect(t *testing.T) {
if got := d.Name(); got != "postgres" {
t.Fatalf("unexpected name: %q", got)
}
if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureCTE|FeatureDefaultPlaceholder|FeatureSavepoint|FeatureSelectLocking|FeatureNullsOrder {
if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureCTE|FeatureDefaultPlaceholder|FeatureSavepoint|FeatureSelectLocking|FeatureNullsOrder|FeatureSelectDistinctOn {
t.Fatalf("unexpected features: %b", got)
}
if got := d.QuoteIdentifier(`user"name`); got != `"user""name"` {
Expand Down
1 change: 1 addition & 0 deletions pkg/dialect/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
FeatureSavepoint
FeatureSelectLocking
FeatureNullsOrder
FeatureSelectDistinctOn
)

// HasFeature reports whether a feature set includes the requested capability.
Expand Down
3 changes: 2 additions & 1 deletion pkg/dialect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func (d *PostgresDialect) Features() Feature {
FeatureDefaultPlaceholder |
FeatureSavepoint |
FeatureSelectLocking |
FeatureNullsOrder
FeatureNullsOrder |
FeatureSelectDistinctOn
}

// QuoteIdentifier quotes identifiers with double quotes.
Expand Down
2 changes: 1 addition & 1 deletion pkg/rain/query_runtime_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func TestPreparedSelectQueryAllowsScanWhenPreparedCountIsUnsupported(t *testing.
if len(rows) != 1 || rows[0].UserCount != 2 {
t.Fatalf("unexpected grouped rows: %#v", rows)
}
if _, err := prepared.Count(ctx, PreparedArgs{"active": true}); err == nil || !strings.Contains(err.Error(), "aggregate helpers do not support DISTINCT, GROUP BY, or HAVING clauses") {
if _, err := prepared.Count(ctx, PreparedArgs{"active": true}); err == nil || !strings.Contains(err.Error(), "aggregate helpers do not support DISTINCT, DISTINCT ON, GROUP BY, or HAVING clauses") {
t.Fatalf("expected prepared count grouped-query error, got %v", err)
}
}
Expand Down
36 changes: 33 additions & 3 deletions pkg/rain/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type SelectQuery struct {
firstOperand *SelectQuery
setOps []setOperation
distinct bool
distinctOn []schema.Expression
limit int
offset int
relationNames []string
Expand Down Expand Up @@ -89,6 +90,13 @@ func (q *SelectQuery) Distinct() *SelectQuery {
return q
}

// DistinctOn marks the SELECT query as DISTINCT ON the provided expressions.
// Supported by PostgreSQL.
func (q *SelectQuery) DistinctOn(exprs ...schema.Expression) *SelectQuery {
q.distinctOn = append(q.distinctOn, exprs...)
return q
}

// GroupBy appends GROUP BY expressions.
func (q *SelectQuery) GroupBy(exprs ...schema.Expression) *SelectQuery {
q.groupBy = append(q.groupBy, exprs...)
Expand Down Expand Up @@ -234,6 +242,7 @@ func (q *SelectQuery) clone() *SelectQuery {
newQ.having = append([]schema.Predicate(nil), q.having...)
newQ.ctes = append([]cteDefinition(nil), q.ctes...)
newQ.setOps = append([]setOperation(nil), q.setOps...)
newQ.distinctOn = append([]schema.Expression(nil), q.distinctOn...)
newQ.relationNames = append([]string(nil), q.relationNames...)
if q.locking != nil {
copyLocking := *q.locking
Expand Down Expand Up @@ -287,7 +296,7 @@ func (q *SelectQuery) withSQLiteInsertSelectConflictWhereChanged() (*SelectQuery
func (q *SelectQuery) isBareCompound() bool {
return q.firstOperand != nil &&
len(q.order) == 0 && q.limit == 0 && q.offset == 0 &&
!q.distinct && len(q.cols) == 0 && q.table == nil &&
!q.distinct && len(q.distinctOn) == 0 && len(q.cols) == 0 && q.table == nil &&
len(q.where) == 0 && len(q.joins) == 0 &&
len(q.groupBy) == 0 && len(q.having) == 0 &&
len(q.relationNames) == 0 && len(q.ctes) == 0 &&
Expand Down Expand Up @@ -390,6 +399,20 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error {
ctx.writeString("SELECT ")
if q.distinct {
ctx.writeString("DISTINCT ")
} else if len(q.distinctOn) > 0 {
if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureSelectDistinctOn) {
return fmt.Errorf("rain: SELECT DISTINCT ON is not supported by %s dialect", ctx.dialect.Name())
}
ctx.writeString("DISTINCT ON (")
for idx, expr := range q.distinctOn {
if idx > 0 {
ctx.writeString(", ")
}
if err := ctx.writeExpression(expr); err != nil {
return err
}
}
ctx.writeString(") ")
}
if len(q.cols) == 0 {
ctx.writeString("*")
Expand Down Expand Up @@ -799,10 +822,17 @@ func (q *SelectQuery) compile() (compiledQuery, error) {
return compiledQuery{}, errors.New("rain: select query requires a table")
}

if q.distinct && len(q.distinctOn) > 0 {
return compiledQuery{}, errors.New("rain: SELECT DISTINCT and DISTINCT ON cannot be used together")
}

if q.firstOperand != nil {
if q.distinct {
return compiledQuery{}, errors.New("rain: compound queries do not support DISTINCT")
}
if len(q.distinctOn) > 0 {
return compiledQuery{}, errors.New("rain: compound queries do not support DISTINCT ON")
}
if len(q.cols) > 0 {
return compiledQuery{}, errors.New("rain: compound queries do not support Column()")
}
Expand Down Expand Up @@ -846,8 +876,8 @@ func (q *SelectQuery) compileAggregate(selection string) (compiledQuery, error)
if len(q.ctes) > 0 {
return compiledQuery{}, errors.New("rain: aggregate helpers do not support WITH clauses")
}
if q.distinct || len(q.groupBy) > 0 || len(q.having) > 0 {
return compiledQuery{}, errors.New("rain: aggregate helpers do not support DISTINCT, GROUP BY, or HAVING clauses")
if q.distinct || len(q.distinctOn) > 0 || len(q.groupBy) > 0 || len(q.having) > 0 {
return compiledQuery{}, errors.New("rain: aggregate helpers do not support DISTINCT, DISTINCT ON, GROUP BY, or HAVING clauses")
}
if q.locking != nil {
return compiledQuery{}, errors.New("rain: aggregate helpers do not support FOR locking clauses")
Expand Down
119 changes: 119 additions & 0 deletions pkg/rain/query_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,122 @@ func TestSelectLockingToSQL(t *testing.T) {
})
}
}

func TestSelectDistinctOnToSQL(t *testing.T) {
t.Parallel()

users, _ := defineTables()

type tc struct {
name string
dialect string
build func(*rain.DB) *rain.SelectQuery
wantSQL string
wantErr string
}

cases := []tc{
{
name: "distinct on one column postgres",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(users.ID).Column(users.ID, users.Email)
},
wantSQL: `SELECT DISTINCT ON ("users"."id") "users"."id", "users"."email" FROM "users"`,
},
{
name: "distinct on multiple columns postgres",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(users.ID, users.Email).Column(users.ID)
},
wantSQL: `SELECT DISTINCT ON ("users"."id", "users"."email") "users"."id" FROM "users"`,
},
{
name: "distinct on with expression postgres",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(schema.Raw("LOWER(email)")).Column(users.ID)
},
wantSQL: `SELECT DISTINCT ON (LOWER(email)) "users"."id" FROM "users"`,
},
{
name: "distinct on unsupported on mysql",
dialect: "mysql",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(users.ID)
},
wantErr: "SELECT DISTINCT ON is not supported by mysql dialect",
},
{
name: "distinct on unsupported on sqlite",
dialect: "sqlite",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(users.ID)
},
wantErr: "SELECT DISTINCT ON is not supported by sqlite dialect",
},
{
name: "distinct and distinct on together error",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).Distinct().DistinctOn(users.ID)
},
wantErr: "SELECT DISTINCT and DISTINCT ON cannot be used together",
},
{
name: "distinct on in compound query error",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
q1 := db.Select().Table(users)
q2 := db.Select().Table(users)
return q1.Union(q2).DistinctOn(users.ID)
},
wantErr: "compound queries do not support DISTINCT ON",
},
{
name: "aggregate helper fails with distinct on",
dialect: "postgres",
build: func(db *rain.DB) *rain.SelectQuery {
return db.Select().Table(users).DistinctOn(users.ID)
},
wantErr: "aggregate helpers do not support DISTINCT, DISTINCT ON, GROUP BY, or HAVING clauses",
},
}

for _, tt := range cases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect(tt.dialect)
if err != nil {
t.Fatalf("OpenDialect returned error: %v", err)
}

q := tt.build(db)

if strings.HasPrefix(tt.name, "aggregate helper fails") {
_, err := q.Count(context.Background())
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("expected error containing %q, got %v", tt.wantErr, err)
}
return
}

sqlText, _, err := q.ToSQL()
if tt.wantErr != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("expected error containing %q, got %v", tt.wantErr, err)
}
return
}
if err != nil {
t.Fatalf("ToSQL returned error: %v", err)
}
if sqlText != tt.wantSQL {
t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", tt.wantSQL, sqlText)
}
})
}
}
3 changes: 2 additions & 1 deletion pkg/rain/query_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ func TestDialectFeatures(t *testing.T) {
dialect.FeatureDefaultPlaceholder |
dialect.FeatureSavepoint |
dialect.FeatureSelectLocking |
dialect.FeatureNullsOrder,
dialect.FeatureNullsOrder |
dialect.FeatureSelectDistinctOn,
},
{
name: "mysql",
Expand Down