Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified gofmt.sh
100644 → 100755
Empty file.
174 changes: 172 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sql

import (
"fmt"
"io"
"strings"
)
Expand Down Expand Up @@ -1304,9 +1305,9 @@ func (p *Parser) parseTriggerBodyStatement() (stmt Statement, err error) {
case INSERT, REPLACE:
stmt, err = p.parseInsertStatement(nil)
case UPDATE:
stmt, err = p.parseUpdateStatement(nil)
stmt, err = p.parseTriggerBodyUpdateStatement(nil)
case DELETE:
stmt, err = p.parseDeleteStatement(nil)
stmt, err = p.parseTriggerBodyDeleteStatement(nil)
case WITH:
stmt, err = p.parseWithStatement()
default:
Expand All @@ -1325,6 +1326,175 @@ func (p *Parser) parseTriggerBodyStatement() (stmt Statement, err error) {
return stmt, nil
}

// parseTriggerBodyDeleteStatement parses a DELETE statement within a trigger body.
// It differs from parseDeleteStatement by only allowing unqualified table names.
func (p *Parser) parseTriggerBodyDeleteStatement(withClause *WithClause) (_ *DeleteStatement, err error) {
assert(p.peek() == DELETE)

var stmt DeleteStatement
stmt.WithClause = withClause

// Parse "DELETE FROM tbl"
stmt.Delete, _, _ = p.scan()
if p.peek() != FROM {
return &stmt, p.errorExpected(p.pos, p.tok, "FROM")
}
stmt.From, _, _ = p.scan()
if !isIdentToken(p.peek()) {
return nil, p.errorExpected(p.pos, p.tok, "table name")
}
ident, _ := p.parseIdent("table name")

// In trigger bodies, only unqualified table names are allowed
if err = p.validateUnqualifiedTableName(ident); err != nil {
return &stmt, err
}
stmt.Table = &QualifiedTableName{Name: ident}

// Parse WHERE clause.
if p.peek() == WHERE {
stmt.Where, _, _ = p.scan()
if stmt.WhereExpr, err = p.ParseExpr(); err != nil {
return &stmt, err
}
}

// Parse ORDER BY clause. This differs from the SELECT parsing in that
// if an ORDER BY is specified then the LIMIT is required.
if p.peek() == ORDER || p.peek() == LIMIT {
if p.peek() == ORDER {
stmt.Order, _, _ = p.scan()
if p.peek() != BY {
return &stmt, p.errorExpected(p.pos, p.tok, "BY")
}
stmt.OrderBy, _, _ = p.scan()

for {
term, err := p.parseOrderingTerm()
if err != nil {
return &stmt, err
}
stmt.OrderingTerms = append(stmt.OrderingTerms, term)

if p.peek() != COMMA {
break
}
p.scan()
}
}

// Parse LIMIT/OFFSET clause.
if p.peek() != LIMIT {
return &stmt, p.errorExpected(p.pos, p.tok, "LIMIT")
}
stmt.Limit, _, _ = p.scan()
if stmt.LimitExpr, err = p.ParseExpr(); err != nil {
return &stmt, err
}

if p.peek() == OFFSET {
stmt.Offset, _, _ = p.scan()
if stmt.OffsetExpr, err = p.ParseExpr(); err != nil {
return &stmt, err
}
}
}

return &stmt, nil
}

// parseTriggerBodyUpdateStatement parses an UPDATE statement within a trigger body.
// It differs from parseUpdateStatement by only allowing unqualified table names.
func (p *Parser) parseTriggerBodyUpdateStatement(withClause *WithClause) (_ *UpdateStatement, err error) {
assert(p.peek() == UPDATE)

var stmt UpdateStatement
stmt.WithClause = withClause

stmt.Update, _, _ = p.scan()
if p.peek() == OR {
stmt.UpdateOr, _, _ = p.scan()

switch p.peek() {
case ROLLBACK:
stmt.UpdateOrRollback, _, _ = p.scan()
case REPLACE:
stmt.UpdateOrReplace, _, _ = p.scan()
case ABORT:
stmt.UpdateOrAbort, _, _ = p.scan()
case FAIL:
stmt.UpdateOrFail, _, _ = p.scan()
case IGNORE:
stmt.UpdateOrIgnore, _, _ = p.scan()
default:
return &stmt, p.errorExpected(p.pos, p.tok, "ROLLBACK, REPLACE, ABORT, FAIL, or IGNORE")
}
}

if !isIdentToken(p.peek()) {
return nil, p.errorExpected(p.pos, p.tok, "table name")
}
ident, _ := p.parseIdent("table name")

// In trigger bodies, only unqualified table names are allowed
if err = p.validateUnqualifiedTableName(ident); err != nil {
return &stmt, err
}
stmt.Table = &QualifiedTableName{Name: ident}

// Parse SET + list of assignments.
if p.peek() != SET {
return &stmt, p.errorExpected(p.pos, p.tok, "SET")
}
stmt.Set, _, _ = p.scan()

for {
assignment, err := p.parseAssignment()
if err != nil {
return &stmt, err
}
stmt.Assignments = append(stmt.Assignments, assignment)

if p.peek() != COMMA {
break
}
p.scan()
}

// Parse WHERE clause.
if p.peek() == WHERE {
stmt.Where, _, _ = p.scan()
if stmt.WhereExpr, err = p.ParseExpr(); err != nil {
return &stmt, err
}
}

// Parse optional RETURNING clause.
if p.peek() == RETURNING {
if stmt.ReturningClause, err = p.parseReturningClause(); err != nil {
return &stmt, err
}
}

return &stmt, nil
}

// validateUnqualifiedTableName ensures that the next tokens do not form
// a qualified table name (schema.table or table alias).
func (p *Parser) validateUnqualifiedTableName(ident *Ident) error {
// Check for schema qualification (schema.table)
if p.peek() == DOT {
return fmt.Errorf("qualified table names not allowed in trigger body")
}

// Check for table alias
if tok := p.peek(); tok == AS || isIdentToken(tok) {
return fmt.Errorf("qualified table names not allowed in trigger body")
}

return nil
}

func (p *Parser) parseDropTriggerStatement(dropPos Pos) (_ *DropTriggerStatement, err error) {
assert(p.peek() == TRIGGER)

Expand Down
6 changes: 6 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1860,6 +1860,12 @@ func TestParser_ParseStatement(t *testing.T) {
End: pos(83),
})

// Test cases that should fail due to qualified table names in trigger body
AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN DELETE FROM host h; END`, `qualified table names not allowed in trigger body`)
AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN UPDATE host h SET x = 1; END`, `qualified table names not allowed in trigger body`)
AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN DELETE FROM schema.host; END`, `qualified table names not allowed in trigger body`)
AssertParseStatementError(t, `CREATE TRIGGER trig AFTER DELETE ON tbl BEGIN UPDATE schema.host SET x = 1; END`, `qualified table names not allowed in trigger body`)

AssertParseStatementError(t, `CREATE TRIGGER`, `1:14: expected index name, found 'EOF'`)
AssertParseStatementError(t, `CREATE TRIGGER IF`, `1:17: expected NOT, found 'EOF'`)
AssertParseStatementError(t, `CREATE TRIGGER IF NOT`, `1:21: expected EXISTS, found 'EOF'`)
Expand Down
Loading