Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions internal/apiserver/service/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ import (
"github.com/labstack/echo/v4/middleware"
)

type sqlWorkbenchProxyConfiguration struct {
enableCloudbeaver bool
cloudbeaverRoot string
enableOdcQuery bool
odcQueryRoot string
}

func buildSQLWorkbenchProxyConfiguration(controller *SqlWorkbenchController) sqlWorkbenchProxyConfiguration {
cfg := sqlWorkbenchProxyConfiguration{}
if controller == nil {
return cfg
}
if controller.CloudbeaverService != nil && controller.CloudbeaverService.CloudbeaverUsecase != nil {
cfg.enableCloudbeaver = controller.CloudbeaverService.CloudbeaverUsecase.IsCloudbeaverConfigured()
cfg.cloudbeaverRoot = controller.CloudbeaverService.CloudbeaverUsecase.GetRootUri()
}
if controller.SqlWorkbenchService != nil {
cfg.enableOdcQuery = controller.SqlWorkbenchService.IsConfigured()
cfg.odcQueryRoot = controller.SqlWorkbenchService.GetRootUri()
}
return cfg
}

func (s *APIServer) initRouter() error {
s.echo.GET("/swagger/*", s.DMSController.SwaggerHandler, SwaggerMiddleWare)

Expand All @@ -30,6 +53,7 @@ func (s *APIServer) initRouter() error {
return err
}
v2 := s.echo.Group(dmsV2.CurrentGroupVersion)
proxyCfg := buildSQLWorkbenchProxyConfiguration(s.SqlWorkbenchController)
// DMS RESTful resource
{
{
Expand Down Expand Up @@ -262,8 +286,8 @@ func (s *APIServer) initRouter() error {
gatewayV1.GET("/tips", s.DMSController.GetGatewayTips)
gatewayV1.PUT("/", s.DMSController.SyncGateways, s.DMSController.DMS.GatewayUsecase.Broadcast())

if s.SqlWorkbenchController.CloudbeaverService.CloudbeaverUsecase.IsCloudbeaverConfigured() {
cloudbeaverV1 := s.echo.Group(s.SqlWorkbenchController.CloudbeaverService.CloudbeaverUsecase.GetRootUri())
if proxyCfg.enableCloudbeaver {
cloudbeaverV1 := s.echo.Group(proxyCfg.cloudbeaverRoot)
targets, err := s.SqlWorkbenchController.CloudbeaverService.ProxyUsecase.GetCloudbeaverProxyTarget()
if err != nil {
return err
Expand All @@ -277,8 +301,8 @@ func (s *APIServer) initRouter() error {
}))
}

if s.SqlWorkbenchController.SqlWorkbenchService.IsConfigured() {
sqlWorkbenchV1 := s.echo.Group(s.SqlWorkbenchController.SqlWorkbenchService.GetRootUri())
if proxyCfg.enableOdcQuery {
sqlWorkbenchV1 := s.echo.Group(proxyCfg.odcQueryRoot)
targets, err := s.SqlWorkbenchController.SqlWorkbenchService.GetOdcProxyTarget()
if err != nil {
return err
Expand Down
64 changes: 64 additions & 0 deletions internal/apiserver/service/router_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package service

import (
"testing"

"github.com/actiontech/dms/internal/dms/biz"
dmsService "github.com/actiontech/dms/internal/dms/service"
sql_workbench "github.com/actiontech/dms/internal/sql_workbench/service"
"github.com/stretchr/testify/assert"
)

func TestBuildSQLWorkbenchProxyConfiguration(t *testing.T) {
t.Parallel()

tests := map[string]struct {
controller *SqlWorkbenchController
expected sqlWorkbenchProxyConfiguration
}{
"nil controller": {
controller: nil,
expected: sqlWorkbenchProxyConfiguration{},
},
"cloudbeaver enabled": {
controller: &SqlWorkbenchController{
CloudbeaverService: &dmsService.CloudbeaverService{
CloudbeaverUsecase: biz.NewCloudbeaverUsecase(noopLogger{}, &biz.CloudbeaverCfg{
Host: "cloudbeaver",
Port: "8978",
AdminUser: "admin",
AdminPassword: "password",
}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil),
},
SqlWorkbenchService: &sql_workbench.SqlWorkbenchService{},
},
expected: sqlWorkbenchProxyConfiguration{
enableCloudbeaver: true,
cloudbeaverRoot: "/sql_query",
enableOdcQuery: false,
odcQueryRoot: "/odc_query",
},
},
"all disabled": {
controller: &SqlWorkbenchController{
CloudbeaverService: &dmsService.CloudbeaverService{
CloudbeaverUsecase: biz.NewCloudbeaverUsecase(noopLogger{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil),
},
SqlWorkbenchService: &sql_workbench.SqlWorkbenchService{},
},
expected: sqlWorkbenchProxyConfiguration{
enableCloudbeaver: false,
cloudbeaverRoot: "/sql_query",
enableOdcQuery: false,
odcQueryRoot: "/odc_query",
},
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expected, buildSQLWorkbenchProxyConfiguration(tc.controller))
})
}
}
33 changes: 18 additions & 15 deletions internal/apiserver/service/sql_workbench_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ import (
"fmt"

dmsV1 "github.com/actiontech/dms/api/dms/service/v1"
sql_workbench "github.com/actiontech/dms/internal/sql_workbench/service"

"github.com/actiontech/dms/internal/apiserver/conf"
"github.com/actiontech/dms/internal/dms/service"
sql_workbench "github.com/actiontech/dms/internal/sql_workbench/service"

utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -44,6 +43,22 @@ func (cc *SqlWorkbenchController) Shutdown() error {
return nil
}

type SQLQueryConfiguration struct {
EnableSQLQuery bool `json:"enable_sql_query"`
SQLQueryRootURI string `json:"sql_query_root_uri"`
EnableOdcQuery bool `json:"enable_odc_query"`
OdcQueryRootURI string `json:"odc_query_root_uri"`
}

func (cc *SqlWorkbenchController) buildSQLQueryConfiguration() SQLQueryConfiguration {
return SQLQueryConfiguration{
EnableSQLQuery: cc.CloudbeaverService.CloudbeaverUsecase.IsCloudbeaverConfigured(),
SQLQueryRootURI: cc.CloudbeaverService.CloudbeaverUsecase.GetRootUri() + "/", // 确保URL以斜杠结尾,防止DMS开启HTTPS时,Web服务器重定向到HTTP根路径导致访问错误
EnableOdcQuery: cc.SqlWorkbenchService.IsConfigured(),
OdcQueryRootURI: cc.SqlWorkbenchService.GetRootUri(),
}
}

// swagger:route GET /v1/dms/configurations/sql_query CloudBeaver GetSQLQueryConfiguration
//
// get sql_query configuration.
Expand All @@ -52,18 +67,6 @@ func (cc *SqlWorkbenchController) Shutdown() error {
// 200: body:GetSQLQueryConfigurationReply
// default: body:GenericResp
func (cc *SqlWorkbenchController) GetSQLQueryConfiguration(c echo.Context) error {
reply := &dmsV1.GetSQLQueryConfigurationReply{
Data: struct {
EnableSQLQuery bool `json:"enable_sql_query"`
SQLQueryRootURI string `json:"sql_query_root_uri"`
EnableOdcQuery bool `json:"enable_odc_query"`
OdcQueryRootURI string `json:"odc_query_root_uri"`
}{
EnableSQLQuery: cc.CloudbeaverService.CloudbeaverUsecase.IsCloudbeaverConfigured(),
SQLQueryRootURI: cc.CloudbeaverService.CloudbeaverUsecase.GetRootUri() + "/", // 确保URL以斜杠结尾,防止DMS开启HTTPS时,Web服务器重定向到HTTP根路径导致访问错误
EnableOdcQuery: cc.SqlWorkbenchService.IsConfigured(),
OdcQueryRootURI: cc.SqlWorkbenchService.GetRootUri(),
},
}
reply := &dmsV1.GetSQLQueryConfigurationReply{Data: cc.buildSQLQueryConfiguration()}
return NewOkRespWithReply(c, reply)
}
111 changes: 111 additions & 0 deletions internal/apiserver/service/sql_workbench_controller_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package service

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

apiError "github.com/actiontech/dms/internal/apiserver/pkg/error"
"github.com/actiontech/dms/internal/dms/biz"
dmsService "github.com/actiontech/dms/internal/dms/service"
sql_workbench "github.com/actiontech/dms/internal/sql_workbench/service"
bV1 "github.com/actiontech/dms/pkg/dms-common/api/base/v1"
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

type noopLogger struct{}

func (noopLogger) Log(level utilLog.Level, keyvals ...interface{}) error {
return nil
}

func TestSqlWorkbenchControllerBuildSQLQueryConfiguration(t *testing.T) {
t.Parallel()

tests := map[string]struct {
controller *SqlWorkbenchController
expected SQLQueryConfiguration
}{
"cloudbeaver enabled": {
controller: &SqlWorkbenchController{
CloudbeaverService: &dmsService.CloudbeaverService{
CloudbeaverUsecase: biz.NewCloudbeaverUsecase(noopLogger{}, &biz.CloudbeaverCfg{
Host: "cloudbeaver",
Port: "8978",
AdminUser: "admin",
AdminPassword: "password",
}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil),
},
SqlWorkbenchService: &sql_workbench.SqlWorkbenchService{},
},
expected: SQLQueryConfiguration{
EnableSQLQuery: true,
SQLQueryRootURI: "/sql_query/",
EnableOdcQuery: false,
OdcQueryRootURI: "/odc_query",
},
},
"all disabled": {
controller: &SqlWorkbenchController{
CloudbeaverService: &dmsService.CloudbeaverService{
CloudbeaverUsecase: biz.NewCloudbeaverUsecase(noopLogger{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil),
},
SqlWorkbenchService: &sql_workbench.SqlWorkbenchService{},
},
expected: SQLQueryConfiguration{
EnableSQLQuery: false,
SQLQueryRootURI: "/sql_query/",
EnableOdcQuery: false,
OdcQueryRootURI: "/odc_query",
},
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expected, tc.controller.buildSQLQueryConfiguration())
})
}
}

func TestGetSQLQueryConfiguration(t *testing.T) {
t.Parallel()

controller := &SqlWorkbenchController{
CloudbeaverService: &dmsService.CloudbeaverService{
CloudbeaverUsecase: biz.NewCloudbeaverUsecase(noopLogger{}, &biz.CloudbeaverCfg{
Host: "cloudbeaver",
Port: "8978",
AdminUser: "admin",
AdminPassword: "password",
}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil),
},
SqlWorkbenchService: &sql_workbench.SqlWorkbenchService{},
}

e := echo.New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/dms/configurations/sql_query", nil)
c := e.NewContext(req, rec)

err := controller.GetSQLQueryConfiguration(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)

var resp struct {
bV1.GenericResp
Data SQLQueryConfiguration `json:"data"`
}
assert.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
assert.Equal(t, int(apiError.StatusOK), resp.Code)
assert.Equal(t, SQLQueryConfiguration{
EnableSQLQuery: true,
SQLQueryRootURI: "/sql_query/",
EnableOdcQuery: false,
OdcQueryRootURI: "/odc_query",
}, resp.Data)
}
33 changes: 1 addition & 32 deletions internal/dms/biz/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1341,38 +1341,7 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
return err
}

// 已配置的项目管理权限和数据源工作台查询权限
projectIdMap := map[string]struct{}{}
dbServiceIdMap := map[string]struct{}{}
for _, opPermission := range opPermissions {
// project permission
if opPermission.OpRangeType == OpRangeTypeProject && opPermission.OpPermissionUID == constant.UIDOfOpPermissionProjectAdmin {
for _, rangeUid := range opPermission.RangeUIDs {
projectIdMap[rangeUid] = struct{}{}
}
}

// db_service permission
if opPermission.OpRangeType == OpRangeTypeDBService && opPermission.OpPermissionUID == constant.UIDOfOpPermissionSQLQuery {
for _, rangeUid := range opPermission.RangeUIDs {
dbServiceIdMap[rangeUid] = struct{}{}
}
}
}

var lastActiveDBServices []*DBService
for _, activeDBService := range activeDBServices {
if _, ok := projectIdMap[activeDBService.ProjectUID]; ok {
lastActiveDBServices = append(lastActiveDBServices, activeDBService)
continue
}

if _, ok := dbServiceIdMap[activeDBService.UID]; ok {
lastActiveDBServices = append(lastActiveDBServices, activeDBService)
}
}

activeDBServices = lastActiveDBServices
activeDBServices = FilterDBServicesBySQLWorkbenchAccess(activeDBServices, opPermissions)
}

cloudbeaverUser, exist, err := cu.repo.GetCloudbeaverUserByID(ctx, cloudbeaverUserId)
Expand Down
3 changes: 2 additions & 1 deletion internal/dms/biz/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package biz
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -231,7 +232,7 @@ func (f *FeishuNotifier) Notify(ctx context.Context, notificationSubject, notifi
}
}
if len(errMsgs) > 0 {
return fmt.Errorf(strings.Join(errMsgs, "\n"))
return errors.New(strings.Join(errMsgs, "\n"))
}
return nil
}
Expand Down
39 changes: 39 additions & 0 deletions internal/dms/biz/sql_workbench_access.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package biz

import pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant"

// FilterDBServicesBySQLWorkbenchAccess applies the shared SQL workbench access rules:
// project admins can access all active DB services in the project, and users with
// SQL query permission on a DB service can access that specific service.
func FilterDBServicesBySQLWorkbenchAccess(dbServices []*DBService, opPermissions []OpPermissionWithOpRange) []*DBService {
projectIDMap := make(map[string]struct{})
dbServiceIDMap := make(map[string]struct{})

for _, opPermission := range opPermissions {
if opPermission.OpRangeType == OpRangeTypeProject && opPermission.OpPermissionUID == pkgConst.UIDOfOpPermissionProjectAdmin {
for _, rangeUID := range opPermission.RangeUIDs {
projectIDMap[rangeUID] = struct{}{}
}
}

if opPermission.OpRangeType == OpRangeTypeDBService && opPermission.OpPermissionUID == pkgConst.UIDOfOpPermissionSQLQuery {
for _, rangeUID := range opPermission.RangeUIDs {
dbServiceIDMap[rangeUID] = struct{}{}
}
}
}

filteredDBServices := make([]*DBService, 0, len(dbServices))
for _, dbService := range dbServices {
if _, hasProjectPermission := projectIDMap[dbService.ProjectUID]; hasProjectPermission {
filteredDBServices = append(filteredDBServices, dbService)
continue
}

if _, hasDBServicePermission := dbServiceIDMap[dbService.UID]; hasDBServicePermission {
filteredDBServices = append(filteredDBServices, dbService)
}
}

return filteredDBServices
}
Loading