Commit a514d36

Added Whisper TF Upgrade
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
1 parent f5e7daa commit a514d36

File tree: 1 file changed, +23 -22 lines changed

QEfficient/transformers/models/whisper/modeling_whisper.py

Lines changed: 23 additions & 22 deletions
@@ -63,18 +63,23 @@ def forward(
         is_cross_attention: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        bsz, tgt_len, _ = hidden_states.size()
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
 
-        # get query proj
-        query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
+        query_states = self.q_proj(hidden_states) * self.scaling
+        query_states = query_states.view(*q_input_shape)
+        query_states = query_states.transpose(1, 2).contiguous()
 
         if self.is_decoder:
             if is_cross_attention and past_key_value:
                 # cross_attentions
                 key_states_old = past_key_value[self.layer_idx][0]
                 value_states_old = past_key_value[self.layer_idx][1]
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+                key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+                value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+                key_states = key_states.transpose(1, 2).contiguous()
+                value_states = value_states.transpose(1, 2).contiguous()
                 indices = (torch.arange(bsz),)
                 key_states_new = torch.index_put(key_states_old, indices, key_states)
                 value_states_new = torch.index_put(value_states_old, indices, value_states)
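
The hunk above drops the _shape helper in favour of an explicit view followed by transpose when building the projections. A minimal sketch (illustrative only, with made-up dimensions, not part of the commit) showing that both reshaping paths land on the same (bsz, num_heads, seq_len, head_dim) layout:

import torch

# illustrative sizes only
bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(bsz, tgt_len, num_heads * head_dim)

# old path: _shape-style reshape, (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
old = hidden.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2).contiguous()

# new path: build the target shape first and let -1 infer the number of heads
q_input_shape = (bsz, tgt_len, -1, head_dim)
new = hidden.view(*q_input_shape).transpose(1, 2).contiguous()

assert torch.equal(old, new)  # both are (bsz, num_heads, tgt_len, head_dim)
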
@@ -85,21 +90,25 @@ def forward(
                     input_features.shape[2] == torch.tensor(1), value_states_old, value_states_new
                 )
 
-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
+                past_key_value.layers[self.layer_idx].keys = key_states
+                past_key_value.layers[self.layer_idx].values = value_states
             else:
                 # self attention decoder
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+                key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+                value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+                key_states = key_states.transpose(1, 2).contiguous()
+                value_states = value_states.transpose(1, 2).contiguous()
                 if past_key_value is not None:
                     cache_kwargs = {"position_ids": position_ids_layer}
                     key_states, value_states = past_key_value.update(
                         key_states, value_states, self.layer_idx, cache_kwargs
                     )
         else:
             # self_attention Encoder
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+            value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+            key_states = key_states.transpose(1, 2).contiguous()
+            value_states = value_states.transpose(1, 2).contiguous()
 
         src_len = key_states.size(2)
 
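The key_cache / value_cache to layers[idx].keys / layers[idx].values rename above follows the layer-based cache layout of newer transformers releases. A hedged sketch of the new access pattern, assuming a transformers version that already ships the layer-based DynamicCache (which the patched code requires):

import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 4, 6, 8)  # (bsz, num_heads, seq_len, head_dim)
v = torch.randn(1, 4, 6, 8)
cache.update(k, v, layer_idx=0)

# layer-based access, matching the style used by the patched modeling code
print(cache.layers[0].keys.shape)    # torch.Size([1, 4, 6, 8])
print(cache.layers[0].values.shape)  # torch.Size([1, 4, 6, 8])
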
@@ -150,7 +159,7 @@ def forward(
 
         attn_output = self.out_proj(attn_output)
 
-        return [attn_output, attn_weights, past_key_value]
+        return [attn_output, attn_weights]
 
 
 class QEffWhisperDecoderLayer(WhisperDecoderLayer):
@@ -203,7 +212,7 @@ def forward(
 
         # Self Attention
         self_attn_past_key_value = past_key_value.self_attention_cache if past_key_value is not None else None
-        hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
@@ -217,13 +226,12 @@ def forward(
         hidden_states = residual + hidden_states
 
         # Cross-Attention Block
-        cross_attn_present_key_value = None
         cross_attn_weights = None
         if is_encoder_decoder:
             residual = hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
             cross_attn_past_key_value = past_key_value.cross_attention_cache if past_key_value is not None else None
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+            hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
@@ -237,13 +245,6 @@ def forward(
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
             hidden_states = residual + hidden_states
 
-            # update the cached past_key_values accordingly
-            past_key_value.self_attention_cache = self_attn_present_key_value
-            past_key_value.cross_attention_cache = cross_attn_present_key_value
-        else:
-            # if no cross_attention, still need to update self_attn cache
-            past_key_value = self_attn_present_key_value
-
         # Fully Connected
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
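
The cache bookkeeping removed above (together with the shorter return lists in the earlier hunks) relies on Cache.update mutating the cache object in place, so the decoder layer no longer needs the attention modules to hand the cache back. A small sketch of that in-place behaviour, assuming standard DynamicCache semantics; toy_attention is a hypothetical stand-in for the attention call:

import torch
from transformers import DynamicCache

def toy_attention(cache: DynamicCache) -> None:
    # hypothetical attention step: it only writes into the shared cache and
    # returns nothing cache-related, mirroring the new two-element return
    k = torch.randn(1, 2, 3, 4)  # (bsz, num_heads, seq_len, head_dim)
    v = torch.randn(1, 2, 3, 4)
    cache.update(k, v, layer_idx=0)

shared_cache = DynamicCache()
toy_attention(shared_cache)
print(shared_cache.get_seq_length(0))  # 3 -- the caller sees the update without any reassignment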
