Skip to content

Commit e84839e

Browse files
committed
Add dict input for pad2d operator
1 parent 0a2c9a2 commit e84839e

2 files changed

Lines changed: 247 additions & 2 deletions

File tree

src/xtc/graphs/xtc/operators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,20 @@ def __init__(self, **attrs: XTCOperatorAttr) -> None:
452452
if isinstance(padding, int):
453453
padding = {axes[0]: (padding, padding), axes[1]: (padding, padding)}
454454
else:
455-
assert isinstance(padding, tuple), (
456-
f"padding for pad2d of wrong type, expect int or tuple: {padding}"
455+
assert isinstance(padding, (tuple, dict)), (
456+
f"padding for pad2d of wrong type, expect int or tuple or dict: {padding}"
457457
)
458458
if len(padding) == 1:
459459
padding = {
460460
axes[0]: (padding[0], padding[0]),
461461
axes[1]: (padding[0], padding[0]),
462462
}
463+
elif isinstance(padding, dict) and len(padding) == 2:
464+
padding = {
465+
i: (pad, pad) if isinstance(pad, int) else pad
466+
for i, pad in padding.items()
467+
}
468+
pass
463469
elif all(isinstance(pad, int) for pad in padding) and len(padding) == 2:
464470
padding = {
465471
axes[0]: (padding[0], padding[1]),

0 commit comments

Comments
 (0)