Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions tensorizer/stream_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import functools
import hashlib
import io
import logging
import os
Expand Down Expand Up @@ -433,26 +435,47 @@ def s3_upload(
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str = default_s3_write_endpoint,
s3_sse_customer_key: Optional[bytes] = None,
s3_sse_customer_algorithm: Optional[str] = None,
):
bucket, key = _parse_s3_uri(target_uri)
client = _new_s3_client(s3_access_key_id, s3_secret_access_key, s3_endpoint)
client.upload_file(path, bucket, key)
extra_args = {}
if s3_sse_customer_key is not None:
extra_args["SSECustomerAlgorithm"] = s3_sse_customer_algorithm
if s3_sse_customer_algorithm is not None:
extra_args["SSECustomerKey"] = s3_sse_customer_key
client.upload_file(path, bucket, key, ExtraArgs=extra_args)


def s3_download(
path_uri: str,
s3_access_key_id: str,
s3_secret_access_key: str,
s3_endpoint: str = default_s3_read_endpoint,
s3_sse_customer_key: Optional[bytes] = None,
s3_sse_customer_algorithm: Optional[str] = None,
) -> CURLStreamFile:
bucket, key = _parse_s3_uri(path_uri)
client = _new_s3_client(s3_access_key_id, s3_secret_access_key, s3_endpoint)
encryption_params = {}
if s3_sse_customer_key is not None:
encryption_params["SSECustomerAlgorithm"] = s3_sse_customer_algorithm
if s3_sse_customer_algorithm is not None:
encryption_params["SSECustomerKey"] = s3_sse_customer_key
url = client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": bucket, "Key": key},
Params={"Bucket": bucket, "Key": key, **encryption_params},
ExpiresIn=300,
)
return CURLStreamFile(url)
request_headers = {}
if s3_sse_customer_algorithm is not None:
request_headers['x-amz-server-side-encryption-customer-algorithm'] = s3_sse_customer_algorithm
if s3_sse_customer_key is not None:
request_headers['x-amz-server-side-encryption-customer-key'] = base64.b64encode(s3_sse_customer_key).decode()
key_md5 = hashlib.md5(s3_sse_customer_key).digest()
request_headers['x-amz-server-side-encryption-customer-key-MD5'] = base64.b64encode(key_md5).decode()
return CURLStreamFile(url, headers=request_headers)


def _infer_credentials(
Expand Down Expand Up @@ -601,6 +624,8 @@ def open_stream(
s3_secret_access_key: Optional[str] = None,
s3_endpoint: Optional[str] = None,
s3_config_path: Optional[Union[str, bytes, os.PathLike]] = None,
s3_sse_customer_key: Optional[bytes] = None,
s3_sse_customer_algorithm: Optional[str] = None,
) -> Union[CURLStreamFile, typing.BinaryIO]:
"""
Open a file path, http(s):// URL, or s3:// URI.
Expand Down Expand Up @@ -638,6 +663,10 @@ def open_stream(
s3_config_path: An explicit path to the `~/.s3cfg` config file
to be parsed if full credentials are not provided.
If None, platform-specific default paths are used.
s3_sse_customer_key: Specifies the customer-provided encryption
key for Amazon S3 to use in encrypting data.
s3_sse_customer_algorithm: Specifies the algorithm to use to
when encrypting the object (for example, AES256).

Returns:
An opened file-like object representing the target resource.
Expand Down Expand Up @@ -753,13 +782,20 @@ def open_stream(
s3_access_key_id,
s3_secret_access_key,
s3_endpoint,
s3_sse_customer_key,
s3_sse_customer_algorithm,
)
temp_file.close = guaranteed_closer
return temp_file
else:
s3_endpoint = s3_endpoint or default_s3_read_endpoint
curl_stream_file = s3_download(
path_uri, s3_access_key_id, s3_secret_access_key, s3_endpoint
path_uri,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
s3_sse_customer_key=s3_sse_customer_key,
s3_sse_customer_algorithm=s3_sse_customer_algorithm,
)
if error_context:
curl_stream_file.register_error_context(error_context)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_stream_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def test_upload(self):
s3_access_key_id=self.ACCESS_KEY,
s3_secret_access_key=self.SECRET_KEY,
s3_endpoint=self.endpoint,
s3_sse_customer_key=os.urandom(32),
s3_sse_customer_algorithm="AES256",
)
long_string = b"Hello" * 1024
s.write(long_string)
Expand All @@ -184,5 +186,7 @@ def test_download(self):
s3_access_key_id="X",
s3_secret_access_key="X",
s3_endpoint=endpoint,
s3_sse_customer_key=os.urandom(32),
s3_sse_customer_algorithm="AES256",
) as s:
self.assertEqual(s.read(), long_string)