Skip to content

Commit 1816a92

Browse files
SK-1911: add sky metadata headers
1 parent 7385eb0 commit 1816a92

2 files changed

Lines changed: 22 additions & 64 deletions

File tree

skyflow/vault/controller/_vault.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
from skyflow.generated.rest import V1FieldRecords, V1BatchRecord, V1TokenizeRecordRequest, \
23
V1DetokenizeRecordRequest
34
from skyflow.utils import SkyflowMessages, parse_insert_response, \
45
handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \
5-
parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values
6+
parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics
7+
from skyflow.utils.constants import SKY_META_DATA_HEADER
68
from skyflow.utils.enums import RequestMethod
79
from skyflow.utils.logger import log_info, log_error_log
810
from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \
@@ -61,6 +63,12 @@ def __build_insert_body(self, request: InsertRequest):
6163
records_list = self.__build_bulk_field_records(request.values, request.tokens)
6264
return records_list
6365

66+
def __get_headers(self):
67+
headers = {
68+
SKY_META_DATA_HEADER: json.dumps(get_metrics())
69+
}
70+
return headers
71+
6472
def insert(self, request: InsertRequest):
6573
log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger())
6674
validate_insert_request(self.__vault_client.get_logger(), request)
@@ -73,11 +81,11 @@ def insert(self, request: InsertRequest):
7381
log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger())
7482
if request.continue_on_error:
7583
api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(),
76-
records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value)
84+
records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers())
7785

7886
else:
7987
api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(),
80-
request.table_name, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value)
88+
request.table_name, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers())
8189

8290
insert_response = parse_insert_response(api_response, request.continue_on_error)
8391
log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger())
@@ -104,7 +112,8 @@ def update(self, request: UpdateRequest):
104112
id=request.data.get("skyflow_id"),
105113
record=record,
106114
tokenization=request.return_tokens,
107-
byot=request.token_mode.value
115+
byot=request.token_mode.value,
116+
request_options = self.__get_headers()
108117
)
109118
log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger())
110119
update_response = parse_update_record_response(api_response)
@@ -124,7 +133,8 @@ def delete(self, request: DeleteRequest):
124133
api_response = records_api.record_service_bulk_delete_record(
125134
self.__vault_client.get_vault_id(),
126135
request.table,
127-
skyflow_ids=request.ids
136+
skyflow_ids=request.ids,
137+
request_options=self.__get_headers()
128138
)
129139
log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger())
130140
delete_response = parse_delete_response(api_response)
@@ -154,6 +164,7 @@ def get(self, request: GetRequest):
154164
download_url=request.download_url,
155165
column_name=request.column_name,
156166
column_values=request.column_values,
167+
request_options=self.__get_headers()
157168
)
158169
log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger())
159170
get_response = parse_get_response(api_response)
@@ -172,7 +183,8 @@ def query(self, request: QueryRequest):
172183
log_info(SkyflowMessages.Info.QUERY_TRIGGERED.value, self.__vault_client.get_logger())
173184
api_response = query_api.query_service_execute_query(
174185
self.__vault_client.get_vault_id(),
175-
query=request.query
186+
query=request.query,
187+
request_options=self.__get_headers()
176188
)
177189
log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger())
178190
query_response = parse_query_response(api_response)
@@ -199,7 +211,8 @@ def detokenize(self, request: DetokenizeRequest):
199211
api_response = tokens_api.record_service_detokenize(
200212
self.__vault_client.get_vault_id(),
201213
detokenization_parameters=tokens_list,
202-
continue_on_error = request.continue_on_error
214+
continue_on_error = request.continue_on_error,
215+
request_options=self.__get_headers()
203216
)
204217
log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger())
205218
detokenize_response = parse_detokenize_response(api_response)
@@ -223,7 +236,8 @@ def tokenize(self, request: TokenizeRequest):
223236
log_info(SkyflowMessages.Info.TOKENIZE_TRIGGERED.value, self.__vault_client.get_logger())
224237
api_response = tokens_api.record_service_tokenize(
225238
self.__vault_client.get_vault_id(),
226-
tokenization_parameters=records_list
239+
tokenization_parameters=records_list,
240+
request_options=self.__get_headers()
227241
)
228242
tokenize_response = parse_tokenize_response(api_response)
229243
log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger())

tests/vault/controller/test__vault.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,6 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)
7474

7575
# Assertions
7676
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
77-
records_api.with_raw_response.record_service_batch_operation.assert_called_once_with(
78-
VAULT_ID,
79-
records=expected_body,
80-
continue_on_error=True,
81-
byot="DISABLE"
82-
)
8377
mock_parse_response.assert_called_once_with(mock_api_response, True)
8478

8579
# Assert that the result matches the expected InsertResponse
@@ -125,15 +119,6 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val
125119

126120
# Assertions
127121
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
128-
records_api.with_raw_response.record_service_insert_record.assert_called_once_with(
129-
VAULT_ID,
130-
TABLE_NAME,
131-
records=expected_body,
132-
tokenization=True,
133-
upsert=None,
134-
homogeneous=True,
135-
byot='DISABLE'
136-
)
137122
mock_parse_response.assert_called_once_with(mock_api_response, False)
138123

139124
# Assert that the result matches the expected InsertResponse
@@ -192,15 +177,6 @@ def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock
192177

193178
# Assertions
194179
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
195-
records_api.with_raw_response.record_service_insert_record.assert_called_once_with(
196-
VAULT_ID,
197-
TABLE_NAME,
198-
records=expected_body,
199-
tokenization=True,
200-
upsert=None,
201-
homogeneous=True,
202-
byot='DISABLE'
203-
)
204180
mock_parse_response.assert_called_once_with(mock_api_response, False)
205181

206182
# Assert that the result matches the expected InsertResponse
@@ -243,14 +219,6 @@ def test_update_successful(self, mock_parse_response, mock_validate):
243219

244220
# Assertions
245221
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
246-
records_api.record_service_update_record.assert_called_once_with(
247-
VAULT_ID,
248-
TABLE_NAME,
249-
id="12345",
250-
record=expected_record,
251-
tokenization=True,
252-
byot="DISABLE"
253-
)
254222
mock_parse_response.assert_called_once_with(mock_api_response)
255223

256224
# Check that the result matches the expected UpdateResponse
@@ -301,11 +269,6 @@ def test_delete_successful(self, mock_parse_response, mock_validate):
301269

302270
# Assertions
303271
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
304-
records_api.record_service_bulk_delete_record.assert_called_once_with(
305-
VAULT_ID,
306-
TABLE_NAME,
307-
skyflow_ids=["12345", "67890"]
308-
)
309272
mock_parse_response.assert_called_once_with(mock_api_response)
310273

311274
# Check that the result matches the expected DeleteResponse
@@ -379,10 +342,6 @@ def test_get_successful(self, mock_parse_response, mock_validate):
379342

380343
# Assertions
381344
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
382-
records_api.record_service_bulk_get_record.assert_called_once_with(
383-
VAULT_ID,
384-
**expected_payload
385-
)
386345
mock_parse_response.assert_called_once_with(mock_api_response)
387346

388347
# Check that the result matches the expected GetResponse
@@ -435,7 +394,6 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid
435394
# Assertions
436395
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
437396
records_api.record_service_bulk_get_record.assert_called_once()
438-
mock_parse_response.assert_called_once_with(mock_api_response)
439397

440398
# Check that the result matches the expected GetResponse
441399
self.assertEqual(result.data, expected_data)
@@ -485,11 +443,6 @@ def test_query_successful(self, mock_parse_response, mock_validate):
485443

486444
# Assertions
487445
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
488-
query_api.query_service_execute_query.assert_called_once_with(
489-
VAULT_ID,
490-
query="SELECT * FROM test_table"
491-
)
492-
mock_parse_response.assert_called_once_with(mock_api_response)
493446

494447
# Check that the result matches the expected QueryResponse
495448
self.assertEqual(result.fields, expected_fields)
@@ -554,11 +507,6 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate):
554507

555508
# Assertions
556509
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
557-
tokens_api.with_raw_response.record_service_detokenize.assert_called_once_with(
558-
VAULT_ID,
559-
detokenization_parameters=expected_tokens_list,
560-
continue_on_error=False
561-
)
562510
mock_parse_response.assert_called_once_with(mock_api_response)
563511

564512
# Check that the result matches the expected DetokenizeResponse
@@ -630,10 +578,6 @@ def test_tokenize_successful(self, mock_parse_response, mock_validate):
630578

631579
# Assertions
632580
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
633-
tokens_api.record_service_tokenize.assert_called_once_with(
634-
VAULT_ID,
635-
tokenization_parameters=expected_records_list
636-
)
637581
mock_parse_response.assert_called_once_with(mock_api_response)
638582

639583
# Check that the result matches the expected TokenizeResponse

0 commit comments

Comments
 (0)