Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ Use the `-dbms` flag to specify the database type:
- `mysql` - MySQL
- `oracle` - Oracle
- `snowflake` - Snowflake
- `sqlite` - SQLite

## Testing

Expand Down
4 changes: 2 additions & 2 deletions cmd/sqllexer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func main() {
mode = flag.String("mode", "obfuscate_and_normalize", "Operation mode: obfuscate, normalize, tokenize, obfuscate_and_normalize")
inputFile = flag.String("input", "", "Input file (default: stdin)")
outputFile = flag.String("output", "", "Output file (default: stdout)")
dbms = flag.String("dbms", "", "Database type: mssql, postgresql, mysql, oracle, snowflake")
dbms = flag.String("dbms", "", "Database type: mssql, postgresql, mysql, oracle, snowflake, sqlite")
replaceDigits = flag.Bool("replace-digits", true, "Replace digits with placeholders")
replaceBoolean = flag.Bool("replace-boolean", true, "Replace boolean values with placeholders")
replaceNull = flag.Bool("replace-null", true, "Replace null values with placeholders")
Expand Down Expand Up @@ -252,7 +252,7 @@ Flags:
-output string
Output file (default: stdout)
-dbms string
Database type: mssql, postgresql, mysql, oracle, snowflake
Database type: mssql, postgresql, mysql, oracle, snowflake, sqlite
-replace-digits
Replace digits with placeholders (default true)
-replace-boolean
Expand Down
1 change: 1 addition & 0 deletions dbms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestQueriesPerDBMS(t *testing.T) {
DBMSSQLServer,
DBMSMySQL,
DBMSSnowflake,
DBMSSQLite,
}

for _, dbms := range dbmsTypes {
Expand Down
92 changes: 77 additions & 15 deletions sqllexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,29 @@ func (s *Lexer) Scan() *Token {
case isWildcard(ch):
return s.scanWildcard()
case ch == '$':
if isDigit(s.lookAhead(1)) {
// if the dollar sign is followed by a digit, then it's a numbered parameter
return s.scanPositionalParameter()
nextCh := s.lookAhead(1)
if isDigit(nextCh) {
// Prefix length 2: consume '$' plus the first digit of SQLite bind parameters that use $VVV,
// where V may be numeric (e.g. $1, $12). Refer to scanSQLiteBindParameter for details
return s.scanNumericParameter(2)
}
if s.config.DBMS == DBMSSQLite && isAlphaNumeric(nextCh) {
return s.scanBindParameter()
}
if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) {
if s.config.DBMS == DBMSSQLServer && isLetter(nextCh) {
return s.scanIdentifier(ch)
}
return s.scanDollarQuotedString()
case ch == ':':
if s.config.DBMS == DBMSOracle && isAlphaNumeric(s.lookAhead(1)) {
if (s.config.DBMS == DBMSOracle || s.config.DBMS == DBMSSQLite) && isAlphaNumeric(s.lookAhead(1)) {
return s.scanBindParameter()
}
return s.scanOperator(ch)
case ch == '`':
if s.config.DBMS == DBMSMySQL {
if s.config.DBMS == DBMSMySQL || s.config.DBMS == DBMSSQLite {
return s.scanDoubleQuotedIdentifier('`')
}
return s.scanUnknown() // backtick is only valid in mysql
return s.scanUnknown() // backtick is only valid in mysql and sqlite
case ch == '#':
if s.config.DBMS == DBMSSQLServer {
return s.scanIdentifier(ch)
Expand All @@ -168,6 +173,13 @@ func (s *Lexer) Scan() *Token {
return s.scanSingleLineComment(ch)
}
return s.scanOperator(ch)
case ch == '?':
if s.config.DBMS == DBMSSQLite {
// Prefix length 1: consume '?' before scanning optional digits of SQLite ?NNN parameters
// SQLite treats bare '?' and '?NNN' as positional parameters (see scanSQLiteBindParameter)
return s.scanNumericParameter(1)
}
return s.scanOperator(ch)
case ch == '@':
if s.lookAhead(1) == '@' {
if isAlphaNumeric(s.lookAhead(2)) {
Expand All @@ -192,7 +204,7 @@ func (s *Lexer) Scan() *Token {
case isOperator(ch):
return s.scanOperator(ch)
case isPunctuation(ch):
if ch == '[' && s.config.DBMS == DBMSSQLServer {
if ch == '[' && (s.config.DBMS == DBMSSQLServer || s.config.DBMS == DBMSSQLite) {
return s.scanDoubleQuotedIdentifier('[')
}
return s.scanPunctuation()
Expand Down Expand Up @@ -595,21 +607,22 @@ func (s *Lexer) scanDollarQuotedString() *Token {
return s.emit(ERROR)
}

func (s *Lexer) scanPositionalParameter() *Token {
func (s *Lexer) scanNumericParameter(prefixLen int) *Token {
s.start = s.cursor
ch := s.nextBy(2) // consume the dollar sign and the number
for {
if !isDigit(ch) {
break
}
ch := s.nextBy(prefixLen)
for isDigit(ch) {
ch = s.next()
}
return s.emit(POSITIONAL_PARAMETER)
}

func (s *Lexer) scanBindParameter() *Token {
s.start = s.cursor
ch := s.nextBy(2) // consume the (colon|at sign) and the char
if s.config.DBMS == DBMSSQLite {
// SQLite allows named bind parameters prefixed with :, @, or $, so use the SQLite-specific scanner
return s.scanSQLiteBindParameter()
}
ch := s.nextBy(2) // consume the (colon|at sign|dollar sign) and the char
for {
if !isAlphaNumeric(ch) {
break
Expand All @@ -619,6 +632,55 @@ func (s *Lexer) scanBindParameter() *Token {
return s.emit(BIND_PARAMETER)
}

// https://sqlite.org/c3ref/bind_blob.html
func (s *Lexer) scanSQLiteBindParameter() *Token {
s.next() // consume the prefix character (:, @, or $)
s.consumeSQLiteIdentifier()

for {
if s.peek() == ':' && s.lookAhead(1) == ':' {
s.nextBy(2) // consume '::'
s.consumeSQLiteIdentifier()
continue
}
break
}

if s.peek() == '(' {
s.consumeSQLiteParameterSuffix()
}

return s.emit(BIND_PARAMETER)
}

func (s *Lexer) consumeSQLiteIdentifier() {
for {
ch := s.peek()
if ch == '_' || isAlphaNumeric(ch) {
s.next()
continue
}
break
}
}

func (s *Lexer) consumeSQLiteParameterSuffix() {
s.next() // consume '('
depth := 1
for depth > 0 {
ch := s.peek()
if isEOF(ch) {
break
}
s.next()
if ch == '(' {
depth++
} else if ch == ')' {
depth--
}
}
}

func (s *Lexer) scanSystemVariable() *Token {
s.start = s.cursor
ch := s.nextBy(2) // consume @@
Expand Down
28 changes: 28 additions & 0 deletions sqllexer_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ func addComplexTestCases(f *testing.F) {
`SELECT $1, $2 FROM @mystage/file.csv`,
}

// SQLite specific patterns
sqlitePatterns := []string{
`SELECT * FROM pragma_table_info('users')`,
`INSERT OR REPLACE INTO kv_store(key, value) VALUES(:key, json_extract($payload, '$.value'))`,
`INSERT INTO logs VALUES($ns::var, $env(config), $ns::name(sub))`,
"CREATE TABLE IF NOT EXISTS logs (id INTEGER PRIMARY KEY, payload TEXT) WITHOUT ROWID",
"WITH ranked AS (SELECT *, row_number() OVER (PARTITION BY type ORDER BY created_at DESC) AS rn FROM events) SELECT * FROM ranked WHERE rn = 1",
"SELECT [user] FROM [main].[table] WHERE [id] = 1",
"ATTACH DATABASE 'archive.db' AS archive; DETACH DATABASE archive",
}

// Common edge cases across all DBMS
commonEdgeCases := []string{
// Nested subqueries
Expand Down Expand Up @@ -181,6 +192,7 @@ func addComplexTestCases(f *testing.F) {
patterns = append(patterns, oraclePatterns...)
patterns = append(patterns, snowflakePatterns...)
patterns = append(patterns, commonEdgeCases...)
patterns = append(patterns, sqlitePatterns...)

// Add each pattern with different DBMS types
dbmsTypes := []string{
Expand All @@ -189,6 +201,7 @@ func addComplexTestCases(f *testing.F) {
string(DBMSMySQL),
string(DBMSOracle),
string(DBMSSnowflake),
string(DBMSSQLite),
}

for _, pattern := range patterns {
Expand Down Expand Up @@ -259,6 +272,15 @@ func addObfuscationTestCases(f *testing.F) {
`SELECT $1, $2, $3 FROM @mystage`,
}

// SQLite specific obfuscation patterns
sqlitePatterns := []string{
`SELECT * FROM logs WHERE id = ?5 AND tag = @tag`,
`SELECT * FROM users WHERE email = :email OR email = $email`,
`SELECT $ns::var, $env(config), $ns::name(sub)`,
`SELECT [user] FROM [main].[table] WHERE [id] = 1`,
`PRAGMA table_info('users')`,
}

// Common obfuscation patterns for all DBMS
commonPatterns := []string{
// Basic numbers
Expand Down Expand Up @@ -331,13 +353,19 @@ func addObfuscationTestCases(f *testing.F) {
f.Add(pattern, string(DBMSSnowflake))
}

// Add SQLite patterns with SQLite DBMS
for _, pattern := range sqlitePatterns {
f.Add(pattern, string(DBMSSQLite))
}

// Add common patterns and quote edge cases with all DBMS types
dbmsTypes := []string{
string(DBMSPostgres),
string(DBMSSQLServer),
string(DBMSMySQL),
string(DBMSOracle),
string(DBMSSnowflake),
string(DBMSSQLite),
}

for _, pattern := range append(commonPatterns, quoteEdgeCases...) {
Expand Down
128 changes: 128 additions & 0 deletions sqllexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,134 @@ here */`,
{BIND_PARAMETER, "@__my_id"},
},
},
{
name: "sqlite named parameters",
input: "SELECT * FROM users WHERE id = :id AND email = $email AND tag = @tag",
expected: []TokenSpec{
{COMMAND, "SELECT"},
{SPACE, " "},
{WILDCARD, "*"},
{SPACE, " "},
{KEYWORD, "FROM"},
{SPACE, " "},
{IDENT, "users"},
{SPACE, " "},
{KEYWORD, "WHERE"},
{SPACE, " "},
{IDENT, "id"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{BIND_PARAMETER, ":id"},
{SPACE, " "},
{KEYWORD, "AND"},
{SPACE, " "},
{IDENT, "email"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{BIND_PARAMETER, "$email"},
{SPACE, " "},
{KEYWORD, "AND"},
{SPACE, " "},
{IDENT, "tag"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{BIND_PARAMETER, "@tag"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLite)},
},
{
name: "sqlite positional parameters",
input: "SELECT * FROM logs WHERE id = ?5 AND alt = ?",
expected: []TokenSpec{
{COMMAND, "SELECT"},
{SPACE, " "},
{WILDCARD, "*"},
{SPACE, " "},
{KEYWORD, "FROM"},
{SPACE, " "},
{IDENT, "logs"},
{SPACE, " "},
{KEYWORD, "WHERE"},
{SPACE, " "},
{IDENT, "id"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{POSITIONAL_PARAMETER, "?5"},
{SPACE, " "},
{KEYWORD, "AND"},
{SPACE, " "},
{IDENT, "alt"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{POSITIONAL_PARAMETER, "?"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLite)},
},
{
name: "sqlite extended dollar parameters",
input: "SELECT $ns::var, $env(config), $ns::name(sub)",
expected: []TokenSpec{
{COMMAND, "SELECT"},
{SPACE, " "},
{BIND_PARAMETER, "$ns::var"},
{PUNCTUATION, ","},
{SPACE, " "},
{BIND_PARAMETER, "$env(config)"},
{PUNCTUATION, ","},
{SPACE, " "},
{BIND_PARAMETER, "$ns::name(sub)"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLite)},
},
{
name: "sqlite square bracket identifier",
input: "SELECT [user] FROM [main].[table] WHERE [id] = 1",
expected: []TokenSpec{
{COMMAND, "SELECT"},
{SPACE, " "},
{QUOTED_IDENT, "[user]"},
{SPACE, " "},
{KEYWORD, "FROM"},
{SPACE, " "},
{QUOTED_IDENT, "[main].[table]"},
{SPACE, " "},
{KEYWORD, "WHERE"},
{SPACE, " "},
{QUOTED_IDENT, "[id]"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{NUMBER, "1"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLite)},
},
{
name: "sqlite backtick quoted identifier",
input: "SELECT `user` FROM `main`.`table` WHERE `id` = 1",
expected: []TokenSpec{
{COMMAND, "SELECT"},
{SPACE, " "},
{QUOTED_IDENT, "`user`"},
{SPACE, " "},
{KEYWORD, "FROM"},
{SPACE, " "},
{QUOTED_IDENT, "`main`.`table`"},
{SPACE, " "},
{KEYWORD, "WHERE"},
{SPACE, " "},
{QUOTED_IDENT, "`id`"},
{SPACE, " "},
{OPERATOR, "="},
{SPACE, " "},
{NUMBER, "1"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLite)},
},
{
name: "select with system variable",
input: "SELECT @@VERSION AS SqlServerVersion",
Expand Down
Loading