-
Notifications
You must be signed in to change notification settings - Fork 338
Add file_extension fields to BlobType #3406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||
| import json | ||||||
| import mimetypes | ||||||
| import os | ||||||
| import re | ||||||
| import sys | ||||||
| import textwrap | ||||||
| import threading | ||||||
|
|
@@ -105,6 +106,44 @@ def get_batch_size(t: Type) -> Optional[int]: | |||||
| return None | ||||||
|
|
||||||
|
|
||||||
| class FileExtension: | ||||||
| """ | ||||||
| This is used to annotate a FlyteFile when we want to download the file with a specific extension. For example, | ||||||
|
|
||||||
| ```python | ||||||
| # ContainerTask | ||||||
| def t1(file: Annotated[FlyteFile, FileExtension("csv")]): | ||||||
| ... # copilot downloads the file to e.g. /inputs/file.csv | ||||||
|
|
||||||
| versus... | ||||||
|
|
||||||
| def t1(file: FlyteFile["csv"]): | ||||||
| ... # copilot downloads the file to e.g. /inputs/file | ||||||
| ``` | ||||||
|
|
||||||
| val: (Default is "") The file extension (e.g. "csv", "parquet") to use during copilot download. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move this docstring to val: The file extension (e.g. "csv", "parquet") to use during copilot download. |
||||||
| """ | ||||||
|
|
||||||
| def __init__(self, val: str = ""): | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| self._val = val | ||||||
|
|
||||||
| pattern = r"^[a-zA-Z0-9]+(\.[a-zA-Z0-9]+)*$" | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Could we add a comment saying that this matches single and multi-part file extension (e.g. tar.gz)? |
||||||
| if not re.match(pattern, self._val): | ||||||
| raise ValueError(f"Invalid file extension: {self._val}") | ||||||
|
|
||||||
| @property | ||||||
| def val(self) -> str: | ||||||
| return self._val | ||||||
|
|
||||||
|
|
||||||
| def get_file_extension(t: Type) -> Optional[str]: | ||||||
| if is_annotated(t): | ||||||
| for annotation in get_args(t)[1:]: | ||||||
| if isinstance(annotation, FileExtension): | ||||||
| return annotation.val | ||||||
| return None | ||||||
|
|
||||||
|
|
||||||
| def modify_literal_uris(lit: Literal): | ||||||
| """ | ||||||
| Modifies the literal object recursively to replace the URIs with the native paths in case they are of | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -38,13 +38,16 @@ class BlobDimensionality(object): | |||||
| SINGLE = _types_pb2.BlobType.SINGLE | ||||||
| MULTIPART = _types_pb2.BlobType.MULTIPART | ||||||
|
|
||||||
| def __init__(self, format, dimensionality): | ||||||
| def __init__(self, format, dimensionality, file_extension=""): | ||||||
| """ | ||||||
| :param Text format: A string describing the format of the underlying blob data. | ||||||
| :param int dimensionality: An integer from BlobType.BlobDimensionality enum | ||||||
| :param Text file_extension: The file extension (e.g. "csv", "parquet") to use | ||||||
| during copilot download, e.g. "csv", "parquet". Empty by default. | ||||||
| """ | ||||||
| self._format = format | ||||||
| self._dimensionality = dimensionality | ||||||
| self._file_extension = file_extension | ||||||
|
|
||||||
| @property | ||||||
| def format(self): | ||||||
|
|
@@ -62,16 +65,33 @@ def dimensionality(self): | |||||
| """ | ||||||
| return self._dimensionality | ||||||
|
|
||||||
| @property | ||||||
| def file_extension(self): | ||||||
| """ | ||||||
| The file extension (e.g. "csv", "parquet") to use during copilot download. | ||||||
| Default is "", which means no extension is appended. | ||||||
| :rtype: Text | ||||||
| """ | ||||||
| return self._file_extension | ||||||
|
|
||||||
| def to_flyte_idl(self): | ||||||
| """ | ||||||
| :rtype: flyteidl.core.types_pb2.BlobType | ||||||
| """ | ||||||
| return _types_pb2.BlobType(format=self.format, dimensionality=self.dimensionality) | ||||||
| return _types_pb2.BlobType( | ||||||
| format=self.format, | ||||||
| dimensionality=self.dimensionality, | ||||||
| file_extension=self._file_extension, | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
nit |
||||||
| ) | ||||||
|
|
||||||
| @classmethod | ||||||
| def from_flyte_idl(cls, proto): | ||||||
| """ | ||||||
| :param flyteidl.core.types_pb2.BlobType proto: | ||||||
| :rtype: BlobType | ||||||
| """ | ||||||
| return cls(format=proto.format, dimensionality=proto.dimensionality) | ||||||
| return cls( | ||||||
| format=proto.format, | ||||||
| dimensionality=proto.dimensionality, | ||||||
| file_extension=proto.file_extension, | ||||||
| ) | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| from flytekit.core.hash import HashMethod | ||
| from flytekit.core.launch_plan import LaunchPlan | ||
| from flytekit.core.task import task | ||
| from flytekit.core.type_engine import TypeEngine | ||
| from flytekit.core.type_engine import FileExtension, TypeEngine | ||
| from flytekit.core.workflow import workflow | ||
| from flytekit.models.core.types import BlobType | ||
| from flytekit.models.literals import LiteralMap, Blob, BlobMetadata | ||
|
|
@@ -764,6 +764,32 @@ def test_headers(): | |
| assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1 | ||
|
|
||
|
|
||
| def test_transform_flytefile_with_file_extension(): | ||
| csv_file_no_file_extension = FlyteFile["csv"] | ||
| lt = FlyteFilePathTransformer().get_literal_type(csv_file_no_file_extension) | ||
| assert lt.blob.file_extension == "" | ||
|
|
||
| csv_file_with_file_extension = Annotated[FlyteFile["csv"], FileExtension("csv")] | ||
| lt = FlyteFilePathTransformer().get_literal_type(csv_file_with_file_extension) | ||
| assert lt.blob.file_extension == "csv" | ||
|
|
||
|
|
||
| def test_file_extension_valid_compound_extension(): | ||
| extension = FileExtension("tar.gz") | ||
| assert extension.val == "tar.gz" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("bad_ext", [ | ||
| ".csv", | ||
| "my file", | ||
| "../../escape", | ||
| "csv!", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also add "" here |
||
| ]) | ||
| def test_file_extension_rejects_invalid_extensions(bad_ext): | ||
| with pytest.raises(ValueError, match="Invalid file extension"): | ||
| FileExtension(bad_ext) | ||
|
|
||
|
|
||
| def test_new_remote_file(): | ||
| nf = FlyteFile.new_remote_file(name="foo.txt") | ||
| assert isinstance(nf, FlyteFile) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Follows same pattern as BatchSize:
flytekit/flytekit/core/type_engine.py
Line 75 in 71194a4