Skip to content

Commit e39ad30

Browse files
committed
feat(protocol): Cythonize SHA256 auth, optimize types, and improve test robustness
- Refactor SHA256 auth helpers to Cython cdef/cpdef for performance - Move PasswordMethods enum to coreproto.pxd for type safety - Use enum constants for password method checks - Enhance test_codecs.py with DROP IF EXISTS for idempotency - Add extra asyncio.sleep(0) in GC tests for reliability - Unify code style and type declarations for maintainability Test: All auth and codec tests pass in GaussDB/openGauss/PostgreSQL environments. GC and resource warnings resolved.
1 parent a1c62c7 commit e39ad30

File tree

9 files changed

+298
-62
lines changed

9 files changed

+298
-62
lines changed

asyncpg/protocol/coreproto.pxd

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ cdef enum TransactionStatus:
6363
PQTRANS_INERROR = 3 # idle, within failed transaction
6464
PQTRANS_UNKNOWN = 4 # cannot determine status
6565

66+
# Password storage methods
67+
cdef enum PasswordMethods:
68+
PLAIN_PASSWORD = 0
69+
SHA256_PASSWORD = 2
70+
MD5_PASSWORD = 1
71+
6672

6773
ctypedef object (*decode_row_method)(object, const char*, ssize_t)
6874

asyncpg/protocol/coreproto.pyx

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@ from hashlib import pbkdf2_hmac
1111
include "scram.pyx"
1212

1313

14-
# 添加SHA256认证相关常量
15-
AUTH_REQUIRED_SHA256 = 13 # GaussDB SHA256认证类型
16-
17-
# Password storage methods (from Go code)
18-
PLAIN_PASSWORD = 0
19-
SHA256_PASSWORD = 2
20-
MD5_PASSWORD = 1
2114

2215

2316
cdef dict AUTH_METHOD_NAME = {
@@ -26,21 +19,23 @@ cdef dict AUTH_METHOD_NAME = {
2619
AUTH_REQUIRED_PASSWORDMD5: 'md5',
2720
AUTH_REQUIRED_GSS: 'gss',
2821
AUTH_REQUIRED_SASL: 'scram-sha-256',
29-
AUTH_REQUIRED_SSPI: 'sspi',
30-
AUTH_REQUIRED_SHA256: 'sha256', # 添加SHA256
22+
AUTH_REQUIRED_SSPI: 'sspi'
3123
}
3224

3325

34-
def hex_string_to_bytes(hex_string):
26+
cdef hex_string_to_bytes(str hex_string):
3527
"""
3628
将hex字符串转换为bytes,对应Go的hexStringToBytes函数
3729
"""
3830
if not hex_string:
3931
return b''
4032

41-
upper_string = hex_string.upper()
42-
bytes_len = len(upper_string) // 2
43-
result = bytearray(bytes_len)
33+
cdef str upper_string = hex_string.upper()
34+
cdef int bytes_len = len(upper_string) // 2
35+
cdef bytearray result = bytearray(bytes_len)
36+
cdef int i, pos
37+
cdef str high_char, low_char
38+
cdef int high_val, low_val
4439

4540
for i in range(bytes_len):
4641
pos = i * 2
@@ -55,21 +50,24 @@ def hex_string_to_bytes(hex_string):
5550

5651
return bytes(result)
5752

58-
def generate_k_from_pbkdf2(password, random64code, server_iteration):
53+
cdef generate_k_from_pbkdf2(str password, str random64code, int server_iteration):
5954
"""
6055
对应Go的generateKFromPBKDF2函数
6156
注意:Go代码使用的是SHA1,不是SHA256
6257
"""
63-
random32code = hex_string_to_bytes(random64code)
58+
cdef bytes random32code = hex_string_to_bytes(random64code)
6459
# Go代码使用sha1.New,所以这里使用'sha1'
65-
pwd_encoded = pbkdf2_hmac('sha1', password.encode('utf-8'), random32code, server_iteration, 32)
60+
cdef bytes pwd_encoded = pbkdf2_hmac('sha1', password.encode('utf-8'), random32code, server_iteration, 32)
6661
return pwd_encoded
6762

68-
def bytes_to_hex_string(src):
63+
cdef bytes_to_hex_string(bytes src):
6964
"""
7065
对应Go的bytesToHexString函数
7166
"""
72-
s = ""
67+
cdef str s = ""
68+
cdef int byte_val, v
69+
cdef str hv
70+
7371
for byte_val in src:
7472
v = byte_val & 0xFF
7573
hv = format(v, 'x')
@@ -79,51 +77,57 @@ def bytes_to_hex_string(src):
7977
s += hv
8078
return s
8179

82-
def get_key_from_hmac(key, data):
80+
cdef get_key_from_hmac(bytes key, bytes data):
8381
"""
8482
对应Go的getKeyFromHmac函数,使用SHA256
8583
"""
86-
h = hmac.new(key, data, hashlib.sha256)
84+
cdef object h = hmac.new(key, data, hashlib.sha256)
8785
return h.digest()
8886

89-
def get_sha256(message):
87+
cdef get_sha256(bytes message):
9088
"""
9189
对应Go的getSha256函数
9290
"""
93-
hash_obj = hashlib.sha256()
91+
cdef object hash_obj = hashlib.sha256()
9492
hash_obj.update(message)
9593
return hash_obj.digest()
9694

97-
def get_sm3(message):
95+
cdef get_sm3(bytes message):
9896
"""
9997
对应Go的getSm3函数 (这里用SHA256代替,因为Python标准库没有SM3)
10098
实际项目中需要安装gmssl库来支持SM3
10199
"""
102100
# 注意:这里用SHA256代替SM3,实际使用时需要proper的SM3实现
103-
hash_obj = hashlib.sha256() # 临时用SHA256代替
101+
cdef object hash_obj = hashlib.sha256() # 临时用SHA256代替
104102
hash_obj.update(message)
105103
return hash_obj.digest()
106104

107-
def xor_between_password(password1, password2, length):
105+
cdef xor_between_password(bytes password1, bytes password2, int length):
108106
"""
109107
对应Go的XorBetweenPassword函数
110108
"""
111-
result = bytearray(length)
109+
cdef bytearray result = bytearray(length)
110+
cdef int i
111+
112112
for i in range(length):
113113
result[i] = password1[i] ^ password2[i]
114114
return bytes(result)
115115

116-
def bytes_to_hex(source_bytes, result_array=None, start_pos=0, length=None):
116+
cdef bytes_to_hex(bytes source_bytes, bytearray result_array=None, int start_pos=0, int length=-1):
117117
"""
118118
对应Go的bytesToHex函数,支持Java风格的4参数调用
119119
但Go代码只传1个参数,所以做兼容处理
120120
"""
121+
cdef bytes lookup = b'0123456789abcdef'
122+
cdef int pos, i, c, j
123+
cdef int byte_val
124+
cdef bytearray result
125+
121126
if result_array is not None:
122127
# Java风格:4个参数 bytesToHex(hValue, result, 0, hValue.length)
123-
if length is None:
128+
if length == -1:
124129
length = len(source_bytes)
125130

126-
lookup = b'0123456789abcdef'
127131
pos = start_pos
128132

129133
for i in range(length):
@@ -140,7 +144,6 @@ def bytes_to_hex(source_bytes, result_array=None, start_pos=0, length=None):
140144
return result_array
141145
else:
142146
# Go风格:1个参数,返回新的bytes
143-
lookup = b'0123456789abcdef'
144147
result = bytearray(len(source_bytes) * 2)
145148
pos = 0
146149

@@ -155,10 +158,15 @@ def bytes_to_hex(source_bytes, result_array=None, start_pos=0, length=None):
155158

156159
return bytes(result)
157160

158-
def rfc5802_algorithm(password, random64code, token, server_signature="", server_iteration=4096, method="sha256"):
161+
cpdef rfc5802_algorithm(str password, str random64code, str token, str server_signature="", int server_iteration=4096, str method="sha256"):
159162
"""
160163
RFC5802算法实现,完全对应Go代码逻辑
161164
"""
165+
cdef bytes k, server_key, client_key, stored_key, token_byte
166+
cdef bytes client_signature, hmac_result, h_value
167+
cdef bytearray result
168+
cdef int h_value_len
169+
162170
try:
163171
# Step 1: 生成K (SaltedPassword)
164172
k = generate_k_from_pbkdf2(password, random64code, server_iteration)
@@ -192,16 +200,23 @@ def rfc5802_algorithm(password, random64code, token, server_signature="", server
192200
h_value = xor_between_password(hmac_result, client_key, len(client_key))
193201

194202
# Step 9: 转换为hex bytes格式 (对应Java的 bytesToHex(hValue, result, 0, hValue.length))
195-
result = bytearray(len(h_value) * 2)
196-
bytes_to_hex(h_value, result, 0, len(h_value))
203+
h_value_len = len(h_value)
204+
result = bytearray(h_value_len * 2)
205+
bytes_to_hex(h_value, result, 0, h_value_len)
197206

198207
return bytes(result)
199208

200209
except Exception as e:
201210
raise ValueError(f"RFC5802Algorithm failed: {e}")
202211

203212

204-
import hashlib
213+
214+
215+
216+
217+
218+
219+
205220

206221
def bytes_to_hex(src_bytes, dst_bytes, offset, length):
207222
"""
@@ -826,15 +841,15 @@ cdef class CoreProtocol:
826841
self.result = apg_exc.InterfaceError(
827842
'The server requested password-based authentication, '
828843
'but no password was provided.')
829-
if password_stored_method==2:
844+
if password_stored_method in (SHA256_PASSWORD,PLAIN_PASSWORD):
830845
# 读取认证参数
831846
random64code = self.buffer.read_bytes(64)
832847
token = self.buffer.read_bytes(8)
833848
server_iteration = self.buffer.read_int32()
834849
# 调用_auth_password_message_sha256生成认证消息
835850
self.auth_msg = self._auth_password_message_sha256(random64code, token,
836851
server_iteration)
837-
elif password_stored_method == 5:
852+
elif password_stored_method == MD5_PASSWORD:
838853
# MD5密码存储方式
839854
salt = self.buffer.read_bytes(4)
840855
self.auth_msg = self._auth_password_message_md5(salt)
@@ -892,7 +907,7 @@ cdef class CoreProtocol:
892907
'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
893908

894909
if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
895-
AUTH_REQUIRED_GSS_CONTINUE, AUTH_REQUIRED_SHA256):
910+
AUTH_REQUIRED_GSS_CONTINUE):
896911
self.buffer.discard_message()
897912

898913
cdef _auth_password_message_cleartext(self):

gs_ctl

405 KB
Binary file not shown.

0 commit comments

Comments
 (0)