Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 82 additions & 72 deletions sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def _create_s3_client(bucket, credentials=None):


def _get_data_from_bucket(object_key, bucket, client):
"""Get a file from an S3 bucket as a bytes object.

Args:
object_key (str):
The key of the object to get.
bucket (str):
The name of the bucket to get the object from.
client (botocore.client.S3):
S3 client.

Returns:
bytes:
The file data from the S3 object as a bytes object.
"""
response = client.get_object(Bucket=bucket, Key=object_key)
return response['Body'].read()

Expand Down Expand Up @@ -271,11 +285,31 @@ def wrapper(*args, **kwargs):


def _download(modality, dataset_name, bucket, credentials=None):
"""Download dataset resources from a bucket.
"""Download dataset resources from S3 bucket and return the bytes.

Args:
modality (str):
The modality of the dataset: ``'single_table'``, ``'multi_table'``,
``'sequential'``.
dataset_name (str):
The name of the dataset.
bucket (str):
The name of the bucket to download from.
credentials (dict or None):
Dictionary containing DataCebo license key and username. It takes the form:
{ 'username': 'example@datacebo.com', 'license_key': '<MY_LICENSE_KEY>' }
If None, the function will use the default credentials.

Returns:
tuple:
(BytesIO(zip_bytes), metadata_bytes)
tuple[BytesIO, bytes]:
(data_bytes, metadata_bytes)
The data is bytes of the ``data.zip`` and
``metadata_bytes`` is the raw bytes of the metadata JSON.

Raises:
DemoResourceNotFoundError:
If the dataset prefix is missing in the bucket, if ``data.zip`` is
missing, or if no V1 metadata file is present.
"""
client = _create_s3_client(bucket=bucket, credentials=credentials)
dataset_prefix = f'{modality}/{dataset_name}/'
Expand All @@ -286,86 +320,64 @@ def _download(modality, dataset_name, bucket, credentials=None):
)
contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
zip_key = _find_data_zip_key(contents, dataset_prefix, bucket)
zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
data_bytes = io.BytesIO(_get_data_from_bucket(zip_key, bucket=bucket, client=client))
metadata_bytes = _get_first_v1_metadata_bytes(
contents, dataset_prefix, bucket=bucket, client=client
)

return io.BytesIO(zip_bytes), metadata_bytes
return data_bytes, metadata_bytes


def _extract_data(bytes_io, output_folder_name):
with ZipFile(bytes_io) as zf:
if output_folder_name:
os.makedirs(output_folder_name, exist_ok=True)
zf.extractall(output_folder_name)

else:
in_memory_directory = {}
for name in zf.namelist():
in_memory_directory[name] = zf.read(name)
def _load_data_from_zip(zip_bytes, bucket, dataset_name, output_folder_name=None):
"""Load CSV tables from in-memory zip bytes into a dict of DataFrames.

return in_memory_directory
This function iterates over the zip bytes and parses each CSV entry with
``pandas.read_csv``. Non-CSV entries are recorded as skipped. When
``output_folder_name`` is provided, the archive is also extracted to disk
as a side effect so the caller keeps a local copy.

Args:
zip_bytes (io.BytesIO):
File-like object containing the bytes of ``data.zip``.
bucket (str):
The name of the bucket the zip was downloaded from. Used only for
error messages.
dataset_name (str):
The name of the dataset. Used only for error messages.
output_folder_name (str or None):
Optional folder path where the zip will also be extracted to disk so
the user keeps a local copy. The returned data dict is always built
in-memory. If ``None``, no folder is created.

def _get_data_with_output_folder(output_folder_name):
"""Load CSV tables from an extracted folder on disk.
Returns:
dict[str, pandas.DataFrame]:
Mapping of table name to DataFrame.

Returns a tuple of (data_dict, skipped_files).
Non-CSV files are ignored.
Raises:
DemoResourceNotFoundError:
If the zip contains no valid CSV entries.
"""
data = {}
skipped_files = []
for root, _dirs, files in os.walk(output_folder_name):
for filename in files:
with ZipFile(zip_bytes, 'r') as z:
if output_folder_name:
os.makedirs(output_folder_name, exist_ok=True)
z.extractall(output_folder_name)

for filename in z.namelist():
if not filename.lower().endswith('.csv'):
skipped_files.append(filename)
continue

table_name = Path(filename).stem
data_path = os.path.join(root, filename)
try:
data[table_name] = pd.read_csv(data_path)
with z.open(filename) as f:
data[table_name] = pd.read_csv(f, low_memory=False)
except UnicodeDecodeError:
data[table_name] = pd.read_csv(data_path, encoding=FALLBACK_ENCODING)
with z.open(filename) as f:
data[table_name] = pd.read_csv(f, low_memory=False, encoding=FALLBACK_ENCODING)
except Exception as e:
rel = os.path.relpath(data_path, output_folder_name)
skipped_files.append(f'{rel}: {e}')

return data, skipped_files


def _get_data_without_output_folder(in_memory_directory):
"""Load CSV tables directly from in-memory zip contents.

Returns a tuple of (data_dict, skipped_files).
Non-CSV entries are ignored.
"""
data = {}
skipped_files = []
for filename, file_ in in_memory_directory.items():
if not filename.lower().endswith('.csv'):
skipped_files.append(filename)
continue

table_name = Path(filename).stem
try:
data[table_name] = pd.read_csv(io.BytesIO(file_), low_memory=False)
except UnicodeDecodeError:
data[table_name] = pd.read_csv(
io.BytesIO(file_), low_memory=False, encoding=FALLBACK_ENCODING
)
except Exception as e:
skipped_files.append(f'{filename}: {e}')

return data, skipped_files


def _get_data(modality, output_folder_name, in_memory_directory, bucket, dataset_name):
if output_folder_name:
data, skipped_files = _get_data_with_output_folder(output_folder_name)
else:
data, skipped_files = _get_data_without_output_folder(in_memory_directory)
skipped_files.append(f'{filename}: {e}')

if skipped_files:
warnings.warn('Skipped files: ' + ', '.join(sorted(skipped_files)))
Expand All @@ -376,9 +388,6 @@ def _get_data(modality, output_folder_name, in_memory_directory, bucket, dataset
'The dataset is missing `csv` file/s.'
)

if modality != 'multi_table':
data = data.popitem()[1]

return data


Expand Down Expand Up @@ -464,17 +473,18 @@ def download_demo(
_validate_modalities(modality)
_validate_output_folder(output_folder_name)

data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
in_memory_directory = _extract_data(data_io, output_folder_name)
data = _get_data(
data, metadata_bytes = _download(
modality,
output_folder_name,
in_memory_directory,
s3_bucket_name,
dataset_name,
s3_bucket_name,
credentials,
)
metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
data = _load_data_from_zip(data, s3_bucket_name, dataset_name, output_folder_name)

if modality != 'multi_table':
data = data.popitem()[1]

metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
return data, metadata


Expand Down
59 changes: 53 additions & 6 deletions tests/integration/datasets/test_demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pandas as pd
import pytest

from sdv.datasets.demo import download_demo, get_available_demos
Expand Down Expand Up @@ -28,36 +29,82 @@ def test_get_available_demos_multi_table():

@pytest.mark.parametrize('output_path', [None, 'tmp_path'])
def test_download_demo_single_table(output_path, tmp_path):
"""Test that the `download_demo` function works as intended for single-table."""
# Run
"""Test `download_demo` function works for single-table."""
# Setup
output_folder_name = tmp_path / 'sdv' if output_path else None

# Run
data, metadata = download_demo(
modality='single_table',
dataset_name='fake_hotel_guests',
output_folder_name=output_folder_name,
)

# Assert
assert isinstance(metadata, Metadata)
metadata.validate()
assert isinstance(data, pd.DataFrame)
metadata.validate_data({'fake_hotel_guests': data})
assert len(data) > 1
assert isinstance(metadata, Metadata)
if output_folder_name:
assert (output_folder_name / 'metadata.json').is_file()
csv_files = list((output_folder_name / 'data').glob('*.csv'))
assert len(csv_files) == 1
assert csv_files[0].name == 'fake_hotel_guests.csv'


@pytest.mark.parametrize('output_path', [None, 'tmp_path'])
def test_download_demo_multi_table(output_path, tmp_path):
"""Test that the `download_demo` function works as intended for multi-table."""
# Run
"""Test `download_demo` function works for multi-table."""
# Setup
output_folder_name = tmp_path / 'sdv' if output_path else None

# Run
data, metadata = download_demo(
modality='multi_table',
dataset_name='fake_hotels',
output_folder_name=output_folder_name,
)

# Assert
assert isinstance(metadata, Metadata)
metadata.validate()
assert isinstance(data, dict)
metadata.validate_data(data)
expected_tables = ['hotels', 'guests']
assert set(expected_tables) == set(data)
assert isinstance(metadata, Metadata)
assert len(data['hotels']) > 1
assert len(data['guests']) > 1
if output_folder_name is not None:
assert (output_folder_name / 'metadata.json').is_file()
csv_files = list((output_folder_name / 'data').glob('*.csv'))
csv_files = [f.name for f in csv_files]
assert len(csv_files) == 2
assert 'hotels.csv' in csv_files
assert 'guests.csv' in csv_files


@pytest.mark.parametrize('output_path', [None, 'tmp_path'])
def test_download_demo_sequential(output_path, tmp_path):
"""Test `download_demo` function works for sequential."""
# Setup
output_folder_name = tmp_path / 'sdv' if output_path else None

# Run
data, metadata = download_demo(
modality='sequential',
dataset_name='ArticularyWordRecognition',
output_folder_name=output_folder_name,
)

# Assert
assert isinstance(metadata, Metadata)
metadata.validate()
metadata = metadata._convert_to_single_table()
metadata.validate_data(data)
assert len(data) > 1
if output_folder_name:
assert (output_folder_name / 'metadata.json').is_file()
csv_files = list((output_folder_name / 'data').glob('*.csv'))
assert len(csv_files) == 1
assert csv_files[0].name == 'ArticularyWordRecognition.csv'
Loading
Loading