@@ -63,18 +63,23 @@ def forward(
         is_cross_attention: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        bsz, tgt_len, _ = hidden_states.size()
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        q_input_shape = (bsz, tgt_len, -1, self.head_dim)

-        # get query proj
-        query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
+        query_states = self.q_proj(hidden_states) * self.scaling
+        query_states = query_states.view(*q_input_shape)
+        query_states = query_states.transpose(1, 2).contiguous()

         if self.is_decoder:
             if is_cross_attention and past_key_value:
                 # cross_attentions
                 key_states_old = past_key_value[self.layer_idx][0]
                 value_states_old = past_key_value[self.layer_idx][1]
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+                key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+                value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+                key_states = key_states.transpose(1, 2).contiguous()
+                value_states = value_states.transpose(1, 2).contiguous()
                 indices = (torch.arange(bsz),)
                 key_states_new = torch.index_put(key_states_old, indices, key_states)
                 value_states_new = torch.index_put(value_states_old, indices, value_states)
@@ -85,21 +90,25 @@ def forward(
                     input_features.shape[2] == torch.tensor(1), value_states_old, value_states_new
                 )

-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
+                past_key_value.layers[self.layer_idx].keys = key_states
+                past_key_value.layers[self.layer_idx].values = value_states
             else:
                 # self attention decoder
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+                key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+                value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+                key_states = key_states.transpose(1, 2).contiguous()
+                value_states = value_states.transpose(1, 2).contiguous()
                 if past_key_value is not None:
                     cache_kwargs = {"position_ids": position_ids_layer}
                     key_states, value_states = past_key_value.update(
                         key_states, value_states, self.layer_idx, cache_kwargs
                     )
         else:
             # self_attention Encoder
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+            value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+            key_states = key_states.transpose(1, 2).contiguous()
+            value_states = value_states.transpose(1, 2).contiguous()

         src_len = key_states.size(2)

@@ -150,7 +159,7 @@ def forward(

         attn_output = self.out_proj(attn_output)

-        return [attn_output, attn_weights, past_key_value]
+        return [attn_output, attn_weights]


 class QEffWhisperDecoderLayer(WhisperDecoderLayer):
@@ -203,7 +212,7 @@ def forward(

         # Self Attention
         self_attn_past_key_value = past_key_value.self_attention_cache if past_key_value is not None else None
-        hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
@@ -217,13 +226,12 @@ def forward(
         hidden_states = residual + hidden_states

         # Cross-Attention Block
-        cross_attn_present_key_value = None
         cross_attn_weights = None
         if is_encoder_decoder:
             residual = hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
             cross_attn_past_key_value = past_key_value.cross_attention_cache if past_key_value is not None else None
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+            hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
@@ -237,13 +245,6 @@ def forward(
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
             hidden_states = residual + hidden_states

-            # update the cached past_key_values accordingly
-            past_key_value.self_attention_cache = self_attn_present_key_value
-            past_key_value.cross_attention_cache = cross_attn_present_key_value
-        else:
-            # if no cross_attention, still need to update self_attn cache
-            past_key_value = self_attn_present_key_value
-
         # Fully Connected
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
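
Taken together, the attention changes above do two things: the query/key/value projections are now reshaped inline with view + transpose instead of the removed _shape helper, and the cross-attention key/value states are written into the cache via index_put/where rather than returned as present key values. A minimal standalone sketch of those two tensor operations follows; the sizes and the is_decode_step flag are illustrative placeholders, not values from the model:

import torch

# Hypothetical sizes for illustration only.
bsz, seq_len, num_heads, head_dim = 2, 7, 4, 8
hidden = torch.randn(bsz, seq_len, num_heads * head_dim)

# Inline reshape that replaces the `_shape` helper:
# (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
key_states = hidden.view(bsz, -1, num_heads, head_dim).transpose(1, 2).contiguous()

# Cross-attention cache update: overwrite the cached keys for the whole batch,
# then keep the old cache on decode steps (dummy input_features of length 1) and
# take the new projections on prefill, mirroring the index_put/where pattern in the diff.
key_states_old = torch.zeros_like(key_states)
indices = (torch.arange(bsz),)
key_states_new = torch.index_put(key_states_old, indices, key_states)
is_decode_step = torch.tensor(True)  # stands in for input_features.shape[2] == torch.tensor(1)
key_cache = torch.where(is_decode_step, key_states_old, key_states_new)

print(key_states.shape, key_cache.shape)  # both torch.Size([2, 4, 7, 8])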