Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/dflow/argo_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import os
import shutil
import tempfile
import time
from collections import UserDict, UserList
from copy import deepcopy
Expand All @@ -13,7 +12,7 @@
from .config import config, s3_config
from .io import S3Artifact
from .op_template import get_k8s_client
from .utils import download_artifact, get_key, upload_s3
from .utils import TempDir, download_artifact, get_key, upload_s3

try:
import kubernetes
Expand Down Expand Up @@ -82,7 +81,7 @@ def __getattr__(self, key):
if ((key == "value" and "value" not in self.data) or
(key == "type" and "type" not in self.data)) and \
hasattr(self, "save_as_artifact"):
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
try:
download_artifact(self, path=tmpdir)
fs = os.listdir(tmpdir)
Expand Down Expand Up @@ -163,7 +162,7 @@ def modify_output_parameter(
self.outputs.parameters[name].value = jsonpickle.dumps(value)

if hasattr(self.outputs.parameters[name], "save_as_artifact"):
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
path = tmpdir + "/" + name
with open(path, "w") as f:
f.write(jsonpickle.dumps(value))
Expand Down
5 changes: 2 additions & 3 deletions src/dflow/io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import json
import tempfile
from collections import UserDict
from copy import copy, deepcopy
from typing import Any, Dict, List, Optional, Union

from .common import (CustomArtifact, HTTPArtifact, LocalArtifact, S3Artifact,
jsonpickle, param_errmsg, param_regex)
from .config import config
from .utils import randstr, s3_config, upload_s3
from .utils import TempDir, randstr, s3_config, upload_s3

try:
from argo.workflows.client import (V1alpha1ArchiveStrategy,
Expand Down Expand Up @@ -517,7 +516,7 @@ def convert_to_argo(self):
path=self.path,
_from=str(self.value))
else:
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
path = tmpdir + "/" + self.name
with open(path, "w") as f:
f.write(jsonpickle.dumps(self.value))
Expand Down
16 changes: 12 additions & 4 deletions src/dflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
pass


class TempDir(tempfile.TemporaryDirectory):
def cleanup(self):
try:
return super().cleanup()
except Exception:
pass


def get_key(artifact, raise_error=True):
if hasattr(artifact, "s3") and hasattr(artifact.s3, "key"):
return artifact.s3.key
Expand Down Expand Up @@ -123,7 +131,7 @@ def download_artifact(
if key[-4:] == ".tgz" and extract:
path = os.path.join(path, os.path.basename(key))
tf = tarfile.open(path, "r:gz")
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
tf.extractall(tmpdir)
tf.close()

Expand Down Expand Up @@ -219,7 +227,7 @@ def upload_artifact(
if archive == "default":
archive = config["archive_mode"]
cwd = os.getcwd()
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
if isinstance(path, dict) or (isinstance(path, list) and any(
[isinstance(p, (list, dict)) for p in path])):
pairs = flatten(path).items()
Expand Down Expand Up @@ -320,7 +328,7 @@ def copy_artifact(src, dst, sort=False, **kwargs) -> S3Artifact:
key=lambda item: item["order"])["order"] + 1
for item in src_catalog:
item["order"] += offset
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
catalog_dir = os.path.join(tmpdir, config["catalog_dir_name"])
os.makedirs(catalog_dir, exist_ok=True)
fpath = os.path.join(catalog_dir, str(uuid.uuid4()))
Expand Down Expand Up @@ -487,7 +495,7 @@ def catalog_of_artifact(art, storage_client=None, **kwargs) -> List[dict]:
else:
client = MinioClient(**kwargs)
catalog = []
with tempfile.TemporaryDirectory() as tmpdir:
with TempDir() as tmpdir:
objs = client.list(prefix=key)
if len(objs) == 1 and objs[0][-1] == "/":
key = objs[0]
Expand Down
Loading