Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions csrc/polymorphic_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,67 @@

namespace nvfuser {

// Implementation of getDataType - moved from type.h to reduce template bloat.
// This function uses for_all_types which triggers heavy template instantiation.
DataType getDataType(const PolymorphicValue& value) {
std::optional<DataType> dtype = std::nullopt;
PolymorphicValue::for_all_types([&value, &dtype](auto _) {
using T = typename decltype(_)::type;
if constexpr (IsPrimitiveNativeType<T>::value) {
if (value.is<T>()) {
dtype = NativeTypeToDataType<T>::type;
}
} else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
if (value.is<T>()) {
const auto& vec = value.as<T>();
size_t size = vec.size();
NVF_CHECK(size > 0, "Empty array is not supported");
dtype =
ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
}
} else if constexpr (std::is_same_v<T, Pointer>) {
// For pointers in polymorphic value, we only store the data size of the
// pointee, so it is impossible to infer the pointer type.
NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
} else if constexpr (std::is_same_v<T, StructHandle>) {
if (value.is<T>()) {
dtype = value.as<T>().type();
}
} else if constexpr (std::is_same_v<T, Opaque>) {
if (value.is<T>()) {
const auto& opaque = value.as<T>();
dtype = DataType(OpaqueType{
.type_info = opaque.any().type(), .size = opaque.size()});
}
}
});
NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
return dtype.value();
}

// Implementation of castToDtype - moved from type.h to reduce template bloat.
// This function uses for_all_types which triggers heavy template instantiation.
PolymorphicValue castToDtype(PolymorphicValue value, const DataType& dtype) {
if (!value.hasValue()) {
return value;
}
// Cast the given value to the given data type. This enables interface
// like: IrBuilder::create<Val>(0, DataType::Double) where value is
// an integer but the desired data type is double.
if (!hasCompatibleDataType(value, dtype)) {
PolymorphicValue::for_all_types([&](auto _) {
using T = typename decltype(_)::type;
if constexpr (IsPrimitiveNativeType<T>::value) {
if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
value = PolymorphicValue(static_cast<T>(value));
}
}
// TODO: support arrays and pointers
});
}
return value;
}

bool StructHandle::operator==(const StructHandle& other) const {
if (struct_ptr_ == other.struct_ptr_) {
return true;
Expand Down Expand Up @@ -42,6 +103,76 @@ bool StructHandle::operator==(const StructHandle& other) const {

namespace PolymorphicValue_functions {

// Implementation of isSame - moved from polymorphic_value.h to reduce template bloat.
// Uses operator== which triggers ForAllTypes template instantiation.
bool isSame(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a.type() != b.type()) {
return false;
}
if (a.is<at::Tensor>()) {
return (a.as<at::Tensor>().is_same(b.as<at::Tensor>()));
}
if (a.is<double>()) {
return isSameNanSensitive(a.as<double>(), b.as<double>());
}
if (a.is<std::complex<double>>()) {
return isSameNanSensitive(
a.as<std::complex<double>>(), b.as<std::complex<double>>());
}
return a == b;
}

// Implementation of ceildiv - moved from polymorphic_value.h to reduce template bloat.
// Uses operator/ which triggers ForAllTypes template instantiation.
PolymorphicValue ceildiv(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a.is<int64_t>() && b.is<int64_t>()) {
auto aa = a.as<int64_t>();
auto bb = b.as<int64_t>();
if (bb > 0) {
return PolymorphicValue((aa + bb - 1) / bb);
} else {
return PolymorphicValue((aa + bb + 1) / bb);
}
}
return PolymorphicValue(std::ceil((a / b).as<double>()));
}

// Implementation of max - moved from polymorphic_value.h to reduce template bloat.
// Uses operator!= and operator> which trigger ForAllTypes template instantiation.
PolymorphicValue max(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(a);
}
return PolymorphicValue(a > b ? a : b);
}

// Implementation of fmax - moved from polymorphic_value.h to reduce template bloat.
// Uses operator!= and operator< which trigger ForAllTypes template instantiation.
PolymorphicValue fmax(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(b);
}
return PolymorphicValue(a < b ? b : a);
}

// Implementation of min - moved from polymorphic_value.h to reduce template bloat.
// Uses operator!= and operator< which trigger ForAllTypes template instantiation.
PolymorphicValue min(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(a);
}
return PolymorphicValue(a < b ? a : b);
}

// Implementation of fmin - moved from polymorphic_value.h to reduce template bloat.
// Uses operator!= and operator> which trigger ForAllTypes template instantiation.
PolymorphicValue fmin(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(b);
}
return PolymorphicValue(a > b ? b : a);
}

size_t hash(const PolymorphicValue& v) {
constexpr size_t nan_hash_value = 572491308;
// NaNs are considered the same, so map all NaN values to same hash value.
Expand Down
76 changes: 15 additions & 61 deletions csrc/polymorphic_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,9 @@ inline bool isSameNanSensitive(const T& a, const T& b) {
return a == b;
}

inline bool isSame(const PolymorphicValue& a, const PolymorphicValue& b) {
if (a.type() != b.type()) {
return false;
}
if (a.is<at::Tensor>()) {
return (a.as<at::Tensor>().is_same(b.as<at::Tensor>()));
}
if (a.is<double>()) {
return isSameNanSensitive(a.as<double>(), b.as<double>());
}
if (a.is<std::complex<double>>()) {
return isSameNanSensitive(
a.as<std::complex<double>>(), b.as<std::complex<double>>());
}
return a == b;
}
// Declaration only - implementation in polymorphic_value.cpp
// Uses operator== which triggers ForAllTypes template instantiation
NVF_API bool isSame(const PolymorphicValue& a, const PolymorphicValue& b);

inline PolymorphicValue signbit(const PolymorphicValue& a) {
if (a.is<int64_t>()) {
Expand Down Expand Up @@ -322,56 +309,23 @@ inline PolymorphicValue fmod(
b.type().name());
}

inline PolymorphicValue ceildiv(
// Declarations only - implementations in polymorphic_value.cpp
// These functions use PolymorphicValue operators which trigger ForAllTypes
NVF_API PolymorphicValue ceildiv(
const PolymorphicValue& a,
const PolymorphicValue& b) {
if (a.is<int64_t>() && b.is<int64_t>()) {
auto aa = a.as<int64_t>();
auto bb = b.as<int64_t>();
if (bb > 0) {
return PolymorphicValue((aa + bb - 1) / bb);
} else {
return PolymorphicValue((aa + bb + 1) / bb);
}
}
return PolymorphicValue(std::ceil((a / b).as<double>()));
}

inline PolymorphicValue max(
const PolymorphicValue& b);
NVF_API PolymorphicValue max(
const PolymorphicValue& a,
const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(a);
}
return PolymorphicValue(a > b ? a : b);
}

inline PolymorphicValue fmax(
const PolymorphicValue& b);
NVF_API PolymorphicValue fmax(
const PolymorphicValue& a,
const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(b);
}
return PolymorphicValue(a < b ? b : a);
}

inline PolymorphicValue min(
const PolymorphicValue& b);
NVF_API PolymorphicValue min(
const PolymorphicValue& a,
const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(a);
}
return PolymorphicValue(a < b ? a : b);
}

inline PolymorphicValue fmin(
const PolymorphicValue& b);
NVF_API PolymorphicValue fmin(
const PolymorphicValue& a,
const PolymorphicValue& b) {
if (a != a) {
return PolymorphicValue(b);
}
return PolymorphicValue(a > b ? b : a);
}
const PolymorphicValue& b);

inline PolymorphicValue gcd(
const PolymorphicValue& a,
Expand Down
63 changes: 6 additions & 57 deletions csrc/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,41 +414,9 @@ DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::ComplexDouble, std::complex<double>);

#undef DEFINE_DATATYPE_TO_NATIVE_TYPE

inline DataType getDataType(const PolymorphicValue& value) {
std::optional<DataType> dtype = std::nullopt;
PolymorphicValue::for_all_types([&value, &dtype](auto _) {
using T = typename decltype(_)::type;
if constexpr (IsPrimitiveNativeType<T>::value) {
if (value.is<T>()) {
dtype = NativeTypeToDataType<T>::type;
}
} else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
if (value.is<T>()) {
const auto& vec = value.as<T>();
size_t size = vec.size();
NVF_CHECK(size > 0, "Empty array is not supported");
dtype =
ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
}
} else if constexpr (std::is_same_v<T, Pointer>) {
// For pointers in polymorphic value, we only store the data size of the
// pointee, so it is impossible to infer the pointer type.
NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
} else if constexpr (std::is_same_v<T, StructHandle>) {
if (value.is<T>()) {
dtype = value.as<T>().type();
}
} else if constexpr (std::is_same_v<T, Opaque>) {
if (value.is<T>()) {
const auto& opaque = value.as<T>();
dtype = DataType(OpaqueType{
.type_info = opaque.any().type(), .size = opaque.size()});
}
}
});
NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
return dtype.value();
}
// Get the DataType corresponding to the runtime type held in a PolymorphicValue.
// Implementation moved to polymorphic_value.cpp to reduce template instantiation.
NVF_API DataType getDataType(const PolymorphicValue& value);

inline bool isCompatibleDataType(DataType dtype, DataType dtype2) {
if (dtype == dtype2) {
Expand Down Expand Up @@ -1128,28 +1096,9 @@ Pointer::Pointer(void* ptr, DataType dtype)
: ptr_(reinterpret_cast<std::byte*>(ptr)),
size_bit_(dataTypeSizeBit(dtype)) {}

inline PolymorphicValue castToDtype(
PolymorphicValue value,
const DataType& dtype) {
if (!value.hasValue()) {
return value;
}
// Cast the given value to the given data type. This enables interface
// like: IrBuilder::create<Val>(0, DataType::Double) where value is
// an integer but the desired data type is double.
if (!hasCompatibleDataType(value, dtype)) {
PolymorphicValue::for_all_types([&](auto _) {
using T = typename decltype(_)::type;
if constexpr (IsPrimitiveNativeType<T>::value) {
if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
value = PolymorphicValue(static_cast<T>(value));
}
}
// TODO: support arrays and pointers
});
}
return value;
}
// Cast a PolymorphicValue to match the specified DataType.
// Implementation moved to polymorphic_value.cpp to reduce template instantiation.
NVF_API PolymorphicValue castToDtype(PolymorphicValue value, const DataType& dtype);

// Converts an enum to its underlying type.
// It corresponds with std::to_underlying introduced in c++23
Expand Down
Loading