Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ set(_tvm_ffi_extra_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/deep_copy.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/module.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc"
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/ffi/extra/deep_copy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ffi/extra/deep_copy.h
* \brief Reflection-based object copy utilities
*/
#ifndef TVM_FFI_EXTRA_DEEP_COPY_H_
#define TVM_FFI_EXTRA_DEEP_COPY_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/extra/base.h>

namespace tvm {
namespace ffi {

/**
* \brief Deep copy an ffi::Any value.
*
* Recursively copies the value and all reachable objects in its object graph.
* Copy-constructible types with `ObjectDef` registration automatically support deep copy.
* Primitive types, strings, and bytes are returned as-is (they are immutable).
* Arrays, Lists, and Maps are recursively deep copied.
* Objects without copy support cause a runtime error.
*
* \param value The value to deep copy.
* \return The deep copied value.
*/
TVM_FFI_EXTRA_CXX_API Any DeepCopy(const Any& value);

} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_EXTRA_DEEP_COPY_H_
28 changes: 27 additions & 1 deletion include/tvm/ffi/reflection/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ struct init {
}
};

/*! \brief Well-known type attribute names used by the reflection system. */
namespace type_attr {
inline constexpr const char* kInit = "__ffi_init__";
inline constexpr const char* kShallowCopy = "__ffi_shallow_copy__";
} // namespace type_attr

/*!
* \brief Helper to register Object's reflection metadata.
* \tparam Class The class type.
Expand All @@ -481,6 +487,7 @@ class ObjectDef : public ReflectionDefBase {
explicit ObjectDef(ExtraArgs&&... extra_args)
: type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {
RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
AutoRegisterCopy();
}

/*!
Expand Down Expand Up @@ -591,6 +598,25 @@ class ObjectDef : public ReflectionDefBase {
template <typename T>
friend class OverloadObjectDef;

/*! \brief Shallow-copy \p self via the C++ copy constructor. */
static ObjectRef ShallowCopy(const Class* self) {
return ObjectRef(ffi::make_object<Class>(*self));
}

void AutoRegisterCopy() {
if constexpr (std::is_copy_constructible_v<Class>) {
// Register __ffi_shallow_copy__ as an instance method
RegisterMethod(type_attr::kShallowCopy, false, &ObjectDef::ShallowCopy);
// Also register as a type attribute for generic deep copy lookup
Function copy_fn = GetMethod(std::string(type_key_) + "." + type_attr::kShallowCopy,
&ObjectDef::ShallowCopy);
TVMFFIByteArray attr_name = {type_attr::kShallowCopy,
std::char_traits<char>::length(type_attr::kShallowCopy)};
TVMFFIAny attr_value = AnyView(copy_fn).CopyToTVMFFIAny();
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &attr_name, &attr_value));
}
}

template <typename... ExtraArgs>
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
TVMFFITypeMetadata info;
Expand Down Expand Up @@ -663,7 +689,7 @@ class ObjectDef : public ReflectionDefBase {

int32_t type_index_;
const char* type_key_;
static constexpr const char* kInitMethodName = "__ffi_init__";
static constexpr const char* kInitMethodName = type_attr::kInit;
};

/*!
Expand Down
31 changes: 18 additions & 13 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,27 @@ def type_info_to_cls(
attrs[field.name] = field.as_property(cls)

# Step 3. Add methods
def _add_method(name: str, func: Callable[..., Any]) -> None:
if name == "__ffi_init__":
name = "__c_ffi_init__"
# Allow overriding methods (including from base classes like Object.__repr__)
# by always adding to attrs, which will be used when creating the new class
func.__module__ = cls.__module__
func.__name__ = name # ty: ignore[unresolved-attribute]
func.__qualname__ = f"{cls.__qualname__}.{name}" # ty: ignore[unresolved-attribute]
func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
attrs[name] = func

for name, method_impl in methods.items():
if method_impl is not None:
_add_method(name, method_impl)
method_impl.__module__ = cls.__module__
method_impl.__name__ = name # ty: ignore[unresolved-attribute]
method_impl.__qualname__ = f"{cls.__qualname__}.{name}" # ty: ignore[unresolved-attribute]
method_impl.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
attrs[name] = method_impl
for method in type_info.methods:
_add_method(method.name, method.func)
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
# as_callable wraps instance methods so `self` is passed to the C++ function,
# and wraps static methods with staticmethod(); it also sets __module__,
# __name__, __qualname__, and __doc__ so we insert directly into attrs.
func = method.as_callable(cls)
if name != method.name:
# Rename was applied (e.g. __ffi_init__ -> __c_ffi_init__)
inner = func.__func__ if isinstance(func, staticmethod) else func
inner.__name__ = name # ty: ignore[invalid-assignment]
inner.__qualname__ = f"{cls.__qualname__}.{name}" # ty: ignore[invalid-assignment]
attrs[name] = func

# Step 4. Create the new class
new_cls = type(cls.__name__, cls_bases, attrs)
Expand Down
5 changes: 5 additions & 0 deletions python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
methods={"__init__": fn_init, "__repr__": fn_repr},
)
_set_type_cls(type_info, type_cls)
# Step 4. Set up __copy__, __deepcopy__, __replace__
from ..registry import _setup_copy_methods # noqa: PLC0415

has_shallow_copy = any(m.name == "__ffi_shallow_copy__" for m in type_info.methods)
_setup_copy_methods(type_cls, has_shallow_copy)
return type_cls

return decorator
Expand Down
78 changes: 77 additions & 1 deletion python/tvm_ffi/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import functools
import json
import sys
from typing import Any, Callable, Literal, Sequence, TypeVar, overload
Expand Down Expand Up @@ -335,25 +336,100 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
if not hasattr(type_cls, name): # skip already defined attributes
setattr(type_cls, name, field.as_property(type_cls))
has_c_init = False
has_shallow_copy = False
for method in type_info.methods:
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
has_c_init = True
if not hasattr(type_cls, name):
if name == "__ffi_shallow_copy__":
has_shallow_copy = True
# Always override: shallow copy is type-specific and must not be inherited
setattr(type_cls, name, method.as_callable(type_cls))
elif not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
if "__init__" not in type_cls.__dict__:
if has_c_init:
setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
elif not issubclass(type_cls, core.PyNativeObject):
setattr(type_cls, "__init__", __init__invalid)
is_container = type_info.type_key in ("ffi.Array", "ffi.Map")
_setup_copy_methods(type_cls, has_shallow_copy, is_container=is_container)
return type_cls


def _setup_copy_methods(
type_cls: type, has_shallow_copy: bool, *, is_container: bool = False
) -> None:
"""Set up __copy__, __deepcopy__, __replace__ based on copy support."""
if has_shallow_copy:
if "__copy__" not in type_cls.__dict__:
setattr(type_cls, "__copy__", _copy_supported)
if "__deepcopy__" not in type_cls.__dict__:
setattr(type_cls, "__deepcopy__", _deepcopy_supported)
if "__replace__" not in type_cls.__dict__:
setattr(type_cls, "__replace__", _replace_supported)
else:
if "__copy__" not in type_cls.__dict__:
setattr(type_cls, "__copy__", _copy_unsupported)
if "__deepcopy__" not in type_cls.__dict__:
# Containers (Array, Map) support deepcopy via ffi.DeepCopy
# even without __ffi_shallow_copy__
if is_container:
setattr(type_cls, "__deepcopy__", _deepcopy_supported)
else:
setattr(type_cls, "__deepcopy__", _deepcopy_unsupported)
if "__replace__" not in type_cls.__dict__:
setattr(type_cls, "__replace__", _replace_unsupported)


def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("The __init__ method of this class is not implemented.")


def _copy_supported(self: Any) -> Any:
return self.__ffi_shallow_copy__()


def _deepcopy_supported(self: Any, memo: Any = None) -> Any:
return _get_deep_copy_func()(self)


@functools.lru_cache(maxsize=1)
def _get_deep_copy_func() -> core.Function:
return get_global_func("ffi.DeepCopy")


def _replace_supported(self: Any, **kwargs: Any) -> Any:
import copy # noqa: PLC0415

obj = copy.copy(self)
for key, value in kwargs.items():
setattr(obj, key, value)
return obj


def _copy_unsupported(self: Any) -> Any:
raise TypeError(
f"Type `{type(self).__name__}` does not support copy. "
f"The underlying C++ type is not copy-constructible."
)


def _deepcopy_unsupported(self: Any, memo: Any = None) -> Any:
raise TypeError(
f"Type `{type(self).__name__}` does not support deepcopy. "
f"The underlying C++ type is not copy-constructible."
)


def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
raise TypeError(
f"Type `{type(self).__name__}` does not support replace. "
f"The underlying C++ type is not copy-constructible."
)


def get_registered_type_keys() -> Sequence[str]:
"""Get the list of valid type keys registered to TVM-FFI.

Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._ffi_api import * # noqa: F403
from .testing import (
TestIntPair,
TestNonCopyable,
TestObjectBase,
TestObjectDerived,
_SchemaAllTypes,
Expand Down
7 changes: 7 additions & 0 deletions python/tvm_ffi/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ class TestObjectDerived(TestObjectBase):
# tvm-ffi-stubgen(end)


@register_object("testing.TestNonCopyable")
class TestNonCopyable(Object):
"""Test object with deleted copy constructor."""

value: int


@register_object("testing.SchemaAllTypes")
class _SchemaAllTypes:
# tvm-ffi-stubgen(ty-map): testing.SchemaAllTypes -> testing._SchemaAllTypes
Expand Down
Loading