
Commit 69ec2a4 (1 parent: 3643fee)

add Qwen3_moe and cleanup

Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com>

16 files changed (+109, -847 lines); the diffs captured below cover 4 of those files.

QEfficient/transformers/modeling_utils.py
Lines changed: 2 additions & 2 deletions

@@ -93,7 +93,7 @@
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
-    QeffCodeGenBlock,
+    QEffCodeGenBlock,
     QEffCodeGenForCausalLM,
     QEffCodeGenModel,
 )
@@ -224,7 +224,7 @@
     CodeGenAttention: QEffCodeGenAttention,
     CodeGenModel: QEffCodeGenModel,
     CodeGenForCausalLM: QEffCodeGenForCausalLM,
-    CodeGenBlock: QeffCodeGenBlock,
+    CodeGenBlock: QEffCodeGenBlock,
     # Mistral model layers
     MistralAttention: QEffMistralAttention,
     MistralDecoderLayer: QEffMistralDecoderLayer,
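
The two hunks above fix the casing of QeffCodeGenBlock to QEffCodeGenBlock so the import and the class-map entry match the class renamed in modeling_codegen.py below. As a rough illustration only (the helper name and traversal are assumptions, not code from this commit), a transformers-to-QEff class map like this one is typically consumed by retagging modules in place, which keeps the original weights:

    import torch.nn as nn

    def replace_module_classes(model: nn.Module, class_map: dict) -> nn.Module:
        """Hypothetical sketch: retag matching modules with their QEff variants."""
        for module in model.modules():
            target_cls = class_map.get(type(module))
            if target_cls is not None:
                # QEff subclasses typically only override forward(), so reassigning
                # the class reuses the module's existing parameters and buffers.
                module.__class__ = target_cls
        return model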

QEfficient/transformers/models/codegen/modeling_codegen.py
Lines changed: 2 additions & 19 deletions

@@ -123,10 +123,7 @@ def forward(
         query = query.permute(0, 2, 1, 3)

         if layer_past is not None:
-            cache_kwargs = {
-                "position_ids": position_ids,
-                "batch_index": batch_index,
-            }
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
             key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

         # compute self-attention: V x Softmax(QK^T)
@@ -163,12 +160,6 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,  # NOOP kwargs, for now
     ) -> Union[tuple, BaseModelOutputWithPast]:
-        r"""
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -316,12 +307,6 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(
@@ -358,9 +343,7 @@ def forward(
         )


-class QeffCodeGenBlock(CodeGenBlock):
-    # Ignore copy
-
+class QEffCodeGenBlock(CodeGenBlock):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
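
Across this commit, cache_kwargs is trimmed down to position_ids and batch_index: rotary embeddings are applied before the cache update, so sin/cos no longer need to be threaded through. A toy illustration of the update(key, value, layer_idx, cache_kwargs) contract follows; it is an assumption-laden sketch, not QEfficient's real cache class:

    import torch

    class ToyStaticCache:
        """Minimal stand-in for the layer_past / past_key_value object used above."""

        def __init__(self, num_layers, batch_size, num_heads, max_len, head_dim):
            shape = (batch_size, num_heads, max_len, head_dim)
            self.key_cache = [torch.zeros(shape) for _ in range(num_layers)]
            self.value_cache = [torch.zeros(shape) for _ in range(num_layers)]

        def update(self, key, value, layer_idx, cache_kwargs):
            # key/value: (batch, num_heads, seq_len, head_dim), RoPE already applied.
            position_ids = cache_kwargs["position_ids"]    # (batch, seq_len): slots to overwrite
            batch_index = cache_kwargs.get("batch_index")  # cache rows for continuous batching, or None
            rows = batch_index.view(-1) if batch_index is not None else torch.arange(key.shape[0])
            for i, row in enumerate(rows.tolist()):
                self.key_cache[layer_idx][row][:, position_ids[i]] = key[i]
                self.value_cache[layer_idx][row][:, position_ids[i]] = value[i]
            return self.key_cache[layer_idx][rows], self.value_cache[layer_idx][rows]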

QEfficient/transformers/models/falcon/modeling_falcon.py
Lines changed: 4 additions & 25 deletions

@@ -140,7 +140,7 @@ def forward(
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

         if layer_past is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

         if attention_mask is not None:
@@ -161,10 +161,7 @@ def forward(

         attn_output = self.dense(attn_output)

-        if output_attentions:
-            return attn_output, layer_past, attention_scores
-        else:
-            return attn_output, layer_past
+        return attn_output, attention_scores


 class QEffFalconDecoderLayer(FalconDecoderLayer):
@@ -192,7 +189,7 @@ def forward(
         attention_layernorm_out = self.input_layernorm(hidden_states)

         # Self attention.
-        attn_outputs = self.self_attention(
+        attention_output, attn_weights = self.self_attention(
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
@@ -206,8 +203,6 @@ def forward(
             cache_position=cache_position,
         )

-        attention_output = attn_outputs[0]
-
         if not self.config.new_decoder_architecture:
             if self.config.parallel_attn:
                 mlp_layernorm_out = attention_layernorm_out
@@ -224,8 +219,6 @@
         ):
             mlp_layernorm_out = attention_layernorm_out

-        outputs = attn_outputs[1:]
-
         # MLP.
         mlp_output = self.mlp(mlp_layernorm_out)
@@ -234,12 +227,7 @@

         output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, past_kv, attentions
+        return output, attn_weights


 class QEffFalconModel(FalconModel):
@@ -366,22 +354,13 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(
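
With the attention module now always returning (attn_output, attention_scores) and the decoder layer returning (output, attn_weights), callers no longer branch on output_attentions or use_cache. For reference, the dropout_add helper invoked in the hunks above is the small residual utility from upstream transformers' Falcon modeling; it is essentially equivalent to this sketch:

    import torch
    from torch.nn import functional as F

    def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        # Apply dropout to the branch output, then add it back onto the residual stream.
        out = F.dropout(x, p=prob, training=training)
        return residual + out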

QEfficient/transformers/models/gemma/modeling_gemma.py
Lines changed: 9 additions & 69 deletions

@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------

-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -104,7 +104,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -114,7 +113,6 @@
     attn_weights = torch.where(
         attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
     )
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
@@ -154,11 +152,10 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attention_interface: Callable = eager_attention_forward
+        attention_interface = eager_attention_forward

         attn_output, attn_weights = attention_interface(
             self,
@@ -167,12 +164,11 @@
             value_states,
             attention_mask,
             scaling=self.scaling,
-            **kwargs,
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights


 class QEffGemmaDecoderLayer(GemmaDecoderLayer):
@@ -189,7 +185,6 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
@@ -200,9 +195,6 @@
             attention_mask (`torch.FloatTensor`, *optional*):
                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                 query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
@@ -215,13 +207,12 @@
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             batch_index=batch_index,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
@@ -234,15 +225,7 @@
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
+        return hidden_states


 class QEffGemmaModel(GemmaModel):
@@ -261,18 +244,10 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
@@ -306,46 +281,27 @@
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 batch_index=batch_index,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         if return_legacy_cache:
             past_key_values = past_key_values.to_legacy_cache()

-        output = BaseModelOutputWithPast(
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
-        return output if return_dict else output.to_tuple()


 class QEffGemmaForCausalLM(GemmaForCausalLM):
@@ -363,21 +319,10 @@ def forward(
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -387,9 +332,6 @@
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )
@@ -405,6 +347,4 @@
             loss=None,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
         )
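
The Gemma changes strip the output_attentions / output_hidden_states / return_dict plumbing: each decoder layer now returns just hidden_states, and the model always returns a BaseModelOutputWithPast. For context, the eager attention path this file keeps boils down to the self-contained sketch below; repeat_kv is re-derived here, MIN_MASKED_ATTENTION_VALUE is a placeholder value, and the explicit group-count argument replaces the module reference the real function receives.

    import torch
    from torch import nn

    MIN_MASKED_ATTENTION_VALUE = -1e4  # placeholder; QEfficient defines its own constant

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Duplicate grouped KV heads so they line up with the query heads.
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

    def eager_attention(query, key, value, attention_mask, scaling, num_key_value_groups):
        key_states = repeat_kv(key, num_key_value_groups)
        value_states = repeat_kv(value, num_key_value_groups)
        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            # Boolean mask: True marks positions to suppress.
            attn_weights = torch.where(
                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
            )
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim)
        return attn_output.transpose(1, 2).contiguous(), attn_weights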
