@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -104,7 +104,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -114,7 +113,6 @@ def eager_attention_forward(
     attn_weights = torch.where(
         attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
     )
-
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
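The two hunks above only drop the unused **kwargs from eager_attention_forward; the masked-softmax math itself is unchanged. A minimal, self-contained sketch of what that path computes, hedged as a toy version that skips the repeat_kv grouped-query expansion and assumes a large negative MIN_MASKED_ATTENTION_VALUE (both defined elsewhere in this file):

    # Hypothetical standalone sketch, not the function from this diff.
    import torch
    from torch import nn

    MIN_MASKED_ATTENTION_VALUE = -1e4  # assumption: large negative fill for masked positions

    def toy_eager_attention(query, key, value, attention_mask, scaling):
        # query/key/value: (batch, heads, seq, head_dim); attention_mask: bool, True = masked out
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if attention_mask is not None:
            attn_weights = torch.where(
                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
            )
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output.transpose(1, 2).contiguous(), attn_weights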
@@ -154,11 +152,10 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface: Callable = eager_attention_forward
+        attention_interface = eager_attention_forward
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -167,12 +164,11 @@ def forward(
             value_states,
             attention_mask,
             scaling=self.scaling,
-            **kwargs,
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 class QEffGemmaDecoderLayer(GemmaDecoderLayer):
@@ -189,7 +185,6 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
@@ -200,9 +195,6 @@ def forward(
             attention_mask (`torch.FloatTensor`, *optional*):
                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                 query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
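The docstring above spells out the default-attention mask layout, (batch_size, 1, query_sequence_length, key_sequence_length). As a hedged illustration only (assuming a boolean mask where True marks positions to drop, consistent with the torch.where usage in eager_attention_forward), such a causal mask can be built like this:

    # Hypothetical helper, not part of this diff: builds a boolean causal mask
    # of shape (batch_size, 1, query_sequence_length, key_sequence_length).
    import torch

    def toy_causal_mask(batch_size: int, q_len: int, kv_len: int) -> torch.Tensor:
        # True marks key positions that must NOT be attended to, matching the
        # torch.where(attention_mask, MIN_MASKED_ATTENTION_VALUE, ...) convention above.
        idx_q = torch.arange(q_len).view(1, 1, q_len, 1)
        idx_k = torch.arange(kv_len).view(1, 1, 1, kv_len)
        offset = kv_len - q_len  # absolute position of the first query token (cached decode)
        mask = idx_k > idx_q + offset  # mask strictly-future key positions
        return mask.expand(batch_size, 1, q_len, kv_len)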
@@ -215,13 +207,12 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             batch_index=batch_index,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
@@ -234,15 +225,7 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
+        return hidden_states
 
 
 class QEffGemmaModel(GemmaModel):
@@ -261,18 +244,10 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
@@ -306,46 +281,27 @@ def forward(
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 batch_index=batch_index,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         if return_legacy_cache:
             past_key_values = past_key_values.to_legacy_cache()
 
-        output = BaseModelOutputWithPast(
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
-        return output if return_dict else output.to_tuple()
 
 
 class QEffGemmaForCausalLM(GemmaForCausalLM):
@@ -363,21 +319,10 @@ def forward(
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -387,9 +332,6 @@ def forward(
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )
@@ -405,6 +347,4 @@ def forward(
             loss=None,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
         )
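With these hunks applied, QEffGemmaForCausalLM.forward no longer accepts labels, output_attentions, output_hidden_states, return_dict, or logits_to_keep, and the returned CausalLMOutputWithPast carries only logits and past_key_values. A hedged usage sketch of the slimmed-down call (the model instance and input construction are assumptions, not code from this diff):

    # Hypothetical usage sketch; `model` is assumed to be an already-built
    # QEffGemmaForCausalLM instance from this modeling file.
    import torch

    def run_prefill(model, input_ids: torch.LongTensor):
        position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
        outputs = model(input_ids=input_ids, position_ids=position_ids, use_cache=True)
        # Only loss (None), logits, and past_key_values are populated now;
        # hidden_states and attentions were dropped from the output.
        return outputs.logits, outputs.past_key_values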