Skip to content

Commit eb587fc

Browse files
yymulleo
authored andcommitted
fix(Hive): normalize identifier quoting and stabilize metadata/query paths
Use Hive-compatible identifier handling and template defaults so generated SQL returns real column values, while also fixing table schema parsing and adding required Hive runtime dependencies. Made-with: Cursor
1 parent d508040 commit eb587fc

5 files changed

Lines changed: 190 additions & 72 deletions

File tree

backend/apps/datasource/models/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def to_dict(self):
143143

144144

145145
class TableSchema:
146-
def __init__(self, attr1, attr2):
146+
def __init__(self, attr1, attr2=None):
147147
self.tableName = attr1
148148
self.tableComment = attr2 if attr2 is None or isinstance(attr2, str) else attr2.decode("utf-8")
149149

backend/apps/db/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class DB(Enum):
2929
pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL', [])
3030
starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks', [])
3131
sqlite = ('sqlite', 'SQLite', '"', '"', ConnectType.sqlalchemy, 'SQLite', [])
32-
hive = ('hive', 'Apache Hive', '"', '"', ConnectType.py_driver, 'Hive', [])
32+
hive = ('hive', 'Apache Hive', '`', '`', ConnectType.py_driver, 'Hive', [])
3333

3434
def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str,
3535
illegalParams: List[str]):

backend/apps/db/db.py

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import platform
5+
import re
56
import urllib.parse
67
from datetime import datetime, date, time, timedelta
78
from decimal import Decimal
@@ -35,12 +36,8 @@
3536
import sqlglot
3637
from sqlglot import expressions as exp
3738
from sqlalchemy.pool import NullPool
39+
from pyhive import hive
3840

39-
try:
40-
from pyhive import hive
41-
PYHIVE_AVAILABLE = True
42-
except ImportError:
43-
PYHIVE_AVAILABLE = False
4441

4542
try:
4643
if os.path.exists(settings.ORACLE_CLIENT_PATH):
@@ -259,25 +256,22 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
259256
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
260257
return False
261258
elif equals_ignore_case(ds.type, 'hive'):
262-
if PYHIVE_AVAILABLE:
263-
try:
264-
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
265-
database=conf.database, **extra_config_dict)
266-
cursor = conn.cursor()
267-
cursor.execute('select 1')
268-
cursor.fetchall()
269-
cursor.close()
270-
conn.close()
271-
SQLBotLogUtil.info("success")
272-
return True
273-
except Exception as e:
274-
SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
275-
if is_raise:
276-
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
277-
return False
278-
else:
279-
SQLBotLogUtil.error("pyhive not installed")
259+
try:
260+
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
261+
database=conf.database, **extra_config_dict)
262+
cursor = conn.cursor()
263+
cursor.execute('select 1')
264+
cursor.fetchall()
265+
cursor.close()
266+
conn.close()
267+
SQLBotLogUtil.info("success")
268+
return True
269+
except Exception as e:
270+
SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
271+
if is_raise:
272+
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
280273
return False
274+
281275
elif equals_ignore_case(ds.type, 'es'):
282276
es_conn = get_es_connect(conf)
283277
if es_conn.ping():
@@ -403,6 +397,30 @@ def get_schema(ds: CoreDatasource):
403397
res = cursor.fetchall()
404398
res_list = [item[0] for item in res]
405399
return res_list
400+
elif equals_ignore_case(ds.type, 'hive'):
401+
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
402+
database=conf.database, **extra_config_dict)
403+
cursor = conn.cursor()
404+
cursor.execute('SHOW DATABASES')
405+
res = cursor.fetchall()
406+
res_list = [item[0] for item in res]
407+
cursor.close()
408+
conn.close()
409+
return res_list
410+
elif equals_ignore_case(ds.type, 'doris', 'starrocks'):
411+
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
412+
port=conf.port, db=conf.database, connect_timeout=10,
413+
read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
414+
cursor.execute('SHOW DATABASES')
415+
res = cursor.fetchall()
416+
res_list = [item[0] for item in res]
417+
return res_list
418+
elif equals_ignore_case(ds.type, 'ck'):
419+
with get_session(ds) as session:
420+
with session.execute(text('SHOW DATABASES')) as result:
421+
res = result.fetchall()
422+
res_list = [item[0] for item in res]
423+
return res_list
406424

407425

408426
def get_tables(ds: CoreDatasource):
@@ -465,17 +483,15 @@ def get_tables(ds: CoreDatasource):
465483
res_list = [TableSchema(*item) for item in res]
466484
return res_list
467485
elif equals_ignore_case(ds.type, 'hive'):
468-
if PYHIVE_AVAILABLE:
469-
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
470-
database=conf.database, **extra_config_dict)
471-
cursor = conn.cursor()
472-
cursor.execute(sql)
473-
res = cursor.fetchall()
474-
res_list = [TableSchema(*item) for item in res]
475-
cursor.close()
476-
conn.close()
477-
return res_list
478-
return []
486+
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
487+
database=conf.database, **extra_config_dict)
488+
cursor = conn.cursor()
489+
cursor.execute(sql)
490+
res = cursor.fetchall()
491+
res_list = [TableSchema(*item) for item in res]
492+
cursor.close()
493+
conn.close()
494+
return res_list
479495

480496

481497
def get_fields(ds: CoreDatasource, table_name: str = None):
@@ -538,17 +554,15 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
538554
res_list = [ColumnSchema(*item) for item in res]
539555
return res_list
540556
elif equals_ignore_case(ds.type, 'hive'):
541-
if PYHIVE_AVAILABLE:
542-
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
543-
database=conf.database, **extra_config_dict)
544-
cursor = conn.cursor()
545-
cursor.execute(sql)
546-
res = cursor.fetchall()
547-
res_list = [ColumnSchema(*item) for item in res]
548-
cursor.close()
549-
conn.close()
550-
return res_list
551-
return []
557+
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
558+
database=conf.database, **extra_config_dict)
559+
cursor = conn.cursor()
560+
cursor.execute(sql)
561+
res = cursor.fetchall()
562+
res_list = [ColumnSchema(*item) for item in res]
563+
cursor.close()
564+
conn.close()
565+
return res_list
552566

553567

554568
def convert_value(value, datetime_format='space'):
@@ -737,37 +751,53 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
737751
except Exception as ex:
738752
raise Exception(str(ex))
739753
elif equals_ignore_case(ds.type, 'hive'):
740-
if PYHIVE_AVAILABLE:
741-
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
742-
database=conf.database, **extra_config_dict)
743-
cursor = conn.cursor()
744-
try:
745-
cursor.execute(sql)
746-
res = cursor.fetchall()
747-
columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for
748-
field in
749-
cursor.description]
750-
result_list = [
751-
{str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in
752-
res
753-
]
754-
return {"fields": columns, "data": result_list,
755-
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
756-
except Exception as ex:
757-
raise ParseSQLResultError(str(ex))
758-
finally:
759-
cursor.close()
760-
conn.close()
761-
raise Exception("pyhive not installed")
754+
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
755+
database=conf.database, **extra_config_dict)
756+
cursor = conn.cursor()
757+
try:
758+
# Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback.
759+
hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql)
760+
cursor.execute(hive_sql)
761+
res = cursor.fetchall()
762+
columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for
763+
field in
764+
cursor.description]
765+
result_list = [
766+
{str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in
767+
res
768+
]
769+
return {"fields": columns, "data": result_list,
770+
"sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))}
771+
except Exception as ex:
772+
raise ParseSQLResultError(str(ex))
773+
finally:
774+
cursor.close()
775+
conn.close()
762776

763777

764778
def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
765779
try:
780+
normalized_sql = sql.strip().lstrip("(").strip()
781+
first_keyword = normalized_sql.split(None, 1)[0].upper() if normalized_sql else ""
782+
allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
783+
denied_write_commands = {
784+
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
785+
"TRUNCATE", "MERGE", "COPY", "REPLACE", "GRANT", "REVOKE",
786+
"USE", "SET", "CALL"
787+
}
788+
789+
if not first_keyword:
790+
raise ValueError("Parse SQL Error")
791+
if first_keyword in denied_write_commands:
792+
return False
793+
766794
dialect = None
767795
if equals_ignore_case(ds.type, 'mysql', 'doris', 'starrocks'):
768796
dialect = 'mysql'
769797
elif equals_ignore_case(ds.type, 'sqlServer'):
770798
dialect = 'tsql'
799+
elif equals_ignore_case(ds.type, 'hive'):
800+
dialect = 'hive'
771801

772802
statements = sqlglot.parse(sql, dialect=dialect)
773803

@@ -777,7 +807,7 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
777807
write_types = (
778808
exp.Insert, exp.Update, exp.Delete,
779809
exp.Create, exp.Drop, exp.Alter,
780-
exp.Merge, exp.Command, exp.Copy
810+
exp.Merge, exp.Copy
781811
)
782812

783813
for stmt in statements:
@@ -786,7 +816,7 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
786816
if isinstance(stmt, write_types):
787817
return False
788818

789-
return True
819+
return first_keyword in allowed_read_commands
790820

791821
except Exception as e:
792822
raise ValueError(f"Parse SQL Error: {e}")

backend/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ dependencies = [
5353
"elasticsearch[requests] (>=7.10,<8.0)",
5454
"ldap3>=2.9.1",
5555
"sqlglot>=28.6.0",
56-
"numpy==2.3.5"
56+
"numpy==2.3.5",
57+
"pyhive[hive]>=0.7.0",
58+
"thrift-sasl"
5759
]
5860

5961
[project.optional-dependencies]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
template:
2+
quot_rule: |
3+
<rule>
4+
必须对数据库名、表名、字段名、别名外层加反引号(`)。
5+
<note>
6+
1. 点号(.)不能包含在引号内,必须写成 `database`.`table`
7+
2. 即使标识符不含特殊字符或非关键字,也需强制加反引号
8+
</note>
9+
</rule>
10+
11+
limit_rule: |
12+
<rule>
13+
当需要限制行数时,必须使用标准的LIMIT语法
14+
</rule>
15+
16+
other_rule: |
17+
<rule>必须为每个表生成别名(不加AS)</rule>
18+
{multi_table_condition}
19+
<rule>禁止使用星号(*),必须明确字段名</rule>
20+
<rule>中文/特殊字符字段需保留原名并添加英文别名</rule>
21+
<rule>不能用 + 拼接字符串,字符串必须使用单引号</rule>
22+
<rule>分组非常严格:SELECT 里的字段必须出现在 GROUP BY 里,或者是聚合函数</rule>
23+
<rule>函数字段必须加别名</rule>
24+
<rule>百分比字段保留两位小数并以%结尾</rule>
25+
<rule>WHERE 条件中不能使用 >、<、>=、<= 等比较运算符,必须使用 =</rule>
26+
<rule>HIVE 中没有 NOT IN 操作符,必须使用 LEFT JOIN 或 EXISTS 替代</rule>
27+
<rule>判空使用 NVL()函数</rule>
28+
<rule>避免与数据库关键字冲突</rule>
29+
30+
basic_example: |
31+
<basic-examples>
32+
<intro>
33+
📌 以下示例严格遵循<Rules>中的 Hive 规范,展示符合要求的 SQL 写法与典型错误案例。
34+
⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。
35+
🔍 重点观察:
36+
1. 反引号包裹所有数据库对象的规范用法
37+
2. 中英别名/百分比/函数等特殊字段的处理
38+
3. 关键字冲突的规避方式
39+
</intro>
40+
<example>
41+
<input>查询 ods.orders 表的前100条订单(含中文字段和百分比)</input>
42+
<output-bad>
43+
SELECT * FROM ods.orders LIMIT 100 -- 错误:未加引号、使用星号
44+
SELECT `订单ID`, `金额` FROM `ods`.`orders` `t1` LIMIT 100 -- 错误:缺少英文别名
45+
SELECT COUNT(`订单ID`) FROM `ods`.`orders` `t1` -- 错误:函数未加别名
46+
</output-bad>
47+
<output-good>
48+
SELECT
49+
`t1`.`订单ID` AS `order_id`,
50+
`t1`.`金额` AS `amount`,
51+
COUNT(`t1`.`订单ID`) AS `total_orders`,
52+
CONCAT(CAST(ROUND(`t1`.`折扣率` * 100, 2) AS STRING), '%') AS `discount_percent`
53+
FROM `ods`.`orders` `t1`
54+
LIMIT 100
55+
</output-good>
56+
</example>
57+
58+
<example>
59+
<input>统计 dim.users(含关键字字段user)的活跃占比</input>
60+
<output-bad>
61+
SELECT user, status FROM dim.users -- 错误:未处理关键字和引号
62+
SELECT `user`, ROUND(active_ratio) FROM `dim`.`users` -- 错误:百分比格式错误
63+
</output-bad>
64+
<output-good>
65+
SELECT
66+
`u`.`user` AS `username`,
67+
CONCAT(CAST(ROUND(`u`.`active_ratio` * 100, 2) AS STRING), '%') AS `active_percent`
68+
FROM `dim`.`users` `u`
69+
WHERE `u`.`status` = 1
70+
</output-good>
71+
</example>
72+
</basic-examples>
73+
74+
example_engine: Apache Hive 2.X
75+
example_answer_1: |
76+
{"success":true,"sql":"SELECT `country` AS `country_name`, `continent` AS `continent_name`, `year` AS `year`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` ORDER BY `country`, `year`","tables":["sample_country_gdp"],"chart-type":"line"}
77+
example_answer_1_with_limit: |
78+
{"success":true,"sql":"SELECT `country` AS `country_name`, `continent` AS `continent_name`, `year` AS `year`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` ORDER BY `country`, `year` LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}
79+
example_answer_2: |
80+
{"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2024' ORDER BY `gdp` DESC","tables":["sample_country_gdp"],"chart-type":"pie"}
81+
example_answer_2_with_limit: |
82+
{"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2024' ORDER BY `gdp` DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}
83+
example_answer_3: |
84+
{"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2025' AND `country` = '中国'","tables":["sample_country_gdp"],"chart-type":"table"}
85+
example_answer_3_with_limit: |
86+
{"success":true,"sql":"SELECT `country` AS `country_name`, `gdp` AS `gdp` FROM `Sample_Database`.`sample_country_gdp` WHERE `year` = '2025' AND `country` = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}

0 commit comments

Comments
 (0)