.github/workflows/check-release.yml (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, macOS-latest, windows-latest]
         python: ['3.13']
-        transformers: ['5.0', 'main']
+        transformers: ['5.1.0', 'main']
        torch: ['2.10', 'main']

     steps:
.github/workflows/ci.yml (6 changes: 3 additions & 3 deletions)

@@ -17,7 +17,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
-        transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.0', 'main']
+        transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.1.0', 'main']
         torch: ['2.10', 'main']
         exclude:
           # 3.10 - torch
@@ -29,7 +29,7 @@
           - python: '3.10'
             transformers: '4.57.6'
           - python: '3.10'
-            transformers: '5.0'
+            transformers: '5.1.0'
           - python: '3.10'
             transformers: 'main'
           # 3.11 - torch
@@ -41,7 +41,7 @@
           - python: '3.11'
             transformers: '4.57.6'
           - python: '3.11'
-            transformers: '5.0'
+            transformers: '5.1.0'
           - python: '3.11'
             transformers: 'main'
           # 3.13 - torch
CHANGELOGS.rst (2 changes: 2 additions & 0 deletions)

@@ -4,6 +4,8 @@ Change Logs
 0.9.1
 +++++

+* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0, 5.2.0 (see https://github.com/huggingface/transformers/pull/43765/)
+
 0.9.0
 +++++
_doc/status/patches_diff.rst (12 changes: 11 additions & 1 deletion)

@@ -61,11 +61,21 @@ Those two versions leads to the following list of patches.
         patch_details=details,
     ):
         pass
+    done = set()
     for patch in details.patched:
-        print(f"* {patch.family} - {getattr(patch.function_to_patch, '__name__', patch.function_to_patch)}")
+        if patch.function_to_patch == patch.patch:
+            continue
+        if patch.refid in done:
+            continue
+        done.add(patch.refid)
+        print(f"* :ref:`{patch.refid}`")
     print()
     print()
+    done = set()
     for patch in details.patched:
+        if patch.refid in done:
+            continue
+        done.add(patch.refid)
         if patch.function_to_patch == patch.patch:
             continue
         rst = patch.format_diff(format="rst")
_scripts/test_backend_onnxruntime.py (2 changes: 1 addition & 1 deletion)

@@ -141,7 +141,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):

 backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")

-if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"):
+if pv.Version(onnxruntime.__version__) <= pv.Version("1.25"):
     backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)")
@@ -299,7 +299,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
 )


-if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"):
+if pv.Version(onnxruntime.__version__) <= pv.Version("1.25"):
     backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)")

 # import all test cases at global scope to make them visible to python.unittest
_unittests/ut_tasks/test_tasks.py (2 changes: 1 addition & 1 deletion)

@@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self):
         model(**inputs)
         model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-        if not has_transformers("5.0.99"):
+        if not has_transformers("5.2.99"):
             raise unittest.SkipTest("The model has control flow.")
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
             torch.export.export(
_unittests/ut_tasks/test_tasks_image_text_to_text.py (6 changes: 3 additions & 3 deletions)

@@ -15,7 +15,7 @@

 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
@@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self):
         self.assertEqualAny(expected, ep.module()(**inputs), atol=1)

     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_tiny_gemma3(self):
         """
@@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self):
         self.assertEqualAny(expected, ep.module()(**inputs))

     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_zai_glm(self):
         """
_unittests/ut_torch_export_patches/test_patch_transformers.py (8 changes: 4 additions & 4 deletions)

@@ -703,7 +703,7 @@ def test_plug_multi_head_attention_qwen25_packed_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)

-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -738,7 +738,7 @@ def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)

-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -773,7 +773,7 @@ def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)

-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -801,7 +801,7 @@ def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
         self.assertLess(results.diffs[0]["abs"], 0.005)

-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
_unittests/ut_torch_onnx/test_discrepancies.py (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ def qwen_sdpa_attention(
             return attn_output

         for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]:
-            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"):
+            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.25"):
                 # not available
                 continue
             with self.subTest(model=model_name):
onnx_diagnostic/helpers/torch_helper.py (9 changes: 9 additions & 0 deletions)

@@ -850,6 +850,15 @@ def torch_deepcopy(value: Any) -> Any:
     if value.__class__.__name__ == "DynamicCache":
         from .cache_helper import CacheKeyValue

+        if (
+            hasattr(value, "layers")
+            and len(value.layers) == 1
+            and value.layers[0].keys is None
+        ):
+            import transformers
+
+            return transformers.cache_utils.DynamicCache(None)
+
         ca = CacheKeyValue(value)
         pairs = list(zip(ca.key_cache, ca.value_cache))
         assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
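Note on the hunk above: in transformers 5.1/5.2, an empty DynamicCache is represented by a single layer whose keys attribute is None, and torch_deepcopy now returns a fresh empty cache in that case instead of iterating over the empty layer. A minimal sketch of the restored behavior (assuming transformers>=5.1, where DynamicCache(None) builds an empty cache):

    # Sketch, not part of the diff; assumes transformers>=5.1.
    import transformers
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    cache = transformers.cache_utils.DynamicCache(None)  # empty cache, one layer with keys=None
    copied = torch_deepcopy(cache)  # previously failed on the empty layer
    assert type(copied).__name__ == "DynamicCache"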
onnx_diagnostic/torch_export_patches/onnx_export_errors.py (4 changes: 2 additions & 2 deletions)

@@ -71,10 +71,10 @@ def patch_module_or_classes(
     if isinstance(mod, list):
         to_patch = mod
         name = "list"
-        list_name = "auto/list"
+        list_name = "_PATCHED_list"
     else:
         name, to_patch = get_patches(mod, verbose)
-        list_name = f"auto/{mod.__name__.split('.')[-1]}"
+        list_name = f"_PATCHED_{mod.__name__.split('.')[-1]}"

     res = {}
     for cls in to_patch:
onnx_diagnostic/torch_export_patches/patch_details.py (16 changes: 16 additions & 0 deletions)

@@ -117,6 +117,18 @@ def make_diff(self) -> str:
     def function_name(cls, f: Callable) -> str:
         return f.__qualname__

+    @property
+    def refid(self) -> str:
+        kind = self.family or ""
+        patch_name = (
+            self.function_name(self.patch)
+            .replace(".", "-")
+            .replace("/", "-")
+            .replace(">", "")
+            .replace("<", "")
+        )
+        return f"patch-{kind}-{patch_name}"
+
     def format_diff(self, format: str = "raw") -> str:
         """
         Format a diff between two function as a string.
@@ -149,11 +161,15 @@ def format_diff(self, format: str = "raw") -> str:
             else self.function_name(self.function_to_patch)
         )
         patch_name = self.function_name(self.patch)
+        kind = kind.replace("_PATCHED_", "")
         title = f"{kind}{function_to_pach_name} -> {patch_name}"
         if format == "raw":
             return f"{title}\n{diff}"

         rows = [
             "",
+            f".. _{self.refid}:",
+            "",
             title,
             "=" * len(title),
             "",
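Note on the refid property above: it derives a Sphinx-label-safe anchor from the patch family and the patch's qualified name, which the documentation script in _doc/status/patches_diff.rst then uses for cross-references. A small standalone illustration (the qualname below is a made-up example, not taken from the repository):

    # Sketch, not part of the diff; qualname is hypothetical.
    family = "transformers"
    qualname = "patched_DynamicCache.<locals>.update"
    patch_name = (
        qualname.replace(".", "-").replace("/", "-").replace(">", "").replace("<", "")
    )
    print(f"patch-{family}-{patch_name}")
    # patch-transformers-patched_DynamicCache-locals-update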
_patch_transformers_output_capturing.py (new file, 27 additions)

@@ -0,0 +1,27 @@
+try:
+    import transformers.utils.output_capturing  # noqa: F401
+
+    patch_output_capturing = True
+except ImportError:
+    patch_output_capturing = False
+
+
+if patch_output_capturing:
+    # Introduced in 5.2.0
+    # https://github.com/huggingface/transformers/pull/43765/
+    # changes#diff-b5f9fdbe43ffd89fbdf2b246dc78dd32aa4bdb587e7a53e4dad37b7efd79ab0a
+    import torch
+    import transformers
+    from transformers.utils.import_utils import is_torchdynamo_compiling
+
+    class patched_CompileableContextVar:
+        _PATCHES_ = ["set"]
+        _PATCHED_CLASS_ = transformers.utils.output_capturing.CompileableContextVar
+
+        def set(self, value):
+            if is_torchdynamo_compiling() and not torch.compiler.is_exporting():
+                self.global_var = value
+                self.compiling = True
+                return None
+            else:
+                return self.context_var.set(value)
@@ -42,6 +42,12 @@
     patched_sdpa_mask_recent_torch,
 )

+from ._patch_transformers_output_capturing import patch_output_capturing
+
+if patch_output_capturing:
+    from ._patch_transformers_output_capturing import patched_CompileableContextVar
+
+
 # transformers models dependent patches

 if _has_transformers("4.51"):