2
2
from collections .abc import Mapping
3
3
from numbers import Number
4
4
5
+ from bindings .looptree import TemporalTag , SequentialTag , PipelineTemporalTag
6
+
5
7
import islpy as isl
6
8
7
9
from pytimeloop .isl .singular import get_sum_of_pw_qpolynomial
10
+ from pytimeloop .isl .sum import sum_with_mask
8
11
from pytimeloop .looptree .mapping_utilities import *
9
12
10
13
@@ -29,7 +32,8 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
29
32
reads_to_parent ,
30
33
mapping ,
31
34
workload ,
32
- is_path = False ):
35
+ is_path = False ,
36
+ per_unit = False ):
33
37
mapping = mapping ['nodes' ]
34
38
dspace_id_to_name = workload .data_space_id_to_name ()
35
39
einsum_id_to_name = workload .einsum_id_to_name ()
@@ -49,8 +53,33 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
49
53
for (buffer_id , dspace_id , einsum_id ), (tags , fill ) in fills .items ():
50
54
read_to_parent = reads_to_parent [(buffer_id , dspace_id , einsum_id )][1 ]
51
55
52
- read_to_parent = get_sum_of_pw_qpolynomial (read_to_parent )
53
- fill = get_sum_of_pw_qpolynomial (fill )
56
+ if not per_unit :
57
+ read_to_parent = get_sum_of_pw_qpolynomial (read_to_parent )
58
+ fill = get_sum_of_pw_qpolynomial (fill )
59
+ else :
60
+ fill = sum_with_mask (
61
+ [
62
+ (
63
+ isinstance (t , TemporalTag ) or
64
+ isinstance (t , PipelineTemporalTag ) or
65
+ isinstance (t , SequentialTag )
66
+ )
67
+ for t in tags
68
+ ],
69
+ fill
70
+ ).max ().to_python ()
71
+ n_read_to_parent_dim = read_to_parent .dim (isl .dim_type .in_ )
72
+ read_to_parent = sum_with_mask (
73
+ [
74
+ (
75
+ isinstance (t , TemporalTag ) or
76
+ isinstance (t , PipelineTemporalTag ) or
77
+ isinstance (t , SequentialTag )
78
+ )
79
+ for t in tags [:n_read_to_parent_dim ]
80
+ ],
81
+ read_to_parent
82
+ ).max ().to_python ()
54
83
55
84
dspace_name = dspace_id_to_name [dspace_id ]
56
85
einsum_name = einsum_id_to_name [einsum_id ]
@@ -61,24 +90,32 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
61
90
key = (parent_buffer , dspace_name , einsum_name )
62
91
if dspace_id in workload .tensors_written_by_einsum (einsum_id ):
63
92
writes [key ] += read_to_parent
93
+ reads [key ] += read_to_parent
64
94
# Subtracted term: elided first read of a read-write tensor
65
- reads [key ] += \
66
- read_to_parent - workload .get_tensor_volume (dspace_id )
95
+ # TODO: figure out how to do this per unit
96
+ if not per_unit :
97
+ reads [key ] -= workload .get_tensor_volume (dspace_id )
67
98
elif dspace_id in workload .tensors_read_by_einsum (einsum_id ):
68
99
reads [key ] += read_to_parent
69
100
# Fills will write into current buffer except for compute (which does
70
101
# not have write action) and top-level buffer
71
102
if buffer_id not in compute_targets and parent_buffer is not None :
72
103
if dspace_id in workload .tensors_written_by_einsum (einsum_id ):
73
- writes [(buffer_id , dspace_name , einsum_name )] += \
74
- fill - workload .get_tensor_volume (dspace_id )
104
+ writes [(buffer_id , dspace_name , einsum_name )] += fill
105
+ if not per_unit :
106
+ writes [(buffer_id , dspace_name , einsum_name )] -= \
107
+ workload .get_tensor_volume (dspace_id )
75
108
else :
76
109
writes [(buffer_id , dspace_name , einsum_name )] += fill
77
110
78
111
return reads , writes
79
112
80
113
81
- def reads_and_writes_from_fill_by_peer (fills : Mapping , mapping , workload , is_path = False ):
114
+ def reads_and_writes_from_fill_by_peer (fills : Mapping ,
115
+ mapping ,
116
+ workload ,
117
+ is_path = False ,
118
+ per_unit = False ):
82
119
mapping = mapping ['nodes' ]
83
120
dspace_id_to_name = workload .data_space_id_to_name ()
84
121
einsum_id_to_name = workload .einsum_id_to_name ()
@@ -89,14 +126,27 @@ def reads_and_writes_from_fill_by_peer(fills: Mapping, mapping, workload, is_pat
89
126
einsums_with_complete_mappings = get_einsums_with_complete_mappings (mapping , workload , is_path )
90
127
91
128
for (buffer_id , dspace_id , einsum_id ), (tags , fill ) in fills .items ():
92
- fill = get_sum_of_pw_qpolynomial (fill )
129
+ if not per_unit :
130
+ fill = get_sum_of_pw_qpolynomial (fill )
131
+ else :
132
+ fill = sum_with_mask (
133
+ [
134
+ (
135
+ isinstance (t , TemporalTag ) or
136
+ isinstance (t , PipelineTemporalTag ) or
137
+ isinstance (t , SequentialTag )
138
+ )
139
+ for t in tags
140
+ ],
141
+ fill
142
+ ).max ().to_python ()
93
143
einsum_name = einsum_id_to_name [einsum_id ]
94
144
dspace_name = dspace_id_to_name [dspace_id ]
95
145
if einsum_id not in einsums_with_complete_mappings :
96
146
continue
97
147
98
148
reads [(buffer_id , dspace_name , einsum_name )] = fill
99
- writes [(buffer_id , dspace_name , einsum_name )] = 0 # already accounted for in above
149
+ writes [(buffer_id , dspace_name , einsum_name )] = 0 # already accounted for in fill_by_parent
100
150
101
151
return reads , writes
102
152
0 commit comments