1
+ from typing import overload
2
+
1
3
from pytimeloop .isl .singular import get_value_from_singular_qpolynomial
2
4
from pytimeloop .looptree .latency .processors import LATENCY_PROCESSORS
3
- from pytimeloop .looptree .reuse .isl .des import IslReuseAnalysisOutput
5
+ from pytimeloop .looptree .reuse .isl import IslReuseAnalysisOutput
6
+ from pytimeloop .looptree .reuse .summarized import SummarizedAnalysisOutput
4
7
from pytimeloop .looptree .latency .memory import memory_latency
5
8
6
9
from bindings .looptree import SpatialTag
@@ -11,9 +14,9 @@ def get_latency(looptree_results: IslReuseAnalysisOutput,
11
14
workload ,
12
15
arch ,
13
16
bindings ):
14
- comp_latency = compute_latency ( mapping ,
15
- looptree_results . temporal_steps ,
16
- workload )
17
+ comp_latency = calculate_compute_latency ( looptree_results ,
18
+ mapping ,
19
+ workload )
17
20
mem_latency = memory_latency (looptree_results ,
18
21
arch ,
19
22
mapping ,
@@ -23,12 +26,40 @@ def get_latency(looptree_results: IslReuseAnalysisOutput,
23
26
return overall_latency , comp_latency , mem_latency
24
27
25
28
26
- def compute_latency (mapping , temporal_steps , workload ):
29
@overload
def calculate_compute_latency(reuse_analysis_results: IslReuseAnalysisOutput,
                              mapping,
                              workload):
    ...


@overload
def calculate_compute_latency(reuse_analysis_results: SummarizedAnalysisOutput,
                              mapping,
                              workload):
    ...


def calculate_compute_latency(reuse_analysis_results, mapping, workload):
    """Compute the compute latency for a reuse-analysis result.

    Dispatches on the concrete type of ``reuse_analysis_results`` and
    forwards its ``temporal_steps`` attribute, together with ``mapping``
    and ``workload``, to the matching backend:

    * ``IslReuseAnalysisOutput`` -> ``compute_isl_latency``
    * ``SummarizedAnalysisOutput`` -> ``compute_summarized_latency``

    Args:
        reuse_analysis_results: Result object of one of the two
            supported types listed above.
        mapping: Passed through unchanged to the backend.
        workload: Passed through unchanged to the backend.

    Returns:
        Whatever the selected backend returns.

    Raises:
        TypeError: If ``reuse_analysis_results`` is of neither supported
            type. Without this, the function would fall through and
            return ``None``, deferring the failure to the caller.
    """
    if isinstance(reuse_analysis_results, IslReuseAnalysisOutput):
        return compute_isl_latency(reuse_analysis_results.temporal_steps,
                                   mapping,
                                   workload)
    if isinstance(reuse_analysis_results, SummarizedAnalysisOutput):
        return compute_summarized_latency(
            reuse_analysis_results.temporal_steps,
            mapping,
            workload
        )
    raise TypeError(
        "calculate_compute_latency expects an IslReuseAnalysisOutput or "
        "SummarizedAnalysisOutput, got "
        f"{type(reuse_analysis_results).__name__}"
    )
50
+
51
+
52
def compute_isl_latency(temporal_steps, mapping, workload):
    """Evaluate the latency derived from ISL-based temporal steps.

    Runs ``_compute_latency`` over ``mapping.nodes`` starting at index 0,
    takes the element at index 1 of its result, extracts the value of
    that singular quasi-polynomial and converts it to a Python object.

    Args:
        temporal_steps: Forwarded unchanged to ``_compute_latency``.
        mapping: Object whose ``nodes`` attribute is traversed.
        workload: Forwarded unchanged to ``_compute_latency``.

    Returns:
        The result of ``.to_python()`` on the extracted value.
    """
    traversal_result = _compute_latency(mapping.nodes,
                                        0,
                                        temporal_steps,
                                        workload)
    # Only the second component of the traversal result is used here.
    latency_qpolynomial = traversal_result[1]
    latency_value = get_value_from_singular_qpolynomial(latency_qpolynomial)
    return latency_value.to_python()
30
56
31
57
58
def compute_summarized_latency(temporal_steps, mapping, workload):
    """Sum the per-entry temporal step counts into a single latency.

    Args:
        temporal_steps: Either a mapping from key to step count, or an
            iterable of ``(key, value)`` pairs. Only the values are used.
        mapping: Unused; kept for a signature parallel to
            ``compute_isl_latency``.
        workload: Unused; kept for a signature parallel to
            ``compute_isl_latency``.

    Returns:
        The sum of all values (``0`` when ``temporal_steps`` is empty).
    """
    # Local import keeps this block self-contained.
    from collections.abc import Mapping

    # Iterating a mapping directly yields only its keys, so unpacking
    # ``key, value`` from it would fail or read the wrong data. Go through
    # ``.items()`` for mappings; iterables of pairs are consumed as before.
    # NOTE(review): the concrete type of
    # SummarizedAnalysisOutput.temporal_steps is not visible here - confirm.
    if isinstance(temporal_steps, Mapping):
        pairs = temporal_steps.items()
    else:
        pairs = temporal_steps

    # TODO: this is only for single-Einsum!!!
    return sum(value for _, value in pairs)
61
+
62
+
32
63
def _compute_latency (mapping , top_idx : int , temporal_steps , workload ):
33
64
einsum_name_to_id = workload .einsum_name_to_id ()
34
65
0 commit comments