Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/sql_workbench/service/sql_workbench_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar
// 读取请求体
bodyBytes, err := io.ReadAll(c.Request().Body)
if err != nil {
sqlWorkbenchService.log.Errorf("failed to read request body: %v", err)
return fmt.Errorf("failed to read request body: %w", err)
}
// 恢复请求体,供后续处理使用
Expand All @@ -1027,43 +1028,51 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar
// 解析请求体获取 SQL 和 datasource ID
sql, datasourceID, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes)
if err != nil {
sqlWorkbenchService.log.Errorf("failed to parse streamExecute request, skipping audit: %v", err)
return fmt.Errorf("failed to parse streamExecute request, skipping audit: %v", err)
}

if sql == "" || datasourceID == "" {
sqlWorkbenchService.log.Debugf("SQL or datasource ID is empty, skipping audit")
return fmt.Errorf("SQL or datasource ID is empty, skipping audit")
}

// 获取当前用户 ID
dmsUserId, err := sqlWorkbenchService.getDMSUserIdFromRequest(c)
if err != nil {
sqlWorkbenchService.log.Errorf("failed to get DMS user ID: %v", err)
return fmt.Errorf("failed to get DMS user ID: %v", err)
}

// 从缓存表获取 dms_db_service_id
dmsDBServiceID, err := sqlWorkbenchService.getDMSDBServiceIDFromCache(c.Request().Context(), datasourceID, dmsUserId)
if err != nil {
sqlWorkbenchService.log.Errorf("failed to get dms_db_service_id from cache: %v", err)
return fmt.Errorf("failed to get dms_db_service_id from cache: %v", err)
}

if dmsDBServiceID == "" {
sqlWorkbenchService.log.Debugf("dms_db_service_id not found in cache for datasource: %s", datasourceID)
return fmt.Errorf("dms_db_service_id not found in cache for datasource: %s", datasourceID)
}

// 获取 DBService 信息
dbService, err := sqlWorkbenchService.dbServiceUsecase.GetDBService(c.Request().Context(), dmsDBServiceID)
if err != nil {
sqlWorkbenchService.log.Errorf("failed to get DBService: %v", err)
return fmt.Errorf("failed to get DBService: %v", err)
}

// 检查是否启用 SQL 审核
if !sqlWorkbenchService.isEnableSQLAudit(dbService) {
sqlWorkbenchService.log.Debugf("SQL audit is not enabled for DBService: %s", dmsDBServiceID)
return fmt.Errorf("SQL audit is not enabled for DBService: %s", dmsDBServiceID)
}

// 调用 SQLE 审核接口
auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService)
if err != nil {
sqlWorkbenchService.log.Errorf("call SQLE audit failed: %v", err)
return fmt.Errorf("call SQLE audit failed: %v", err)
}

Expand Down Expand Up @@ -1225,6 +1234,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex
SQLContent: sql,
SQLType: "sql",
ProjectId: dbService.ProjectUID,
InstanceName: dbService.Name,
RuleTemplateName: dbService.SQLEConfig.SQLQueryConfig.RuleTemplateName,
}

Expand Down