-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy path edb_controller.py
More file actions
305 lines (233 loc) · 13.1 KB
/
edb_controller.py
File metadata and controls
305 lines (233 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import hashlib
from itertools import cycle
from RocksDBWrapper import RocksDBWrapper as RocksWrapper
from Crypto.Cipher import AES
import utilities
from flush_server_observer import Flushing_SMonitor
class SSE_Server:
    """Server side of a dynamic Searchable Symmetric Encryption (SSE) scheme.

    Encrypted index entries are stored in RocksDB (via RocksDBWrapper).
    Searches walk a backward-linked chain of encrypted states derived from
    the client-supplied token.  An optional flushing monitor thread tracks
    write throughput, and a count-attack simulation is included for
    evaluation purposes.
    """

    def __init__(self):
        self.rockwrappers = RocksWrapper()
        # True while the RocksDB connection is open.
        self.isAlive = False
        # True while the flushing monitor thread is running.
        self.f_mode = False
        # Monitor thread; created lazily by start_monitor().
        self.monitor = None

    def import_encrypted_batch(self, encrypted_batch):
        """Store every (key, value) pair of *encrypted_batch*; return its size."""
        for k, v in encrypted_batch.items():
            self.rockwrappers.put(k, v)
            if self.f_mode:
                # Per-write counter consumed by the flush monitor thread.
                # NOTE(review): assumed to count individual puts — confirm
                # against Flushing_SMonitor's expectations.
                self.monitor.tracker += 1
        return len(encrypted_batch)

    def start_monitor(self, f_period, f_interval):
        """Start the flushing monitor thread.  Returns 1."""
        self.f_mode = True
        self.monitor = Flushing_SMonitor(f_period, f_interval)
        print(">Monitoring thread enabled")
        self.monitor.start()
        return 1

    def close_monitor(self):
        """Stop the flushing monitor thread.  Returns 1.

        Safe to call even when the monitor was never started.
        """
        print(">Monitoring thread disabled")
        # Stop counting writes in import_encrypted_batch.
        self.f_mode = False
        if self.monitor is not None:
            # NOTE(review): relies on Flushing_SMonitor exposing _stop();
            # threading.Thread._stop() no longer exists in modern CPython —
            # confirm the monitor class defines its own stop hook.
            self.monitor._stop()
        return 1

    def open_conn(self):
        """Open the RocksDB connection and mark the server alive."""
        self.isAlive = True
        return self.rockwrappers.open()

    def close_conn(self):
        """Close the RocksDB connection and mark the server down."""
        self.isAlive = False
        return self.rockwrappers.close()

    def get_db_info(self):
        """Return backend database information."""
        return self.rockwrappers.getInfo()

    def getStatus(self):
        """Return True while the connection is considered alive."""
        return self.isAlive

    def _traverse_chain(self, token, encrypted_IDs, collected_keys=None):
        """Walk the encrypted state chain described by *token*.

        token = (k_w, k_id, latest_state, latest_count, iv).  For every slot
        of every batch the masked ciphertext is fetched from RocksDB,
        unmasked with a hash-derived pad, and appended to *encrypted_IDs*
        (the client still needs k_id to fully decrypt the IDs).  The extra
        slot after each batch stores the AES key and entry count of the
        previous batch, which is how the walk continues backwards until a
        stored count of 0 terminates it.

        If *collected_keys* is a list, every DB key touched is appended to
        it (used by search_delete to remove the entries afterwards).
        Appending into caller-owned lists keeps partial results available
        when an exception interrupts the walk.
        """
        k_w = token[0]
        k_id = token[1]
        latest_state = token[2]
        latest_count = token[3]
        iv = token[4]
        while latest_count != 0:
            for slot in range(latest_count):
                # Deterministic per-slot label: F(state, slot) under (key, iv).
                enc_state_slot = self.pseudo_function_F(latest_state, str(slot), iv)
                u = self.primitive_hash_h(enc_state_slot + k_w)
                if collected_keys is not None:
                    collected_keys.append(u)
                v = self.rockwrappers.get(u)  # bytes
                # Pad derived with k_id unmasks the encrypted document ID.
                v_part = self.primitive_hash_h(enc_state_slot + k_id)
                encrypted_IDs.append(bytes(a ^ b for (a, b) in zip(v, cycle(v_part))))
            # Slot index == latest_count holds the previous batch's
            # AES key (first 16 bytes) and entry count (ASCII tail).
            enc_key_slot = self.pseudo_function_F(latest_state, str(latest_count), iv)
            u_k = self.primitive_hash_h(enc_key_slot + k_w)
            if collected_keys is not None:
                collected_keys.append(u_k)
            v_k = self.rockwrappers.get(u_k)  # bytes
            v_k_part = self.primitive_hash_h(enc_key_slot + k_id)
            k_e_dec = bytes(a ^ b for (a, b) in zip(v_k, cycle(v_k_part)))
            k_e = k_e_dec[:16]
            latest_count = int(k_e_dec[16:].decode("ascii"))
            latest_state = self.pseudo_inverse_permutation_P(k_e, latest_state, iv)

    def search(self, token):
        """Best-effort search: return the list of encrypted document IDs.

        Any failure mid-walk (e.g. a missing DB entry) stops the traversal
        and whatever IDs were recovered so far are returned.
        """
        encrypted_IDs = []
        try:
            self._traverse_chain(token, encrypted_IDs)
        except Exception:
            # Keep the original best-effort behaviour: partial results are
            # still useful to the client, so do not propagate.
            print("Exception")
        return encrypted_IDs

    def search_delete(self, token):
        """Search as in search(), then delete every DB entry that was touched."""
        del_key = []
        encrypted_IDs = []
        self._traverse_chain(token, encrypted_IDs, del_key)
        for item in del_key:
            self.rockwrappers.delete(item)
        return encrypted_IDs

    ####################### primitive functions ##############
    def _pad_str(self, s, bs=32):
        """PKCS#7-style pad *s* (a str) up to a multiple of *bs* characters."""
        pad_len = bs - len(s) % bs
        return s + pad_len * chr(pad_len)

    def _unpad_str(self, s):
        """Strip the PKCS#7-style padding added by _pad_str."""
        return s[:-ord(s[len(s) - 1:])]

    def primitive_hash_h(self, msg):
        """Return the 32-byte SHA-256 digest of *msg* (bytes)."""
        m = hashlib.sha256()
        m.update(msg)
        return m.digest()

    def pseudo_permutation_P(self, key, raw, iv):
        """AES-CBC encrypt *raw*; raw must already be a multiple of 16 bytes."""
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return cipher.encrypt(raw)

    def pseudo_inverse_permutation_P(self, key, ctext, iv):
        """Invert pseudo_permutation_P."""
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return cipher.decrypt(ctext)

    def pseudo_function_F(self, key, raw, iv):
        """Deterministic PRF: pad *raw* then AES-CBC encrypt under (key, iv)."""
        # NOTE(review): raw is padded as str; modern PyCryptodome requires
        # bytes for encrypt() — confirm the intended Python/crypto versions.
        raw = self._pad_str(raw)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return cipher.encrypt(raw)

    def pseudo_inverse_function_F(self, key, ctext, iv):
        """Invert pseudo_function_F and strip the padding."""
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return self._unpad_str(cipher.decrypt(ctext))

    ####################### counting attack ######################
    def perform_count_attacks(self, lookUp, query_tokens, access_patterns, query_keywords, max_pad):
        """Simulate a count attack against padded search results.

        lookUp          : keyword -> set of document IDs (attacker's
                          background knowledge)
        query_tokens    : observed query tokens
        access_patterns : per-query sets of returned document IDs
        query_keywords  : ground-truth keyword per query (scoring only)
        max_pad         : padding window applied to result lengths

        Returns the fraction of queries whose keyword was recovered.
        """
        query_number = len(query_tokens)
        if query_number == 0:
            # Nothing to attack; avoid dividing by zero below.
            return 0
        # occurence_C_q[i][j] = observed co-occurrence of queries i and j.
        occurence_C_q = [[] for _ in range(query_number)]
        for i in range(query_number):
            for j in range(query_number):
                occurence_C_q[i].append(len(access_patterns[i] & access_patterns[j]))
        # Candidate keywords per query: any keyword whose known result length
        # falls inside the max_pad window below the observed result length.
        possible_candidate_list_dict = {i: set() for i in range(query_number)}
        for k, v in lookUp.items():
            temp_len = len(v)
            for i in range(query_number):
                result_length = len(access_patterns[i])
                if result_length - max_pad <= temp_len <= result_length:
                    possible_candidate_list_dict[i].add(k)
        # Queries with a single candidate are immediately resolved.
        query_map = []
        for query_index in range(query_number):
            if len(possible_candidate_list_dict[query_index]) == 1:
                keyword = next(iter(possible_candidate_list_dict[query_index]))
                query_map.append((query_index, query_tokens[query_index], keyword,
                                  len(access_patterns[query_index])))
                del possible_candidate_list_dict[query_index]
        guess_count = 0
        for item in query_map:
            if item[2] == query_keywords[item[0]]:
                guess_count += 1
        print("Knowledge query map size caused by unique length is %i " % guess_count)
        # Refine the remaining queries by co-occurrence consistency.
        for query_index, possible_keyword_list in possible_candidate_list_dict.items():
            remove_keyword_set = []
            for candidate_keyword in possible_keyword_list:
                inconsistency_stop = False
                # 1) Consistency against already-resolved queries.
                for known_query in query_map:
                    if not inconsistency_stop:
                        count_in_Cq = occurence_C_q[query_index][known_query[0]]
                        count_in_Cl = utilities.find_cooccurence(candidate_keyword, known_query[2], lookUp)
                        if count_in_Cq < count_in_Cl or count_in_Cq > count_in_Cl + max_pad:
                            remove_keyword_set.append(candidate_keyword)
                            inconsistency_stop = True
                # 2) Consistency against the other still-unresolved queries.
                if not inconsistency_stop:
                    for unknown_query_index in range(query_number):
                        if not inconsistency_stop and unknown_query_index != query_index:
                            if not utilities.check_known_query_index(unknown_query_index, query_map):
                                candidates_next = possible_candidate_list_dict[unknown_query_index]
                                Cq_co_value = occurence_C_q[query_index][unknown_query_index]
                                if not utilities.check_consistency_with_unknown_query(
                                        candidate_keyword, Cq_co_value, candidates_next,
                                        lookUp, max_pad):
                                    remove_keyword_set.append(candidate_keyword)
                                    inconsistency_stop = True
            for item in set(remove_keyword_set):
                possible_keyword_list.remove(item)
            if len(possible_keyword_list) == 1:
                keyword = next(iter(possible_keyword_list))
                query_map.append((query_index, query_tokens[query_index], keyword,
                                  len(access_patterns[query_index])))
                if keyword == query_keywords[query_index]:
                    guess_count += 1
        return guess_count / query_number