Sym reshapes#4832
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #4832 +/- ##
===========================================
+ Coverage 92.79% 92.80% +0.01%
===========================================
Files 584 586 +2
Lines 30111 30140 +29
===========================================
+ Hits 27941 27970 +29
Misses 2170 2170
🚀 New features to boost your workflow:
|
|
I would like to merge in #3753 first which unifies the reshape calculation for both reshape and reshape_lazy and it also propagates the permutation. |
| // A dim attribute entry that may be either a plain int64_t or a | ||
| // dynamic_dimension. Used by ops whose dim-valued attributes need to carry | ||
| // either static integers or dynamic/symbolic dimensions. | ||
| struct MIGRAPHX_EXPORT dim_like |
There was a problem hiding this comment.
Why do you wrap it in the class instead of using the variant directly? If its for implicit constructor for SLES maybe we need to create our own variant wrapper.
But if we do make it a seperate class we can try to make it more friendly to use.
There was a problem hiding this comment.
I think the main issue was that a lot of passes will pass some shape.lens() (size_t) to the dims attribute which would require modifying too many callsites.
Can you explain a little more about what this variant wrapper would look like? It should like a good idea since I had to do similar wrapping for scalar in sym
There was a problem hiding this comment.
I am thinking something like this:
template <class Picker, class... Ts>
struct picked_variant : std::variant<Ts...>
{
using base = std::variant<Ts...>;
using base::base; // inherit default, in_place_type, in_place_index ctors
template <class T,
MIGRAPHX_REQUIRES(std::is_base_of<T, base>{})>
picked_variant(T&& x) : base(Picker::apply(std::forward<T>(x))) {}
};But std::visit wont work directly without p2162r2 from C++23, so we may need to write a custom visit method that downcasts the variants.
There was a problem hiding this comment.
I dont quite get the meaning of MIGRAPHX_REQUIRES(std::is_base_of<T, base>{})>. Shouldn't it be something like not std::is_constructible<base, T&&>{} so it only got to the picker for types that are not supported by default?
Motivation
Refactor reshape ops to use dim-like variant to handle symbolic shapes. This will be required when trying to enable simplification passes for symbolic shapes.
Technical Details
Add
dim_likevariant that works out of the box for current int64 representation for static shapes but can be extended to dyn_dims for symbolic shapes to allow symbolic target dims.compute_shape will be modified in a later PR, the purpose here is to just introduce the new variant and ensure it causes no regressions with existing static shape compilation.
Later we would want to use this similarly for other ops such as slice, resize, etc. We can also consider refactoring broadcast ops to consolidate the current dual attribute implementation with out_lens and out_dyn_dims
Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot Applicable