Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions backend/apps/system/api/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from common.core.deps import SessionDep, Trans
from common.core.security import create_access_token
from common.core.sqlbot_cache import clear_cache
from common.utils.utils import get_origin_from_referer
from common.utils.utils import get_origin_from_referer, origin_match_domain

router = APIRouter(tags=["system/assistant"], prefix="/system/assistant")

Expand All @@ -30,13 +30,15 @@ async def info(request: Request, response: Response, session: SessionDep, trans:
if not db_model:
raise RuntimeError(f"assistant application not exist")
db_model = AssistantModel.model_validate(db_model)
response.headers["Access-Control-Allow-Origin"] = db_model.domain

origin = request.headers.get("origin") or get_origin_from_referer(request)
if not origin:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))
origin = origin.rstrip('/')
if origin != db_model.domain:
if not origin_match_domain(origin, db_model.domain):
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))

response.headers["Access-Control-Allow-Origin"] = origin
return db_model


Expand All @@ -48,13 +50,14 @@ async def getApp(request: Request, response: Response, session: SessionDep, tran
if not db_model:
raise RuntimeError(f"assistant application not exist")
db_model = AssistantModel.model_validate(db_model)
response.headers["Access-Control-Allow-Origin"] = db_model.domain
origin = request.headers.get("origin") or get_origin_from_referer(request)
if not origin:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))
origin = origin.rstrip('/')
if origin != db_model.domain:
if not origin_match_domain(origin, db_model.domain):
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))

response.headers["Access-Control-Allow-Origin"] = origin
return db_model


Expand Down
16 changes: 16 additions & 0 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,22 @@ class AssistantOutDs:
assistant: AssistantHeader
ds_list: Optional[list[AssistantOutDsSchema]] = None
certificate: Optional[str] = None
request_origin: Optional[str] = None

def __init__(self, assistant: AssistantHeader):
self.assistant = assistant
self.ds_list = None
self.certificate = assistant.certificate
self.request_origin = assistant.request_origin
self.get_ds_from_api()

# @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
def get_ds_from_api(self):
config: dict[any] = json.loads(self.assistant.configuration)
endpoint: str = config['endpoint']
endpoint = self.get_complete_endpoint(endpoint=endpoint)
if not endpoint:
raise Exception(f"Failed to get datasource list from {config['endpoint']}, error: [Assistant domain or endpoint miss]")
certificateList: list[any] = json.loads(self.certificate)
header = {}
cookies = {}
Expand Down Expand Up @@ -137,6 +142,17 @@ def get_ds_from_api(self):
else:
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")

def get_complete_endpoint(self, endpoint: str) -> str | None:
if endpoint.startswith("http://") or endpoint.startswith("https://"):
return endpoint
domain_text = self.assistant.domain
if not domain_text:
return None
if ',' in domain_text:
return self.request_origin.strip('/') if self.request_origin else domain_text.split(',')[0].strip('/') + endpoint
else:
return f"{domain_text}{endpoint}"

def get_simple_ds_list(self):
if self.ds_list:
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
Expand Down
5 changes: 4 additions & 1 deletion backend/apps/system/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from common.core.config import settings
from common.core.schemas import TokenPayload
from common.utils.locale import I18n
from common.utils.utils import SQLBotLogUtil
from common.utils.utils import SQLBotLogUtil, get_origin_from_referer
from common.utils.whitelist import whiteUtils
from fastapi.security.utils import get_authorization_scheme_param
from common.core.deps import get_i18n
Expand All @@ -40,6 +40,9 @@ async def dispatch(self, request, call_next):
if validator[0]:
request.state.current_user = validator[1]
request.state.assistant = validator[2]
origin = request.headers.get("origin") or get_origin_from_referer(request)
if origin and validator[2]:
request.state.assistant.request_origin = origin
return await call_next(request)
message = trans('i18n_permission.authenticate_invalid', msg = validator[1])
return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"})
Expand Down
1 change: 1 addition & 0 deletions backend/apps/system/schemas/system_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class AssistantHeader(AssistantDTO):
unique: Optional[str] = None
certificate: Optional[str] = None
online: bool = False
request_origin: Optional[str] = None


class AssistantValidator(BaseModel):
Expand Down
6 changes: 6 additions & 0 deletions backend/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def get_origin_from_referer(request: Request):
SQLBotLogUtil.error(f"解析 Referer 出错: {e}")
return referer

def origin_match_domain(origin: str, domain: str) -> bool:
if not origin or not domain:
return False
origin_text = origin.rstrip('/')
domain_list = domain.replace(" ", "").split(',')
return origin_text in [d.rstrip('/') for d in domain_list]

def equals_ignore_case(str1: str, *args: str) -> bool:
if str1 is None:
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,9 @@
"creating_advanced_applications": "Creating Advanced Applications",
"configure_interface": "Configure interface",
"interface_url": "Interface URL",
"format_is_incorrect": "format is incorrect",
"format_is_incorrect": "format is incorrect{msg}",
"domain_format_incorrect": ",start with http/https, no trailing slash (/), multiple domains separated by half-width commas (,)",
"interface_url_incorrect": ",enter a relative path starting with /",
"aes_enable": "Enable AES encryption",
"aes_enable_tips": "The fields (host, user, password, dataBase, schema) are all encrypted using the AES-CBC-PKCS5Padding encryption method",
"bit": "bit",
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/i18n/ko-KR.json
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@
"private": "비공개",
"configure_interface": "인터페이스 설정",
"interface_url": "인터페이스 URL",
"format_is_incorrect": "형식이 올바르지 않습니다",
"format_is_incorrect": "형식이 올바르지 않습니다{msg}",
"domain_format_incorrect": ", http/https로 시작, 슬래시(/)로 끝나지 않음, 여러 도메인은 반각 쉼표(,)로 구분",
"interface_url_incorrect": ", 상대 경로를 입력해주세요. /로 시작합니다",
"aes_enable": "AES 암호화 활성화",
"aes_enable_tips": "암호화 필드 (host, user, password, dataBase, schema)는 모두 AES-CBC-PKCS5Padding 암호화 방식을 사용합니다",
"bit": "비트",
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/i18n/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@
"private": "私有",
"configure_interface": "配置接口",
"interface_url": "接口 URL",
"format_is_incorrect": "格式不对",
"format_is_incorrect": "格式不对{msg}",
"domain_format_incorrect": ",http或https开头,不能以 / 结尾,多个域名以逗号(半角)分隔",
"interface_url_incorrect": ",请填写相对路径,以/开头",
"aes_enable": "开启 AES 加密",
"aes_enable_tips": "加密字段 (host, user, password, dataBase, schema) 均采用 AES-CBC-PKCS5Padding 加密方式",
"bit": "位",
Expand Down
21 changes: 14 additions & 7 deletions frontend/src/views/system/embedded/Page.vue
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,20 @@ const validateUrl = (_: any, value: any, callback: any) => {
)
} else {
// var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line
var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
var objExp = new RegExp(Expression)
if (objExp.test(value) && !value.endsWith('/')) {
callback()
} else {
callback(t('embedded.format_is_incorrect'))
}
value
.trim()
.split(',')
.forEach((tempVal: string) => {
var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
var objExp = new RegExp(Expression)
if (objExp.test(tempVal) && !tempVal.endsWith('/')) {
callback()
} else {
callback(
t('embedded.format_is_incorrect', { msg: t('embedded.domain_format_incorrect') })
)
}
})
}
}
const rules = {
Expand Down
39 changes: 28 additions & 11 deletions frontend/src/views/system/embedded/iframe.vue
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,16 @@ const handleBaseEmbedded = (row: any) => {
const handleAdvancedEmbedded = (row: any) => {
advancedApplication.value = true
if (row) {
Object.assign(urlForm, cloneDeep(JSON.parse(row.configuration)))
const tempData = cloneDeep(JSON.parse(row.configuration))
if (tempData?.endpoint.startsWith('http')) {
row.domain
.trim()
.split(',')
.forEach((domain: string) => {
tempData.endpoint = tempData.endpoint.replace(domain, '')
})
}
Object.assign(urlForm, tempData)
}
ruleConfigvVisible.value = true
dialogTitle.value = row?.id
Expand Down Expand Up @@ -265,13 +274,20 @@ const validateUrl = (_: any, value: any, callback: any) => {
)
} else {
// var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line
var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
var objExp = new RegExp(Expression)
if (objExp.test(value) && !value.endsWith('/')) {
callback()
} else {
callback(t('embedded.format_is_incorrect'))
}
value
.trim()
.split(',')
.forEach((tempVal: string) => {
var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
var objExp = new RegExp(Expression)
if (objExp.test(tempVal) && !tempVal.endsWith('/')) {
callback()
} else {
callback(
t('embedded.format_is_incorrect', { msg: t('embedded.domain_format_incorrect') })
)
}
})
}
}
const rules = {
Expand Down Expand Up @@ -307,12 +323,13 @@ const validatePass = (_: any, value: any, callback: any) => {
)
} else {
// var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line
var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
// var Expression = /^https?:\/\/[^\s/?#]+(:\d+)?/i
var Expression = /^\/([a-zA-Z0-9_-]+\/)*[a-zA-Z0-9_-]+(\?[a-zA-Z0-9_=&-]+)?$/
var objExp = new RegExp(Expression)
if (objExp.test(value) && value.startsWith(currentEmbedded.domain)) {
if (objExp.test(value)) {
callback()
} else {
callback(t('embedded.format_is_incorrect'))
callback(t('embedded.format_is_incorrect', { msg: t('embedded.interface_url_incorrect') }))
}
}
}
Expand Down