66
77
88import hashlib
9+ import hmac
10+ from hashlib import pbkdf2_hmac
11+ include " scram.pyx"
912
1013
11- include " scram.pyx"
14+ # 添加SHA256认证相关常量
15+ AUTH_REQUIRED_SHA256 = 13 # GaussDB SHA256认证类型
16+
17+ # Password storage methods (from Go code)
18+ PLAIN_PASSWORD = 0
19+ SHA256_PASSWORD = 2
20+ MD5_PASSWORD = 1
1221
1322
1423cdef dict AUTH_METHOD_NAME = {
@@ -18,9 +27,230 @@ cdef dict AUTH_METHOD_NAME = {
1827 AUTH_REQUIRED_GSS: ' gss' ,
1928 AUTH_REQUIRED_SASL: ' scram-sha-256' ,
2029 AUTH_REQUIRED_SSPI: ' sspi' ,
30+ AUTH_REQUIRED_SHA256: ' sha256' , # 添加SHA256
2131}
2232
2333
34+ def hex_string_to_bytes (hex_string ):
35+ """
36+ 将hex字符串转换为bytes,对应Go的hexStringToBytes函数
37+ """
38+ if not hex_string:
39+ return b' '
40+
41+ upper_string = hex_string.upper()
42+ bytes_len = len (upper_string) // 2
43+ result = bytearray(bytes_len)
44+
45+ for i in range (bytes_len):
46+ pos = i * 2
47+ high_char = upper_string[pos]
48+ low_char = upper_string[pos + 1 ]
49+
50+ # 将字符转换为数值
51+ high_val = " 0123456789ABCDEF" .index(high_char)
52+ low_val = " 0123456789ABCDEF" .index(low_char)
53+
54+ result[i] = (high_val << 4 ) | low_val
55+
56+ return bytes(result)
57+
58+ def generate_k_from_pbkdf2 (password , random64code , server_iteration ):
59+ """
60+ 对应Go的generateKFromPBKDF2函数
61+ 注意:Go代码使用的是SHA1,不是SHA256
62+ """
63+ random32code = hex_string_to_bytes(random64code)
64+ # Go代码使用sha1.New,所以这里使用'sha1'
65+ pwd_encoded = pbkdf2_hmac(' sha1' , password.encode(' utf-8' ), random32code, server_iteration, 32 )
66+ return pwd_encoded
67+
68+ def bytes_to_hex_string (src ):
69+ """
70+ 对应Go的bytesToHexString函数
71+ """
72+ s = " "
73+ for byte_val in src:
74+ v = byte_val & 0xFF
75+ hv = format(v, ' x' )
76+ if len (hv) < 2 :
77+ s += " 0" + hv
78+ else :
79+ s += hv
80+ return s
81+
82+ def get_key_from_hmac (key , data ):
83+ """
84+ 对应Go的getKeyFromHmac函数,使用SHA256
85+ """
86+ h = hmac.new(key, data, hashlib.sha256)
87+ return h.digest()
88+
89+ def get_sha256 (message ):
90+ """
91+ 对应Go的getSha256函数
92+ """
93+ hash_obj = hashlib.sha256()
94+ hash_obj.update(message)
95+ return hash_obj.digest()
96+
97+ def get_sm3 (message ):
98+ """
99+ 对应Go的getSm3函数 (这里用SHA256代替,因为Python标准库没有SM3)
100+ 实际项目中需要安装gmssl库来支持SM3
101+ """
102+ # 注意:这里用SHA256代替SM3,实际使用时需要proper的SM3实现
103+ hash_obj = hashlib.sha256() # 临时用SHA256代替
104+ hash_obj.update(message)
105+ return hash_obj.digest()
106+
107+ def xor_between_password (password1 , password2 , length ):
108+ """
109+ 对应Go的XorBetweenPassword函数
110+ """
111+ result = bytearray(length)
112+ for i in range (length):
113+ result[i] = password1[i] ^ password2[i]
114+ return bytes(result)
115+
116+ def bytes_to_hex (source_bytes , result_array = None , start_pos = 0 , length = None ):
117+ """
118+ 对应Go的bytesToHex函数,支持Java风格的4参数调用
119+ 但Go代码只传1个参数,所以做兼容处理
120+ """
121+ if result_array is not None :
122+ # Java风格:4个参数 bytesToHex(hValue, result, 0, hValue.length)
123+ if length is None :
124+ length = len (source_bytes)
125+
126+ lookup = b' 0123456789abcdef'
127+ pos = start_pos
128+
129+ for i in range (length):
130+ if i >= len (source_bytes):
131+ break
132+ byte_val = source_bytes[i]
133+ c = int (byte_val & 0xFF )
134+ j = c >> 4
135+ result_array[pos] = lookup[j]
136+ pos += 1
137+ j = c & 0xF
138+ result_array[pos] = lookup[j]
139+ pos += 1
140+ return result_array
141+ else :
142+ # Go风格:1个参数,返回新的bytes
143+ lookup = b' 0123456789abcdef'
144+ result = bytearray(len (source_bytes) * 2 )
145+ pos = 0
146+
147+ for byte_val in source_bytes:
148+ c = int (byte_val & 0xFF )
149+ j = c >> 4
150+ result[pos] = lookup[j]
151+ pos += 1
152+ j = c & 0xF
153+ result[pos] = lookup[j]
154+ pos += 1
155+
156+ return bytes(result)
157+
158+ def rfc5802_algorithm (password , random64code , token , server_signature = " " , server_iteration = 4096 , method = " sha256" ):
159+ """
160+ RFC5802算法实现,完全对应Go代码逻辑
161+ """
162+ try :
163+ # Step 1: 生成K (SaltedPassword)
164+ k = generate_k_from_pbkdf2(password, random64code, server_iteration)
165+
166+ # Step 2: 生成ServerKey和ClientKey
167+ server_key = get_key_from_hmac(k, b" Sever Key" ) # 保持"Sever Key"拼写
168+ client_key = get_key_from_hmac(k, b" Client Key" )
169+
170+ # Step 3: 生成StoredKey
171+ if method.lower() == " sha256" :
172+ stored_key = get_sha256(client_key)
173+ elif method.lower() == " sm3" :
174+ stored_key = get_sm3(client_key)
175+ else :
176+ stored_key = get_sha256(client_key) # 默认使用SHA256
177+
178+ # Step 4: 转换token为bytes
179+ token_byte = hex_string_to_bytes(token)
180+
181+ # Step 5: 计算clientSignature (实际上是ServerSignature,用于验证)
182+ client_signature = get_key_from_hmac(server_key, token_byte)
183+
184+ # Step 6: 验证serverSignature (如果提供)
185+ if server_signature and server_signature != bytes_to_hex_string(client_signature):
186+ return b" "
187+
188+ # Step 7: 计算真正的ClientSignature
189+ hmac_result = get_key_from_hmac(stored_key, token_byte)
190+
191+ # Step 8: XOR操作得到ClientProof
192+ h_value = xor_between_password(hmac_result, client_key, len (client_key))
193+
194+ # Step 9: 转换为hex bytes格式 (对应Java的 bytesToHex(hValue, result, 0, hValue.length))
195+ result = bytearray(len (h_value) * 2 )
196+ bytes_to_hex(h_value, result, 0 , len (h_value))
197+
198+ return bytes(result)
199+
200+ except Exception as e:
201+ raise ValueError (f" RFC5802Algorithm failed: {e}" )
202+
203+
204+ import hashlib
205+
206+ def bytes_to_hex (src_bytes , dst_bytes , offset , length ):
207+ """
208+ Java: bytesToHex(byte[] src, byte[] dst, int offset, int length)
209+ - src: 源字节数组
210+ - dst: 目标字节数组
211+ - offset: dst写入起始位置
212+ - length: 需要转换的src字节数量
213+ 写入的输出是十六进制ASCII字节(不是16进制数值),每个字节转换成2个字母。
214+ """
215+ HEX_DIGITS = b' 0123456789abcdef'
216+ for i in range (length):
217+ v = src_bytes[i]
218+ if isinstance (v, str ):
219+ v = ord (v)
220+ if v < 0 :
221+ v += 256
222+ dst_bytes[offset + (i * 2 )] = HEX_DIGITS[v >> 4 ]
223+ dst_bytes[offset + (i * 2 ) + 1 ] = HEX_DIGITS[v & 0x0F ]
224+
225+ def SHA256_MD5encode (user: bytes , password: bytes , salt: bytes ) -> bytes:
226+ try:
227+ md = hashlib.md5()
228+ md.update(password )
229+ md.update(user )
230+ temp_digest = md.digest() # 16 bytes
231+
232+ # hex_digest 70字节(实际前6和后64有效)
233+ hex_digest = bytearray(70 )
234+
235+ # 前32个字节为temp_digest的hex(16字节*2)
236+ bytes_to_hex(temp_digest , hex_digest , 0, 16)
237+
238+ # 取前32字节(hex后缀): hex_digest[0 :32 ], 作为SHA256输入
239+ sha = hashlib.sha256()
240+ sha.update(hex_digest[0 :32 ])
241+ sha.update(salt)
242+ pass_digest = sha.digest() # 32 bytes
243+
244+ # pass_digest的hex写到hex_digest[6:]
245+ bytes_to_hex(pass_digest, hex_digest, 6 , 32 )
246+
247+ # 填入ASCII签名'sha256'
248+ hex_digest[0 :6 ] = b' sha256'
249+
250+ except Exception as e:
251+ raise ValueError (' SHA256_MD5encode failed: %s ' % str (e))
252+ return bytes(hex_digest)
253+
24254cdef class CoreProtocol:
25255
26256 def __init__ (self , addr , con_params ):
@@ -564,7 +794,10 @@ cdef class CoreProtocol:
564794 bytes md5_salt
565795 list sasl_auth_methods
566796 list unsupported_sasl_auth_methods
567-
797+ int32_t password_stored_method
798+ bytes random64code
799+ bytes token
800+ int32_t server_iteration
568801 status = self .buffer.read_int32()
569802
570803 if status == AUTH_SUCCESSFUL:
@@ -587,32 +820,30 @@ cdef class CoreProtocol:
587820 # This requires making additional requests to the server in order
588821 # to follow the SCRAM protocol defined in RFC 5802.
589822 # get the SASL authentication methods that the server is providing
590- sasl_auth_methods = []
591- unsupported_sasl_auth_methods = []
592- # determine if the advertised authentication methods are supported,
593- # and if so, add them to the list
594- auth_method = self .buffer.read_null_str()
595- while auth_method:
596- if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
597- sasl_auth_methods.append(auth_method)
598- else :
599- unsupported_sasl_auth_methods.append(auth_method)
600- auth_method = self .buffer.read_null_str()
601-
602- # if none of the advertised authentication methods are supported,
603- # raise an error
604- # otherwise, initialize the SASL authentication exchange
605- if not sasl_auth_methods:
606- unsupported_sasl_auth_methods = [m.decode(" ascii" )
607- for m in unsupported_sasl_auth_methods]
823+ password_stored_method = self .buffer.read_int32()
824+ if not self .password:
608825 self .result_type = RESULT_FAILED
609826 self .result = apg_exc.InterfaceError(
610- ' unsupported SASL Authentication methods requested by the '
611- ' server: {!r}' .format(
612- " , " .join(unsupported_sasl_auth_methods)))
827+ ' The server requested password-based authentication, '
828+ ' but no password was provided.' )
829+ if password_stored_method== 2 :
830+ # 读取认证参数
831+ random64code = self .buffer.read_bytes(64 )
832+ token = self .buffer.read_bytes(8 )
833+ server_iteration = self .buffer.read_int32()
834+ # 调用_auth_password_message_sha256生成认证消息
835+ self .auth_msg = self ._auth_password_message_sha256(random64code, token,
836+ server_iteration)
837+ elif password_stored_method == 5 :
838+ # MD5密码存储方式
839+ salt = self .buffer.read_bytes(4 )
840+ self .auth_msg = self ._auth_password_message_md5(salt)
841+
613842 else :
614- self .auth_msg = self ._auth_password_message_sasl_initial(
615- sasl_auth_methods)
843+ self .result_type = RESULT_FAILED
844+ self .result = apg_exc.InterfaceError(
845+ f' The password-stored method {password_stored_method} is not supported, '
846+ ' must be plain, md5 or sha256.' )
616847
617848 elif status == AUTH_SASL_CONTINUE:
618849 # AUTH_SASL_CONTINUE
@@ -661,7 +892,7 @@ cdef class CoreProtocol:
661892 ' server: {!r}' .format(AUTH_METHOD_NAME.get(status, status)))
662893
663894 if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
664- AUTH_REQUIRED_GSS_CONTINUE):
895+ AUTH_REQUIRED_GSS_CONTINUE, AUTH_REQUIRED_SHA256 ):
665896 self .buffer.discard_message()
666897
667898 cdef _auth_password_message_cleartext(self ):
@@ -690,6 +921,37 @@ cdef class CoreProtocol:
690921
691922 return msg
692923
924+ cdef _auth_password_message_sha256(self , bytes random64code, bytes token,
925+ int32_t server_iteration):
926+ """
927+ 处理SHA256认证消息
928+ """
929+ cdef:
930+ WriteBuffer msg
931+ bytes result
932+
933+ # 调用RFC5802算法计算认证结果
934+ result = rfc5802_algorithm(
935+ self .password,
936+ random64code.decode(' utf-8' ),
937+ token.decode(' utf-8' ),
938+ ' ' , # salt为空
939+ server_iteration,
940+ ' sha256'
941+ )
942+ if not result:
943+ self .result_type = RESULT_FAILED
944+ self .result = apg_exc.InterfaceError(
945+ ' Invalid username/password, login denied.' )
946+ return None
947+
948+ # 构建认证响应消息
949+ msg = WriteBuffer.new_message(b' p' )
950+ msg.write_bytes(result)
951+ msg.end_message()
952+
953+ return msg
954+
693955 cdef _auth_password_message_sasl_initial(self , list sasl_auth_methods):
694956 cdef:
695957 WriteBuffer msg
@@ -938,7 +1200,7 @@ cdef class CoreProtocol:
9381200
9391201 # protocol version
9401202 buf.write_int16(3 )
941- buf.write_int16(0 )
1203+ buf.write_int16(51 )
9421204
9431205 buf.write_bytestring(b' client_encoding' )
9441206 buf.write_bytestring(" '{}'" .format(self .encoding).encode(' ascii' ))
0 commit comments