Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flytekit.core import utils
from flytekit.core.base_task import IgnoreOutputs, PythonTask
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.constants import FLYTE_FAIL_ON_ERROR
from flytekit.core.constants import FLYTE_FAIL_ON_ERROR, RUNTIME_PACKAGES_ENV_NAME
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
Expand Down Expand Up @@ -63,6 +63,18 @@ def get_version_message():
return f"Welcome to Flyte! Version: {flytekit.__version__}"


def _run_subprocess(cmd: List[str], env: Optional[dict] = None) -> int:
"""Run cmd with proper SIGTERM handling."""
p = subprocess.Popen(cmd, env=env)

def handle_sigterm(signum, frame):
logger.info(f"passing signum {signum} [frame={frame}] to subprocess")
p.send_signal(signum)

signal.signal(signal.SIGTERM, handle_sigterm)
return p.wait()


def _compute_array_job_index():
"""
Computes the absolute index of the current array job. This is determined by summing the compute-environment-specific
Expand Down Expand Up @@ -432,6 +444,14 @@ def setup_execution(

compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

if runtime_packages := os.getenv(RUNTIME_PACKAGES_ENV_NAME):
import importlib
import site

dev_packages_list = runtime_packages.split(" ")
_run_subprocess([sys.executable, "-m", "pip", "install", *dev_packages_list])
importlib.reload(site)

ctx = FlyteContextManager.current_context()
# Create directories
user_workspace_dir = ctx.file_access.get_random_local_directory()
Expand Down Expand Up @@ -751,14 +771,7 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec
env["PYTHONPATH"] += os.pathsep + dest_dir_resolved
else:
env["PYTHONPATH"] = dest_dir_resolved
p = subprocess.Popen(cmd, env=env)

def handle_sigterm(signum, frame):
logger.info(f"passing signum {signum} [frame={frame}] to subprocess")
p.send_signal(signum)

signal.signal(signal.SIGTERM, handle_sigterm)
returncode = p.wait()
returncode = _run_subprocess(cmd, env)
exit(returncode)


Expand Down
6 changes: 5 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,15 @@ def run_remote(
cluster_pool=run_level_params.cluster_pool,
execution_cluster_label=run_level_params.execution_cluster_label,
)
additional_info_for_execution = get_plugin().get_additional_info_for_execution(
remote.generate_console_http_domain(), execution
)
s = (
click.style("\n[✔] ", fg="green")
+ "Go to "
+ click.style(execution.execution_url, fg="cyan")
+ " to see execution in the console."
+ " to see the execution in the console."
+ additional_info_for_execution
)
click.echo(s)

Expand Down
26 changes: 24 additions & 2 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
```
"""

import os
from typing import List, Optional, Protocol, runtime_checkable
from typing import List, Optional, Protocol, Union, runtime_checkable

from click import Group
from importlib_metadata import entry_points

from flytekit import CachePolicy
from flytekit.configuration import Config, get_config_file
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.workflow import WorkflowBase
from flytekit.loggers import logger
from flytekit.remote import FlyteRemote
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution


@runtime_checkable
Expand Down Expand Up @@ -58,6 +60,15 @@ def get_auth_success_html(endpoint: str) -> Optional[str]:
def get_default_cache_policies() -> List[CachePolicy]:
"""Get default cache policies for tasks."""

@staticmethod
def get_additional_context_for_version_hash(entity: Union[PythonAutoContainerTask, WorkflowBase]) -> List[bytes]:
"""Get additional context to be used for calculating the version hash."""

def get_additional_info_for_execution(
console_http_domain: str, entity: Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution]
) -> str:
"""Get additional info for a given execution. Useful to pass in additional urls."""


class FlytekitPlugin:
@staticmethod
Expand Down Expand Up @@ -113,6 +124,17 @@ def get_default_cache_policies() -> List[CachePolicy]:
"""Get default cache policies for tasks."""
return []

@staticmethod
def get_additional_context_for_version_hash(entity: Union[PythonAutoContainerTask, WorkflowBase]) -> List[bytes]:
"""Get additional context to be used for calculating the version hash."""
return []

def get_additional_info_for_execution(
console_http_domain: str, entity: Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution]
) -> str:
"""Get additional info for a given execution. Useful to pass in additional urls."""
return ""


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@
# Shared memory mount name and path
SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory"
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"

# Packages to be installed at the beginning of runtime
RUNTIME_PACKAGES_ENV_NAME = "_F_RUNTIME_PACKAGES"
7 changes: 7 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.constants import CopyFileDetection
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.constants import RUNTIME_PACKAGES_ENV_NAME
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources
Expand Down Expand Up @@ -231,6 +232,12 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
for elem in (settings.env, self.environment):
if elem:
env.update(elem)

# Add runtime dependencies into environment
if isinstance(self.container_image, ImageSpec) and self.container_image.runtime_packages:
runtime_packages = " ".join(self.container_image.runtime_packages)
env[RUNTIME_PACKAGES_ENV_NAME] = runtime_packages

return _get_container_definition(
image=self.get_image(settings),
resource_spec=self.resources,
Expand Down
48 changes: 48 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import mimetypes
import os
import re
import sys
import textwrap
import threading
Expand Down Expand Up @@ -102,6 +103,53 @@ def get_batch_size(t: Type) -> Optional[int]:
return None


class FileDownloadConfig:
"""
This is used to annotate a FlyteFile when we want to download the file with a specific extension. For example,

```python
# ContainerTask
def t1(file: Annotated[FlyteFile, FileDownloadConfig(file_extension="csv")]):
... # copilot downloads the file to e.g. /inputs/file.csv

versus...

def t1(file: FlyteFile["csv"]):
... # copilot downloads the file to e.g. /inputs/file
```

file_extension: (Default is "") The file extension (e.g. "csv", "parquet") to use during copilot download.
enable_legacy_filename: (Default is False) When true and file_extension is non-empty, the copilot download phase
writes the blob to both the full path (with extension) and the old path (without extension), preserving backward compatibility for
workflows with tasks that may read from both.
"""

def __init__(self, file_extension: str = "", enable_legacy_filename: bool = False):
self._file_extension = file_extension
self._enable_legacy_filename = enable_legacy_filename

if self._file_extension is not "":
pattern = r"^[a-zA-Z0-9]+(\.[a-zA-Z0-9]+)*$"
if not re.match(pattern, self._file_extension):
raise ValueError(f"Invalid file extension: {self._file_extension}")

@property
def file_extension(self) -> str:
return self._file_extension

@property
def enable_legacy_filename(self) -> bool:
return self._enable_legacy_filename


def get_file_download_config(t: Type) -> Optional[FileDownloadConfig]:
if is_annotated(t):
for arg in get_args(t):
if isinstance(arg, FileDownloadConfig):
return arg
return None


def modify_literal_uris(lit: Literal):
"""
Modifies the literal object recursively to replace the URIs with the native paths in case they are of
Expand Down
3 changes: 3 additions & 0 deletions flytekit/image_spec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from .default_builder import DefaultImageBuilder
from .image_spec import ImageBuildEngine, ImageSpec
from .noop_builder import NoOpBuilder

# Set this to a lower priority compared to `envd` to maintain backward compatibility
ImageBuildEngine.register(DefaultImageBuilder.builder_type, DefaultImageBuilder(), priority=1)
# Lower priority compared to Default.
ImageBuildEngine.register(NoOpBuilder.builder_type, NoOpBuilder(), priority=0)
Loading