1
1
import unittest
2
- from pathlib import Path
3
- from itertools import starmap
2
+ from sympy import ceiling
4
3
5
- from bindings .looptree import *
6
- from tests .util import TEST_TMP_DIR
4
+ from bindings .looptree import LooptreeWorkload , LooptreeWorkloadDependencyAnalyzer
7
5
8
6
from tests .load_config_mixin import LoadConfigMixin
9
7
@@ -22,5 +20,20 @@ def test_model_with_two_level_mm(self):
22
20
workload = LooptreeWorkload .parse_cfg (config .root ['problem' ])
23
21
analyzer = LooptreeWorkloadDependencyAnalyzer (workload )
24
22
25
- result = analyze_reuse (mapping , workload , analyzer )
26
- print (result )
23
+ tile_shapes , result = analyze_reuse (mapping , workload , analyzer )
24
+
25
+ self .assertEqual (len (tile_shapes ), 3 )
26
+ P1_tile_shape , C1_tile_shape , M1_tile_shape = tile_shapes
27
+
28
+ REFERENCE_FILLS = {
29
+ ('DRAM' , 0 , 0 ): (None , 18 ),
30
+ ('DRAM' , 1 , 0 ): (None , 8 ),
31
+ ('DRAM' , 2 , 0 ): (None , 36 ),
32
+ ('GlobalBuffer' , 0 , 0 ): (None , 18.0 * ceiling (4 / M1_tile_shape )),
33
+ ('GlobalBuffer' , 1 , 0 ): (None , 8 )
34
+ }
35
+
36
+ for key , ref_value in REFERENCE_FILLS .items ():
37
+ self .assertEqual (result .fills [key ],
38
+ ref_value ,
39
+ f'fills for { key } do not match' )
0 commit comments