Commit 3945e71

doc: EinsumTree
1 parent dd63379 commit 3945e71

1 file changed

Lines changed: 205 additions & 51 deletions

docs_sphinx/chapters/einsum_trees.rst

@@ -1,18 +1,39 @@
11
Einsum Trees
22
============
33

4+
This chapter expands the capabilities of our tensor compiler by adding support for einsum trees. Specifically, we execute einsum trees
5+
by mapping them to a tree of unary and binary tensor operations. These operations can then be executed by our tensor operation backend.
6+
47
Lowering
58
--------
69

7-
This section expands the capabilities of our tensor compiler by adding support for einsum trees. Specifically, we execute einsum trees
8-
by mapping them to a tree of unary and binary tensor operations. These operations can then be executed by our tensor operation backend.
10+
An einsum tree represents multiple dependent tensor operations. It may contain nodes with two children (contractions), nodes with one
11+
child (transpositions), or leaf nodes (input tensors). The output tensor of the root node represents the result of the entire tree.
912

1013
1. Parsing
1114
^^^^^^^^^^
1215

1316
**Task**: Implement a function that parses the string representation of a tree and the numerically sorted dimension sizes.
1417

15-
First, we implemented a struct called ``EinsumNode`` to parse the string representation of a tree and the numerically sorted dimension sizes.
18+
To complete the task, we have to parse a string in which a tensor is represented by a list of numerical dimension IDs, e.g. ``[3,2,4,1]``.
19+
Furthermore, an operation is written with a ``->`` symbol, e.g. a contraction ``[<input0>],[<input1>]->[<output>]`` or a unary operation ``[<input>]->[<output>]``.
20+
Note that an operation can depend on another operation by nesting that operation inside a tensor, e.g. ``[[<in0>],[<in1>]->[<op0>]]->[<output>]``,
21+
where the output tensor of the nested operation defines the shape of the corresponding intermediate input tensor.
22+
23+
A complete example of an einsum tree input is the string:
24+
25+
.. code-block::
26+
27+
[[[[3,6,8,9]->[8,6,9,3]],[[2,5,7,9]->[7,5,2,9]]->[7,8,5,6,2,3]],[0,4,5,6]->[0,4,7,8,2,3]],[1,4,7,8]->[0,1,2,3]
28+
29+
It is accompanied by a list of dimension sizes, sorted by the numerical ID of each dimension:
30+
31+
.. code-block::
32+
33+
60,60,20,20,8,8,8,8,8,8
34+
35+
36+
To achieve this goal, we implemented a struct called ``EinsumNode`` to parse the string representation of a tree and the numerically sorted dimension sizes.
1637
This structure holds one node of the tree, its possible children, dimension sizes, and a tensor representing an intermediate or final
1738
(root node) result.
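
To illustrate the string format, the following is a minimal recursive-descent sketch of such a parser. It is not the ``EinsumNode`` implementation of this chapter: ``SketchNode`` and ``SketchParser`` are hypothetical names, the sketch stores only the dimension IDs and the child pointers, and it assumes the comma-separated form used in the complete example above.

```cpp
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, reduced stand-in for EinsumNode (no sizes, no tensor data).
struct SketchNode
{
  std::vector<std::int64_t> output_dim_ids;
  std::unique_ptr<SketchNode> left;   // first (or only) input, null for a leaf
  std::unique_ptr<SketchNode> right;  // second input, null for leaf and unary nodes
};

class SketchParser
{
public:
  explicit SketchParser(std::string text) : m_text(std::move(text)) {}

  std::unique_ptr<SketchNode> parse()
  {
    std::unique_ptr<SketchNode> root = parse_operation();
    if (m_pos != m_text.size())
      throw std::invalid_argument("unexpected trailing characters");
    return root;
  }

private:
  // Consumes c if it is the next character and reports whether it did.
  bool consume(char c)
  {
    if (m_pos < m_text.size() && m_text[m_pos] == c)
    {
      ++m_pos;
      return true;
    }
    return false;
  }

  void expect(char c)
  {
    if (!consume(c))
      throw std::invalid_argument(std::string("expected '") + c + "'");
  }

  // ids := number (',' number)*
  std::vector<std::int64_t> parse_ids()
  {
    std::vector<std::int64_t> ids;
    do
    {
      if (m_pos >= m_text.size() || !std::isdigit(static_cast<unsigned char>(m_text[m_pos])))
        throw std::invalid_argument("expected a dimension id");
      std::int64_t id = 0;
      while (m_pos < m_text.size() && std::isdigit(static_cast<unsigned char>(m_text[m_pos])))
        id = id * 10 + (m_text[m_pos++] - '0');
      ids.push_back(id);
    } while (consume(','));
    return ids;
  }

  // operand := '[' (operation | ids) ']'
  std::unique_ptr<SketchNode> parse_operand()
  {
    expect('[');
    std::unique_ptr<SketchNode> node;
    if (m_pos < m_text.size() && m_text[m_pos] == '[')
    {
      node = parse_operation();  // nested operation
    }
    else
    {
      node = std::make_unique<SketchNode>();  // leaf (input tensor)
      node->output_dim_ids = parse_ids();
    }
    expect(']');
    return node;
  }

  // operation := operand (',' operand)? '->' '[' ids ']'
  std::unique_ptr<SketchNode> parse_operation()
  {
    auto node = std::make_unique<SketchNode>();
    node->left = parse_operand();
    if (consume(','))
      node->right = parse_operand();
    expect('-');
    expect('>');
    expect('[');
    node->output_dim_ids = parse_ids();
    expect(']');
    return node;
  }

  std::string m_text;
  std::size_t m_pos = 0;
};
```

Calling ``SketchParser(text).parse()`` returns the root node. A node with two children corresponds to a contraction, a node with only a left child to a transposition, and a node without children to an input tensor.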
1839

@@ -105,6 +126,65 @@ To ensure the success of all tensor operations, the methods return an ``ErrorExe
105126

106127
**Task**: Benchmark the performance of your implementation for the above examples. Report the measured performance in GFLOPS.
107128

129+
**First Example**
130+
131+
Einsum tree:
132+
133+
.. code-block:: vim
134+
135+
0,1,2,3,4
136+
├─ 7,3,4
137+
| ├─ 8,4
138+
| └─ 7,3,8
139+
└─ 0,1,2,7
140+
├─ 1,2,5,7
141+
| ├─ 2,6,7
142+
| └─ 1,5,6
143+
└─ 0,5
144+
145+
String representation:
146+
147+
.. code-block::
148+
149+
[[8,4],[7,3,8]->[7,3,4]],[[[2,6,7],[1,5,6]->[1,2,5,7]],[0,5]->[0,1,2,7]]->[0,1,2,3,4]
150+
151+
Dimension sizes (sorted by numerical ID):
152+
153+
.. code-block::
154+
155+
100,72,128,128,3,71,305,32,3
156+
157+
**Second Example**
158+
159+
Einsum tree:
160+
161+
.. code-block:: vim
162+
163+
0,1,2,3
164+
├─ 0,4,7,8,2,3
165+
| ├─ 7,8,5,6,2,3
166+
| | ├─ 8,6,9,3
167+
| | | └─ 3,6,8,9
168+
| | └─ 7,5,2,9
169+
| | └─ 2,5,7,9
170+
| └─ 0,4,5,6
171+
└─ 1,4,7,8
172+
173+
String representation:
174+
175+
.. code-block::
176+
177+
[[[[3,6,8,9]->[8,6,9,3]],[[2,5,7,9]->[7,5,2,9]]->[7,8,5,6,2,3]],[0,4,5,6]->[0,4,7,8,2,3]],[1,4,7,8]->[0,1,2,3]
178+
179+
Dimension sizes (sorted by numerical ID):
180+
181+
.. code-block::
182+
183+
60,60,20,20,8,8,8,8,8,8
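
For reference, the following sketch shows one way to turn a measured runtime into GFLOPS. It assumes the common convention of counting one multiplication and one addition per innermost iteration of a contraction, so the operation count is twice the product of the sizes of all distinct dimensions of both inputs. The function names are hypothetical and not part of our benchmark code, and the convention used by the actual benchmark may differ.

```cpp
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Floating point operations of one binary contraction.
// sizes[i] is the size of the dimension with numerical ID i.
// Assumption: 2 operations (multiply and add) per point of the iteration space.
double contraction_flops(const std::vector<std::int64_t> &left_ids,
                         const std::vector<std::int64_t> &right_ids,
                         const std::vector<std::int64_t> &sizes)
{
  // The iteration space is spanned by every distinct dimension of both inputs.
  std::set<std::int64_t> ids(left_ids.begin(), left_ids.end());
  ids.insert(right_ids.begin(), right_ids.end());

  double flops = 2.0;
  for (std::int64_t id : ids)
  {
    flops *= static_cast<double>(sizes.at(static_cast<std::size_t>(id)));
  }
  return flops;
}

// Converts an operation count and a runtime in seconds into GFLOPS.
double gflops(double flops, double seconds)
{
  return flops / seconds / 1e9;
}
```

For the contraction ``[8,4],[7,3,8]->[7,3,4]`` of the first example, the involved sizes are 3, 3, 32 and 128, which gives 2 * 3 * 3 * 32 * 128 = 73728 operations under this convention. The count for a whole tree is the sum over all of its contraction nodes.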
184+
185+
186+
Benchmarking both einsum trees, we measured the following performance:
187+
108188
.. code-block:: bash
109189
:emphasize-lines: 4, 8
110190
@@ -132,18 +212,25 @@ Optimization
132212

133213
**Task**: Develop an optimization pass for einsum trees that applies the three transformations.
134214

215+
The three transformations that can be performed on the einsum tree are reorder, swap, and permutation insert.
216+
217+
- **Reorder**: Reorders the dimensions of an individual tensor so that the next tensor operation that uses it performs better.
218+
- **Swap**: Swaps the two children of a contraction to reduce the need for permutation inserts.
219+
- **Permutation Insert**: Inserts an additional node in the tree to perform a reordering for the next tensor operation.
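
All three transformations depend on the type of each dimension of a contraction. The following sketch shows one possible classification into ``C``, ``M``, ``N`` and ``K`` based only on where a dimension ID occurs. The names ``DimType``, ``classify_dim`` and ``unit_stride_is_n`` are hypothetical and chosen for illustration, and the sketch assumes that the last dimension ID of a tensor is its unit-stride dimension.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using DimIds = std::vector<std::int64_t>;

enum class DimType { C, M, N, K };

static bool contains(const DimIds &ids, std::int64_t id)
{
  return std::find(ids.begin(), ids.end(), id) != ids.end();
}

// Classifies a dimension of the contraction left,right->out.
// Precondition (assumed): id occurs in at least one of the inputs.
DimType classify_dim(std::int64_t id, const DimIds &left, const DimIds &right, const DimIds &out)
{
  const bool in_left = contains(left, id);
  const bool in_right = contains(right, id);
  const bool in_out = contains(out, id);

  if (in_left && in_right)
  {
    // Shared by both inputs: kept in the output (C) or contracted away (K).
    return in_out ? DimType::C : DimType::K;
  }
  // Only in one input: M for the left input, N for the right input.
  return in_left ? DimType::M : DimType::N;
}

// True if the unit-stride (last) output dimension is of type N,
// which is the condition under which the children would be swapped.
bool unit_stride_is_n(const DimIds &left, const DimIds &right, const DimIds &out)
{
  return !out.empty() && classify_dim(out.back(), left, right, out) == DimType::N;
}
```

For the contraction ``[7,3,8],[8,4]->[7,3,4]``, dimension ``4`` is the unit-stride dimension of the output and occurs only in the right input, so it is of type ``N`` and the children are candidates for a swap.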
220+
135221
Reorder Node
136222
""""""""""""
137223

138-
For the reorder node we divided into an different optimization pass for the left and the right node.
224+
For the reorder node, we divided the optimization into separate passes for the left and the right node.
139225

140226
For the reorder pass, we divided the transformation into two methods. The first is ``reorder_left_node``, which reorders the left child node
141227
of a node. The second method is ``reorder_right_node``, which is designed to reorder the right child node of a node.
142-
This division is due to the fact that the left node requires the M dimension as the unit stride, while the right node requires the K1 dimension.
228+
This division is needed because the left node requires the M dimension as unit-stride, while the right node requires the K1 dimension
229+
as unit-stride.
143230

144231
*Left Node:*
145232

146-
The method ``reorder_left_node`` checks if the last dimensions of the left child node are ``KM``. If not, it permutes the dimensions to
233+
The method ``reorder_left_node`` checks if the last dimensions of the left child node are ``KM`` dimensions. If not, it permutes the dimensions to
147234
move ``KM`` to the rightmost location. First, we determine the index of the first occurrence of the ``M`` and ``K`` dimension in the left
148235
child node of the node from right to left. If they are already in order, we return. Otherwise, we place them at the desired index location.
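
The idea can be sketched on plain lists of dimension IDs. The following is a simplified illustration and not our ``reorder_left_node`` method: the function name ``move_km_last`` is hypothetical, it works on ID vectors instead of ``EinsumNode`` objects, and it picks the rightmost ``K`` and the rightmost ``M`` dimension without any further checks.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using DimIds = std::vector<std::int64_t>;

static bool contains(const DimIds &ids, std::int64_t id)
{
  return std::find(ids.begin(), ids.end(), id) != ids.end();
}

// Returns the left input's IDs with its rightmost K and rightmost M
// dimension moved to the end, in the order ..., K, M.
DimIds move_km_last(const DimIds &left, const DimIds &right, const DimIds &out)
{
  const std::int64_t none = -1;  // dimension IDs are assumed to be non-negative
  std::int64_t k_id = none;
  std::int64_t m_id = none;

  // Scan from right to left and remember the first K and the first M dimension.
  for (auto it = left.rbegin(); it != left.rend(); ++it)
  {
    const bool in_right = contains(right, *it);
    const bool in_out = contains(out, *it);
    if (k_id == none && in_right && !in_out)
      k_id = *it;  // K: in both inputs, not in the output
    if (m_id == none && !in_right && in_out)
      m_id = *it;  // M: only in the left input and in the output
  }

  if (k_id == none || m_id == none)
    return left;  // this sketch leaves such tensors untouched

  const std::size_t n = left.size();
  if (left[n - 2] == k_id && left[n - 1] == m_id)
    return left;  // already in the desired order

  DimIds result;
  for (std::int64_t id : left)
  {
    if (id != k_id && id != m_id)
      result.push_back(id);
  }
  result.push_back(k_id);
  result.push_back(m_id);
  return result;
}
```

For the contraction ``[0,4,7,8,2,3],[1,4,7,8]->[0,1,2,3]``, the rightmost ``K`` dimension of the left input is ``8`` and the rightmost ``M`` dimension is ``3``, so the sketch returns ``0,4,7,2,8,3``.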
149236

@@ -221,7 +308,7 @@ child node of the node from right to left. If they are already in order, we retu
221308
222309
*Right Node:*
223310

224-
The method ``reorder_right_node`` checks if the last dimensions of the right child node are ``NK``. If not, it permutes the dimensions to
311+
The method ``reorder_right_node`` checks if the last dimensions of the right child node are ``NK`` dimensions. If not, it permutes the dimensions to
225312
move ``NK`` to the rightmost location. First, we determine the index of the first occurrence of the ``N`` and ``K`` dimension in the right
226313
child node of the node from right to left. If they are already in order, we return. Otherwise, we place them at the desired index location.
227314

@@ -256,8 +343,9 @@ The right node reordering is very similar to the left node reordering, but it or
256343
Insert Permutation Node
257344
"""""""""""""""""""""""
258345

259-
If the ``reorder_left_node`` or ``reorder_right_node`` method reorders a leaf node, an additional permutation node is inserted. Here the
260-
fragment in the ``reorder_left_node`` method:
346+
A permutation node is only inserted if the ``reorder_left_node`` or ``reorder_right_node`` method reorders a leaf node, i.e. an input tensor that is provided by the user.
347+
348+
The code fragment that inserts a permutation node in the ``reorder_left_node`` method:
261349

262350
.. code-block:: cpp
263351
@@ -301,9 +389,10 @@ And for the ``reorder_right_node`` method:
301389
Swap Contraction Nodes
302390
""""""""""""""""""""""
303391

304-
For our current needs, a conditional swap is sufficient. The idea behind the method is to check if a node's unit stride dimension is of type
305-
``N``. If this is the case, we swap its children to later obtain a unit stride dimension in the first input tensor (left child node). We use
306-
the C++ ``swap`` method to swap the child nodes of a node, swapping the left child node pointer with the right child node pointer.
392+
The swap method makes the performance of a contraction independent of the order in which its input tensors are given.
393+
The idea behind the swap method is to check whether a node's unit-stride dimension is of type ``N``.
394+
If this is the case, we swap its children to obtain a unit-stride dimension in the first input tensor (left child node).
395+
We use ``std::swap`` to exchange the left and right child node pointers of the node.
307396

308397
.. code-block:: cpp
309398
@@ -314,66 +403,131 @@ the C++ ``swap`` method to swap the child nodes of a node, swapping the left chi
314403
{
315404
std::swap(node->left, node->right);
316405
}
317-
}.. code-block:: cpp
318-
319-
void mini_jit::EinsumTree::reorder_left_node(EinsumNode *node)
320-
{
321-
...
322-
323-
if (node->left->type == NodeType::Leaf)
324-
{
325-
// Add additional Permutation Node
326-
EinsumNode *reorderNode = new EinsumNode();
327-
reorderNode->type = NodeType::Transposition;
328-
reorderNode->output_dim_ids = std::move(reorderDimIds);
329-
330-
reorderNode->left = node->left;
331-
node->left = reorderNode;
332-
}
333-
else
334-
{
335-
// Only reorder the output of the left operation
336-
node->left->output_dim_ids = std::move(reorderDimIds);
337-
}
338406
}
339407
340408
Heuristic
341409
"""""""""
342410

343-
We used a heuristic to apply the optimization passes to our einsum tree.
411+
To apply the optimization passes to the tree, we use a heuristic that decides when and how the optimizations are applied.
412+
We perform the following steps:
413+
414+
1. First, we check whether the node is a contraction node, and if it is, we proceed to the next check. Otherwise we return from the optimization.
415+
2. Next, we check if the unit stride dimension type of the node is ``N``. If so, we swap the child nodes of the node to get a unit stride
416+
in the ``M`` dimension of the first input tensor (the left child node).
417+
3. We call the ``reorder_left_node`` method on the node. The method then checks if the last dimensions of the left child node are
418+
``KM``. If not, it permutes the dimensions to move ``KM`` to the rightmost location.
419+
4. We call the ``reorder_right_node`` method on the node. The method then checks if the last dimensions of the right child node are
420+
``NK``. If not, it permutes the dimensions to move ``NK`` to the rightmost location.
421+
5. We recursively call the optimization pass on both child nodes.
422+
423+
Implementation of the heuristic:
344424

345425
.. code-block:: cpp
346426
347427
void mini_jit::EinsumTree::optimize(EinsumNode *node)
348428
{
349-
if (node->type != NodeType::Contraction)
350-
{
351-
return;
352-
}
429+
if (node->type != NodeType::Contraction)
430+
{
431+
return;
432+
}
353433
354-
conditional_swap(node);
434+
conditional_swap(node);
355435
356-
reorder_left_node(node);
357-
reorder_right_node(node);
436+
reorder_left_node(node);
437+
reorder_right_node(node);
358438
359-
optimize(node->left);
360-
optimize(node->right);
439+
optimize(node->left);
440+
optimize(node->right);
361441
}
362442
363-
1. First, we check whether the node is a contraction node, and if it is, we proceed to the next check. Otherwise we return from the optimization.
364-
2. Next, we check if the unit stride dimension type of the node is ``N``. If so, we swap the child nodes of the node to get a unit stride
365-
in the ``M`` dimension of the first input tensor (the left child node).
366-
3. We call the ``reorder_left_node`` method on the node. The method then checks if the last dimensions of the left child node are
367-
``KM``. If not, it permutes the dimensions to move ``KM`` to the rightmost location.
368-
4. We call the ``reorder_right_node`` method on the node. The method then checks if the last dimensions of the right child node are
369-
``NK``. If not, it permutes the dimensions to move ``NK`` to the rightmost location.
370-
5. We call on both child nodes recursively the optimization pass.
371443
372444
2. Performance
373445
^^^^^^^^^^^^^^
374446

375447
**Task**: Benchmark the performance of your implementation on the provided examples. Report the measured performance in GFLOPS.
376448

449+
**First Example**
450+
451+
Einsum tree:
452+
453+
.. code-block::
454+
455+
0,1,2,3,4
456+
├─ 7,3,4
457+
| ├─ 7,3,8
458+
| └─ 8,4
459+
└─ 0,1,2,7
460+
├─ 0,5
461+
└─ 5,1,2,7
462+
├─ 5,1,6
463+
└─ 6,2,7
464+
465+
String representation:
466+
467+
.. code-block::
468+
469+
[[7,3,8],[8,4]->[7,3,4]],[[0,5],[[5,1,6],[6,2,7]->[5,1,2,7]]->[0,1,2,7]]->[0,1,2,3,4]
470+
471+
Dimension sizes (by numerical ID):
472+
473+
.. code-block::
474+
475+
100,72,128,128,3,71,305,32,3
476+
477+
**Second Example**
478+
479+
Einsum tree:
480+
481+
.. code-block:: vim
482+
483+
0,1,2,3
484+
├─ 1,4,7,8
485+
└─ 0,4,2,7,3,8
486+
├─ 0,4,5,6
487+
└─ 2,5,7,3,6,8
488+
├─ 2,5,7,9
489+
└─ 3,6,8,9
490+
491+
String representation:
492+
493+
.. code-block::
494+
495+
[1,4,7,8],[[0,4,5,6],[[2,5,7,9],[3,6,8,9]->[2,5,7,3,6,8]]->[0,4,2,7,3,8]]->[0,1,2,3]
496+
497+
Dimension sizes (by numerical ID):
498+
499+
.. code-block::
500+
501+
60,60,20,20,8,8,8,8,8,8
502+
503+
**Third Example**

Einsum tree:
504+
505+
.. code-block:: vim
506+
507+
5,6,7,8,9
508+
├─ 2,7,8,4
509+
| ├─ 2,7,3
510+
| └─ 3,8,4
511+
└─ 4,9,5,6,2
512+
├─ 4,9,0
513+
└─ 0,5,6,2
514+
├─ 0,5,1
515+
└─ 1,6,2
516+
517+
String representation:
518+
519+
.. code-block::
520+
521+
[[2,7,3],[3,8,4]->[2,7,8,4]],[[4,9,0],[[0,5,1],[1,6,2]->[0,5,6,2]]->[4,9,5,6,2]]->[5,6,7,8,9]
522+
523+
Dimension sizes (by numerical ID):
524+
525+
.. code-block::
526+
527+
40,40,40,40,40,25,25,25,25,25
528+
529+
On the three examples, we measured the following performance:
530+
377531
.. code-block:: bash
378532
:emphasize-lines: 4, 8, 12
379533
