From d3dec44a6f52746742bfaabf5a07ec8109a5f9e9 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 15 Aug 2025 14:14:29 +0800 Subject: [PATCH 01/77] 2025/08/15 --- .../autoencoders/autoencoder_kl_qwenimage.py | 1070 +++++++++++++++++ .../transformers/transformer_qwenimage.py | 653 ++++++++++ .../diffusers/pipelines/qwenimage/__init__.py | 53 + .../pipelines/qwenimage/pipeline_output.py | 21 + .../pipelines/qwenimage/pipeline_qwenimage.py | 735 +++++++++++ .../qwenimage/pipeline_qwenimage_img2img.py | 839 +++++++++++++ .../qwenimage/pipeline_qwenimage_inpaint.py | 1025 ++++++++++++++++ .../pipelines/qwenimage/__init__.py | 0 .../pipelines/qwenimage/test_qwenimage.py | 239 ++++ .../qwenimage/test_qwenimage_img2img.py | 221 ++++ .../qwenimage/test_qwenimage_inpaint.py | 236 ++++ 11 files changed, 5092 insertions(+) create mode 100644 mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py create mode 100644 mindone/diffusers/models/transformers/transformer_qwenimage.py create mode 100644 mindone/diffusers/pipelines/qwenimage/__init__.py create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_output.py create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py create mode 100644 tests/diffusers_tests/pipelines/qwenimage/__init__.py create mode 100644 tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py create mode 100644 tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py create mode 100644 tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000000..87ac406592 --- /dev/null +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -0,0 +1,1070 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. 
+# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - arXiv: https://arxiv.org/abs/2503.20314 + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. 
+ """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = 
layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. 
+ tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py new file mode 100644 index 0000000000..b790462f93 --- /dev/null +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -0,0 +1,653 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import functools +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import mint, nn, ops +# import torch +# import torch.nn as nn +# import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +# from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +# from ...utils.torch_utils import maybe_allow_in_graph +from ...utils import logging +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention +# from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, RMSNorm +from ..layers_compat import unflatten, view_as_complex + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_timestep_embedding( + timesteps: ms.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> ms.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (ms.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + ms.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * mint.arange( + start=0, end=half_dim, dtype=ms.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = mint.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = mint.cat([mint.sin(emb), mint.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = mint.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = mint.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: ms.Tensor, + freqs_cis: Union[ms.Tensor, Tuple[ms.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[ms.Tensor, ms.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`ms.Tensor`): + Query or key tensor to apply rotary embeddings. 
[B, S, H, D] xk (ms.Tensor): Key tensor to apply + freqs_cis (`Tuple[ms.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[ms.Tensor, ms.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = mint.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = mint.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = ops.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Cell): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def construct(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Cell): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = mint.arange(1024) + neg_index = mint.arange(1024).flip(0) * -1 - 1 + pos_freqs = mint.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + neg_freqs = mint.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + self.register_buffer("pos_freqs", pos_freqs, persistent=False) + self.register_buffer("neg_freqs", neg_freqs, persistent=False) + + # 是否使用 scale rope + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = mint.outer(index, 1.0 / mint.pow(theta, mint.arange(0, dim, 2).to(ms.float32).div(dim))) + freqs = mint.polar(mint.ones_like(freqs), freqs) + return freqs + + def construct(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + frame, height, width = video_fhw + # rope_key = f"{frame}_{height}_{width}" + + # if not torch.compiler.is_compiling(): # 未匹配 + # if rope_key not in self.rope_cache: + # 
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width) + # vid_freqs = self.rope_cache[rope_key] + # else: + # vid_freqs = self._compute_video_freqs(frame, height, width) + vid_freqs = self._compute_video_freqs(frame, height, width) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = mint.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = mint.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = mint.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + + # def __init__(self): + # if not hasattr(F, "scaled_dot_product_attention"): + # raise ImportError( + # "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ # ) + + def __call__( + self, + attn: Attention, + hidden_states: ms.Tensor, # Image stream + encoder_hidden_states: ms.Tensor = None, # Text stream + encoder_hidden_states_mask: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + image_rotary_emb: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = unflatten(img_query, -1, (attn.heads, -1)).swapaxes(1, 2) + img_key = unflatten(img_key, -1, (attn.heads, -1)).swapaxes(1, 2) + img_value = unflatten(img_value, -1, (attn.heads, -1)).swapaxes(1, 2) + # img_query = img_query.unflatten(-1, (attn.heads, -1)) + # img_key = img_key.unflatten(-1, (attn.heads, -1)) + # img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = unflatten(txt_query, -1, (attn.heads, -1)).swapaxes(1, 2) + txt_key = unflatten(txt_key, -1, (attn.heads, -1)).swapaxes(1, 2) + txt_value = unflatten(txt_value, -1, (attn.heads, -1)).swapaxes(1, 2) + # txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + # txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + # txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = mint.cat([txt_query, img_query], dim=1) + joint_key = mint.cat([txt_key, img_key], dim=1) + joint_value = mint.cat([txt_value, img_value], dim=1) + + # Compute joint attention + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +# @maybe_allow_in_graph +class QwenImageTransformerBlock(nn.Cell): + def __init__( + self, dim: int, num_attention_heads: int, 
attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: ms.Tensor, + encoder_hidden_states_mask: ms.Tensor, + temb: ms.Tensor, + image_rotary_emb: Optional[Tuple[ms.Tensor, ms.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[ms.Tensor, ms.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == ms.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == ms.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
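+
+    Example (a minimal construction sketch; `num_layers` is reduced purely to keep the example light, every other
+    value follows the defaults documented above, and the import path mirrors this file's location):
+
+        >>> from mindone.diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
+        >>> model = QwenImageTransformer2DModel(num_layers=2)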
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.CellList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: ms.Tensor = None, + encoder_hidden_states_mask: ms.Tensor = None, + timestep: ms.Tensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: ms.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[ms.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`ms.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`ms.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `ms.Tensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None and "scale" in attention_kwargs: + # weight the lora layers by setting `lora_scale` for each PEFT layer here + # and remove `lora_scale` from each PEFT layer at the end. 
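The `img_mod`/`txt_mod` heads above emit a `[B, 6*dim]` vector that `_modulate` splits into (shift, scale, gate) triples for the norm1/attention and norm2/MLP stages of each stream. A minimal, self-contained PyTorch sketch of that adaLN-style modulation; the tensor sizes and the toy `mod_head` below are illustrative only and not part of the patch:

```py
# Illustrative sketch of the shift/scale/gate modulation used by QwenImageTransformerBlock._modulate.
# All sizes are toy values; the real block uses dim = num_attention_heads * attention_head_dim.
import torch

batch, seq_len, dim = 2, 8, 16
temb = torch.randn(batch, dim)                                  # pooled timestep embedding
mod_head = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(dim, 6 * dim))

mod_params = mod_head(temb)                                     # [B, 6*dim]
mod1, mod2 = mod_params.chunk(2, dim=-1)                        # one [B, 3*dim] set per stage
shift, scale, gate = mod1.chunk(3, dim=-1)                      # each [B, dim]

x = torch.randn(batch, seq_len, dim)                            # layer-normed hidden states
modulated = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)   # what _modulate returns
update = gate.unsqueeze(1) * modulated                          # gated branch added back to the residual
print(modulated.shape, update.shape)                            # torch.Size([2, 8, 16]) twice
```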
+ # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + # for index_block, block in enumerate(self.transformer_blocks): + # if torch.is_grad_enabled() and self.gradient_checkpointing: + # encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + # block, + # hidden_states, + # encoder_hidden_states, + # encoder_hidden_states_mask, + # temb, + # image_rotary_emb, + # ) + + # else: + # encoder_hidden_states, hidden_states = block( + # hidden_states=hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # encoder_hidden_states_mask=encoder_hidden_states_mask, + # temb=temb, + # image_rotary_emb=image_rotary_emb, + # joint_attention_kwargs=attention_kwargs, + # ) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/mindone/diffusers/pipelines/qwenimage/__init__.py b/mindone/diffusers/pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000..64265880e7 --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"] + _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] + _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] + _import_structure["pipeline_qwenimage_inpaint"] = 
["QwenImageInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_qwenimage import QwenImagePipeline + from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline + from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_output.py b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py new file mode 100644 index 0000000000..eef4b60e37 --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py new file mode 100644 index 0000000000..47549ab4af --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -0,0 +1,735 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
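For reference, the lazy `__init__.py` above only records the pipeline names in `_import_structure`; the actual submodules are loaded by `_LazyModule` the first time one of those names is resolved. A hedged usage sketch, assuming the `mindone` package layout from this patch is importable:

```py
# Importing the package is cheap; pipeline_qwenimage.py is only loaded when the name is resolved.
from mindone.diffusers.pipelines.qwenimage import QwenImagePipeline

# Equivalent attribute-style resolution through the lazy module:
import mindone.diffusers.pipelines.qwenimage as qwenimage
pipeline_cls = getattr(qwenimage, "QwenImageImg2ImgPipeline")
```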
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
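The height/width arithmetic used in `prepare_latents` below, together with `_pack_latents`/`_unpack_latents` above, can be checked with a small standalone sketch. With illustrative numbers (a 1024x1024 image, `vae_scale_factor=8`, 16 latent channels), the packed sequence ends up with 64 features per token, matching the transformer's `in_channels=64`:

```py
# Illustrative only: shape bookkeeping for the 2x2 latent packing (sizes are made up).
import torch

vae_scale_factor, num_channels_latents = 8, 16
height = width = 1024
lat_h = 2 * (int(height) // (vae_scale_factor * 2))   # 128, forced to a multiple of 2
lat_w = 2 * (int(width) // (vae_scale_factor * 2))    # 128

latents = torch.randn(1, 1, num_channels_latents, lat_h, lat_w)     # (B, T=1, C, H', W')
packed = latents.view(1, num_channels_latents, lat_h // 2, 2, lat_w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(1, (lat_h // 2) * (lat_w // 2), num_channels_latents * 4)
print(packed.shape)   # torch.Size([1, 4096, 64]): image_seq_len=4096 tokens of 4*16=64 channels
```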
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 
+ + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py new file mode 100644 index 0000000000..4fc84a31cc --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -0,0 +1,839 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImageImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney" + >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0] + >>> images.save("qwenimage_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 
+def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
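The `calculate_shift` helper defined above (both pipelines carry an identical copy) sets the flow-matching shift `mu` as a linear function of the packed image sequence length. A self-contained sketch with the default constants; the function body is restated verbatim so the example runs on its own, and the sample sequence lengths are illustrative:

```py
# mu interpolates linearly from base_shift at base_seq_len to max_shift at max_seq_len.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

print(round(calculate_shift(256), 4))   # 0.5   -> short sequences keep the base shift
print(round(calculate_shift(1024), 4))  # 0.63  -> e.g. a 512x512 image (32*32 packed tokens)
print(round(calculate_shift(4096), 4))  # 1.15  -> a 1024x1024 image saturates at max_shift
```

The resulting `mu` is then passed to `retrieve_timesteps`, which forwards it to the scheduler's `set_timesteps` as a keyword argument.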
So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + + image_latents = (image_latents - latents_mean) * latents_std + + return image_latents + + # Copied from 
diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 1024,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+        return prompt_embeds, prompt_embeds_mask
+
+    def check_inputs(
+        self,
+        prompt,
+        strength,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        prompt_embeds_mask=None,
+        negative_prompt_embeds_mask=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. 
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py new file mode 100644 index 0000000000..5ffec0c447 --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -0,0 +1,1025 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import 
PIL.Image +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImageInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0] + >>> image.save("qwenimage_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + 
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        return prompt_embeds, encoder_attention_mask
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(image_latents.device, image_latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            image_latents.device, image_latents.dtype
+        )
+
+        image_latents = (image_latents - latents_mean) * latents_std
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 1024,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." 
+ ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
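+        # Shape math, as a rough sketch (assuming the usual vae_scale_factor of 8): a 1024x1024 request
+        # gives height = 2 * (1024 // 16) = 128 and width = 128, so the mask below is resized to 128x128
+        # and packed into (128 // 2) * (128 // 2) = 4096 patch tokens, matching the packed image latents.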
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if masked_image.dim() == 4: + masked_image = masked_image.unsqueeze(2) + elif masked_image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == self.latent_channels: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. 
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. 
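+                # Inpainting blend: re-noise the original image latents to the next timestep and keep them
+                # wherever the packed mask is 0, so only the masked (white) region is actually repainted:
+                #     latents = (1 - mask) * noised_original_latents + mask * denoised_latents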
+ init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/tests/diffusers_tests/pipelines/qwenimage/__init__.py b/tests/diffusers_tests/pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py new file mode 100644 index 0000000000..5b8a6cbb92 --- /dev/null +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
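A note on the guidance step in the inpaint pipeline's denoising loop above, before the fast tests: when true CFG is enabled, the guided prediction is rescaled so that its per-position norm matches the conditional prediction before the scheduler step. A minimal NumPy sketch of just that rescaling (array names and shapes are illustrative, not taken from the patch):

import numpy as np

true_cfg_scale = 4.0
noise_pred = np.random.randn(1, 64, 16)          # conditional prediction (toy shape)
neg_noise_pred = np.random.randn(1, 64, 16)      # negative / unconditional prediction

# standard classifier-free guidance combination
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# rescale so the guided prediction keeps the conditional prediction's norm
cond_norm = np.linalg.norm(noise_pred, axis=-1, keepdims=True)
noise_norm = np.linalg.norm(comb_pred, axis=-1, keepdims=True)
guided = comb_pred * (cond_norm / noise_norm)

assert np.allclose(np.linalg.norm(guided, axis=-1), cond_norm.squeeze(-1))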
+ +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + QwenImagePipeline, + QwenImageTransformer2DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = QwenImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = QwenImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + 
generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + # fmt: off + expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208]) + # fmt: on + + generated_slice = generated_image.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py new file mode 100644 index 0000000000..afdbd2c44b --- /dev/null +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py @@ -0,0 +1,221 @@ +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. 
+ +import random +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + QwenImageImg2ImgPipeline, + QwenImageTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class QwenImageImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = QwenImageImg2ImgPipeline + params = frozenset(["prompt", "image", "height", "width", "guidance_scale", "true_cfg_scale", "strength"]) + batch_params = frozenset(["prompt", "image"]) + image_params = frozenset(["image"]) + image_latents_params = frozenset(["latents"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_attention_slicing = True + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = QwenImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = 
self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs).images[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs).images[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs).images[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py new file mode 100644 index 0000000000..11845e130f --- /dev/null +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py @@ -0,0 +1,236 @@ +# Copyright 2025 The HuggingFace Team. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + QwenImageInpaintPipeline, + QwenImageTransformer2DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = QwenImageInpaintPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = QwenImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = 
self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) From 60069601d39b9b04ecab8214e928d58d5d633956 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 15 Aug 2025 17:18:19 +0800 Subject: [PATCH 02/77] 2025/8/15 17:18 revised --- mindone/diffusers/models/__init__.py | 4 + .../diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_qwenimage.py | 192 +++++++++--------- .../diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_qwenimage.py | 25 ++- .../diffusers/pipelines/qwenimage/__init__.py | 55 ++--- .../pipelines/qwenimage/pipeline_output.py | 2 + .../pipelines/qwenimage/pipeline_qwenimage.py | 3 + .../qwenimage/pipeline_qwenimage_img2img.py | 165 +++++++-------- .../qwenimage/pipeline_qwenimage_inpaint.py | 2 + 10 files changed, 209 insertions(+), 241 deletions(-) diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index c913574a3e..1255263693 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -31,6 +31,7 @@ 
"autoencoders.autoencoder_kl_ltx": ["AutoencoderKLLTXVideo"], "autoencoders.autoencoder_kl_magvit": ["AutoencoderKLMagvit"], "autoencoders.autoencoder_kl_mochi": ["AutoencoderKLMochi"], + "autoencoders.autoencoder_kl_qwenimage": ["AutoencoderKLQwenImage"], "autoencoders.autoencoder_kl_temporal_decoder": ["AutoencoderKLTemporalDecoder"], "autoencoders.autoencoder_kl_wan": ["AutoencoderKLWan"], "autoencoders.autoencoder_oobleck": ["AutoencoderOobleck"], @@ -77,6 +78,7 @@ "transformers.transformer_lumina2": ["Lumina2Transformer2DModel"], "transformers.transformer_mochi": ["MochiTransformer3DModel"], "transformers.transformer_omnigen": ["OmniGenTransformer2DModel"], + "transformers.transformer_qwenimage": ["QwenImageTransformer2DModel"], "transformers.transformer_sd3": ["SD3Transformer2DModel"], "transformers.transformer_temporal": ["TransformerTemporalModel"], "transformers.transformer_wan": ["WanTransformer3DModel"], @@ -105,6 +107,7 @@ AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, + AutoencoderKLQwenImage, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -152,6 +155,7 @@ OmniGenTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + QwenImageTransformer2DModel, SanaTransformer2DModel, SD3Transformer2DModel, StableAudioDiTModel, diff --git a/mindone/diffusers/models/autoencoders/__init__.py b/mindone/diffusers/models/autoencoders/__init__.py index 4bf41bccad..53508641f2 100644 --- a/mindone/diffusers/models/autoencoders/__init__.py +++ b/mindone/diffusers/models/autoencoders/__init__.py @@ -9,6 +9,7 @@ from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan from .autoencoder_oobleck import AutoencoderOobleck diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 87ac406592..e4fe00df56 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -1,5 +1,8 @@ # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. # +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -20,15 +23,19 @@ from typing import List, Optional, Tuple, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint +import numpy as np + +import mindspore as ms +from mindspore import mint, nn, ops +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import logging -from ...utils.accelerate_utils import apply_forward_hook +# from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin @@ -40,7 +47,7 @@ CACHE_T = 2 -class QwenImageCausalConv3d(nn.Conv3d): +class QwenImageCausalConv3d(mint.nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. @@ -75,17 +82,17 @@ def __init__( self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) - def forward(self, x, cache_x=None): + def construct(self, x, cache_x=None): padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) + x = mint.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - return super().forward(x) + x = mint.nn.functional.pad(x, padding) + return super().construct(x) -class QwenImageRMS_norm(nn.Module): +class QwenImageRMS_norm(nn.Cell): r""" A custom RMS normalization layer. @@ -104,11 +111,11 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.channel_first = channel_first self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + self.gamma = nn.Parameter(mint.ones(shape)) + self.bias = nn.Parameter(mint.zeros(shape)) if bias else 0.0 - def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + def construct(self, x): + return mint.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias class QwenImageUpsample(nn.Upsample): @@ -116,17 +123,17 @@ class QwenImageUpsample(nn.Upsample): Perform upsampling while ensuring the output tensor has the same data type as the input. Args: - x (torch.Tensor): Input tensor to be upsampled. + x (ms.Tensor): Input tensor to be upsampled. Returns: - torch.Tensor: Upsampled tensor with the same data type as the input. + ms.Tensor: Upsampled tensor with the same data type as the input. """ - def forward(self, x): - return super().forward(x.float()).type_as(x) + def construct(self, x): + return super().construct(x.float()).type_as(x) -class QwenImageResample(nn.Module): +class QwenImageResample(nn.Cell): r""" A custom resampling module for 2D and 3D data. 
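A quick illustration of the causal padding set up in QwenImageCausalConv3d above: all temporal padding is placed on the past side, so a frame's output never depends on later frames and the frame count is preserved. The sketch below uses the upstream torch ops that this hunk replaces with their mint equivalents; shapes and names are illustrative only.

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 16, 16)        # (batch, channels, frames, height, width)
k_t = 3                                 # temporal kernel size
# pad order is (W_left, W_right, H_top, H_bottom, T_past, T_future): no future padding
pad = (1, 1, 1, 1, k_t - 1, 0)
conv = torch.nn.Conv3d(8, 8, kernel_size=(k_t, 3, 3))
y = conv(F.pad(x, pad))
print(y.shape)                          # torch.Size([1, 8, 5, 16, 16])

The converted module swaps F.pad for mint.nn.functional.pad with the same argument layout, which is why cached frames can simply be concatenated along the time axis before padding.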
@@ -149,25 +156,25 @@ def __init__(self, dim: int, mode: str) -> None: if mode == "upsample2d": self.resample = nn.Sequential( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), + mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), + mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = nn.Sequential(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = nn.Sequential(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: - self.resample = nn.Identity() + self.resample = mint.nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): b, c, t, h, w = x.size() if self.mode == "upsample3d": if feat_cache is not None: @@ -179,11 +186,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk - cache_x = torch.cat( + cache_x = mint.cat( [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 ) if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([mint.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) if feat_cache[idx] == "Rep": x = self.time_conv(x) else: @@ -192,7 +199,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = mint.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) x = x.reshape(b, c, t * 2, h, w) t = x.shape[2] x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) @@ -207,13 +214,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + x = self.time_conv(mint.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 return x -class QwenImageResidualBlock(nn.Module): +class QwenImageResidualBlock(nn.Cell): r""" A custom residual block module. 
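The upsample3d branch above doubles the temporal resolution by letting time_conv emit 2*c channels and then interleaving the two channel halves along the frame axis. A small torch sketch of just that reshape/stack step (shapes are illustrative, not taken from the patch):

import torch

b, c, t, h, w = 1, 4, 3, 2, 2
x = torch.randn(b, 2 * c, t, h, w)            # time_conv output: two "phases" per frame
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0], x[:, 1]), dim=3)    # (b, c, t, 2, h, w): each phase sits next to its frame
x = x.reshape(b, c, t * 2, h, w)
print(x.shape)                                # torch.Size([1, 4, 6, 2, 2]): frames doubled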
@@ -242,9 +249,9 @@ def __init__( self.norm2 = QwenImageRMS_norm(out_dim, images=False) self.dropout = nn.Dropout(dropout) self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) - self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else mint.nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): # Apply shortcut connection h = self.conv_shortcut(x) @@ -256,7 +263,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x @@ -275,7 +282,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv2(x, feat_cache[idx]) feat_cache[idx] = cache_x @@ -287,7 +294,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x + h -class QwenImageAttentionBlock(nn.Module): +class QwenImageAttentionBlock(nn.Cell): r""" Causal self-attention with a single head. @@ -301,10 +308,10 @@ def __init__(self, dim): # layers self.norm = QwenImageRMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) + self.to_qkv = mint.nn.Conv2d(dim, dim * 3, 1) + self.proj = mint.nn.Conv2d(dim, dim, 1) - def forward(self, x): + def construct(self, x): identity = x batch_size, channels, time, height, width = x.size() @@ -318,7 +325,10 @@ def forward(self, x): q, k, v = qkv.chunk(3, dim=-1) # apply attention - x = F.scaled_dot_product_attention(q, k, v) + # x = F.scaled_dot_product_attention(q, k, v) + x = ops.operation.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( + q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), None, None, None, None + )[3].to(q.dtype) x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) @@ -332,7 +342,7 @@ def forward(self, x): return x + identity -class QwenImageMidBlock(nn.Module): +class QwenImageMidBlock(nn.Cell): """ Middle block for QwenImageVAE encoder and decoder. @@ -352,12 +362,12 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", for _ in range(num_layers): attentions.append(QwenImageAttentionBlock(dim)) resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): # First residual block x = self.resnets[0](x, feat_cache, feat_idx) @@ -371,7 +381,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x -class QwenImageEncoder3d(nn.Module): +class QwenImageEncoder3d(nn.Cell): r""" A 3D encoder module. 
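The attention hunk above swaps torch's F.scaled_dot_product_attention for MindSpore's FlashAttentionScore, run in float16 and cast back to the query dtype. For reference, a plain NumPy sketch of the single-head scaled dot-product attention being computed (toy shapes, illustrative only); note that scaled dot-product attention includes the 1/sqrt(head_dim) factor, whereas FlashAttentionScore's scale_value defaults to 1.0, so that factor is worth double-checking in the port.

import numpy as np

def sdpa(q, k, v):
    # softmax(q @ k^T / sqrt(d)) @ v over the token axis
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = k = v = np.random.randn(2, 1, 16, 8)      # (batch*frames, heads=1, H*W tokens, head_dim)
print(sdpa(q, k, v).shape)                    # (2, 1, 16, 8)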
@@ -414,7 +424,7 @@ def __init__( self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) # downsample blocks - self.down_blocks = nn.ModuleList([]) + self.down_blocks = nn.CellList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks for _ in range(num_res_blocks): @@ -438,13 +448,13 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv_in(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -469,7 +479,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv_out(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -478,7 +488,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x -class QwenImageUpBlock(nn.Module): +class QwenImageUpBlock(nn.Cell): """ A block that handles upsampling for the QwenImageVAE decoder. @@ -512,26 +522,26 @@ def __init__( resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) current_dim = out_dim - self.resnets = nn.ModuleList(resnets) + self.resnets = nn.CellList(resnets) # Add upsampling layer if needed self.upsamplers = None if upsample_mode is not None: - self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + self.upsamplers = nn.CellList([QwenImageResample(out_dim, mode=upsample_mode)]) self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): """ Forward pass through the upsampling block. Args: - x (torch.Tensor): Input tensor + x (ms.Tensor): Input tensor feat_cache (list, optional): Feature cache for causal convolutions feat_idx (list, optional): Feature index for cache management Returns: - torch.Tensor: Output tensor + ms.Tensor: Output tensor """ for resnet in self.resnets: if feat_cache is not None: @@ -547,7 +557,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x -class QwenImageDecoder3d(nn.Module): +class QwenImageDecoder3d(nn.Cell): r""" A 3D decoder module. 
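Both the encoder and decoder forwards above thread a feat_cache list and a feat_idx counter through every causal convolution so long clips can be processed chunk by chunk: each conv stores its last CACHE_T input frames and prepends them to the next chunk. A toy, pure-Python sketch of that bookkeeping (frames are plain integers here; the real cache holds tensors and also feeds the causal padding):

CACHE_T = 2

def cached_step(chunk, cache):
    seen = cache + chunk           # what the causal conv effectively convolves over
    return seen, seen[-CACHE_T:]   # carry the last CACHE_T frames into the next chunk

cache = []
for chunk in ([0, 1], [2, 3], [4]):
    seen, cache = cached_step(list(chunk), cache)
    print(seen, "->", cache)
# [0, 1] -> [0, 1]
# [0, 1, 2, 3] -> [2, 3]
# [2, 3, 4] -> [3, 4]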
@@ -594,7 +604,7 @@ def __init__( self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) # upsample blocks - self.up_blocks = nn.ModuleList([]) + self.up_blocks = nn.CellList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks if i > 0: @@ -626,14 +636,14 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def construct(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv_in(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -655,7 +665,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv_out(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -801,7 +811,7 @@ def _count_conv3d(model): self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num - def _encode(self, x: torch.Tensor): + def _encode(self, x: ms.Tensor): _, _, num_frame, height, width = x.shape if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): @@ -819,21 +829,21 @@ def _encode(self, x: torch.Tensor): feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx, ) - out = torch.cat([out, out_], 2) + out = mint.cat([out, out_], 2) enc = self.quant_conv(out) self.clear_cache() return enc - @apply_forward_hook + # @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: ms.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: r""" Encode a batch of images into latents. Args: - x (`torch.Tensor`): Input batch of images. + x (`ms.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 
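With encode and decode both converted above, a hedged usage sketch may help tie them together. It uses the tiny configuration from the fast tests added earlier in this patch and runs a single-image round trip through the upstream torch class that this file ports; the constructor arguments come from those tests, everything else is illustrative.

import torch
from diffusers import AutoencoderKLQwenImage   # upstream torch implementation

z_dim = 4
vae = AutoencoderKLQwenImage(
    base_dim=z_dim * 6,
    z_dim=z_dim,
    dim_mult=[1, 2, 4],
    num_res_blocks=1,
    temperal_downsample=[False, True],
    latents_mean=[0.0] * z_dim,
    latents_std=[1.0] * z_dim,
)
x = torch.randn(1, 3, 1, 32, 32)     # images are treated as single-frame videos
posterior = vae.encode(x).latent_dist
z = posterior.sample()               # (1, 4, 1, 8, 8) with this config
recon = vae.decode(z).sample         # decoded back to (1, 3, 1, 32, 32)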
@@ -843,7 +853,7 @@ def encode( """ if self.use_slicing and x.shape[0] > 1: encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] - h = torch.cat(encoded_slices) + h = mint.cat(encoded_slices) else: h = self._encode(x) posterior = DiagonalGaussianDistribution(h) @@ -852,7 +862,7 @@ def encode( return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.Tensor, return_dict: bool = True): + def _decode(self, z: ms.Tensor, return_dict: bool = True): _, _, num_frame, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio @@ -868,22 +878,22 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) + out = mint.cat([out, out_], 2) - out = torch.clamp(out, min=-1.0, max=1.0) + out = mint.clamp(out, min=-1.0, max=1.0) self.clear_cache() if not return_dict: return (out,) return DecoderOutput(sample=out) - @apply_forward_hook - def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + # @apply_forward_hook + def decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderOutput, ms.Tensor]: r""" Decode a batch of images. Args: - z (`torch.Tensor`): Input batch of latent vectors. + z (`ms.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -894,7 +904,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp """ if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] - decoded = torch.cat(decoded_slices) + decoded = mint.cat(decoded_slices) else: decoded = self._decode(z).sample @@ -902,7 +912,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) - def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_v(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) for y in range(blend_extent): b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( @@ -910,7 +920,7 @@ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b - def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + def blend_h(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) for x in range(blend_extent): b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( @@ -918,14 +928,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b - def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + def tiled_encode(self, x: ms.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. Args: - x (`torch.Tensor`): Input batch of videos. + x (`ms.Tensor`): Input batch of videos. Returns: - `torch.Tensor`: + `ms.Tensor`: The latent representation of the encoded videos. 
""" _, _, num_frames, height, width = x.shape @@ -964,7 +974,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) tile = self.quant_conv(tile) time.append(tile) - row.append(torch.cat(time, dim=2)) + row.append(mint.cat(time, dim=2)) rows.append(row) self.clear_cache() @@ -979,17 +989,17 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) - result_rows.append(torch.cat(result_row, dim=-1)) + result_rows.append(mint.cat(result_row, dim=-1)) - enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + enc = mint.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] return enc - def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def tiled_decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderOutput, ms.Tensor]: r""" Decode a batch of images using a tiled decoder. Args: - z (`torch.Tensor`): Input batch of latent vectors. + z (`ms.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -1024,7 +1034,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod tile = self.post_quant_conv(tile) decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) time.append(decoded) - row.append(torch.cat(time, dim=2)) + row.append(mint.cat(time, dim=2)) rows.append(row) self.clear_cache() @@ -1039,24 +1049,24 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) - result_rows.append(torch.cat(result_row, dim=-1)) + result_rows.append(mint.cat(result_row, dim=-1)) - dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + dec = mint.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] if not return_dict: return (dec,) return DecoderOutput(sample=dec) - def forward( + def construct( self, - sample: torch.Tensor, + sample: ms.Tensor, sample_posterior: bool = False, return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.Tensor]: + generator: Optional[np.random.Generator] = None, + ) -> Union[DecoderOutput, ms.Tensor]: """ Args: - sample (`torch.Tensor`): Input sample. + sample (`ms.Tensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
""" diff --git a/mindone/diffusers/models/transformers/__init__.py b/mindone/diffusers/models/transformers/__init__.py index 577cb3ae11..f6f7b13b26 100644 --- a/mindone/diffusers/models/transformers/__init__.py +++ b/mindone/diffusers/models/transformers/__init__.py @@ -25,6 +25,7 @@ from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel +from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index b790462f93..4717a79ef9 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -21,7 +21,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mindspore as ms -import mindspore.mint.nn.functional as F from mindspore import mint, nn, ops # import torch # import torch.nn as nn @@ -77,7 +76,7 @@ def get_timestep_embedding( half_dim = embedding_dim // 2 exponent = -math.log(max_period) * mint.arange( - start=0, end=half_dim, dtype=ms.float32, device=timesteps.device + start=0, end=half_dim, dtype=ms.float32 ) exponent = exponent / (half_dim - downscale_freq_shift) @@ -124,7 +123,7 @@ def apply_rotary_emb_qwen( cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) + # cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit @@ -204,7 +203,7 @@ def rope_params(self, index, dim, theta=10000): freqs = mint.polar(mint.ones_like(freqs), freqs) return freqs - def construct(self, video_fhw, txt_seq_lens, device): + def construct(self, video_fhw, txt_seq_lens): """ Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: txt_length: [bs] a list of 1 integers representing the length of the text @@ -388,17 +387,17 @@ def __init__( qk_norm=qk_norm, eps=eps, ) - self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_norm2 = mint.nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # Text processing modules self.txt_mod = nn.Sequential( - nn.SiLU(), - nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + mint.nn.SiLU(), + mint.nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) - self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm1 = mint.nn.LayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation - self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = mint.nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") def _modulate(self, x, mod_params): @@ -528,8 +527,8 @@ def __init__( self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) - self.img_in = nn.Linear(in_channels, self.inner_dim) - self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + self.img_in = mint.nn.Linear(in_channels, self.inner_dim) + self.txt_in = 
mint.nn.Linear(joint_attention_dim, self.inner_dim) self.transformer_blocks = nn.CellList( [ @@ -543,7 +542,7 @@ def __init__( ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.proj_out = mint.nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False @@ -610,7 +609,7 @@ def construct( else self.time_text_embed(timestep, guidance, hidden_states) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens) # for index_block, block in enumerate(self.transformer_blocks): # if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/mindone/diffusers/pipelines/qwenimage/__init__.py b/mindone/diffusers/pipelines/qwenimage/__init__.py index 64265880e7..c12ac23cd6 100644 --- a/mindone/diffusers/pipelines/qwenimage/__init__.py +++ b/mindone/diffusers/pipelines/qwenimage/__init__.py @@ -1,42 +1,20 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_torch_available, - is_transformers_available, -) - +"""Adapted from https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/qwenimage/__init__.py.""" -_dummy_objects = {} -_additional_imports = {} -_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]} +from typing import TYPE_CHECKING -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 +from ...utils import _LazyModule - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"] - _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] - _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] - _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"] +_import_structure = { + "modeling_qwenimage": ["ReduxImageEncoder"], + "pipeline_qwenimage": ["QwenImagePipeline"], + "pipeline_qwenimage_img2img": ["QwenImageImg2ImgPipeline"], + "pipeline_qwenimage_inpaint": ["QwenImageInpaintPipeline"], + } -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 - else: - from .pipeline_qwenimage import QwenImagePipeline - from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline - from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline +if TYPE_CHECKING: + from .pipeline_qwenimage import QwenImagePipeline + from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline + from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline else: import sys @@ -45,9 +23,4 @@ globals()["__file__"], _import_structure, module_spec=__spec__, - ) - - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) - for name, value in _additional_imports.items(): - setattr(sys.modules[__name__], name, value) + ) \ No newline at end of file diff --git 
a/mindone/diffusers/pipelines/qwenimage/pipeline_output.py b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py index eef4b60e37..48a8b8464b 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_output.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py @@ -1,3 +1,5 @@ +"""Adapted from https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/qwenimage/pipeline_output.py.""" + from dataclasses import dataclass from typing import List, Union diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 47549ab4af..788086e3c1 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -1,5 +1,8 @@ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. # +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 4fc84a31cc..7dce470e45 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -1,27 +1,24 @@ +"""Adapted from https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py.""" + import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import torch -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +import mindspore as ms +from mindspore import mint +from transformers import Qwen2Tokenizer +from ....transformers import Qwen2_5_VLForConditionalGeneration from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor +from ...utils import logging #, scale_lora_layers, unscale_lora_layers +from ...utils.mindspore_utils import randn_tensor #, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - +XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -45,7 +42,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -75,7 +72,6 @@ def calculate_shift( def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = 
None, sigmas: Optional[List[float]] = None, **kwargs, @@ -90,8 +86,6 @@ def retrieve_timesteps( num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. @@ -100,7 +94,7 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: @@ -112,7 +106,7 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: @@ -122,11 +116,11 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps @@ -183,11 +177,11 @@ def __init__( self.default_sample_size = 128 # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + def _extract_masked_hidden(self, hidden_states: ms.Tensor, mask: ms.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] - split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + split_result = mint.split(selected, valid_lengths.tolist(), dim=0) return split_result @@ -195,10 +189,8 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + dtype: Optional[ms.dtype] = None, ): - device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -208,7 +200,7 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" - ).to(device) + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -217,36 +209,36 @@ def _get_qwen_prompt_embeds( hidden_states = 
encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + encoder_attention_mask = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype) return prompt_embeds, encoder_attention_mask - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] - image_latents = torch.cat(image_latents, dim=0) + image_latents = mint.cat(image_latents, dim=0) else: image_latents = retrieve_latents(self.vae.encode(image), generator=generator) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(image_latents.device, image_latents.dtype) + .to(image_latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - image_latents.device, image_latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.dtype ) image_latents = (image_latents - latents_mean) * latents_std @@ -254,7 +246,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -269,10 +261,9 @@ def get_timesteps(self, num_inference_steps, strength, device): def encode_prompt( self, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, max_sequence_length: int = 1024, ): r""" @@ -280,21 +271,17 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - device: (`torch.device`): - torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be generated from `prompt` input argument. """ - device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -365,10 +352,10 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + def _prepare_latent_image_ids(batch_size, height, width, dtype): + latent_image_ids = mint.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -376,7 +363,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids.to(dtype=dtype) @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents @@ -442,7 +429,6 @@ def prepare_latents( height, width, dtype, - device, generator, latents=None, ): @@ -465,10 +451,10 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) + return latents.to(dtype=dtype), latent_image_ids - image = image.to(device=device, dtype=dtype) + image = image.to(dtype=dtype) if image.shape[1] != self.latent_channels: image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] else: @@ -476,20 +462,20 @@ def prepare_latents( if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + image_latents = mint.cat([image_latents] * additional_image_per_prompt, dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
) else: - image_latents = torch.cat([image_latents], dim=0) + image_latents = mint.cat([image_latents], dim=0) image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) return latents, latent_image_ids @@ -513,8 +499,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -528,12 +512,12 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_mask: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -552,7 +536,7 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is not greater than `1`). - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a @@ -585,17 +569,17 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/np.random.Generator.html) to make generation deterministic. - latents (`torch.Tensor`, *optional*): + latents (`ms.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): + negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. @@ -652,7 +636,7 @@ def __call__( # 2. Preprocess image init_image = self.image_processor.preprocess(image, height=height, width=width) - init_image = init_image.to(dtype=torch.float32) + init_image = init_image.to(dtype=ms.float32) # 3. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -662,8 +646,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) @@ -672,7 +654,6 @@ def __call__( prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -681,7 +662,6 @@ def __call__( prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -699,11 +679,10 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, sigmas=sigmas, mu=mu, ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) if num_inference_steps < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" @@ -721,7 +700,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, latents, ) @@ -732,7 +710,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = mint.full([1], guidance_scale, dtype=ms.float32) guidance = guidance.expand(latents.shape[0]) else: guidance = None @@ -782,8 +760,8 @@ def __call__( )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = mint.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 @@ -791,9 +769,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} @@ -808,9 +784,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - self._current_timestep = None if output_type == "latent": image = latents @@ -818,12 +791,12 @@ def __call__( latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = latents.to(self.vae.dtype) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype ) latents = latents / latents_std + latents_mean diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 5ffec0c447..3953440d30 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -1,3 +1,5 @@ +"""Adapted from https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py.""" + import inspect from typing import Any, Callable, Dict, List, Optional, Union From 15bc8aefcdcecfe29adf119068b194381bc814ba Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 18 Aug 2025 10:22:35 +0800 Subject: [PATCH 03/77] 2025/8/18 10:22 revised --- .../pipelines/qwenimage/pipeline_qwenimage.py | 140 +++++-------- .../qwenimage/pipeline_qwenimage_img2img.py | 11 +- .../qwenimage/pipeline_qwenimage_inpaint.py | 193 ++++++++---------- 3 files changed, 140 insertions(+), 204 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 788086e3c1..0a42cfaf41 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -19,41 +19,35 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import torch -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +import mindspore as ms +from mindspore import mint +from transformers import Qwen2Tokenizer +from ....transformers import Qwen2_5_VLForConditionalGeneration from ...image_processor import VaeImageProcessor from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - +XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name 
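# A minimal sketch of the numpy-Generator plumbing this port relies on. Shapes here are
# hypothetical and the absolute import path is inferred from the relative imports in this
# file; the `device` arguments disappear because MindSpore places tensors through its
# global context rather than per call.
import numpy as np
import mindspore as ms
from mindone.diffusers.utils.mindspore_utils import randn_tensor

rng = np.random.default_rng(42)  # stands in for torch.Generator(...).manual_seed(42)
# Latent noise shaped (batch, 1, num_channels_latents, latent_height, latent_width),
# mirroring prepare_latents() below; reusing the same seed reproduces the draw, which is
# what makes the `generator` argument deterministic in the docstrings that follow.
noise = randn_tensor((1, 1, 16, 128, 128), generator=rng, dtype=ms.float32)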
EXAMPLE_DOC_STRING = """ Examples: ```py - >>> import torch - >>> from diffusers import QwenImagePipeline + >>> import mindspore as ms + >>> from mindone.diffusers import QwenImagePipeline - >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", mindspore_dtype=ms.bfloat16) >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image = pipe(prompt, num_inference_steps=50)[0][0] >>> image.save("qwenimage.png") ``` """ @@ -76,7 +70,6 @@ def calculate_shift( def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, @@ -91,8 +84,6 @@ def retrieve_timesteps( num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. @@ -101,7 +92,7 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: @@ -113,7 +104,7 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: @@ -123,11 +114,11 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps @@ -180,21 +171,19 @@ def __init__( self.prompt_template_encode_start_idx = 34 self.default_sample_size = 128 - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + def _extract_masked_hidden(self, hidden_states: ms.Tensor, mask: ms.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] - split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + split_result = mint.split(selected, valid_lengths.tolist(), dim=0) return split_result def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + dtype: Optional[ms.dtype] = None, ): - device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -204,7 +193,7 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" - ).to(device) + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -213,26 +202,25 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + encoder_attention_mask = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype) return prompt_embeds, encoder_attention_mask def encode_prompt( self, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, max_sequence_length: int = 1024, ): r""" @@ -240,21 +228,17 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - device: (`torch.device`): - torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. """ - device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -320,10 +304,10 @@ def check_inputs( raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + def _prepare_latent_image_ids(batch_size, height, width, dtype): + latent_image_ids = mint.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -331,7 +315,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids.to(dtype=dtype) @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -393,7 +377,6 @@ def prepare_latents( height, width, dtype, - device, generator, latents=None, ): @@ -405,8 +388,8 @@ def prepare_latents( shape = (batch_size, 1, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) + return latents.to(dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -414,10 +397,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) return latents, latent_image_ids @@ -441,8 +424,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -454,12 +435,12 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_mask: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -502,14 +483,14 @@ def __call__( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.Tensor`, *optional*): + latents (`ms.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): + negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
@@ -571,8 +552,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) @@ -581,7 +560,6 @@ def __call__( prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -590,7 +568,6 @@ def __call__( prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -603,7 +580,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, latents, ) @@ -622,7 +598,6 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, sigmas=sigmas, mu=mu, ) @@ -631,7 +606,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = mint.full([1], guidance_scale, dtype=ms.float32) guidance = guidance.expand(latents.shape[0]) else: guidance = None @@ -682,8 +657,8 @@ def __call__( )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = mint.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 @@ -691,9 +666,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} @@ -708,9 +681,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - self._current_timestep = None if output_type == "latent": image = latents @@ -718,12 +688,12 @@ def __call__( latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = latents.to(self.vae.dtype) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype ) latents = latents / latents_std + latents_mean image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 7dce470e45..ce6ab2bb40 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -25,16 +25,15 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> import torch - >>> from diffusers import QwenImageImg2ImgPipeline - >>> from diffusers.utils import load_image + >>> import mindspore + >>> from mindone.diffusers import QwenImageImg2ImgPipeline + >>> from mindone.diffusers.utils import load_image - >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) - >>> pipe = pipe.to("cuda") + >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", mindspore_dtype=mindspore.bfloat16) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> init_image = load_image(url).resize((1024, 1024)) >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney" - >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0] + >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95)[0][0] >>> images.save("qwenimage_img2img.png") ``` """ diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 3953440d30..1bc0c4a35e 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -5,44 +5,38 @@ import numpy as np import PIL.Image -import torch -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +import mindspore as ms +from mindspore import mint +from transformers import Qwen2Tokenizer +from ....transformers import Qwen2_5_VLForConditionalGeneration from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import 
is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - +XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py - >>> import torch - >>> from diffusers import QwenImageInpaintPipeline - >>> from diffusers.utils import load_image + >>> import mindspore as ms + >>> from mindone.diffusers import QwenImageInpaintPipeline + >>> from mindone.diffusers.utils import load_image - >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", mindspore_dtype=ms.bfloat16) >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" >>> source = load_image(img_url) >>> mask = load_image(mask_url) - >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0] + >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85)[0][0] >>> image.save("qwenimage_inpainting.png") ``` """ @@ -50,7 +44,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -80,7 +74,6 @@ def calculate_shift( def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, @@ -95,8 +88,6 @@ def retrieve_timesteps( num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. @@ -105,7 +96,7 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. 
""" if timesteps is not None and sigmas is not None: @@ -117,7 +108,7 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: @@ -127,11 +118,11 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps @@ -195,11 +186,11 @@ def __init__( self.default_sample_size = 128 # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + def _extract_masked_hidden(self, hidden_states: ms.Tensor, mask: ms.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] - split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + split_result = mint.split(selected, valid_lengths.tolist(), dim=0) return split_result @@ -207,10 +198,8 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + dtype: Optional[ms.dtype] = None, ): - device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -220,7 +209,7 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" - ).to(device) + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -229,37 +218,37 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + encoder_attention_mask = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) - prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype) return prompt_embeds, encoder_attention_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] - image_latents = torch.cat(image_latents, dim=0) + image_latents = mint.cat(image_latents, dim=0) else: image_latents = retrieve_latents(self.vae.encode(image), generator=generator) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(image_latents.device, image_latents.dtype) + .to(image_latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - image_latents.device, image_latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.dtype ) image_latents = (image_latents - latents_mean) * latents_std @@ -267,7 +256,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -282,10 +271,9 @@ def get_timesteps(self, num_inference_steps, strength, device): def encode_prompt( self, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, max_sequence_length: int = 1024, ): r""" @@ -293,21 +281,17 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - device: (`torch.device`): - torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
""" - device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -394,10 +378,10 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + def _prepare_latent_image_ids(batch_size, height, width, dtype): + latent_image_ids = mint.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -405,7 +389,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height * latent_image_id_width, latent_image_id_channels ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids.to(dtype=dtype) @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents @@ -471,7 +455,6 @@ def prepare_latents( height, width, dtype, - device, generator, latents=None, ): @@ -494,10 +477,10 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) + return latents.to(dtype=dtype), latent_image_ids - image = image.to(device=device, dtype=dtype) + image = image.to(dtype=dtype) if image.shape[1] != self.latent_channels: image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] else: @@ -505,28 +488,28 @@ def prepare_latents( if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + image_latents = mint.cat([image_latents] * additional_image_per_prompt, dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
) else: - image_latents = torch.cat([image_latents], dim=0) + image_latents = mint.cat([image_latents], dim=0) image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: - noise = latents.to(device) + noise = latents latents = noise noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) return latents, noise, image_latents, latent_image_ids @@ -540,7 +523,6 @@ def prepare_mask_latents( height, width, dtype, - device, generator, ): # VAE applies 8x compression on images but we must also account for packing which requires @@ -550,8 +532,8 @@ def prepare_mask_latents( # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) - mask = mask.to(device=device, dtype=dtype) + mask = mint.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -560,7 +542,7 @@ def prepare_mask_latents( elif masked_image.dim() != 5: raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") - masked_image = masked_image.to(device=device, dtype=dtype) + masked_image = masked_image.to(dtype=dtype) if masked_image.shape[1] == self.latent_channels: masked_image_latents = masked_image @@ -585,8 +567,7 @@ def prepare_mask_latents( ) masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = masked_image_latents.to(dtype=dtype) masked_image_latents = self._pack_latents( masked_image_latents, @@ -625,8 +606,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -643,12 +622,12 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_mask: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = 
True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -667,7 +646,7 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is not greater than `1`). - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a @@ -675,14 +654,14 @@ def __call__( latents as `image`, but if passing latents directly it is not encoded again. true_cfg_scale (`float`, *optional*, defaults to 1.0): When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. - mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + mask_image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. - mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + mask_image_latent (`ms.Tensor`, `List[ms.Tensor]`): `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask latents tensor will ge generated by `mask_image`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -717,17 +696,17 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/np.random.Generator.html) to make generation deterministic. - latents (`torch.Tensor`, *optional*): + latents (`ms.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
- negative_prompt_embeds (`torch.Tensor`, *optional*): + negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. @@ -798,7 +777,7 @@ def __call__( init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) - init_image = init_image.to(dtype=torch.float32) + init_image = init_image.to(dtype=ms.float32) # 3. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -808,8 +787,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) @@ -818,7 +795,6 @@ def __call__( prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -827,7 +803,6 @@ def __call__( prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -845,11 +820,10 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, sigmas=sigmas, mu=mu, ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) if num_inference_steps < 1: raise ValueError( @@ -869,7 +843,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, latents, ) @@ -892,7 +865,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, ) @@ -903,7 +875,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = mint.full([1], guidance_scale, dtype=ms.float32) guidance = guidance.expand(latents.shape[0]) else: guidance = None @@ -953,8 +925,8 @@ def __call__( )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = mint.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 @@ -968,15 +940,13 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise + init_latents_proper, ms.Tensor([noise_timestep]), noise ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} @@ -991,9 +961,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - self._current_timestep = None if output_type == "latent": image = latents @@ -1001,12 +968,12 @@ def __call__( latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = latents.to(self.vae.dtype) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype ) latents = latents / latents_std + latents_mean From 103db5063a6ed2f143957d5d77fb2c2e7c02f620 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 18 Aug 2025 17:00:21 +0800 Subject: [PATCH 04/77] 2025/8/18 17:00 revised --- mindone/diffusers/__init__.py | 12 + .../diffusers/models/attention_dispatch.py | 1218 +++++++++++++++++ .../transformers/transformer_qwenimage.py | 76 +- mindone/diffusers/pipelines/__init__.py | 12 + .../diffusers/pipelines/qwenimage/__init__.py | 2 + .../pipelines/qwenimage/pipeline_qwenimage.py | 28 +- .../qwenimage/pipeline_qwenimage_edit.py | 851 ++++++++++++ .../qwenimage/pipeline_qwenimage_img2img.py | 29 +- .../qwenimage/pipeline_qwenimage_inpaint.py | 29 +- 9 files changed, 2160 insertions(+), 97 deletions(-) create mode 100644 mindone/diffusers/models/attention_dispatch.py create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index f18d0af574..c607ae65e7 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -28,6 +28,7 @@ "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", + "AutoencoderKLQwenImage", "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", @@ -65,6 +66,7 @@ "OmniGenTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", + "QwenImageTransformer2DModel", "SanaControlNetModel", "SanaTransformer2DModel", "SD3ControlNetModel", @@ -210,6 +212,10 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "QwenImageImg2ImgPipeline", + "QwenImageInpaintPipeline", + "QwenImagePipeline", + "QwenImageEditPipeline", "ReduxImageEncoder", "SanaControlNetPipeline", "SanaPAGPipeline", @@ -357,6 +363,7 @@ AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, + AutoencoderKLQwenImage, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -394,6 +401,7 @@ OmniGenTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + QwenImageTransformer2DModel, SanaControlNetModel, SanaTransformer2DModel, SD3ControlNetModel, @@ -538,6 +546,10 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + QwenImageEditPipeline, + QwenImageImg2ImgPipeline, + QwenImageInpaintPipeline, + QwenImagePipeline, ReduxImageEncoder, SanaControlNetPipeline, SanaPAGPipeline, diff --git a/mindone/diffusers/models/attention_dispatch.py 
b/mindone/diffusers/models/attention_dispatch.py new file mode 100644 index 0000000000..7cc30e47ab --- /dev/null +++ b/mindone/diffusers/models/attention_dispatch.py @@ -0,0 +1,1218 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +import math +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + get_logger, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_npu_available, + is_torch_version, + is_torch_xla_available, + is_torch_xla_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS + + +_REQUIRED_FLASH_VERSION = "2.6.3" +_REQUIRED_SAGE_VERSION = "2.1.1" +_REQUIRED_FLEX_VERSION = "2.5.0" +_REQUIRED_XLA_VERSION = "2.2" +_REQUIRED_XFORMERS_VERSION = "0.0.29" + +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) +_CAN_USE_NPU_ATTN = is_torch_npu_available() +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) + + +if _CAN_USE_FLASH_ATTN: + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + flash_attn_func = None + flash_attn_varlen_func = None + + +if _CAN_USE_FLASH_ATTN_3: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + + +if _CAN_USE_SAGE_ATTN: + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if _CAN_USE_FLEX_ATTN: + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. 
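
The comment above is the reason the module, not the bare function, is bound to a name: the attribute lookup then happens at call time, so a `torch.compile`'d replacement installed by the user is picked up. A minimal sketch of that late-binding pattern, assuming a PyTorch build (>= 2.5) where the eager flex-attention path is available; `run_flex` is just an illustrative helper name:

```python
import torch
import torch.nn.attention.flex_attention as flex_attention


def run_flex(query, key, value):
    # Resolved on every call, so rebinding flex_attention.flex_attention to a
    # torch.compile'd version is transparently honored by this helper.
    return flex_attention.flex_attention(query, key, value)


# flex_attention expects (batch, num_heads, seq_len, head_dim) tensors.
q = k = v = torch.randn(1, 2, 8, 16)
out = run_flex(q, k, v)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```
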
+ import torch.nn.attention.flex_attention as flex_attention + + +if _CAN_USE_NPU_ATTN: + from torch_npu import npu_fusion_attention +else: + npu_fusion_attention = None + + +if _CAN_USE_XLA_ATTN: + from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention +else: + xla_flash_attention = None + + +if _CAN_USE_XFORMERS_ATTN: + import xformers.ops as xops +else: + xops = None + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +# TODO(aryan): Add support for the following: +# - Sage Attention++ +# - block sparse, radial and other attention methods +# - CP with sage attention, flex, xformers, other missing backends +# - Add support for normal and CP training with backends that don't support it yet + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionBackendName(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionBackendRegistry: + _backends = {} + _constraints = {} + _supported_arg_names = {} + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") + + def decorator(func): + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] + + @classmethod + def list_backends(cls): + return list(cls._backends.keys()) + + +@contextlib.contextmanager +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): + """ + Context manager to set the active attention backend. 
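
As a usage illustration for this context manager and `dispatch_attention_fn`, a hedged sketch follows. It assumes the module's optional-dependency imports resolve in the installed environment (the file as added here still imports `torch`) and sticks to the native backends, which need no extra packages. Note that the dispatcher expects `(batch, seq_len, num_heads, head_dim)` inputs and permutes internally.

```python
import torch

from mindone.diffusers.models.attention_dispatch import (
    AttentionBackendName,
    attention_backend,
    dispatch_attention_fn,
)

q = k = v = torch.randn(2, 128, 8, 64)  # (batch, seq_len, num_heads, head_dim)

# Default backend, resolved from the DIFFUSERS_ATTN_BACKEND environment variable.
out_default = dispatch_attention_fn(q, k, v)

# Temporarily route every dispatch inside the block through the math reference backend.
with attention_backend(AttentionBackendName._NATIVE_MATH):
    out_math = dispatch_attention_fn(q, k, v)

# A backend can also be forced per call.
out_forced = dispatch_attention_fn(q, k, v, backend=AttentionBackendName._NATIVE_MATH)

print(out_default.shape, out_math.shape, out_forced.shape)  # all (2, 128, 8, 64)
```
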
+ """ + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") + + backend = AttentionBackendName(backend) + _check_attention_backend_requirements(backend) + + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend + + try: + yield + finally: + _AttentionBackendRegistry._active_backend = old_backend + + +def dispatch_attention_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + **attention_kwargs, + } + if is_torch_version(">=", "2.5.0"): + kwargs["enable_gqa"] = enable_gqa + + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) + + +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
+ ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +# ===== Helper functions ===== + + +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: + if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: + if not _CAN_USE_FLASH_ATTN: + raise RuntimeError( + f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." + ) + + elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: + if not _CAN_USE_FLASH_ATTN_3: + raise RuntimeError( + f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." + ) + + elif backend in [ + AttentionBackendName.SAGE, + AttentionBackendName.SAGE_VARLEN, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + ]: + if not _CAN_USE_SAGE_ATTN: + raise RuntimeError( + f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." + ) + + elif backend == AttentionBackendName.FLEX: + if not _CAN_USE_FLEX_ATTN: + raise RuntimeError( + f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." + ) + + elif backend == AttentionBackendName._NATIVE_NPU: + if not _CAN_USE_NPU_ATTN: + raise RuntimeError( + f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." + ) + + elif backend == AttentionBackendName._NATIVE_XLA: + if not _CAN_USE_XLA_ATTN: + raise RuntimeError( + f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." + ) + + elif backend == AttentionBackendName.XFORMERS: + if not _CAN_USE_XFORMERS_ATTN: + raise RuntimeError( + f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." 
+ ) + + +@functools.lru_cache(maxsize=128) +def _prepare_for_flash_attn_or_sage_varlen_without_mask( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size: int, + seq_len_q: int, + attn_mask: torch.Tensor, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, +) -> None: + if attn_mask is None: + return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device) + return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. + """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." 
+ ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== torch op registrations ===== +# Registrations are required for fullgraph tracing compatibility + + +# TODO: library.custom_op and register_fake probably need version guards? +# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding +# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3_original( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = flash_attn_3_func(query, key, value) + lse = lse.permute(0, 2, 1) + return out, lse + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, num_heads, head_dim = query.shape + lse_shape = (batch_size, seq_len, num_heads) + return torch.empty_like(query), query.new_empty(lse_shape) + + +# ===== Attention backends ===== + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, 
seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out, lse, *_ = flash_attn_3_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + attention_chunk=0, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_VARLEN_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, 
dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out, lse, *_ = flash_attn_3_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.unflatten(0, (batch_size, -1)) + + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, seq_len_q, num_heads, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = flex_attention.create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! 
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = flex_attention.flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=kernel_options, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = 
None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, # not supported + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_NPU, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_npu_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + return npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + +# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_XLA, + constraints=[_check_device, _check_shape], +) +def _native_xla_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + query = query / math.sqrt(query.shape[-1]) + out = xla_flash_attention( + q=query, + k=key, + v=value, + causal=is_causal, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = 
None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", + smooth_k: bool = True, + 
smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="NHD", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, seq_len_q, num_heads_q, _ = query.shape + _, seq_len_kv, num_heads_kv, _ = key.shape + + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + + if enable_gqa: + out = out.flatten(2, 3) + + return out diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index 4717a79ef9..6ac4b046a2 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -168,9 +168,9 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim - pos_index = mint.arange(1024) - neg_index = mint.arange(1024).flip(0) * -1 - 1 - pos_freqs = mint.cat( + pos_index = mint.arange(4096) + neg_index = mint.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = mint.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), @@ -178,7 +178,7 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): ], dim=1, ) - neg_freqs = mint.cat( + self.neg_freqs = mint.cat( [ 
self.rope_params(neg_index, self.axes_dim[0], self.theta), self.rope_params(neg_index, self.axes_dim[1], self.theta), @@ -187,10 +187,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): dim=1, ) self.rope_cache = {} - self.register_buffer("pos_freqs", pos_freqs, persistent=False) - self.register_buffer("neg_freqs", neg_freqs, persistent=False) - # 是否使用 scale rope + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): @@ -210,34 +208,44 @@ def construct(self, video_fhw, txt_seq_lens): """ if isinstance(video_fhw, list): video_fhw = video_fhw[0] - frame, height, width = video_fhw - # rope_key = f"{frame}_{height}_{width}" - - # if not torch.compiler.is_compiling(): # 未匹配 - # if rope_key not in self.rope_cache: - # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width) - # vid_freqs = self.rope_cache[rope_key] - # else: - # vid_freqs = self._compute_video_freqs(frame, height, width) - vid_freqs = self._compute_video_freqs(frame, height, width) - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) + + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + # jit-related, 25/8/18. Remain to fix. + # if not torch.compiler.is_compiling(): + # if rope_key not in self.rope_cache: + # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + # video_freq = self.rope_cache[rope_key] + # else: + # video_freq = self._compute_video_freqs(frame, height, width) + # vid_freqs.append(video_freq) + video_freq = self._compute_video_freqs(frame, height, width) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = mint.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width): + def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = mint.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) @@ -329,14 +337,18 @@ def __call__( joint_value = mint.cat([txt_value, img_value], dim=1) # Compute joint attention - joint_hidden_states = dispatch_attention_fn( - joint_query, - joint_key, - joint_value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, + # NOTICE! 2025/8/18. Replace in the present version. 
+ # joint_hidden_states = dispatch_attention_fn( + # joint_query, + # joint_key, + # joint_value, + # attn_mask=attention_mask, + # dropout_p=0.0, + # is_causal=False, + # backend=self._attention_backend, + # ) + joint_hidden_states = attn.scaled_dot_product_attention( + joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) # Reshape back diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index 0b8f911e63..173bd904cc 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -243,6 +243,12 @@ "WuerstchenPriorPipeline", ], "wan": ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"], + "qwenimage": [ + "QwenImageEditPipeline", + "QwenImageImg2ImgPipeline", + "QwenImageInpaintPipeline", + "QwenImagePipeline", + ], "pipeline_utils": [ "AudioPipelineOutput", "DiffusionPipeline", @@ -388,6 +394,12 @@ from .pia import PIAPipeline from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline + from .qwenimage import ( + QwenImageEditPipeline, + QwenImageImg2ImgPipeline, + QwenImageInpaintPipeline, + QwenImagePipeline, + ) from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline diff --git a/mindone/diffusers/pipelines/qwenimage/__init__.py b/mindone/diffusers/pipelines/qwenimage/__init__.py index c12ac23cd6..b6f05c169f 100644 --- a/mindone/diffusers/pipelines/qwenimage/__init__.py +++ b/mindone/diffusers/pipelines/qwenimage/__init__.py @@ -9,10 +9,12 @@ "pipeline_qwenimage": ["QwenImagePipeline"], "pipeline_qwenimage_img2img": ["QwenImageImg2ImgPipeline"], "pipeline_qwenimage_inpaint": ["QwenImageInpaintPipeline"], + "pipeline_qwenimage_edit": ["QwenImageEditPipeline"], } if TYPE_CHECKING: from .pipeline_qwenimage import QwenImagePipeline + from .pipeline_qwenimage_edit import QwenImageEditPipeline from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline else: diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 0a42cfaf41..82b5d73a40 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -240,6 +240,9 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -303,20 +306,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, dtype): - latent_image_ids = mint.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] - - 
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(dtype=dtype) - @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) @@ -388,8 +377,7 @@ def prepare_latents( shape = (batch_size, 1, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - return latents.to(dtype=dtype), latent_image_ids + return latents.to(dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -400,9 +388,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - - return latents, latent_image_ids + return latents @property def guidance_scale(self): @@ -574,7 +560,7 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -583,7 +569,7 @@ def __call__( generator, latents, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py new file mode 100644 index 0000000000..56a2513b1d --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -0,0 +1,851 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
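
Throughout these pipelines, `_pack_latents` (see the hunk above) folds every 2x2 block of latent pixels into a single token, which is why `num_channels_latents` is computed as `in_channels // 4` and why `img_shapes` records `(1, height // 2, width // 2)` in latent units. A rough numpy sketch of that shape bookkeeping; the permute/reshape steps after the visible `view` call are assumed to follow the usual Flux-style packing order:

```python
import numpy as np

batch_size, num_channels_latents, height, width = 1, 16, 8, 8  # latent-space sizes

latents = np.random.randn(batch_size, num_channels_latents, height, width).astype(np.float32)

# (B, C, H, W) -> (B, C, H/2, 2, W/2, 2): split each spatial axis into 2x2 patches.
x = latents.reshape(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)

# Bring the patch grid to the front, mirroring permute(0, 2, 4, 1, 3, 5).
x = x.transpose(0, 2, 4, 1, 3, 5)

# Flatten to (B, (H/2) * (W/2), C * 4): one token per 2x2 patch.
packed = x.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

print(packed.shape)  # (1, 16, 64)
```
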
+ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import mindspore as ms +from mindspore import mint +from transformers import Qwen2Tokenizer, Qwen2VLProcessor + +from ....transformers import Qwen2_5_VLForConditionalGeneration +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + +XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from mindone.diffusers import QwenImageEditPipeline + >>> from mindone.diffusers.utils import load_image + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", mindspore_dtype=mindspore.bfloat16) + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50)[0][0] + >>> image.save("qwenimage_edit.png") + ``` +""" +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+
+    Returns:
+        `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+    width = math.sqrt(target_area * ratio)
+    height = width / ratio
+
+    width = round(width / 32) * 32
+    height = round(height / 32) * 32
+
+    return width, height, None
+
+
+class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+    r"""
+    The Qwen-Image-Edit pipeline for image editing.
+
+    Args:
+        transformer ([`QwenImageTransformer2DModel`]):
+            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKLQwenImage`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+            The Qwen2.5-VL multimodal text encoder, specifically the
+            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+        tokenizer (`Qwen2Tokenizer`):
+            Tokenizer of class
+            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
+        processor (`Qwen2VLProcessor`):
+            Multimodal processor that prepares the paired image-text inputs for the Qwen2.5-VL text encoder.
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: ms.Tensor, mask: ms.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = mint.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[ms.Tensor] = None, + dtype: Optional[mint.dtype] = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = mint.stack( + [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + 
self, + prompt: Union[str, List[str]], + image: Optional[ms.Tensor] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`ms.Tensor`, *optional*): + image to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = mint.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + ms.Tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.dtype) + ) + latents_std = ( + ms.Tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = mint.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = mint.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_mask: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. 
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
+            height (`int`, *optional*):
+                The height in pixels of the generated image. If not provided, it is derived from the input image so
+                that the total area is close to 1024 * 1024 pixels while preserving the aspect ratio.
+            width (`int`, *optional*):
+                The width in pixels of the generated image. If not provided, it is derived from the input image so
+                that the total area is close to 1024 * 1024 pixels while preserving the aspect ratio.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 1.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*):
+                One or a list of [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html)
+                objects to make generation deterministic.
+            latents (`ms.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`ms.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+                [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+                returning a tuple, the first element is a list with the generated images.
+        """
+        image_size = image[0].size if isinstance(image, list) else image.size
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+        height = height or calculated_height
+        width = width or calculated_width
+
+        multiple_of = self.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # 3.
Preprocess image + if image is not None and not (isinstance(image, ms.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + prompt_image = image + image = self.image_processor.preprocess(image, image_height, image_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + # negative image is the same size as the original image, but all pixels are white + # negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255)) + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = mint.full([1], guidance_scale, dtype=ms.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = mint.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = mint.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + ms.Tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.dtype) + ) + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index ce6ab2bb40..141c0fd6e2 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -282,6 +282,9 @@ def 
encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -349,21 +352,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, dtype): - latent_image_ids = mint.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(dtype=dtype) - @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -450,8 +438,7 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - return latents.to(dtype=dtype), latent_image_ids + return latents.to(dtype=dtype) image = image.to(dtype=dtype) if image.shape[1] != self.latent_channels: @@ -474,9 +461,7 @@ def prepare_latents( latents = self.scheduler.scale_noise(image_latents, timestep, noise) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - - return latents, latent_image_ids + return latents @property def guidance_scale(self): @@ -691,7 +676,7 @@ def __call__( # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + latents = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, @@ -702,7 +687,7 @@ def __call__( generator, latents, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1bc0c4a35e..ee74aa41ec 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -293,6 +293,9 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -376,21 +379,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, dtype): - latent_image_ids = mint.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + mint.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + mint.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(dtype=dtype) - @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -477,8 +465,7 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - return latents.to(dtype=dtype), latent_image_ids + return latents.to(dtype=dtype) image = image.to(dtype=dtype) if image.shape[1] != self.latent_channels: @@ -509,9 +496,7 @@ def prepare_latents( image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) - - return latents, noise, image_latents, latent_image_ids + return latents, noise, image_latents def prepare_mask_latents( self, @@ -835,7 +820,7 @@ def __call__( # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, noise, image_latents, latent_image_ids = self.prepare_latents( + latents, noise, image_latents = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, @@ -868,7 +853,7 @@ def __call__( generator, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From 77779b5503be9768b3a893c86944656eaf325e8e Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 18 Aug 2025 19:08:35 +0800 Subject: [PATCH 05/77] 2025/8/18 19:08 revised --- mindone/diffusers/loaders/__init__.py | 2 + .../loaders/lora_conversion_utils.py | 35 ++ mindone/diffusers/loaders/lora_pipeline.py | 331 ++++++++++++++++++ .../transformers/transformer_qwenimage.py | 2 +- 4 files changed, 369 insertions(+), 1 deletion(-) diff --git a/mindone/diffusers/loaders/__init__.py b/mindone/diffusers/loaders/__init__.py index adee07f818..0299e5e487 100644 --- a/mindone/diffusers/loaders/__init__.py +++ b/mindone/diffusers/loaders/__init__.py @@ -76,6 +76,7 @@ def text_encoder_attn_modules(text_encoder): "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", + "QwenImageLoraLoaderMixin" ], "peft": ["PeftAdapterMixin"], "single_file": ["FromSingleFileMixin"], @@ -98,6 +99,7 @@ def text_encoder_attn_modules(text_encoder): LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, Mochi1LoraLoaderMixin, + QwenImageLoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, diff --git a/mindone/diffusers/loaders/lora_conversion_utils.py b/mindone/diffusers/loaders/lora_conversion_utils.py index 4e33a4c030..352e6235e2 100644 --- a/mindone/diffusers/loaders/lora_conversion_utils.py +++ b/mindone/diffusers/loaders/lora_conversion_utils.py @@ -1920,3 +1920,38 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()} converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} return converted_state_dict + +def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): + converted_state_dict = {} + all_keys = list(state_dict.keys()) + down_key = ".lora_down.weight" + up_key = ".lora_up.weight" + + def get_alpha_scales(down_weight, alpha_key): + rank = down_weight.shape[0] + alpha = state_dict.pop(alpha_key).item() + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + return scale_down, scale_up + + for k in all_keys: + if k.endswith(down_key): + diffusers_down_key = k.replace(down_key, ".lora_A.weight") + diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight") + alpha_key = k.replace(down_key, ".alpha") + + down_weight = state_dict.pop(k) + up_weight = state_dict.pop(k.replace(down_key, up_key)) + scale_down, scale_up = get_alpha_scales(down_weight, alpha_key) + converted_state_dict[diffusers_down_key] = down_weight * scale_down + converted_state_dict[diffusers_up_key] = up_weight * scale_up + + if len(state_dict) > 0: + raise ValueError(f"`state_dict` should be 
empty at this point but has {state_dict.keys()=}") + + converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} + return converted_state_dict diff --git a/mindone/diffusers/loaders/lora_pipeline.py b/mindone/diffusers/loaders/lora_pipeline.py index e12608d053..2bf2558694 100644 --- a/mindone/diffusers/loaders/lora_pipeline.py +++ b/mindone/diffusers/loaders/lora_pipeline.py @@ -41,6 +41,7 @@ _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, + _convert_non_diffusers_qwen_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, @@ -5658,3 +5659,333 @@ def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." # noqa: E501 deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) super().__init__(*args, **kwargs) + +class QwenImageLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. 
+ return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict) + if has_alphas_in_sd: + state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
+ kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`QwenImageTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. + """ + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[nn.Cell, ms.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, + ): + r""" + Save the LoRA parameters corresponding to the transformer. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, ms.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. 
+ is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + """ + state_dict = {} + lora_adapter_metadata = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. 
+ """ + super().unfuse_lora(components=components, **kwargs) \ No newline at end of file diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index 6ac4b046a2..bbf4b3b90b 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -32,7 +32,7 @@ # from ...utils.torch_utils import maybe_allow_in_graph from ...utils import logging from ..attention import FeedForward -from ..attention_dispatch import dispatch_attention_fn +# from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention # from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps From 0cab22b487cfeee4a979aa03b4d226f22317c2c2 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 18 Aug 2025 19:13:41 +0800 Subject: [PATCH 06/77] 2025/8/18 19:13 revised --- mindone/diffusers/models/transformers/transformer_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index bbf4b3b90b..3a211654cf 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -169,7 +169,7 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): self.theta = theta self.axes_dim = axes_dim pos_index = mint.arange(4096) - neg_index = mint.arange(4096).flip(0) * -1 - 1 + neg_index = mint.arange(4096).flip(dims=[0]) * -1 - 1 self.pos_freqs = mint.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), From 2b7b4c9f8b8da57a56c990f1a12ea4b0fbfd20b5 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Aug 2025 09:02:14 +0800 Subject: [PATCH 07/77] 2025/8/19 9:02 revised --- .../autoencoders/autoencoder_kl_qwenimage.py | 14 +++++++------- .../models/transformers/transformer_qwenimage.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index e4fe00df56..656c3b40a4 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -111,8 +111,8 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.channel_first = channel_first self.scale = dim**0.5 - self.gamma = nn.Parameter(mint.ones(shape)) - self.bias = nn.Parameter(mint.zeros(shape)) if bias else 0.0 + self.gamma = ms.Parameter(mint.ones(shape)) + self.bias = ms.Parameter(mint.zeros(shape)) if bias else 0.0 def construct(self, x): return mint.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias @@ -154,21 +154,21 @@ def __init__(self, dim: int, mode: str) -> None: # layers if mode == "upsample2d": - self.resample = nn.Sequential( + self.resample = ms.SequentialCell( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": - self.resample = nn.Sequential( + self.resample = ms.SequentialCell( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == "downsample2d": 
- self.resample = nn.Sequential(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = ms.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": - self.resample = nn.Sequential(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = ms.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: @@ -247,7 +247,7 @@ def __init__( self.norm1 = QwenImageRMS_norm(in_dim, images=False) self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) self.norm2 = QwenImageRMS_norm(out_dim, images=False) - self.dropout = nn.Dropout(dropout) + self.dropout = mint.nn.Dropout(dropout) self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else mint.nn.Identity() diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index 3a211654cf..bdefaace2b 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -381,7 +381,7 @@ def __init__( self.attention_head_dim = attention_head_dim # Image processing modules - self.img_mod = nn.Sequential( + self.img_mod = nn.SequentialCell( nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) @@ -403,7 +403,7 @@ def __init__( self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # Text processing modules - self.txt_mod = nn.Sequential( + self.txt_mod = nn.SequentialCell( mint.nn.SiLU(), mint.nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) From d7eaa37300b2840c01a90dfaa03ca625bb1978a1 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Aug 2025 09:04:56 +0800 Subject: [PATCH 08/77] 2025/8/19 9:04 revised --- .../models/autoencoders/autoencoder_kl_qwenimage.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 656c3b40a4..bb37223548 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -154,21 +154,21 @@ def __init__(self, dim: int, mode: str) -> None: # layers if mode == "upsample2d": - self.resample = ms.SequentialCell( + self.resample = nn.SequentialCell( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": - self.resample = ms.SequentialCell( + self.resample = nn.SequentialCell( QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == "downsample2d": - self.resample = ms.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = nn.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) elif mode == "downsample3d": - self.resample = ms.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.resample = 
nn.SequentialCell(mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2))) self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: From dddd8f296779c2963c5cf2fdc2cf96ee990e4927 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Aug 2025 09:12:42 +0800 Subject: [PATCH 09/77] 2025/8/19 9:12 revised --- .../diffusers/models/autoencoders/autoencoder_kl_qwenimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index bb37223548..469a46722d 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -736,10 +736,10 @@ def __init__( # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { - "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.cells_and_names()) if self.decoder is not None else 0, - "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.cells_and_names()) if self.encoder is not None else 0, } From 3117bdc22a186e71efdc04ea23ab1b008d4ccb4e Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Aug 2025 10:28:02 +0800 Subject: [PATCH 10/77] 2025/8/19 10:27 revised --- mindone/diffusers/models/transformers/transformer_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index bdefaace2b..adc43d5b6f 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -385,7 +385,7 @@ def __init__( nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) - self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_norm1 = mint.nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Attention( query_dim=dim, cross_attention_dim=None, # Enable cross attention for joint computation From e19c2e327dd5651db63f70d2fdc4c52aea5c79fe Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 09:22:05 +0800 Subject: [PATCH 11/77] 2025/8/20 9:22 revised --- mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a07a08ee21..268423daf7 100644 --- a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1219,7 +1219,7 @@ def __init__(self, config): super().__init__(config) # TODO: we need this patch here, may fix later config.vision_config._attn_implementation = config._attn_implementation - config.vision_config.torch_dtype = getattr(config, "mindspore_dtype", None) + config.vision_config.mindspore_dtype = getattr(config, "mindspore_dtype", None) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2_5_VLModel(config) self.vocab_size = config.vocab_size From 
0fb127a8134b6337f0024374d28d03b95a0511d9 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 09:47:23 +0800 Subject: [PATCH 12/77] 2025/8/20 9:47 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 82b5d73a40..15c1f2b9ef 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -192,7 +192,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids,
diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 56a2513b1d..6d11f318c9 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -255,7 +255,7 @@ def _get_qwen_prompt_embeds( text=txt, images=image, padding=True, - return_tensors="pt", + return_tensors="ms", ) outputs = self.text_encoder(
diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 141c0fd6e2..2f3af0bc3a 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -198,7 +198,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids,
diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index ee74aa41ec..b3302e43b7 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -208,7 +208,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids,
From 5d317bcfaa2d9916db78f063d399f7044b2e0206 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 09:48:09 +0800 Subject: [PATCH 13/77] 2025/8/20 9:48 revised ---
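Note: for reference, the prompt-encoding pattern the pipeline patches in this series converge on is to ask the tokenizer for NumPy arrays ("np") and wrap them in ms.Tensor before calling the MindSpore text encoder (the pipeline's Qwen2.5-VL model). The sketch below is illustrative only; the helper name encode_prompt and the max_length default are placeholders, not code from these commits.

    import mindspore as ms

    def encode_prompt(tokenizer, text_encoder, prompts, max_length=1024):
        # Tokenize to NumPy arrays, mirroring the patched pipelines.
        tokens = tokenizer(
            prompts, max_length=max_length, padding=True, truncation=True, return_tensors="np"
        )
        # Wrap the NumPy arrays as MindSpore tensors before the encoder call.
        outputs = text_encoder(
            input_ids=ms.Tensor(tokens.input_ids),
            attention_mask=ms.Tensor(tokens.attention_mask),
            output_hidden_states=True,
        )
        # The pipelines take the last hidden state as the prompt embedding.
        return outputs.hidden_states[-1]
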
mindone/diffusers/pipelines/qwenimage/pipeline_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_output.py b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py index 48a8b8464b..a997d5f7bd 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_output.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_output.py @@ -20,4 +20,4 @@ class QwenImagePipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ - images: Union[List[PIL.Image.Image], np.ndarray] + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file From e8043d88cf3c315285b73913d97c6e80cb23f7ea Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 09:52:13 +0800 Subject: [PATCH 14/77] 2025/8/20 9:52 revised --- mindone/transformers/modeling_utils.py | 2 +- mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index c62042197c..8731ed7a42 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -950,7 +950,7 @@ def _from_config(cls, config, **kwargs): if isinstance(mindspore_dtype, str): mindspore_dtype = getattr(ms, mindspore_dtype) - elif mindspore_dtype is not None: + elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): TORCH_TO_MINDSPORE_DTYPE_MAP = { "torch.float32": ms.float32, "torch.bfloat16": ms.bfloat16, diff --git a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 268423daf7..0cb29635f0 100644 --- a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1219,7 +1219,7 @@ def __init__(self, config): super().__init__(config) # TODO: we need this patch here, may fix later config.vision_config._attn_implementation = config._attn_implementation - config.vision_config.mindspore_dtype = getattr(config, "mindspore_dtype", None) + # config.vision_config.mindspore_dtype = getattr(config, "mindspore_dtype", None) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2_5_VLModel(config) self.vocab_size = config.vocab_size From b78ef0a96425f13d31f9d87f85a174107946b4a9 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 10:15:53 +0800 Subject: [PATCH 15/77] 2025/8/20 10:15 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py | 2 +- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 15c1f2b9ef..89c432211f 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -192,7 +192,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" + txt, 
max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 6d11f318c9..4d73cd13e2 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -255,7 +255,7 @@ def _get_qwen_prompt_embeds( text=txt, images=image, padding=True, - return_tensors="ms", + return_tensors="np", ) outputs = self.text_encoder( diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 2f3af0bc3a..024c03da22 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -198,7 +198,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index b3302e43b7..1e71ea646a 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -208,7 +208,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="ms" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, From 656accec7363c2deca3a9e64d683b0b2ba85d8f7 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 10:50:04 +0800 Subject: [PATCH 16/77] 2025/8/20 10:50 revised --- .../diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 7 ++++--- .../transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 89c432211f..a938bccc6e 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -192,11 +192,12 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" \ + "" ) encoder_hidden_states = self.text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask, + input_ids=ms.Tensor(txt_tokens.input_ids), + attention_mask=ms.Tensor(txt_tokens.attention_mask), 
output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] diff --git a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0cb29635f0..4938512188 100644 --- a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1219,7 +1219,6 @@ def __init__(self, config): super().__init__(config) # TODO: we need this patch here, may fix later config.vision_config._attn_implementation = config._attn_implementation - # config.vision_config.mindspore_dtype = getattr(config, "mindspore_dtype", None) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2_5_VLModel(config) self.vocab_size = config.vocab_size From 9a33d83b9ecd499ed2edada33b78b7cc7932bab9 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 11:11:18 +0800 Subject: [PATCH 17/77] 2025/8/20 11:11 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index a938bccc6e..e0cc5969c5 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -201,7 +201,7 @@ def _get_qwen_prompt_embeds( output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) From c2f972c5b2dcac2389690be60622c3ee8202b64e Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 11:27:31 +0800 Subject: [PATCH 18/77] 2025/8/20 11:27 revised --- .../diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 8 ++++---- .../pipelines/qwenimage/pipeline_qwenimage_edit.py | 6 +++--- .../pipelines/qwenimage/pipeline_qwenimage_img2img.py | 6 +++--- .../pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index e0cc5969c5..905088fd21 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -203,13 +203,13 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0], u.shape[1])]) for u in 
split_hidden_states] ) encoder_attention_mask = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0])]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 4d73cd13e2..e463e44017 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -259,15 +259,15 @@ def _get_qwen_prompt_embeds( ) outputs = self.text_encoder( - input_ids=model_inputs.input_ids, - attention_mask=model_inputs.attention_mask, + input_ids=ms.Tensor(model_inputs.input_ids), + attention_mask=ms.Tensor(model_inputs.attention_mask), pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True, ) hidden_states = outputs.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(model_inputs.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 024c03da22..0ea98e5068 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -201,12 +201,12 @@ def _get_qwen_prompt_embeds( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" ) encoder_hidden_states = self.text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask, + input_ids=ms.Tensor(txt_tokens.input_ids), + attention_mask=ms.Tensor(txt_tokens.attention_mask), output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1e71ea646a..a125682c14 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -211,12 +211,12 @@ def _get_qwen_prompt_embeds( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="np" ) encoder_hidden_states = self.text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask, + input_ids=ms.Tensor(txt_tokens.input_ids), + attention_mask=ms.Tensor(txt_tokens.attention_mask), output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = 
self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.size(0) for e in split_hidden_states]) From 9e2cccf765919026b5ff3dccd539383e3b562037 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 11:47:48 +0800 Subject: [PATCH 19/77] 2025/8/20 11:47 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 905088fd21..762451e3ba 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -206,10 +206,10 @@ def _get_qwen_prompt_embeds( attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0], u.shape[1])]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0])]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0]))]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) From 9b5be21cc67fbb6f2d8f24d3a1e6dd2733375009 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 14:25:09 +0800 Subject: [PATCH 20/77] 2025/8/20 14:25 revised --- .../pipelines/qwenimage/pipeline_qwenimage.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 762451e3ba..41cb330094 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -594,7 +594,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand((latents.shape[0],)) else: guidance = None @@ -615,33 +615,33 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + timestep = t.expand((latents.shape[0],)).to(latents.dtype) + # with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + # with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - 
txt_seq_lens=txt_seq_lens, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) From 19069195309a43c4e6ceb2ed7fc605f1d688cbe9 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 20 Aug 2025 14:26:32 +0800 Subject: [PATCH 21/77] 2025/8/20 14:26 revised --- .../models/transformers/transformer_qwenimage.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index adc43d5b6f..f36bbae731 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -245,15 +245,15 @@ def _compute_video_freqs(self, frame, height, width, idx=0): freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand((frame, height, width, -1),) if self.scale_rope: freqs_height = mint.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height = freqs_height.view(1, height, 1, -1).expand((frame, height, width, -1),) freqs_width = mint.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width = freqs_width.view(1, 1, width, -1).expand((frame, height, width, -1),) else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand((frame, height, width, -1),) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand((frame, height, width, -1),) freqs = mint.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() @@ -398,6 +398,7 @@ def __init__( processor=QwenDoubleStreamAttnProcessor2_0(), qk_norm=qk_norm, eps=eps, + ) self.img_norm2 = mint.nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") From c3055babc95d4fd812263434d923ff0ff559769c Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 21 Aug 2025 15:20:44 +0800 Subject: [PATCH 22/77] 2025/8/21 15:20 revised --- .../diffusers/models/attention_dispatch.py | 1218 ----------------- .../diffusers/models/attention_processor.py | 3 +- .../transformers/transformer_qwenimage.py | 12 +- .../qwenimage/pipeline_qwenimage_edit.py | 44 +- .../qwenimage/pipeline_qwenimage_img2img.py | 42 +- .../qwenimage/pipeline_qwenimage_inpaint.py | 
42 +- 6 files changed, 72 insertions(+), 1289 deletions(-) delete mode 100644 mindone/diffusers/models/attention_dispatch.py diff --git a/mindone/diffusers/models/attention_dispatch.py b/mindone/diffusers/models/attention_dispatch.py deleted file mode 100644 index 7cc30e47ab..0000000000 --- a/mindone/diffusers/models/attention_dispatch.py +++ /dev/null @@ -1,1218 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import functools -import inspect -import math -from enum import Enum -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union - -import torch - -from ..utils import ( - get_logger, - is_flash_attn_3_available, - is_flash_attn_available, - is_flash_attn_version, - is_sageattention_available, - is_sageattention_version, - is_torch_npu_available, - is_torch_version, - is_torch_xla_available, - is_torch_xla_version, - is_xformers_available, - is_xformers_version, -) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS - - -_REQUIRED_FLASH_VERSION = "2.6.3" -_REQUIRED_SAGE_VERSION = "2.1.1" -_REQUIRED_FLEX_VERSION = "2.5.0" -_REQUIRED_XLA_VERSION = "2.2" -_REQUIRED_XFORMERS_VERSION = "0.0.29" - -_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) -_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() -_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) -_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) -_CAN_USE_NPU_ATTN = is_torch_npu_available() -_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) -_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) - - -if _CAN_USE_FLASH_ATTN: - from flash_attn import flash_attn_func, flash_attn_varlen_func -else: - flash_attn_func = None - flash_attn_varlen_func = None - - -if _CAN_USE_FLASH_ATTN_3: - from flash_attn_interface import flash_attn_func as flash_attn_3_func - from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func -else: - flash_attn_3_func = None - flash_attn_3_varlen_func = None - - -if _CAN_USE_SAGE_ATTN: - from sageattention import ( - sageattn, - sageattn_qk_int8_pv_fp8_cuda, - sageattn_qk_int8_pv_fp8_cuda_sm90, - sageattn_qk_int8_pv_fp16_cuda, - sageattn_qk_int8_pv_fp16_triton, - sageattn_varlen, - ) -else: - sageattn = None - sageattn_qk_int8_pv_fp16_cuda = None - sageattn_qk_int8_pv_fp16_triton = None - sageattn_qk_int8_pv_fp8_cuda = None - sageattn_qk_int8_pv_fp8_cuda_sm90 = None - sageattn_varlen = None - - -if _CAN_USE_FLEX_ATTN: - # We cannot import the flex_attention function from the package directly because it is expected (from the - # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the - # compiled function. 
- import torch.nn.attention.flex_attention as flex_attention - - -if _CAN_USE_NPU_ATTN: - from torch_npu import npu_fusion_attention -else: - npu_fusion_attention = None - - -if _CAN_USE_XLA_ATTN: - from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention -else: - xla_flash_attention = None - - -if _CAN_USE_XFORMERS_ATTN: - import xformers.ops as xops -else: - xops = None - - -logger = get_logger(__name__) # pylint: disable=invalid-name - -# TODO(aryan): Add support for the following: -# - Sage Attention++ -# - block sparse, radial and other attention methods -# - CP with sage attention, flex, xformers, other missing backends -# - Add support for normal and CP training with backends that don't support it yet - -_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] -_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] -_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] - - -class AttentionBackendName(str, Enum): - # EAGER = "eager" - - # `flash-attn` - FLASH = "flash" - FLASH_VARLEN = "flash_varlen" - _FLASH_3 = "_flash_3" - _FLASH_VARLEN_3 = "_flash_varlen_3" - - # PyTorch native - FLEX = "flex" - NATIVE = "native" - _NATIVE_CUDNN = "_native_cudnn" - _NATIVE_EFFICIENT = "_native_efficient" - _NATIVE_FLASH = "_native_flash" - _NATIVE_MATH = "_native_math" - _NATIVE_NPU = "_native_npu" - _NATIVE_XLA = "_native_xla" - - # `sageattention` - SAGE = "sage" - SAGE_VARLEN = "sage_varlen" - _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" - _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" - _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" - _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" - # TODO: let's not add support for Sparge Attention now because it requires tuning per model - # We can look into supporting something "autotune"-ing in the future - # SPARGE = "sparge" - - # `xformers` - XFORMERS = "xformers" - - -class _AttentionBackendRegistry: - _backends = {} - _constraints = {} - _supported_arg_names = {} - _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) - _checks_enabled = DIFFUSERS_ATTN_CHECKS - - @classmethod - def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): - logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") - - def decorator(func): - cls._backends[backend] = func - cls._constraints[backend] = constraints or [] - cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) - return func - - return decorator - - @classmethod - def get_active_backend(cls): - return cls._active_backend, cls._backends[cls._active_backend] - - @classmethod - def list_backends(cls): - return list(cls._backends.keys()) - - -@contextlib.contextmanager -def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): - """ - Context manager to set the active attention backend. 
- """ - if backend not in _AttentionBackendRegistry._backends: - raise ValueError(f"Backend {backend} is not registered.") - - backend = AttentionBackendName(backend) - _check_attention_backend_requirements(backend) - - old_backend = _AttentionBackendRegistry._active_backend - _AttentionBackendRegistry._active_backend = backend - - try: - yield - finally: - _AttentionBackendRegistry._active_backend = old_backend - - -def dispatch_attention_fn( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, - attention_kwargs: Optional[Dict[str, Any]] = None, - *, - backend: Optional[AttentionBackendName] = None, -) -> torch.Tensor: - attention_kwargs = attention_kwargs or {} - - if backend is None: - # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment - # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager - backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() - else: - backend_name = AttentionBackendName(backend) - backend_fn = _AttentionBackendRegistry._backends.get(backend_name) - - kwargs = { - "query": query, - "key": key, - "value": value, - "attn_mask": attn_mask, - "dropout_p": dropout_p, - "is_causal": is_causal, - "scale": scale, - **attention_kwargs, - } - if is_torch_version(">=", "2.5.0"): - kwargs["enable_gqa"] = enable_gqa - - if _AttentionBackendRegistry._checks_enabled: - removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) - if removed_kwargs: - logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") - for check in _AttentionBackendRegistry._constraints.get(backend_name): - check(**kwargs) - - kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} - return backend_fn(**kwargs) - - -# ===== Checks ===== -# A list of very simple functions to catch common errors quickly when debugging. - - -def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: - if attn_mask is not None and is_causal: - raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") - - -def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: - if query.device != key.device or query.device != value.device: - raise ValueError("Query, key, and value must be on the same device.") - if query.dtype != key.dtype or query.dtype != value.dtype: - raise ValueError("Query, key, and value must have the same dtype.") - - -def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: - _check_device(query, key, value) - if query.device.type != "cuda": - raise ValueError("Query, key, and value must be on a CUDA device.") - - -def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: - def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: - _check_device_cuda(query, key, value) - if torch.cuda.get_device_capability(query.device) < (major, minor): - raise ValueError( - f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
- ) - - return check_device_cuda - - -def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: - if query.dtype != key.dtype: - raise ValueError("Query and key must have the same dtype.") - if query.dtype != value.dtype: - raise ValueError("Query and value must have the same dtype.") - - -def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: - _check_qkv_dtype_match(query, key, value) - if query.dtype not in (torch.bfloat16, torch.float16): - raise ValueError("Query, key, and value must be either bfloat16 or float16.") - - -def _check_shape( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - **kwargs, -) -> None: - if query.shape[-1] != key.shape[-1]: - raise ValueError("Query and key must have the same last dimension.") - if query.shape[-2] != value.shape[-2]: - raise ValueError("Query and value must have the same second to last dimension.") - if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: - raise ValueError("Attention mask must match the key's second to last dimension.") - - -# ===== Helper functions ===== - - -def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: - if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: - if not _CAN_USE_FLASH_ATTN: - raise RuntimeError( - f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." - ) - - elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: - if not _CAN_USE_FLASH_ATTN_3: - raise RuntimeError( - f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." - ) - - elif backend in [ - AttentionBackendName.SAGE, - AttentionBackendName.SAGE_VARLEN, - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, - AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, - AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, - ]: - if not _CAN_USE_SAGE_ATTN: - raise RuntimeError( - f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." - ) - - elif backend == AttentionBackendName.FLEX: - if not _CAN_USE_FLEX_ATTN: - raise RuntimeError( - f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." - ) - - elif backend == AttentionBackendName._NATIVE_NPU: - if not _CAN_USE_NPU_ATTN: - raise RuntimeError( - f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." - ) - - elif backend == AttentionBackendName._NATIVE_XLA: - if not _CAN_USE_XLA_ATTN: - raise RuntimeError( - f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." - ) - - elif backend == AttentionBackendName.XFORMERS: - if not _CAN_USE_XFORMERS_ATTN: - raise RuntimeError( - f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." 
- ) - - -@functools.lru_cache(maxsize=128) -def _prepare_for_flash_attn_or_sage_varlen_without_mask( - batch_size: int, - seq_len_q: int, - seq_len_kv: int, - device: Optional[torch.device] = None, -): - seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) - seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) - cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) - max_seqlen_q = seqlens_q.max().item() - max_seqlen_k = seqlens_k.max().item() - return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) - - -def _prepare_for_flash_attn_or_sage_varlen_with_mask( - batch_size: int, - seq_len_q: int, - attn_mask: torch.Tensor, - device: Optional[torch.device] = None, -): - seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) - seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) - cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) - max_seqlen_q = seqlens_q.max().item() - max_seqlen_k = seqlens_k.max().item() - return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) - - -def _prepare_for_flash_attn_or_sage_varlen( - batch_size: int, - seq_len_q: int, - seq_len_kv: int, - attn_mask: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, -) -> None: - if attn_mask is None: - return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device) - return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) - - -def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: - """ - Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in - FlashAttention/Sage varlen. - - Supports 1D to 4D shapes and common broadcasting patterns. - """ - if attn_mask.dtype != torch.bool: - raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") - - if attn_mask.ndim == 1: - # [seq_len_k] -> broadcast across batch - attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) - - elif attn_mask.ndim == 2: - # [batch_size, seq_len_k]. Maybe broadcast across batch - if attn_mask.size(0) not in [1, batch_size]: - raise ValueError( - f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." - ) - attn_mask = attn_mask.expand(batch_size, seq_len_k) - - elif attn_mask.ndim == 3: - # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension - # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. - if attn_mask.size(0) not in [1, batch_size]: - raise ValueError( - f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." 
- ) - attn_mask = attn_mask.any(dim=1) - attn_mask = attn_mask.expand(batch_size, seq_len_k) - - elif attn_mask.ndim == 4: - # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions - if attn_mask.size(0) not in [1, batch_size]: - raise ValueError( - f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." - ) - attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] - attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] - - else: - raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") - - if attn_mask.shape != (batch_size, seq_len_k): - raise ValueError( - f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" - ) - - return attn_mask - - -def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): - return q_idx >= kv_idx - - -# ===== torch op registrations ===== -# Registrations are required for fullgraph tracing compatibility - - -# TODO: library.custom_op and register_fake probably need version guards? -# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding -# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 -@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _wrapped_flash_attn_3_original( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - out, lse = flash_attn_3_func(query, key, value) - lse = lse.permute(0, 2, 1) - return out, lse - - -@torch.library.register_fake("flash_attn_3::_flash_attn_forward") -def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, seq_len, num_heads, head_dim = query.shape - lse_shape = (batch_size, seq_len, num_heads) - return torch.empty_like(query), query.new_empty(lse_shape) - - -# ===== Attention backends ===== - - -@_AttentionBackendRegistry.register( - AttentionBackendName.FLASH, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - scale: Optional[float] = None, - is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, -) -> torch.Tensor: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, - ) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName.FLASH_VARLEN, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_varlen_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, - dropout_p: float = 0.0, - scale: Optional[float] = None, - is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - batch_size, 
seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out = flash_attn_varlen_func( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, - ) - out = out.unflatten(0, (batch_size, -1)) - - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._FLASH_3, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_attention_3( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: Optional[float] = None, - is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, -) -> torch.Tensor: - out, lse, *_ = flash_attn_3_func( - q=query, - k=key, - v=value, - softmax_scale=scale, - causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - attention_chunk=0, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, - ) - return (out, lse) if return_attn_probs else out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._FLASH_VARLEN_3, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_varlen_attention_3( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, - scale: Optional[float] = None, - is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, 
dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out, lse, *_ = flash_attn_3_varlen_func( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - seqused_q=None, - seqused_k=None, - softmax_scale=scale, - causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, - ) - out = out.unflatten(0, (batch_size, -1)) - - return (out, lse) if return_attn_probs else out - - -@_AttentionBackendRegistry.register( - AttentionBackendName.FLEX, - constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], -) -def _native_flex_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, - return_lse: bool = False, - kernel_options: Optional[Dict[str, Any]] = None, -) -> torch.Tensor: - # TODO: should we LRU cache the block mask creation? - score_mod = None - block_mask = None - batch_size, seq_len_q, num_heads, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): - block_mask = attn_mask - elif is_causal: - block_mask = flex_attention.create_block_mask( - _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device - ) - elif torch.is_tensor(attn_mask): - if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) - - attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) - - if attn_mask.dtype == torch.bool: - # TODO: this probably does not work but verify! 
- def mask_mod(batch_idx, head_idx, q_idx, kv_idx): - return attn_mask[batch_idx, head_idx, q_idx, kv_idx] - - block_mask = flex_attention.create_block_mask( - mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device - ) - else: - - def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): - return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] - else: - raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") - - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = flex_attention.flex_attention( - query=query, - key=key, - value=value, - score_mod=score_mod, - block_mask=block_mask, - scale=scale, - enable_gqa=enable_gqa, - return_lse=return_lse, - kernel_options=kernel_options, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName.NATIVE, - constraints=[_check_device, _check_shape], -) -def _native_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_CUDNN, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _native_cudnn_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_EFFICIENT, - constraints=[_check_device, _check_shape], -) -def _native_efficient_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_FLASH, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _native_flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = 
None, - enable_gqa: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=None, # not supported - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_MATH, - constraints=[_check_device, _check_shape], -) -def _native_math_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_NPU, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _native_npu_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - scale: Optional[float] = None, -) -> torch.Tensor: - return npu_fusion_attention( - query, - key, - value, - query.size(2), # num_heads - input_layout="BSND", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0 - dropout_p, - sync=False, - inner_precise=0, - )[0] - - -# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 -@_AttentionBackendRegistry.register( - AttentionBackendName._NATIVE_XLA, - constraints=[_check_device, _check_shape], -) -def _native_xla_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, -) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - query = query / math.sqrt(query.shape[-1]) - out = xla_flash_attention( - q=query, - k=key, - v=value, - causal=is_causal, - ) - out = out.permute(0, 2, 1, 3) - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName.SAGE, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _sage_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - scale: Optional[float] = None, - return_lse: bool = False, -) -> torch.Tensor: - return sageattn( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName.SAGE_VARLEN, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _sage_varlen_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, - is_causal: bool = False, - scale: Optional[float] = 
None, - smooth_k: bool = True, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out = sageattn_varlen( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - is_causal=is_causal, - sm_scale=scale, - smooth_k=smooth_k, - ) - out = out.unflatten(0, (batch_size, -1)) - - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], -) -def _sage_qk_int8_pv_fp8_cuda_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, - smooth_v: bool = False, - return_lse: bool = False, -) -> torch.Tensor: - return sageattn_qk_int8_pv_fp8_cuda( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - qk_quant_gran=qk_quant_gran, - sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, - constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], -) -def _sage_qk_int8_pv_fp8_cuda_sm90_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, - return_lse: bool = False, -) -> torch.Tensor: - return sageattn_qk_int8_pv_fp8_cuda_sm90( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - qk_quant_gran=qk_quant_gran, - sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, - constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], -) -def _sage_qk_int8_pv_fp16_cuda_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", - smooth_k: bool = True, - 
smooth_v: bool = False, - return_lse: bool = False, -) -> torch.Tensor: - return sageattn_qk_int8_pv_fp16_cuda( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - qk_quant_gran=qk_quant_gran, - sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, - constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], -) -def _sage_qk_int8_pv_fp16_triton_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - scale: Optional[float] = None, - quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", - smooth_k: bool = True, - return_lse: bool = False, -) -> torch.Tensor: - return sageattn_qk_int8_pv_fp16_triton( - q=query, - k=key, - v=value, - tensor_layout="NHD", - quantization_backend=quantization_backend, - is_causal=is_causal, - sm_scale=scale, - smooth_k=smooth_k, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName.XFORMERS, - constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], -) -def _xformers_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - batch_size, seq_len_q, num_heads_q, _ = query.shape - _, seq_len_kv, num_heads_kv, _ = key.shape - - if is_causal: - attn_mask = xops.LowerTriangularMask() - elif attn_mask is not None: - if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) - elif attn_mask.ndim != 4: - raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) - - if enable_gqa: - if num_heads_q % num_heads_kv != 0: - raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") - num_heads_per_group = num_heads_q // num_heads_kv - query = query.unflatten(2, (num_heads_kv, -1)) - key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) - value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) - - out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) - - if enable_gqa: - out = out.flatten(2, 3) - - return out diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py index e186653d5e..18aa7168e7 100644 --- a/mindone/diffusers/models/attention_processor.py +++ b/mindone/diffusers/models/attention_processor.py @@ -789,7 +789,8 @@ def flash_attention_op( ): # For most scenarios, qkv has been processed into a BNSD layout before sdp input_layout = "BNSD" - head_num = self.heads + # head_num = self.heads + head_num = query.shape[1] # In case qkv is 3-dim after `head_to_batch_dim` if query.ndim == 3: diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index f36bbae731..f4346b8d79 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -298,16 +298,16 @@ def __call__( txt_value = attn.add_v_proj(encoder_hidden_states) # Reshape for multi-head attention - 
img_query = unflatten(img_query, -1, (attn.heads, -1)).swapaxes(1, 2) - img_key = unflatten(img_key, -1, (attn.heads, -1)).swapaxes(1, 2) - img_value = unflatten(img_value, -1, (attn.heads, -1)).swapaxes(1, 2) + img_query = unflatten(img_query, -1, (attn.heads, -1)) + img_key = unflatten(img_key, -1, (attn.heads, -1)) + img_value = unflatten(img_value, -1, (attn.heads, -1)) # img_query = img_query.unflatten(-1, (attn.heads, -1)) # img_key = img_key.unflatten(-1, (attn.heads, -1)) # img_value = img_value.unflatten(-1, (attn.heads, -1)) - txt_query = unflatten(txt_query, -1, (attn.heads, -1)).swapaxes(1, 2) - txt_key = unflatten(txt_key, -1, (attn.heads, -1)).swapaxes(1, 2) - txt_value = unflatten(txt_value, -1, (attn.heads, -1)).swapaxes(1, 2) + txt_query = unflatten(txt_query, -1, (attn.heads, -1)) + txt_key = unflatten(txt_key, -1, (attn.heads, -1)) + txt_value = unflatten(txt_value, -1, (attn.heads, -1)) # txt_query = txt_query.unflatten(-1, (attn.heads, -1)) # txt_key = txt_key.unflatten(-1, (attn.heads, -1)) # txt_value = txt_value.unflatten(-1, (attn.heads, -1)) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index e463e44017..ac5954e6a7 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -743,7 +743,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand((latents.shape[0],)) else: guidance = None @@ -769,34 +769,34 @@ def __call__( latent_model_input = mint.cat([latents, image_latents], dim=1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + timestep = t.expand((latents.shape[0],)).to(latents.dtype) + # with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + # with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] 
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 0ea98e5068..245e1e6760 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -695,7 +695,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand((latents.shape[0],)) else: guidance = None @@ -715,33 +715,33 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + timestep = t.expand((latents.shape[0],)).to(latents.dtype) + # with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + # with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index a125682c14..15499d37a5 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -861,7 +861,7 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand((latents.shape[0],)) else: guidance = None @@ -881,33 +881,33 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + timestep = t.expand((latents.shape[0],)).to(latents.dtype) + # with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + 
guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + # with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) From e02580045414dbf620d8d44d928bbf2832cb5749 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 21 Aug 2025 15:24:09 +0800 Subject: [PATCH 23/77] 2025/8/21 15:24 revised --- .../diffusers/models/autoencoders/autoencoder_kl_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 469a46722d..fff01dfc91 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -798,7 +798,7 @@ def disable_slicing(self) -> None: def clear_cache(self): def _count_conv3d(model): count = 0 - for m in model.modules(): + for m in model.cells_and_names(): if isinstance(m, QwenImageCausalConv3d): count += 1 return count From 436ebf34590992e070429c3e178f80bc6ae1ed5c Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 21 Aug 2025 17:08:43 +0800 Subject: [PATCH 24/77] 2025/8/21 17:08 revised --- .../models/autoencoders/autoencoder_kl_qwenimage.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index fff01dfc91..45b39d93ff 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -175,7 +175,7 @@ def __init__(self, dim: int, mode: str) -> None: self.resample = mint.nn.Identity() def construct(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() + b, c, t, h, w = x.shape if self.mode == "upsample3d": if feat_cache is not None: idx = feat_idx[0] @@ -204,7 +204,7 @@ def construct(self, x, feat_cache=None, feat_idx=[0]): t = x.shape[2] x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = self.resample(x) - x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + x = x.view(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) if self.mode == "downsample3d": if feat_cache is not None: @@ -313,7 +313,7 @@ def __init__(self, dim): def construct(self, x): identity = x - batch_size, channels, 
time, height, width = x.size() + batch_size, channels, time, height, width = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) x = self.norm(x) @@ -736,10 +736,10 @@ def __init__( # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { - "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.cells_and_names()) + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for _, m in self.decoder.cells_and_names()) if self.decoder is not None else 0, - "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.cells_and_names()) + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for _, m in self.encoder.cells_and_names()) if self.encoder is not None else 0, } @@ -798,7 +798,7 @@ def disable_slicing(self) -> None: def clear_cache(self): def _count_conv3d(model): count = 0 - for m in model.cells_and_names(): + for _, m in model.cells_and_names(): if isinstance(m, QwenImageCausalConv3d): count += 1 return count From dafec1a60256696e34855033c79b9813f549ca01 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 21 Aug 2025 17:57:24 +0800 Subject: [PATCH 25/77] 2025/8/21 17:57 revised --- .../autoencoders/autoencoder_kl_qwenimage.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 45b39d93ff..fc9b398edc 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -85,7 +85,6 @@ def __init__( def construct(self, x, cache_x=None): padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) x = mint.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] x = mint.nn.functional.pad(x, padding) @@ -111,8 +110,9 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.channel_first = channel_first self.scale = dim**0.5 - self.gamma = ms.Parameter(mint.ones(shape)) - self.bias = ms.Parameter(mint.zeros(shape)) if bias else 0.0 + self.gamma = ms.Parameter(mint.ones(shape), name="gamma") + self.bias = ms.Parameter(mint.zeros(shape), name="bias" \ + "") if bias else 0.0 def construct(self, x): return mint.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias @@ -155,12 +155,12 @@ def __init__(self, dim: int, mode: str) -> None: # layers if mode == "upsample2d": self.resample = nn.SequentialCell( - QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact", recompute_scale_factor = True), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.SequentialCell( - QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact", recompute_scale_factor = True), mint.nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -182,15 +182,16 @@ def construct(self, x, feat_cache=None, feat_idx=[0]): if feat_cache[idx] is None: feat_cache[idx] = "Rep" feat_idx[0] += 1 + else: cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk cache_x = mint.cat( 
- [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2 ) if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = mint.cat([mint.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([mint.zeros_like(cache_x), cache_x], dim=2) if feat_cache[idx] == "Rep": x = self.time_conv(x) else: @@ -326,7 +327,7 @@ def construct(self, x): # apply attention # x = F.scaled_dot_product_attention(q, k, v) - x = ops.operation.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( + x = ops.operations.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), None, None, None, None )[3].to(q.dtype) @@ -454,7 +455,7 @@ def construct(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) x = self.conv_in(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -715,6 +716,8 @@ def __init__( base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout ) + self.diag_gauss_dist = DiagonalGaussianDistribution() + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension @@ -903,10 +906,10 @@ def decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderOutput, returned. """ if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [self._decode(z_slice)[0] for z_slice in z.split(1)] decoded = mint.cat(decoded_slices) else: - decoded = self._decode(z).sample + decoded = self._decode(z)[0] if not return_dict: return (decoded,) @@ -1071,10 +1074,10 @@ def construct( Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
""" x = sample - posterior = self.encode(x).latent_dist + posterior = self.encode(x)[0] if sample_posterior: - z = posterior.sample(generator=generator) + z = posterior.diag_gauss_dist.sample(posterior, generator=generator) else: - z = posterior.mode() + z = posterior.diag_gauss_dist.mode(posterior) dec = self.decode(z, return_dict=return_dict) return dec From e573be160e354a057ee5ea073ae1493d12babf84 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 21 Aug 2025 19:13:54 +0800 Subject: [PATCH 26/77] 2025/8/21 19:13 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 3 --- .../pipelines/qwenimage/pipeline_qwenimage_img2img.py | 3 --- .../pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 41cb330094..0822f2495b 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -686,9 +686,6 @@ def __call__( image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models - self.maybe_free_model_hooks() - if not return_dict: return (image,) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 245e1e6760..dfc91089c2 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -787,9 +787,6 @@ def __call__( image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models - self.maybe_free_model_hooks() - if not return_dict: return (image,) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 15499d37a5..89e31ffd24 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -970,9 +970,6 @@ def __call__( self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image ] - # Offload all models - self.maybe_free_model_hooks() - if not return_dict: return (image,) From d549ab2ed8d301e27774c70642c1d6f48118cb41 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 22 Aug 2025 11:32:04 +0800 Subject: [PATCH 27/77] 2025/8/22 11:32 revised --- .../models/autoencoders/autoencoder_kl_qwenimage.py | 8 +++++--- .../models/transformers/transformer_qwenimage.py | 6 ------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index fc9b398edc..2c3002c999 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -859,11 +859,13 @@ def encode( h = mint.cat(encoded_slices) else: h = self._encode(x) - posterior = DiagonalGaussianDistribution(h) + + # we cannot use class in grapha mode, even for jit_class or subclass of Tensor. 
:-( + # posterior = DiagonalGaussianDistribution(h) if not return_dict: - return (posterior,) - return AutoencoderKLOutput(latent_dist=posterior) + return (h,) + return AutoencoderKLOutput(latent_dist=h) def _decode(self, z: ms.Tensor, return_dict: bool = True): _, _, num_frame, height, width = z.shape diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index f4346b8d79..f7166694db 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -301,16 +301,10 @@ def __call__( img_query = unflatten(img_query, -1, (attn.heads, -1)) img_key = unflatten(img_key, -1, (attn.heads, -1)) img_value = unflatten(img_value, -1, (attn.heads, -1)) - # img_query = img_query.unflatten(-1, (attn.heads, -1)) - # img_key = img_key.unflatten(-1, (attn.heads, -1)) - # img_value = img_value.unflatten(-1, (attn.heads, -1)) txt_query = unflatten(txt_query, -1, (attn.heads, -1)) txt_key = unflatten(txt_key, -1, (attn.heads, -1)) txt_value = unflatten(txt_value, -1, (attn.heads, -1)) - # txt_query = txt_query.unflatten(-1, (attn.heads, -1)) - # txt_key = txt_key.unflatten(-1, (attn.heads, -1)) - # txt_value = txt_value.unflatten(-1, (attn.heads, -1)) # Apply QK normalization if attn.norm_q is not None: From 09ac0bdd741e99f0c9835eebb1b77b7dbc4ef06a Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 22 Aug 2025 17:40:54 +0800 Subject: [PATCH 28/77] 2025/8/22 17:40 revised --- .../autoencoders/autoencoder_kl_qwenimage.py | 10 ++++----- .../transformers/transformer_qwenimage.py | 21 ------------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 2c3002c999..750e5e88c2 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -22,7 +22,7 @@ # - arXiv: https://arxiv.org/abs/2503.20314 from typing import List, Optional, Tuple, Union - +import math import numpy as np import mindspore as ms @@ -326,10 +326,10 @@ def construct(self, x): q, k, v = qkv.chunk(3, dim=-1) # apply attention - # x = F.scaled_dot_product_attention(q, k, v) - x = ops.operations.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( - q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), None, None, None, None - )[3].to(q.dtype) + x = ops.flash_attention_score(q, k, v, 1, scalar_value=1/math.sqrt(q.shape[-1]), input_layout="BNSD") + # x = ops.operations.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( + # q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), None, None, None, None + # )[3].to(q.dtype) x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index f7166694db..c6d0dea973 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -618,27 +618,6 @@ def construct( image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens) - # for index_block, block in enumerate(self.transformer_blocks): - # if torch.is_grad_enabled() and self.gradient_checkpointing: - # encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - # block, - # hidden_states, - # 
encoder_hidden_states, - # encoder_hidden_states_mask, - # temb, - # image_rotary_emb, - # ) - - # else: - # encoder_hidden_states, hidden_states = block( - # hidden_states=hidden_states, - # encoder_hidden_states=encoder_hidden_states, - # encoder_hidden_states_mask=encoder_hidden_states_mask, - # temb=temb, - # image_rotary_emb=image_rotary_emb, - # joint_attention_kwargs=attention_kwargs, - # ) - for index_block, block in enumerate(self.transformer_blocks): encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, From fb5877ba98df1421eb3d85068158133f4f9ab934 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 25 Aug 2025 10:40:59 +0800 Subject: [PATCH 29/77] 2025/8/25 10:40 revised --- .../pipelines/qwenimage/pipeline_qwenimage_edit.py | 14 +++++++------- .../qwenimage/pipeline_qwenimage_img2img.py | 8 ++++---- .../qwenimage/pipeline_qwenimage_inpaint.py | 8 ++++---- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ac5954e6a7..5db0ac92f8 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -269,13 +269,13 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(model_inputs.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0], u.shape[1])]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0])]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) @@ -662,7 +662,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. 
Preprocess image - if image is not None and not (isinstance(image, ms.Tensor) and image.size(1) == self.latent_channels): + if image is not None and not (isinstance(image, ms.Tensor) and image.shape[1] == self.latent_channels): img = image[0] if isinstance(image, list) else image image_height, image_width = self.image_processor.get_default_height_width(img) aspect_ratio = image_width / image_height @@ -782,7 +782,7 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred[:, : latents.size(1)] + noise_pred = noise_pred[:, : latents.shape[1]] if do_true_cfg: # with self.transformer.cache_context("uncond"): @@ -797,7 +797,7 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.shape[1]] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index dfc91089c2..b5c3d849e7 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -208,13 +208,13 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0]))]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 89e31ffd24..6b099cc9db 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -218,13 +218,13 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.size(0), dtype=ms.int64) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = 
mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0]))]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) From 358b20b886c24d1eb0adca8f01ef4f76b24d0674 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 26 Aug 2025 10:30:02 +0800 Subject: [PATCH 30/77] 2025/8/26 10:30 revised --- .../api/models/autoencoderkl_qwenimage.md | 26 ++ .../api/models/qwenimage_transformer2d.md | 24 ++ docs/diffusers/api/pipelines/qwenimage.md | 42 +++ .../pipelines/qwenimage/test_qwenimage.py | 350 +++++++++--------- 4 files changed, 264 insertions(+), 178 deletions(-) create mode 100644 docs/diffusers/api/models/autoencoderkl_qwenimage.md create mode 100644 docs/diffusers/api/models/qwenimage_transformer2d.md create mode 100644 docs/diffusers/api/pipelines/qwenimage.md diff --git a/docs/diffusers/api/models/autoencoderkl_qwenimage.md b/docs/diffusers/api/models/autoencoderkl_qwenimage.md new file mode 100644 index 0000000000..3c2ab7c76a --- /dev/null +++ b/docs/diffusers/api/models/autoencoderkl_qwenimage.md @@ -0,0 +1,26 @@ + + +# AutoencoderKLQwenImage + +The model can be loaded with the following code snippet. + +```python +from mindone.diffusers import AutoencoderKLQwenImage + +vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae") +``` + +::: mindone.diffusers.AutoencoderKLQwenImage + +::: mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +::: mindone.diffusers.models.autoencoders.vae.DecoderOutput diff --git a/docs/diffusers/api/models/qwenimage_transformer2d.md b/docs/diffusers/api/models/qwenimage_transformer2d.md new file mode 100644 index 0000000000..bfe0037e70 --- /dev/null +++ b/docs/diffusers/api/models/qwenimage_transformer2d.md @@ -0,0 +1,24 @@ + + +# QwenImageTransformer2DModel + +The model can be loaded with the following code snippet. + +```python +import mindspore + +from mindone.diffusers import QwenImageTransformer2DModel + +transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/QwenImage-20B", subfolder="transformer", mindspore_dtype=mindspore.bfloat16) +``` + +::: mindone.diffusers.QwenImageTransformer2DModel + +::: mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput \ No newline at end of file diff --git a/docs/diffusers/api/pipelines/qwenimage.md b/docs/diffusers/api/pipelines/qwenimage.md new file mode 100644 index 0000000000..cf569f21b8 --- /dev/null +++ b/docs/diffusers/api/pipelines/qwenimage.md @@ -0,0 +1,42 @@ + +
+ LoRA +
+ +Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese. + +Qwen-Image comes in the following variants: + +| model type | model id | +|:----------:|:--------:| +| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) | +| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) | + +!!! Tip + +[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. + +!!! Tip + +::: mindone.diffusers.QwenImagePipeline + +::: mindone.diffusers.pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput + +::: mindone.diffusers.QwenImageImg2ImgPipeline + +::: mindone.diffusers.QwenImageInpaintPipeline \ No newline at end of file diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 5b8a6cbb92..0e831a90b7 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -19,221 +19,215 @@ import numpy as np import torch -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from ddt import data, ddt, unpack +from transformers import Qwen2_5_VLConfig +# from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer -from diffusers import ( +import mindspore as ms + +from mindone.diffusers import ( AutoencoderKLQwenImage, - FlowMatchEulerDiscreteScheduler, QwenImagePipeline, QwenImageTransformer2DModel, ) -from diffusers.utils.testing_utils import enable_full_determinism, torch_device - -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +# from diffusers.utils.testing_utils import enable_full_determinism, torch_device +from diffusers.utils.testing_utils import load_numpy_from_local_file, slow + +# from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +# from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..pipeline_test_utils import ( + THRESHOLD_FP16, + THRESHOLD_FP32, + THRESHOLD_PIXEL, + PipelineTesterMixin, + get_module, + get_pipeline_components, +) +test_cases = [ + {"mode": ms.PYNATIVE_MODE, "dtype": "float32"}, + {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, + {"mode": ms.GRAPH_MODE, "dtype": "float32"}, + {"mode": ms.GRAPH_MODE, "dtype": "bfloat16"}, +] -enable_full_determinism() +# enable_full_determinism() +@ddt class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = QwenImagePipeline - params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - required_optional_params = frozenset( + pipeline_config = [ + [ + "transformer", + "diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + dict( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + 
joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ), + ], [ - "num_inference_steps", - "generator", - "latents", - "return_dict", - "callback_on_step_end", - "callback_on_step_end_tensor_inputs", - ] - ) - supports_dduf = False - test_xformers_attention = False - test_layerwise_casting = True - test_group_offloading = True + "vae", + "diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + "mindone.diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + dict( + base_dim=4 * 6, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ), + ], + [ + "scheduler", + "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + "mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + ], + [ + "text_encoder", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + "mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + dict( + config=Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ), + ), + ], + [ + "tokenizer", + "transformers.models.qwem2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwem2.tokenization_qwen2.Qwen2Tokenizer", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ), + ], + ] def get_dummy_components(self): - torch.manual_seed(0) - transformer = QwenImageTransformer2DModel( - patch_size=2, - in_channels=16, - out_channels=4, - num_layers=2, - attention_head_dim=16, - num_attention_heads=3, - joint_attention_dim=16, - guidance_embeds=False, - axes_dims_rope=(8, 4, 4), - ) - - torch.manual_seed(0) - z_dim = 4 - vae = AutoencoderKLQwenImage( - base_dim=z_dim * 6, - z_dim=z_dim, - dim_mult=[1, 2, 4], - num_res_blocks=1, - temperal_downsample=[False, True], - # fmt: off - latents_mean=[0.0] * 4, - latents_std=[1.0] * 4, - # fmt: on - ) - - torch.manual_seed(0) - scheduler = FlowMatchEulerDiscreteScheduler() - - torch.manual_seed(0) - config = Qwen2_5_VLConfig( - text_config={ - "hidden_size": 16, - "intermediate_size": 16, - "num_hidden_layers": 2, - "num_attention_heads": 2, - "num_key_value_heads": 2, - "rope_scaling": { - "mrope_section": [1, 1, 2], - "rope_type": "default", - "type": "default", - }, - "rope_theta": 1000000.0, - }, - vision_config={ - "depth": 2, - "hidden_size": 16, - "intermediate_size": 16, - "num_heads": 2, - "out_hidden_size": 16, - }, - hidden_size=16, - vocab_size=152064, - vision_end_token_id=151653, - vision_start_token_id=151652, - vision_token_id=151654, - ) - text_encoder = Qwen2_5_VLForConditionalGeneration(config) - tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") - components = { - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - "text_encoder": text_encoder, - 
"tokenizer": tokenizer, + key: None + for key in [ + "transformer", + "vae", + "scheduler", + "text_encoder", + "tokenizer", + ] } - return components - def get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) + return get_pipeline_components(components, self.pipeline_config) + + def get_dummy_inputs(self): inputs = { "prompt": "dance monkey", "negative_prompt": "bad quality", - "generator": generator, "num_inference_steps": 2, "guidance_scale": 3.0, "true_cfg_scale": 1.0, "height": 32, "width": 32, "max_sequence_length": 16, - "output_type": "pt", + "output_type": "np", } return inputs - def test_inference(self): - device = "cpu" + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + ms.set_context(mode=mode) components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) + ms.set_context(mode=mode) - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - generated_image = image[0] - self.assertEqual(generated_image.shape, (3, 32, 32)) + pt_components, ms_components = self.get_dummy_components() + pt_pipe_cls = get_module("diffusers.pipelines.qwenimage.QwenImagePipeline") + ms_pipe_cls = get_module("mindone.diffusers.pipelines.qwenimage.QwenImagePipeline") - # fmt: off - expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208]) - # fmt: on + pt_pipe = pt_pipe_cls(**pt_components) + ms_pipe = ms_pipe_cls(**ms_components) - generated_slice = generated_image.flatten() - generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + pt_pipe.set_progress_bar_config(disable=None) + ms_pipe.set_progress_bar_config(disable=None) - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype) + pt_pipe = pt_pipe.to(pt_dtype) + ms_pipe = ms_pipe.to(ms_dtype) - def test_attention_slicing_forward_pass( - self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 - ): - if not self.test_attention_slicing: - return + inputs = self.get_dummy_inputs() - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - output_without_slicing = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=1) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing1 = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=2) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing2 = pipe(**inputs)[0] - - if test_max_difference: - max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() - max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() - self.assertLess( - max(max_diff1, max_diff2), - expected_max_diff, - "Attention slicing should not affect the inference results", - ) - - def 
test_vae_tiling(self, expected_diff_max: float = 0.2): - generator_device = "cpu" - components = self.get_dummy_components() + torch.manual_seed(0) + pt_image = pt_pipe(**inputs).images + torch.manual_seed(0) + ms_image = ms_pipe(**inputs)[0] - pipe = self.pipeline_class(**components) - pipe.to("cpu") - pipe.set_progress_bar_config(disable=None) - - # Without tiling - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_without_tiling = pipe(**inputs)[0] - - # With tiling - pipe.vae.enable_tiling( - tile_sample_min_height=96, - tile_sample_min_width=96, - tile_sample_stride_height=64, - tile_sample_stride_width=64, - ) - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_with_tiling = pipe(**inputs)[0] - - self.assertLess( - (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), - expected_diff_max, - "VAE tiling should not affect the inference results", + pt_generated_image = pt_image[0] + ms_generated_image = ms_image[0] + + threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16 + assert np.max(np.linalg.norm(pt_generated_image - ms_generated_image) / np.linalg.norm(pt_generated_image)) < threshold + + +@slow +@ddt +class QwenImagePipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + ms.set_context(mode=mode) + ms_dtype = getattr(ms, dtype) + + model_id = "Qwen/Qwen-Image" + pipe = QwenImagePipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) + + pipe.transformer.to(ms.bfloat16) + pipe.vae.enable_tiling() + + torch.manual_seed(0) + image = pipe( + prompt="dance monkey", + negative_prompt="bad quality", + )[0][0] + + expected_image = load_numpy_from_local_file( + "mindone-testing-arrays", + f"qwenimage_t2i_{dtype}.npy", + subfolder="qwenimage", ) + assert np.mean(np.abs(np.array(image, dtype=np.float32) - expected_image)) < THRESHOLD_PIXEL From 721543eb9133221a1d49f4d079df295be4ac4501 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 26 Aug 2025 17:10:06 +0800 Subject: [PATCH 31/77] 2025/8/26 17:10 revised --- .../diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py | 2 +- tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py | 2 +- tests/diffusers_tests/pipelines/wan/test_wan.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 5db0ac92f8..f4a171de2c 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -241,7 +241,7 @@ def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, image: Optional[ms.Tensor] = None, - dtype: Optional[mint.dtype] = None, + dtype: Optional[ms.dtype] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 0e831a90b7..2c3a846610 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -31,7 +31,7 @@ QwenImageTransformer2DModel, ) # from diffusers.utils.testing_utils import enable_full_determinism, torch_device -from diffusers.utils.testing_utils import load_numpy_from_local_file, slow +from mindone.diffusers.utils.testing_utils import 
load_numpy_from_local_file, slow # from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS # from ..test_pipelines_common import PipelineTesterMixin, to_np diff --git a/tests/diffusers_tests/pipelines/wan/test_wan.py b/tests/diffusers_tests/pipelines/wan/test_wan.py index 2b631961cf..dfc0ad63e4 100644 --- a/tests/diffusers_tests/pipelines/wan/test_wan.py +++ b/tests/diffusers_tests/pipelines/wan/test_wan.py @@ -28,7 +28,8 @@ from mindone.diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline from mindone.diffusers.utils.testing_utils import load_numpy_from_local_file, slow -from ..pipeline_test_utils import ( +# from ..pipeline_test_utils import ( +from mindone.tests.diffusers_tests.pipeline_test_utils import ( THRESHOLD_FP16, THRESHOLD_FP32, THRESHOLD_PIXEL, From fc02927a861d517d0823d9c35d4297e9cab6cb84 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 26 Aug 2025 17:20:09 +0800 Subject: [PATCH 32/77] 2025/8/26 17:20 revised --- tests/diffusers_tests/pipelines/wan/test_wan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/diffusers_tests/pipelines/wan/test_wan.py b/tests/diffusers_tests/pipelines/wan/test_wan.py index dfc0ad63e4..49195b43b2 100644 --- a/tests/diffusers_tests/pipelines/wan/test_wan.py +++ b/tests/diffusers_tests/pipelines/wan/test_wan.py @@ -28,8 +28,7 @@ from mindone.diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline from mindone.diffusers.utils.testing_utils import load_numpy_from_local_file, slow -# from ..pipeline_test_utils import ( -from mindone.tests.diffusers_tests.pipeline_test_utils import ( +from ..pipeline_test_utils import ( THRESHOLD_FP16, THRESHOLD_FP32, THRESHOLD_PIXEL, From 3b16c50739e00a578e46ca58d3f5f176a38826da Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 27 Aug 2025 14:08:30 +0800 Subject: [PATCH 33/77] 2025/8/27 14:08 revised --- mindone/utils/weight_norm.py | 8 ++++---- .../pipelines/qwenimage/test_qwenimage.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mindone/utils/weight_norm.py b/mindone/utils/weight_norm.py index 76c02aa928..b0a0594dd1 100644 --- a/mindone/utils/weight_norm.py +++ b/mindone/utils/weight_norm.py @@ -1,16 +1,16 @@ import mindspore as ms -from mindspore import Parameter, nn, ops +from mindspore import Parameter, nn, ops, mint def norm_except_dim(v, pow, dim): if dim == -1: - return ops.norm(v, pow) + return mint.norm(v, pow) elif dim == 0: output_size = (v.shape[0],) + (1,) * (v.ndim - 1) - return ops.norm(v.view((v.shape[0], -1)), pow, 1).view(output_size) + return mint.norm(v.view((v.shape[0], -1)), pow, 1).view(output_size) elif dim == (v.ndim - 1): output_size = (1,) * (v.ndim - 1) + (v.shape[v.ndim - 1],) - return ops.norm(v.view((-1, v.shape[v.ndim - 1])), pow, 0).view(output_size) + return mint.norm(v.view((-1, v.shape[v.ndim - 1])), pow, 0).view(output_size) else: return norm_except_dim(v.transpose(0, v.ndim - 1), pow, 0).transpose(0, v.ndim - 1) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 2c3a846610..6362e2af13 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -93,6 +93,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): "scheduler", "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", 
"mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + dict(), ], [ "text_encoder", @@ -120,11 +121,27 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): "num_heads": 2, "out_hidden_size": 16, }, + attention_dropout=0.0, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + rms_norm_eps=1e-06 + max_position_embeddings=128000, hidden_size=16, + hidden_act="silu", + intermediate_size=16, vocab_size=152064, vision_end_token_id=151653, vision_start_token_id=151652, vision_token_id=151654, + use_sliding_window=False, + attn_implementation="eager", + rope_scaling={ + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + rope_theta=1000000.0, ), ), ], From 151ed25989d10c84601ad1d881f890cd011dfb06 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 27 Aug 2025 17:05:10 +0800 Subject: [PATCH 34/77] 2025/8/27 17:05 revised --- .../pipelines/qwenimage/pipeline_qwenimage.py | 2 +- .../pipelines/qwenimage/test_qwenimage.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 0822f2495b..34348ef5b2 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -203,7 +203,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int32) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 6362e2af13..606742f7bc 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -125,16 +125,18 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): num_hidden_layers=2, num_attention_heads=2, num_key_value_heads=2, - rms_norm_eps=1e-06 + rms_norm_eps=1e-06, max_position_embeddings=128000, hidden_size=16, hidden_act="silu", intermediate_size=16, + initializer_range=0.02, vocab_size=152064, vision_end_token_id=151653, vision_start_token_id=151652, vision_token_id=151654, use_sliding_window=False, + use_cache=True, attn_implementation="eager", rope_scaling={ "mrope_section": [1, 1, 2], @@ -147,10 +149,13 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): ], [ "tokenizer", - "transformers.models.qwem2.tokenization_qwen2.Qwen2Tokenizer", - "transformers.models.qwem2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( - pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + 
pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + local_files_only=True, + trust_remote_code=True, ), ], ] From 46cd675f5234fcc9dea218a8a997a2e39f18836d Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 27 Aug 2025 17:09:15 +0800 Subject: [PATCH 35/77] 2025/8/27 17:09 revised --- tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 606742f7bc..2be2623aab 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -135,6 +135,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): vision_end_token_id=151653, vision_start_token_id=151652, vision_token_id=151654, + sliding_window=32768, #None use_sliding_window=False, use_cache=True, attn_implementation="eager", From 35b35fcded93f96347ea3ca1ce6b144a4a26e11d Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 27 Aug 2025 17:23:54 +0800 Subject: [PATCH 36/77] 2025/8/27 17:23 revised --- mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py | 4 ++-- tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 34348ef5b2..285acaa84f 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -196,8 +196,8 @@ def _get_qwen_prompt_embeds( "" ) encoder_hidden_states = self.text_encoder( - input_ids=ms.Tensor(txt_tokens.input_ids), - attention_mask=ms.Tensor(txt_tokens.attention_mask), + input_ids=ms.Tensor(txt_tokens.input_ids, dtype=ms.int32), + attention_mask=ms.Tensor(txt_tokens.attention_mask, dtype=ms.int32), output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 2be2623aab..cd673006dd 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -154,7 +154,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", local_files_only=True, trust_remote_code=True, ), From 19c938e6e50b292a3b1d009082e66c70fe6365eb Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 29 Aug 2025 15:42:48 +0800 Subject: [PATCH 37/77] 2025/8/29 15:42 revised --- .../models/transformers/transformer_qwenimage.py | 11 ++++++++++- .../pipelines/qwenimage/pipeline_qwenimage.py | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index c6d0dea973..72988d1489 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ 
b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -331,7 +331,7 @@ def __call__( joint_value = mint.cat([txt_value, img_value], dim=1) # Compute joint attention - # NOTICE! 2025/8/18. Replace in the present version. + # TODO: dispatch_attention_fn.py # joint_hidden_states = dispatch_attention_fn( # joint_query, # joint_key, @@ -341,9 +341,11 @@ def __call__( # is_causal=False, # backend=self._attention_backend, # ) + joint_query, joint_key, joint_value = (x.permute(0, 2, 1, 3) for x in (joint_query, joint_key, joint_value)) joint_hidden_states = attn.scaled_dot_product_attention( joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3) # Reshape back joint_hidden_states = joint_hidden_states.flatten(2, 3) @@ -563,6 +565,7 @@ def construct( txt_seq_lens: Optional[List[int]] = None, guidance: ms.Tensor = None, # TODO: this should probably be removed attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, return_dict: bool = True, ) -> Union[ms.Tensor, Transformer2DModelOutput]: """ @@ -628,6 +631,12 @@ def construct( joint_attention_kwargs=attention_kwargs, ) + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(mint.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + # Use only the image part (hidden_states) from the dual-stream blocks hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 285acaa84f..0822f2495b 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -196,14 +196,14 @@ def _get_qwen_prompt_embeds( "" ) encoder_hidden_states = self.text_encoder( - input_ids=ms.Tensor(txt_tokens.input_ids, dtype=ms.int32), - attention_mask=ms.Tensor(txt_tokens.attention_mask, dtype=ms.int32), + input_ids=ms.Tensor(txt_tokens.input_ids), + attention_mask=ms.Tensor(txt_tokens.attention_mask), output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(txt_tokens.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int32) for e in split_hidden_states] + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] From 44504e593738f03c58b7e28aa538ba4b942e923b Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 1 Sep 2025 09:19:04 +0800 Subject: [PATCH 38/77] 2025/9/1 09:18 revised --- .../pipelines/qwenimage/test_qwenimage.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index cd673006dd..c6336045a9 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -18,10 +18,10 @@ import unittest 
import numpy as np +import pytest import torch from ddt import data, ddt, unpack from transformers import Qwen2_5_VLConfig -# from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer import mindspore as ms @@ -30,11 +30,8 @@ QwenImagePipeline, QwenImageTransformer2DModel, ) -# from diffusers.utils.testing_utils import enable_full_determinism, torch_device from mindone.diffusers.utils.testing_utils import load_numpy_from_local_file, slow -# from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -# from ..test_pipelines_common import PipelineTesterMixin, to_np from ..pipeline_test_utils import ( THRESHOLD_FP16, THRESHOLD_FP32, @@ -47,13 +44,8 @@ test_cases = [ {"mode": ms.PYNATIVE_MODE, "dtype": "float32"}, {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, - {"mode": ms.GRAPH_MODE, "dtype": "float32"}, - {"mode": ms.GRAPH_MODE, "dtype": "bfloat16"}, ] - -# enable_full_determinism() - @ddt class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_config = [ @@ -236,10 +228,10 @@ def test_inference(self, mode, dtype): ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) - model_id = "Qwen/Qwen-Image" + # model_id = "Qwen/Qwen-Image" + model_id = "/data6/Qwen-Image" pipe = QwenImagePipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) - pipe.transformer.to(ms.bfloat16) pipe.vae.enable_tiling() torch.manual_seed(0) @@ -249,7 +241,8 @@ def test_inference(self, mode, dtype): )[0][0] expected_image = load_numpy_from_local_file( - "mindone-testing-arrays", + # "mindone-testing-arrays", + "/data4/mindone-testing-arrays", f"qwenimage_t2i_{dtype}.npy", subfolder="qwenimage", ) From d5cfad23de116a67a17f156eaca63e11a8c45e3e Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 1 Sep 2025 09:40:17 +0800 Subject: [PATCH 39/77] 2025/9/1 09:40 revised --- .../diffusers_tests/pipelines/qwenimage/test_qwenimage.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index c6336045a9..c186f48e3d 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -146,7 +146,8 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="test/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", local_files_only=True, trust_remote_code=True, ), @@ -225,6 +226,9 @@ class QwenImagePipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): @data(*test_cases) @unpack def test_inference(self, mode, dtype): + if dtype == "float32": + pytest.skip("Skipping this case since this pipeline will OOM in float32") + ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) @@ -240,6 +244,8 @@ def test_inference(self, mode, dtype): negative_prompt="bad quality", )[0][0] + # The text_coder causes deviations between ms and pt versions, but the test result (2.809) \ + # is within THRESHOLD_PIXEL when using the same intermediate results of 
text_encoder. expected_image = load_numpy_from_local_file( # "mindone-testing-arrays", "/data4/mindone-testing-arrays", From aa08b1a12ef294ff8704b69104d6acc2f81b4f8d Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 2 Sep 2025 14:06:43 +0800 Subject: [PATCH 40/77] 2025/9/2 14:06, img2img infer --- .../diffusers/models/attention_processor.py | 2 +- .../autoencoders/autoencoder_kl_qwenimage.py | 8 ++-- .../qwenimage/pipeline_qwenimage_img2img.py | 41 +++++++++++-------- .../pipelines/qwenimage/test_qwenimage.py | 2 +- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py index 18aa7168e7..9bbcb282a6 100644 --- a/mindone/diffusers/models/attention_processor.py +++ b/mindone/diffusers/models/attention_processor.py @@ -789,7 +789,7 @@ def flash_attention_op( ): # For most scenarios, qkv has been processed into a BNSD layout before sdp input_layout = "BNSD" - # head_num = self.heads + # head_num = self.heads # may cause some errors head_num = query.shape[1] # In case qkv is 3-dim after `head_to_batch_dim` diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 750e5e88c2..b5f888cec8 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -840,7 +840,7 @@ def _encode(self, x: ms.Tensor): # @apply_forward_hook def encode( - self, x: ms.Tensor, return_dict: bool = True + self, x: ms.Tensor, return_dict: bool = False ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: r""" Encode a batch of images into latents. @@ -867,7 +867,7 @@ def encode( return (h,) return AutoencoderKLOutput(latent_dist=h) - def _decode(self, z: ms.Tensor, return_dict: bool = True): + def _decode(self, z: ms.Tensor, return_dict: bool = False): _, _, num_frame, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio @@ -893,7 +893,7 @@ def _decode(self, z: ms.Tensor, return_dict: bool = True): return DecoderOutput(sample=out) # @apply_forward_hook - def decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderOutput, ms.Tensor]: + def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]: r""" Decode a batch of images. 
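Note on the return_dict changes in the hunks above: with return_dict defaulting to False, AutoencoderKLQwenImage.encode/decode return plain tuples, and the pipelines in this series draw the posterior sample through vae.diag_gauss_dist rather than a latent_dist attribute. The following is a minimal, illustrative sketch of that calling pattern only (it is not part of the patch); it assumes an AutoencoderKLQwenImage instance named vae and a video-shaped input x of shape [B, C, T, H, W], with names taken from the surrounding diffs.

import mindspore as ms

def encode_to_latents(vae, x: ms.Tensor) -> ms.Tensor:
    # encode(...) now returns a tuple by default, so index [0] to get the posterior parameters
    posterior_params = vae.encode(x)[0]
    # sample through the helper distribution attached to the VAE, as retrieve_latents does in the pipelines
    latents = vae.diag_gauss_dist.sample(posterior_params)
    # normalize with the config statistics, mirroring _encode_vae_image in the pipelines
    mean = ms.Tensor(vae.config.latents_mean, dtype=latents.dtype).view(1, vae.config.z_dim, 1, 1, 1)
    std = 1.0 / ms.Tensor(vae.config.latents_std, dtype=latents.dtype).view(1, vae.config.z_dim, 1, 1, 1)
    return (latents - mean) * std

def decode_to_images(vae, latents: ms.Tensor) -> ms.Tensor:
    # decode(...) also returns a tuple; [:, :, 0] drops the singleton temporal axis
    # (the pipelines first undo the latent normalization, which is omitted here for brevity)
    return vae.decode(latents, return_dict=False)[0][:, :, 0]
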
@@ -1066,7 +1066,7 @@ def construct( self, sample: ms.Tensor, sample_posterior: bool = False, - return_dict: bool = True, + return_dict: bool = False, generator: Optional[np.random.Generator] = None, ) -> Union[DecoderOutput, ms.Tensor]: """ diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index b5c3d849e7..6e572830cb 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -14,7 +14,7 @@ from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging #, scale_lora_layers, unscale_lora_layers -from ...utils.mindspore_utils import randn_tensor #, pynative_context +from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -41,16 +41,17 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" + vae, encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents + if sample_mode == "sample": + return vae.diag_gauss_dist.sample(encoder_output, generator=generator) + elif sample_mode == "argmax": + return vae.diag_gauss_dist.mode(encoder_output) + # This brach is not needed because the encoder_output type is ms.Tensor as per AutoencoderKLOuput change + # elif hasattr(encoder_output, "latents"): + # return encoder_output.latents else: - raise AttributeError("Could not access latents of provided encoder_output") + return encoder_output # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift @@ -222,14 +223,16 @@ def _get_qwen_prompt_embeds( return prompt_embeds, encoder_attention_mask def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = mint.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])[0]) + for i in range(image.shape[0]) + ] + image_latents = mint.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae, self.vae.encode(image)[0]) latents_mean = ( ms.Tensor(self.vae.config.latents_mean) @@ -784,7 +787,9 @@ def __call__( ) latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + image = self.vae.decode(latents, 
return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index c186f48e3d..f74ed2f57a 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -147,7 +147,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): dict( # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="test/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", local_files_only=True, trust_remote_code=True, ), From 92d1a2396696a08635ddfc5a654f2c622b87c20c Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Sep 2025 08:50:16 +0800 Subject: [PATCH 41/77] 2025/9/3 8:50, inpaint infer --- .../qwenimage/pipeline_qwenimage_edit.py | 155 +-- .../pipeline_qwenimage_edit_inpaint.py | 1103 +++++++++++++++++ .../qwenimage/pipeline_qwenimage_img2img.py | 1 - .../qwenimage/pipeline_qwenimage_inpaint.py | 42 +- .../pipelines/qwenimage/test_qwenimage.py | 2 +- 5 files changed, 1205 insertions(+), 98 deletions(-) create mode 100644 mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index f4a171de2c..8158c9bd55 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -29,8 +29,8 @@ from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging -from ...utils.mindspore_utils import randn_tensor +from ...utils import logging #, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -41,7 +41,7 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> import torch + >>> import mindspore >>> from PIL import Image >>> from mindone.diffusers import QwenImageEditPipeline >>> from mindone.diffusers.utils import load_image @@ -59,25 +59,6 @@ >>> image.save("qwenimage_edit.png") ``` """ -PREFERRED_QWENIMAGE_RESOLUTIONS = [ - (672, 1568), - (688, 1504), - (720, 1456), - (752, 1392), - (800, 1328), - (832, 1248), - (880, 1184), - (944, 1104), - (1024, 1024), - (1104, 944), - (1184, 880), - (1248, 832), - (1328, 800), - (1392, 752), - (1456, 720), - (1504, 688), - (1568, 672), -] # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift @@ -153,17 +134,17 @@ def retrieve_timesteps( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" + vae, encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): - if 
hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents + if sample_mode == "sample": + return vae.diag_gauss_dist.sample(encoder_output, generator=generator) + elif sample_mode == "argmax": + return vae.diag_gauss_dist.mode(encoder_output) + # This brach is not needed because the encoder_output type is ms.Tensor as per AutoencoderKLOuput change + # elif hasattr(encoder_output, "latents"): + # return encoder_output.latents else: - raise AttributeError("Could not access latents of provided encoder_output") - + return encoder_output def calculate_dimensions(target_area, ratio): width = math.sqrt(target_area * ratio) @@ -261,21 +242,21 @@ def _get_qwen_prompt_embeds( outputs = self.text_encoder( input_ids=ms.Tensor(model_inputs.input_ids), attention_mask=ms.Tensor(model_inputs.attention_mask), - pixel_values=model_inputs.pixel_values, - image_grid_thw=model_inputs.image_grid_thw, + pixel_values=ms.Tensor(model_inputs.pixel_values), + image_grid_thw=ms.Tensor(model_inputs.image_grid_thw), output_hidden_states=True, ) hidden_states = outputs.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(model_inputs.attention_mask)) + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0], u.shape[1])]) for u in split_hidden_states] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( - [mint.cat([u, u.new_zeros(max_seq_len - u.shape[0])]) for u in attn_mask_list] + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0]))]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype) @@ -400,14 +381,17 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") - for i in range(image.shape[0]) - ] - image_latents = mint.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])[0], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = mint.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae, self.vae.encode(image)[0], sample_mode="argmax") + latents_mean = ( ms.Tensor(self.vae.config.latents_mean) .view(1, self.latent_channels, 1, 1, 1) @@ -535,7 +519,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, sigmas: Optional[List[float]] = None, - guidance_scale: float = 1.0, + guidance_scale: Optional[float] = None, num_images_per_prompt: int = 1, generator: 
Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, latents: Optional[ms.Tensor] = None, @@ -549,12 +533,17 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - _auto_resize: bool = True, ): r""" Function invoked when calling the pipeline for generation. Args: + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -563,7 +552,12 @@ def __call__( `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is not greater than `1`). true_cfg_scale (`float`, *optional*, defaults to 1.0): - When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -575,12 +569,16 @@ def __call__( Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. - guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): @@ -624,9 +622,8 @@ def __call__( [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - image_size = image[0].size if isinstance(image, list) else image.size - width, height = image_size - calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height) + image_size = image[0].shape if isinstance(image, list) else image.shape + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) height = height or calculated_height width = width or calculated_width @@ -663,23 +660,24 @@ def __call__( # 3. Preprocess image if image is not None and not (isinstance(image, ms.Tensor) and image.shape[1] == self.latent_channels): - img = image[0] if isinstance(image, list) else image - image_height, image_width = self.image_processor.get_default_height_width(img) - aspect_ratio = image_width / image_height - if _auto_resize: - _, image_width, image_height = min( - (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS - ) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - image = self.image_processor.resize(image, image_height, image_width) + image = self.image_processor.resize(image, calculated_height, calculated_width) prompt_image = image - image = self.image_processor.preprocess(image, image_height, image_width) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." 
+ ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( image=prompt_image, @@ -690,9 +688,6 @@ def __call__( max_sequence_length=max_sequence_length, ) if do_true_cfg: - # negative image is the same size as the original image, but all pixels are white - # negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255)) - negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( image=prompt_image, prompt=negative_prompt, @@ -717,7 +712,7 @@ def __call__( img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), - (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), ] ] * batch_size @@ -741,10 +736,17 @@ def __call__( self._num_timesteps = len(timesteps) # handle guidance - if self.transformer.config.guidance_embeds: + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) guidance = guidance.expand((latents.shape[0],)) - else: + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: guidance = None if self.attention_kwargs is None: @@ -839,12 +841,11 @@ def __call__( latents.dtype ) latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models - self.maybe_free_model_hooks() - if not return_dict: return (image,) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py new file mode 100644 index 0000000000..0794b84ebd --- /dev/null +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -0,0 +1,1103 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
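For context on the guidance handling reworked just above in QwenImageEditPipeline: guidance_scale is only forwarded to guidance-distilled transformers, while classifier-free guidance is driven by true_cfg_scale together with a (possibly blank) negative prompt. The snippet below is a hedged usage sketch, not part of the patch; the prompt text and file names are placeholders, and the checkpoint id and call signature follow the examples elsewhere in this series.

import mindspore as ms
from mindone.diffusers import QwenImageEditPipeline
from mindone.diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", mindspore_dtype=ms.bfloat16)
source = load_image("input.png")  # placeholder input image

image = pipe(
    image=source,
    prompt="turn the cat's fur white",  # placeholder edit instruction
    negative_prompt=" ",                # even a blank negative prompt enables true CFG
    true_cfg_scale=4.0,                 # values > 1 switch on classifier-free guidance
    num_inference_steps=50,
)[0][0]
image.save("qwenimage_edit_cfg.png")
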
+ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import mindspore as ms +from mindspore import mint +from transformers import Qwen2Tokenizer, Qwen2VLProcessor + +from ....transformers import Qwen2_5_VLForConditionalGeneration +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging #, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor, pynative_context +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + +XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from mindone.diffusers import QwenImageEditInpaintPipeline + >>> from mindone.diffusers.utils import load_image + + >>> pipe = QwenImageEditInpaintPipeline.from_pretrained("Qwen/Qwen-Image-Edit", mindspore_dtype=mindspore.bfloat16) + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe( + ... prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=1.0, num_inference_steps=50 + ... )[0][0] + >>> image.save("qwenimage_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.calculate_dimensions +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height, None + + +class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in 
split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + + image_latents = (image_latents - latents_mean) * latents_std + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, noise, image_latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
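+ # The mask follows the same packing scheme as the image latents: it is resized to the latent grid,
+ # broadcast across the latent channels, and packed into 2x2 patch tokens so it can blend
+ # `image_latents` with the denoised latents token-by-token inside the denoising loop.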
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if masked_image.dim() == 4: + masked_image = masked_image.unsqueeze(2) + elif masked_image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == self.latent_channels: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). 
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled
+ by setting `true_cfg_scale > 1` together with a `negative_prompt`. A higher guidance scale encourages
+ the model to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality. The combination rule is sketched just after this argument list.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)` or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or
+ `(H, W)`.
+ masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image`, generated by the VAE. If not provided, the mask
+ latents tensor will be generated from `mask_image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contains information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
+ where the guidance scale is applied during inference through noise prediction rescaling,
+ guidance-distilled models take the guidance scale directly as an input parameter during the forward
+ pass. Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages
+ the model to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality. This parameter exists to support future guidance-distilled models and is ignored
+ when the model is not guidance-distilled. To enable traditional classifier-free guidance, please pass
+ `true_cfg_scale > 1.0` and a `negative_prompt` (even an empty negative prompt like " " should enable
+ classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
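The `true_cfg_scale` path documented above is implemented later in the denoising loop as an extrapolation from the unconditional toward the conditional prediction, followed by a per-token norm rescale. A minimal, self-contained sketch of that combination (illustrative only; the helper name and toy shapes are placeholders, not part of this patch):

import torch

def combine_true_cfg(noise_pred, neg_noise_pred, true_cfg_scale):
    # Extrapolate from the unconditional toward the conditional prediction.
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    # Rescale each token to the norm of the conditional prediction to counteract
    # the norm inflation introduced by large guidance scales.
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)

# Toy shapes: batch of 1, 16 packed latent tokens, 64 features per token.
pos = torch.randn(1, 16, 64)
neg = torch.randn(1, 16, 64)
guided = combine_true_cfg(pos, neg, true_cfg_scale=4.0)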
+ + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + + # height and width are the same as the calculated height and width + height = calculated_height + width = calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + original_image = image + prompt_image = image + image = self.image_processor.preprocess( + image, + height=calculated_height, + width=calculated_width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + image = image.to(dtype=torch.float32) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, noise, image_latents = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 6e572830cb..4b2c719e3c 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -53,7 +53,6 @@ def retrieve_latents( else: return encoder_output - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift def calculate_shift( image_seq_len, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 6b099cc9db..bae2aaf342 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -15,7 +15,7 @@ from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging -from ...utils.mindspore_utils import randn_tensor +from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -44,17 +44,17 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" + vae, encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif 
hasattr(encoder_output, "latents"): - return encoder_output.latents + if sample_mode == "sample": + return vae.diag_gauss_dist.sample(encoder_output, generator=generator) + elif sample_mode == "argmax": + return vae.diag_gauss_dist.mode(encoder_output) + # This brach is not needed because the encoder_output type is ms.Tensor as per AutoencoderKLOuput change + # elif hasattr(encoder_output, "latents"): + # return encoder_output.latents else: - raise AttributeError("Could not access latents of provided encoder_output") - + return encoder_output # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift def calculate_shift( @@ -233,14 +233,16 @@ def _get_qwen_prompt_embeds( # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = mint.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])[0]) + for i in range(image.shape[0]) + ] + image_latents = mint.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae, self.vae.encode(image)[0]) latents_mean = ( ms.Tensor(self.vae.config.latents_mean) @@ -962,7 +964,9 @@ def __call__( ) latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) if padding_mask_crop is not None: diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index f74ed2f57a..0b4e95acbd 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -244,7 +244,7 @@ def test_inference(self, mode, dtype): negative_prompt="bad quality", )[0][0] - # The text_coder causes deviations between ms and pt versions, but the test result (2.809) \ + # The text_coder causes deviations between ms and pt versions. However, the deviation\ # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
expected_image = load_numpy_from_local_file( # "mindone-testing-arrays", From c1998e4fe1eb560d74f3f96e7ffa848fe38c517b Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Sep 2025 14:07:51 +0800 Subject: [PATCH 42/77] 2025/9/3 14:07, img2img test --- .../pipelines/qwenimage/test_qwenimage.py | 1 - .../qwenimage/test_qwenimage_img2img.py | 392 ++++++++++-------- 2 files changed, 209 insertions(+), 184 deletions(-) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index 0b4e95acbd..a671369b96 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -165,7 +165,6 @@ def get_dummy_components(self): "tokenizer", ] } - return get_pipeline_components(components, self.pipeline_config) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py index afdbd2c44b..19efc5732a 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py @@ -5,217 +5,243 @@ import unittest import numpy as np +import pytest import torch -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from ddt import data, ddt, unpack +from transformers import Qwen2_5_VLConfig -from diffusers import ( +import minsdspore as ms + +from mindone.diffusers import ( AutoencoderKLQwenImage, - FlowMatchEulerDiscreteScheduler, QwenImageImg2ImgPipeline, QwenImageTransformer2DModel, ) -from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, +from mindone.diffusers.utils.testing_utils import ( + load_numpy_from_local_file, + slow, + floats_tensor, ) -from ..test_pipelines_common import PipelineTesterMixin, to_np - - -enable_full_determinism() +from ..pipeline_test_utils import ( + THRESHOLD_FP16, + THRESHOLD_FP32, + THRESHOLD_PIXEL, + PipelineTesterMixin, + get_module, + get_pipeline_components, +) +test_cases = [ + {"mode": ms.PYNATIVE_MODE, "dtype": "float32"}, + {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, +] +@ddt class QwenImageImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = QwenImageImg2ImgPipeline - params = frozenset(["prompt", "image", "height", "width", "guidance_scale", "true_cfg_scale", "strength"]) - batch_params = frozenset(["prompt", "image"]) - image_params = frozenset(["image"]) - image_latents_params = frozenset(["latents"]) - required_optional_params = frozenset( + pipeline_config = [ + [ + "transformer", + "diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + dict( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ), + ], [ - "num_inference_steps", - "generator", - "latents", - "return_dict", - "callback_on_step_end", - "callback_on_step_end_tensor_inputs", - ] - ) - supports_dduf = False - test_xformers_attention = False - test_attention_slicing = True - test_layerwise_casting = True - test_group_offloading = True + "vae", + "diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + 
"mindone.diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + dict( + base_dim=4 * 6, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ), + ], + [ + "scheduler", + "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + "mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + dict(), + ], + [ + "text_encoder", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + "mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + dict( + config=Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + attention_dropout=0.0, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + rms_norm_eps=1e-06, + max_position_embeddings=128000, + hidden_size=16, + hidden_act="silu", + intermediate_size=16, + initializer_range=0.02, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + sliding_window=32768, #None + use_sliding_window=False, + use_cache=True, + attn_implementation="eager", + rope_scaling={ + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + rope_theta=1000000.0, + ), + ), + ], + [ + "tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + dict( + # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + local_files_only=True, + trust_remote_code=True, + ), + ], + ] def get_dummy_components(self): - torch.manual_seed(0) - transformer = QwenImageTransformer2DModel( - patch_size=2, - in_channels=16, - out_channels=4, - num_layers=2, - attention_head_dim=16, - num_attention_heads=3, - joint_attention_dim=16, - guidance_embeds=False, - axes_dims_rope=(8, 4, 4), - ) - - torch.manual_seed(0) - z_dim = 4 - vae = AutoencoderKLQwenImage( - base_dim=z_dim * 6, - z_dim=z_dim, - dim_mult=[1, 2, 4], - num_res_blocks=1, - temperal_downsample=[False, True], - latents_mean=[0.0] * 4, - latents_std=[1.0] * 4, - ) - - torch.manual_seed(0) - scheduler = FlowMatchEulerDiscreteScheduler() - - torch.manual_seed(0) - config = Qwen2_5_VLConfig( - text_config={ - "hidden_size": 16, - "intermediate_size": 16, - "num_hidden_layers": 2, - "num_attention_heads": 2, - "num_key_value_heads": 2, - "rope_scaling": { - "mrope_section": [1, 1, 2], - "rope_type": "default", - "type": "default", - }, - "rope_theta": 1000000.0, - }, - vision_config={ - "depth": 2, - "hidden_size": 16, - "intermediate_size": 16, - "num_heads": 2, - "out_hidden_size": 16, - }, - hidden_size=16, - vocab_size=152064, - vision_end_token_id=151653, - vision_start_token_id=151652, - vision_token_id=151654, - ) - 
text_encoder = Qwen2_5_VLForConditionalGeneration(config) - tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") - - return { - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - "text_encoder": text_encoder, - "tokenizer": tokenizer, + components = { + key: None + for key in [ + "transformer", + "vae", + "scheduler", + "text_encoder", + "tokenizer", + ] } + return get_pipeline_components(components, self.pipeline_config) - def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - + def get_dummy_inputs(self): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) inputs = { "image": image, "prompt": "dance monkey", "negative_prompt": "bad quality", - "generator": generator, "num_inference_steps": 2, "guidance_scale": 3.0, "true_cfg_scale": 1.0, "height": 32, "width": 32, "max_sequence_length": 16, - "output_type": "pt", + "output_type": "np", } return inputs - def test_inference(self): - device = "cpu" - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - generated_image = image[0] - self.assertEqual(generated_image.shape, (3, 32, 32)) - - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) - - def test_attention_slicing_forward_pass( - self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 - ): - if not self.test_attention_slicing: - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - output_without_slicing = pipe(**inputs).images[0] - - pipe.enable_attention_slicing(slice_size=1) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing1 = pipe(**inputs).images[0] - - pipe.enable_attention_slicing(slice_size=2) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing2 = pipe(**inputs).images[0] - - if test_max_difference: - max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() - max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() - self.assertLess( - max(max_diff1, max_diff2), - expected_max_diff, - "Attention slicing should not affect the inference results", - ) - - def test_vae_tiling(self, expected_diff_max: float = 0.2): - generator_device = "cpu" - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe.to("cpu") - pipe.set_progress_bar_config(disable=None) - - # Without tiling - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_without_tiling = pipe(**inputs)[0] - - # With tiling - pipe.vae.enable_tiling( - tile_sample_min_height=96, - tile_sample_min_width=96, - tile_sample_stride_height=64, - tile_sample_stride_width=64, - ) - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = 
inputs["width"] = 128 - output_with_tiling = pipe(**inputs)[0] - - self.assertLess( - (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), - expected_diff_max, - "VAE tiling should not affect the inference results", + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + ms.set_context(mode=mode) + + pt_components, ms_components = self.get_dummy_components() + pt_pipe_cls = get_module("diffusers.pipelines.qwenimage.QwenImageImg2ImgPipeline") + ms_pipe_cls = get_module("mindone.diffusers.pipelines.qwenimage.QwenImageImg2ImgPipeline") + + pt_pipe = pt_pipe_cls(**pt_components) + ms_pipe = ms_pipe_cls(**ms_components) + + pt_pipe.set_progress_bar_config(disable=None) + ms_pipe.set_progress_bar_config(disable=None) + + ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype) + pt_pipe = pt_pipe.to(pt_dtype) + ms_pipe = ms_pipe.to(ms_dtype) + + inputs = self.get_dummy_inputs() + + torch.manual_seed(0) + pt_image = pt_pipe(**inputs).images + torch.manual_seed(0) + ms_image = ms_pipe(**inputs)[0] + + pt_generated_image = pt_image[0] + ms_generated_image = ms_image[0] + + threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16 + assert np.max(np.linalg.norm(pt_generated_image - ms_generated_image) / np.linalg.norm(pt_generated_image)) < threshold + + +@slow +@ddt +class QwenImageImg2ImgPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + if dtype == "float32": + pytest.skip("Skipping this case since this pipeline will OOM in float32") + + ms.set_context(mode=mode) + ms_dtype = getattr(ms, dtype) + + # model_id = "Qwen/Qwen-Image" + model_id = "/data6/Qwen-Image" + image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image + + pipe = QwenImageImg2ImgPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) + + pipe.vae.enable_tiling() + + torch.manual_seed(0) + image = pipe( + image=image, + prompt="dance monkey", + negative_prompt="bad quality", + )[0][0] + + # The text_coder causes deviations between ms and pt versions. However, the deviation\ + # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
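+ # `expected_image` below is a precomputed reference array (one .npy file per dtype) loaded from local
+ # storage; the assertion bounds the mean absolute pixel error against it by THRESHOLD_PIXEL.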
+ expected_image = load_numpy_from_local_file( + # "mindone-testing-arrays", + "/data4/mindone-testing-arrays", + f"qwenimage_i2i_{dtype}.npy", + subfolder="qwenimage", ) + + assert np.mean(np.abs(np.array(image, dtype=np.float32) - expected_image)) < THRESHOLD_PIXEL From ae1f7b2d592df4528282e4091642624d274484a0 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Sep 2025 14:18:04 +0800 Subject: [PATCH 43/77] 2025/9/3 14:18, img2img test --- .../pipelines/qwenimage/test_qwenimage.py | 3 -- .../qwenimage/test_qwenimage_img2img.py | 33 ++++++++++++++----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index a671369b96..bc60ba2cbe 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -188,9 +188,6 @@ def get_dummy_inputs(self): def test_inference(self, mode, dtype): ms.set_context(mode=mode) - components = self.get_dummy_components() - ms.set_context(mode=mode) - pt_components, ms_components = self.get_dummy_components() pt_pipe_cls = get_module("diffusers.pipelines.qwenimage.QwenImagePipeline") ms_pipe_cls = get_module("mindone.diffusers.pipelines.qwenimage.QwenImagePipeline") diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py index 19efc5732a..69748ab905 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py @@ -10,7 +10,7 @@ from ddt import data, ddt, unpack from transformers import Qwen2_5_VLConfig -import minsdspore as ms +import mindspore as ms from mindone.diffusers import ( AutoencoderKLQwenImage, @@ -158,10 +158,25 @@ def get_dummy_components(self): } return get_pipeline_components(components, self.pipeline_config) - def get_dummy_inputs(self): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) - inputs = { - "image": image, + def get_dummy_inputs(self, seed): + pt_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) + ms_image = ms.Tensor(pt_image.numpy()) + + pt_inputs = { + "image": pt_image, + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "np", + } + + ms_inputs = { + "image": ms_image, "prompt": "dance monkey", "negative_prompt": "bad quality", "num_inference_steps": 2, @@ -173,7 +188,7 @@ def get_dummy_inputs(self): "output_type": "np", } - return inputs + return pt_inputs, ms_inputs @data(*test_cases) @unpack @@ -194,12 +209,12 @@ def test_inference(self, mode, dtype): pt_pipe = pt_pipe.to(pt_dtype) ms_pipe = ms_pipe.to(ms_dtype) - inputs = self.get_dummy_inputs() + pt_inputs, ms_inputs = self.get_dummy_inputs() torch.manual_seed(0) - pt_image = pt_pipe(**inputs).images + pt_image = pt_pipe(**pt_inputs).images torch.manual_seed(0) - ms_image = ms_pipe(**inputs)[0] + ms_image = ms_pipe(**ms_inputs)[0] pt_generated_image = pt_image[0] ms_generated_image = ms_image[0] From be1b18b80e1728f0d3773f81492b806051e5ef3a Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Sep 2025 16:30:23 +0800 Subject: [PATCH 44/77] 2025/9/3 16:30, inpaint test --- .../qwenimage/pipeline_qwenimage_edit.py | 2 +- .../qwenimage/test_qwenimage_img2img.py | 13 +- 
.../qwenimage/test_qwenimage_inpaint.py | 417 ++++++++++-------- 3 files changed, 241 insertions(+), 191 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 8158c9bd55..c52049d103 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -30,7 +30,7 @@ from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging #, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import randn_tensor, pynative_context +from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py index 69748ab905..5df9730dd1 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py @@ -2,6 +2,7 @@ # with modifications to run diffusers on mindspore. import random +import sys import unittest import numpy as np @@ -20,7 +21,6 @@ from mindone.diffusers.utils.testing_utils import ( load_numpy_from_local_file, slow, - floats_tensor, ) from ..pipeline_test_utils import ( @@ -28,8 +28,10 @@ THRESHOLD_FP32, THRESHOLD_PIXEL, PipelineTesterMixin, + floats_tensor, get_module, get_pipeline_components, + randn_tensor, ) test_cases = [ @@ -158,7 +160,7 @@ def get_dummy_components(self): } return get_pipeline_components(components, self.pipeline_config) - def get_dummy_inputs(self, seed): + def get_dummy_inputs(self, seed=0): pt_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) ms_image = ms.Tensor(pt_image.numpy()) @@ -209,6 +211,9 @@ def test_inference(self, mode, dtype): pt_pipe = pt_pipe.to(pt_dtype) ms_pipe = ms_pipe.to(ms_dtype) + sys.modules[ms_pipe.__module__].randn_tensor = randn_tensor + sys.modules[ms_pipe.vae.diag_gauss_dist.__module__].randn_tensor = randn_tensor + pt_inputs, ms_inputs = self.get_dummy_inputs() torch.manual_seed(0) @@ -238,14 +243,14 @@ def test_inference(self, mode, dtype): # model_id = "Qwen/Qwen-Image" model_id = "/data6/Qwen-Image" image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image - + pipe = QwenImageImg2ImgPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) pipe.vae.enable_tiling() torch.manual_seed(0) image = pipe( - image=image, + image=ms.Tensor(image.numpy()), prompt="dance monkey", negative_prompt="bad quality", )[0][0] diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py index 11845e130f..946c6404af 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py @@ -16,221 +16,266 @@ # limitations under the License. 
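+# This fast test mirrors test_qwenimage_img2img.py: it builds the same tiny dummy components for both
+# the PyTorch (diffusers) and MindSpore (mindone.diffusers) pipelines, runs them on identical inputs,
+# and asserts that the relative L2 difference between the generated images stays below
+# THRESHOLD_FP32 / THRESHOLD_FP16.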
import random +import sys import unittest import numpy as np +import pytest import torch -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from ddt import data, ddt, unpack +from transformers import Qwen2_5_VLConfig + +import mindspore as ms from diffusers import ( AutoencoderKLQwenImage, - FlowMatchEulerDiscreteScheduler, QwenImageInpaintPipeline, QwenImageTransformer2DModel, ) -from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device - -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np - +from mindone.diffusers.utils.testing_utils import ( + load_numpy_from_local_file, + slow, +) -enable_full_determinism() +from ..pipeline_test_utils import ( + THRESHOLD_FP16, + THRESHOLD_FP32, + THRESHOLD_PIXEL, + PipelineTesterMixin, + floats_tensor, + get_module, + get_pipeline_components, + randn_tensor, +) +test_cases = [ + {"mode": ms.PYNATIVE_MODE, "dtype": "float32"}, + {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, +] +@ddt class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = QwenImageInpaintPipeline - params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - required_optional_params = frozenset( + pipeline_config = [ + [ + "transformer", + "diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + dict( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ), + ], + [ + "vae", + "diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + "mindone.diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + dict( + base_dim=4 * 6, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ), + ], [ - "num_inference_steps", - "generator", - "latents", - "return_dict", - "callback_on_step_end", - "callback_on_step_end_tensor_inputs", - ] - ) - supports_dduf = False - test_xformers_attention = False - test_layerwise_casting = True - test_group_offloading = True + "scheduler", + "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + "mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + dict(), + ], + [ + "text_encoder", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + "mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + dict( + config=Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + attention_dropout=0.0, + num_hidden_layers=2, + num_attention_heads=2, + 
num_key_value_heads=2, + rms_norm_eps=1e-06, + max_position_embeddings=128000, + hidden_size=16, + hidden_act="silu", + intermediate_size=16, + initializer_range=0.02, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + sliding_window=32768, #None + use_sliding_window=False, + use_cache=True, + attn_implementation="eager", + rope_scaling={ + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + rope_theta=1000000.0, + ), + ), + ], + [ + "tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + dict( + # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + local_files_only=True, + trust_remote_code=True, + ), + ], + ] def get_dummy_components(self): - torch.manual_seed(0) - transformer = QwenImageTransformer2DModel( - patch_size=2, - in_channels=16, - out_channels=4, - num_layers=2, - attention_head_dim=16, - num_attention_heads=3, - joint_attention_dim=16, - guidance_embeds=False, - axes_dims_rope=(8, 4, 4), - ) - - torch.manual_seed(0) - z_dim = 4 - vae = AutoencoderKLQwenImage( - base_dim=z_dim * 6, - z_dim=z_dim, - dim_mult=[1, 2, 4], - num_res_blocks=1, - temperal_downsample=[False, True], - # fmt: off - latents_mean=[0.0] * 4, - latents_std=[1.0] * 4, - # fmt: on - ) - - torch.manual_seed(0) - scheduler = FlowMatchEulerDiscreteScheduler() - - torch.manual_seed(0) - config = Qwen2_5_VLConfig( - text_config={ - "hidden_size": 16, - "intermediate_size": 16, - "num_hidden_layers": 2, - "num_attention_heads": 2, - "num_key_value_heads": 2, - "rope_scaling": { - "mrope_section": [1, 1, 2], - "rope_type": "default", - "type": "default", - }, - "rope_theta": 1000000.0, - }, - vision_config={ - "depth": 2, - "hidden_size": 16, - "intermediate_size": 16, - "num_heads": 2, - "out_hidden_size": 16, - }, - hidden_size=16, - vocab_size=152064, - vision_end_token_id=151653, - vision_start_token_id=151652, - vision_token_id=151654, - ) - text_encoder = Qwen2_5_VLForConditionalGeneration(config) - tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") - components = { - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - "text_encoder": text_encoder, - "tokenizer": tokenizer, + key: None + for key in [ + "transformer", + "vae", + "scheduler", + "text_encoder", + "tokenizer", + ] } - return components + return get_pipeline_components(components, self.pipeline_config) - def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - mask_image = torch.ones((1, 1, 32, 32)).to(device) - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) + def get_dummy_inputs(self, seed=0): + pt_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) + ms_image = ms.Tensor(pt_image.numpy()) + + pt_inputs = { + "image": pt_image, + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + 
"output_type": "np", + } - inputs = { + ms_inputs = { + "image": ms_image, "prompt": "dance monkey", "negative_prompt": "bad quality", - "image": image, - "mask_image": mask_image, - "generator": generator, "num_inference_steps": 2, "guidance_scale": 3.0, "true_cfg_scale": 1.0, "height": 32, "width": 32, "max_sequence_length": 16, - "output_type": "pt", + "output_type": "np", } - return inputs - - def test_inference(self): - device = "cpu" - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - generated_image = image[0] - self.assertEqual(generated_image.shape, (3, 32, 32)) - - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) - - def test_attention_slicing_forward_pass( - self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 - ): - if not self.test_attention_slicing: - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - output_without_slicing = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=1) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing1 = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=2) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing2 = pipe(**inputs)[0] - - if test_max_difference: - max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() - max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() - self.assertLess( - max(max_diff1, max_diff2), - expected_max_diff, - "Attention slicing should not affect the inference results", - ) - - def test_vae_tiling(self, expected_diff_max: float = 0.2): - generator_device = "cpu" - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe.to("cpu") - pipe.set_progress_bar_config(disable=None) - - # Without tiling - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_without_tiling = pipe(**inputs)[0] - - # With tiling - pipe.vae.enable_tiling( - tile_sample_min_height=96, - tile_sample_min_width=96, - tile_sample_stride_height=64, - tile_sample_stride_width=64, - ) - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_with_tiling = pipe(**inputs)[0] - - self.assertLess( - (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), - expected_diff_max, - "VAE tiling should not affect the inference results", + return pt_inputs, ms_inputs + + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + ms.set_context(mode=mode) + + pt_components, ms_components = self.get_dummy_components() + pt_pipe_cls = get_module("diffusers.pipelines.qwenimage.QwenImageInpaintPipeline") + ms_pipe_cls = get_module("mindone.diffusers.pipelines.qwenimage.QwenImageInpaintPipeline") + + pt_pipe = pt_pipe_cls(**pt_components) + ms_pipe = ms_pipe_cls(**ms_components) + + pt_pipe.set_progress_bar_config(disable=None) + ms_pipe.set_progress_bar_config(disable=None) + + 
ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype) + pt_pipe = pt_pipe.to(pt_dtype) + ms_pipe = ms_pipe.to(ms_dtype) + + sys.modules[ms_pipe.__module__].randn_tensor = randn_tensor + sys.modules[ms_pipe.vae.diag_gauss_dist.__module__].randn_tensor = randn_tensor + + pt_inputs, ms_inputs = self.get_dummy_inputs() + + torch.manual_seed(0) + pt_image = pt_pipe(**pt_inputs).images + torch.manual_seed(0) + ms_image = ms_pipe(**ms_inputs)[0] + + pt_generated_image = pt_image[0] + ms_generated_image = ms_image[0] + + threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16 + assert np.max(np.linalg.norm(pt_generated_image - ms_generated_image) / np.linalg.norm(pt_generated_image)) < threshold + + +@slow +@ddt +class QwenImageInpaintPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + if dtype == "float32": + pytest.skip("Skipping this case since this pipeline will OOM in float32") + + ms.set_context(mode=mode) + ms_dtype = getattr(ms, dtype) + + # model_id = "Qwen/Qwen-Image" + model_id = "/data6/Qwen-Image" + image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image + + pipe = QwenImageInpaintPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) + + pipe.vae.enable_tiling() + + torch.manual_seed(0) + image = pipe( + image=ms.Tensor(image.numpy()), + prompt="dance monkey", + negative_prompt="bad quality", + )[0][0] + + # The text_coder causes deviations between ms and pt versions. However, the deviation\ + # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. + expected_image = load_numpy_from_local_file( + # "mindone-testing-arrays", + "/data4/mindone-testing-arrays", + f"qwenimage_inpaint_{dtype}.npy", + subfolder="qwenimage", ) + + assert np.mean(np.abs(np.array(image, dtype=np.float32) - expected_image)) < THRESHOLD_PIXEL From d6cf7e10b5912836061ec888cb71b1d1d83721e6 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Sep 2025 14:21:39 +0800 Subject: [PATCH 45/77] 2025/9/4 14:21, edit bugs --- .../qwenimage/pipeline_qwenimage_edit.py | 6 +- .../pipeline_qwenimage_edit_inpaint.py | 237 ++++++++---------- .../qwenimage/test_qwenimage_inpaint.py | 9 +- 3 files changed, 117 insertions(+), 135 deletions(-) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index c52049d103..85a0f5e5bc 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -248,7 +248,7 @@ def _get_qwen_prompt_embeds( ) hidden_states = outputs.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(model_inputs.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) @@ -622,7 +622,7 @@ def __call__( [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. 
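The two test classes above rely on different agreement metrics: the fast test checks a relative Frobenius-norm error between the PyTorch and MindSpore outputs against THRESHOLD_FP32/THRESHOLD_FP16, while the slow test compares against a stored reference array with a mean absolute pixel error under THRESHOLD_PIXEL. Pulled out in isolation (threshold values are repo constants and not restated here), the checks are:

import numpy as np

def relative_norm_error(pt_image: np.ndarray, ms_image: np.ndarray) -> float:
    # fast-test metric: ||pt - ms||_F / ||pt||_F, compared against THRESHOLD_FP32 / THRESHOLD_FP16
    return float(np.linalg.norm(pt_image - ms_image) / np.linalg.norm(pt_image))

def mean_pixel_error(image: np.ndarray, expected_image: np.ndarray) -> float:
    # slow-test metric: mean absolute difference against a pre-computed reference, vs THRESHOLD_PIXEL
    return float(np.mean(np.abs(image.astype(np.float32) - expected_image)))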
""" - image_size = image[0].shape if isinstance(image, list) else image.shape + image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) height = height or calculated_height width = width or calculated_width @@ -659,7 +659,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. Preprocess image - if image is not None and not (isinstance(image, ms.Tensor) and image.shape[1] == self.latent_channels): + if image is not None and not (isinstance(image, ms.Tensor) and image.size[1] == self.latent_channels): image = self.image_processor.resize(image, calculated_height, calculated_width) prompt_image = image image = self.image_processor.preprocess(image, calculated_height, calculated_width) diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 0794b84ebd..a290a50d28 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -80,7 +80,6 @@ def calculate_shift( def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, @@ -95,8 +94,6 @@ def retrieve_timesteps( num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. @@ -105,7 +102,7 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: @@ -117,7 +114,7 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: @@ -127,28 +124,28 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + vae, encoder_output: ms.Tensor, generator: Optional[np.random.Generator] = None, sample_mode: str = "sample" ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents + if sample_mode == "sample": + return vae.diag_gauss_dist.sample(encoder_output, generator=generator) + elif sample_mode == "argmax": + return vae.diag_gauss_dist.mode(encoder_output) + # This brach is not needed because the encoder_output type is ms.Tensor as per AutoencoderKLOuput change + # elif hasattr(encoder_output, "latents"): + # return encoder_output.latents else: - raise AttributeError("Could not access latents of provided encoder_output") - + return encoder_output # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.calculate_dimensions def calculate_dimensions(target_area, ratio): @@ -222,11 +219,11 @@ def __init__( self.default_sample_size = 128 # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + def _extract_masked_hidden(self, hidden_states: ms.Tensor, mask: ms.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] - split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + split_result = mint.split(selected, valid_lengths.tolist(), dim=0) return split_result @@ -234,11 +231,9 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - image: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + image: Optional[ms.Tensor] = None, + dtype: Optional[ms.dtype] = None, ): - device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -251,30 +246,30 @@ def _get_qwen_prompt_embeds( text=txt, images=image, padding=True, - return_tensors="pt", - ).to(device) + return_tensors="np", + ) outputs = self.text_encoder( - input_ids=model_inputs.input_ids, - attention_mask=model_inputs.attention_mask, - pixel_values=model_inputs.pixel_values, - image_grid_thw=model_inputs.image_grid_thw, + input_ids=ms.Tensor(model_inputs.input_ids), + attention_mask=ms.Tensor(model_inputs.attention_mask), + pixel_values=ms.Tensor(model_inputs.pixel_values), + image_grid_thw=ms.Tensor(model_inputs.image_grid_thw), output_hidden_states=True, ) hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in 
split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] + max_seq_len = max([e.shape[0] for e in split_hidden_states]) + prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + encoder_attention_mask = mint.stack( + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0]))]) for u in attn_mask_list] ) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype) return prompt_embeds, encoder_attention_mask @@ -282,11 +277,10 @@ def _get_qwen_prompt_embeds( def encode_prompt( self, prompt: Union[str, List[str]], - image: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, + image: Optional[ms.Tensor] = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, max_sequence_length: int = 1024, ): r""" @@ -294,23 +288,19 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - image (`torch.Tensor`, *optional*): + image (`ms.Tensor`, *optional*): image to be encoded - device: (`torch.device`): - torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
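The `_extract_masked_hidden` / re-padding pair above strips the tokenizer padding from each sample, drops the chat-template prefix (`drop_idx`), and re-pads everything to the longest remaining sequence. A toy walk-through of the shapes involved (values are arbitrary; only the shapes matter):

import mindspore as ms
from mindspore import mint

hidden = mint.ones((2, 4, 3))                                    # (batch=2, seq_len=4, hidden=3)
mask = ms.Tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=ms.int64)   # tokenizer attention mask

bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)             # [3, 2] real tokens per sample
selected = hidden[bool_mask]                     # (5, 3): un-padded tokens of both samples, concatenated
per_sample = mint.split(selected, valid_lengths.tolist(), dim=0)  # tuple of (3, 3) and (2, 3)

# the pipeline then slices off the first drop_idx template tokens and re-pads to the longest sample
max_seq_len = max(e.shape[0] for e in per_sample)
prompt_embeds = mint.stack(
    [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in per_sample]
)                                                # (2, 3, 3) with zero padding at the tail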
""" - device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -423,23 +413,25 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])) + for i in range(image.shape[0]) + ] + image_latents = mint.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae, self.vae.encode(image)) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(image_latents.device, image_latents.dtype) + .to(image_latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - image_latents.device, image_latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.dtype ) image_latents = (image_latents - latents_mean) * latents_std @@ -447,7 +439,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -497,7 +489,6 @@ def prepare_latents( height, width, dtype, - device, generator, latents=None, ): @@ -520,9 +511,9 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - return latents.to(device=device, dtype=dtype) + return latents.to(dtype=dtype) - image = image.to(device=device, dtype=dtype) + image = image.to(dtype=dtype) if image.shape[1] != self.latent_channels: image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] else: @@ -530,21 +521,21 @@ def prepare_latents( if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + image_latents = mint.cat([image_latents] * additional_image_per_prompt, 
dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." ) else: - image_latents = torch.cat([image_latents], dim=0) + image_latents = mint.cat([image_latents], dim=0) image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: - noise = latents.to(device) + noise = latents latents = noise noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) @@ -564,7 +555,6 @@ def prepare_mask_latents( height, width, dtype, - device, generator, ): # VAE applies 8x compression on images but we must also account for packing which requires @@ -574,8 +564,8 @@ def prepare_mask_latents( # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) - mask = mask.to(device=device, dtype=dtype) + mask = mint.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -584,7 +574,7 @@ def prepare_mask_latents( elif masked_image.dim() != 5: raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") - masked_image = masked_image.to(device=device, dtype=dtype) + masked_image = masked_image.to(dtype=dtype) if masked_image.shape[1] == self.latent_channels: masked_image_latents = masked_image @@ -609,8 +599,7 @@ def prepare_mask_latents( ) masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = masked_image_latents.to(dtype=dtype) masked_image_latents = self._pack_latents( masked_image_latents, @@ -649,8 +638,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image: Optional[PipelineImageInput] = None, @@ -667,12 +654,12 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: Optional[float] = None, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + prompt_embeds_mask: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_mask: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -684,7 +671,7 @@ def __call__( Function invoked when calling the pipeline for generation. 
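`prepare_latents` above seeds the inpainting latents by noising the encoded image at the first kept timestep instead of starting from pure noise, and `strength` decides how much of the schedule is kept. The exact formula lives in the flow-match scheduler's `scale_noise`; schematically (a sketch of the intuition, not the scheduler implementation):

# schematic only: sigma is the flow-match noise level at the chosen start timestep
def scale_noise_sketch(image_latents, noise, sigma):
    # sigma close to 1.0 means "mostly noise", close to 0.0 means "mostly the original image"
    return sigma * noise + (1.0 - sigma) * image_latents

def strength_to_start_step(num_inference_steps, strength, order=1):
    # mirrors the SD3-style get_timesteps logic copied above: keep only the tail of the schedule
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    return t_start * order   # index of the first timestep that is actually run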
Args: - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a @@ -704,14 +691,14 @@ def __call__( enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + mask_image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. - mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + mask_image_latent (`ms.Tensor`, `List[ms.Tensor]`): `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask latents tensor will ge generated by `mask_image`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -750,17 +737,17 @@ def __call__( enable classifier-free guidance computations). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/np.random.Generator.html) to make generation deterministic. - latents (`torch.Tensor`, *optional*): + latents (`ms.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): + prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): + negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
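The `true_cfg_scale` behaviour documented above is implemented further down in this diff by blending the conditional and unconditional predictions and rescaling the blend back to the conditional branch's per-token norm. A NumPy toy version of that combination:

import numpy as np

pos = np.random.randn(1, 8, 16).astype(np.float32)   # noise_pred from the positive prompt
neg = np.random.randn(1, 8, 16).astype(np.float32)   # noise_pred from the negative prompt
true_cfg_scale = 4.0

comb = neg + true_cfg_scale * (pos - neg)
# keep the guided prediction at the conditional branch's norm so large scales do not inflate the latents
cond_norm = np.linalg.norm(pos, axis=-1, keepdims=True)
comb_norm = np.linalg.norm(comb, axis=-1, keepdims=True)
guided = comb * (cond_norm / comb_norm)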
@@ -834,7 +821,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device # 3. Preprocess image if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) @@ -843,7 +829,7 @@ def __call__( crops_coords = None resize_mode = "default" - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if image is not None and not (isinstance(image, ms.Tensor) and image.size(1) == self.latent_channels): image = self.image_processor.resize(image, calculated_height, calculated_width) original_image = image prompt_image = image @@ -854,7 +840,7 @@ def __call__( crops_coords=crops_coords, resize_mode=resize_mode, ) - image = image.to(dtype=torch.float32) + image = image.to(dtype=ms.float32) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None @@ -875,7 +861,6 @@ def __call__( prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -885,7 +870,6 @@ def __call__( prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, - device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) @@ -903,12 +887,11 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, sigmas=sigmas, mu=mu, ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) if num_inference_steps < 1: raise ValueError( @@ -927,7 +910,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, latents, ) @@ -950,7 +932,6 @@ def __call__( height, width, prompt_embeds.dtype, - device, generator, ) @@ -968,7 +949,7 @@ def __call__( if self.transformer.config.guidance_embeds and guidance_scale is None: raise ValueError("guidance_scale is required for guidance-distilled model.") elif self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = mint.full([1], guidance_scale, dtype=ms.float32) guidance = guidance.expand(latents.shape[0]) elif not self.transformer.config.guidance_embeds and guidance_scale is not None: logger.warning( @@ -996,23 +977,23 @@ def __call__( latent_model_input = latents if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) + latent_model_input = mint.cat([latents, image_latents], dim=1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] + # with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + 
encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.shape[1]] if do_true_cfg: with self.transformer.cache_context("uncond"): @@ -1027,11 +1008,11 @@ def __call__( attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.shape[1]] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + cond_norm = mint.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = mint.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) # compute the previous noisy sample x_t -> x_t-1 @@ -1045,15 +1026,13 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise + init_latents_proper, ms.Tensor([noise_timestep]), noise ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} @@ -1068,9 +1047,6 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - self._current_timestep = None if output_type == "latent": image = latents @@ -1078,15 +1054,17 @@ def __call__( latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = latents.to(self.vae.dtype) latents_mean = ( - torch.tensor(self.vae.config.latents_mean) + ms.Tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents_std = 1.0 / ms.Tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype ) latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) if padding_mask_crop is not None: @@ -1094,9 +1072,6 @@ def __call__( self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image ] - # Offload all models - self.maybe_free_model_hooks() - if not return_dict: return (image,) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py index 946c6404af..8dca8335bf 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py @@ -27,7 +27,7 @@ import mindspore as ms -from diffusers import ( +from mindone.diffusers import ( AutoencoderKLQwenImage, 
QwenImageInpaintPipeline, QwenImageTransformer2DModel, @@ -176,13 +176,17 @@ def get_dummy_components(self): def get_dummy_inputs(self, seed=0): pt_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) + pt_mask_image = torch.ones((1, 1, 32, 32)) ms_image = ms.Tensor(pt_image.numpy()) + ms_mask_image = ms.mint.ones((1, 1, 32, 32)) pt_inputs = { "image": pt_image, + "mask_image": pt_mask_image, "prompt": "dance monkey", "negative_prompt": "bad quality", "num_inference_steps": 2, + "guidance_scale": 3.0, "true_cfg_scale": 1.0, "height": 32, @@ -193,6 +197,7 @@ def get_dummy_inputs(self, seed=0): ms_inputs = { "image": ms_image, + "mask_image": ms_mask_image, "prompt": "dance monkey", "negative_prompt": "bad quality", "num_inference_steps": 2, @@ -257,6 +262,7 @@ def test_inference(self, mode, dtype): # model_id = "Qwen/Qwen-Image" model_id = "/data6/Qwen-Image" image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image + mask_image = ms.mint.ones((1, 1, 32, 32)) pipe = QwenImageInpaintPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) @@ -265,6 +271,7 @@ def test_inference(self, mode, dtype): torch.manual_seed(0) image = pipe( image=ms.Tensor(image.numpy()), + mask_image=mask_image, prompt="dance monkey", negative_prompt="bad quality", )[0][0] From dfc5e23f72a9d736ec91300f8dd01281e38966c6 Mon Sep 17 00:00:00 2001 From: GUOGUO <55723162+Dong1017@users.noreply.github.com> Date: Thu, 4 Sep 2025 15:58:34 +0800 Subject: [PATCH 46/77] 2025/9/4 15:58, edit ut --- mindone/diffusers/__init__.py | 1 + .../models/BAK_model_loading_utils.py | 702 ++++ .../diffusers/models/BAK_modeling_utils.py | 1307 +++++++ .../diffusers/models/model_loading_utils.py | 11 +- mindone/diffusers/models/modeling_patch.py | 49 + mindone/diffusers/models/modeling_utils.py | 8 +- mindone/diffusers/pipelines/__init__.py | 1 + .../diffusers/pipelines/qwenimage/__init__.py | 2 + mindone/transformers/BAK_modeling_utils.py | 3210 +++++++++++++++++ mindone/transformers/modeling_patch.py | 49 + mindone/transformers/modeling_utils.py | 26 +- .../qwenimage/test_qwenimage_edit.py | 274 ++ 12 files changed, 5628 insertions(+), 12 deletions(-) create mode 100644 mindone/diffusers/models/BAK_model_loading_utils.py create mode 100644 mindone/diffusers/models/BAK_modeling_utils.py create mode 100644 mindone/diffusers/models/modeling_patch.py create mode 100644 mindone/transformers/BAK_modeling_utils.py create mode 100644 mindone/transformers/modeling_patch.py create mode 100644 tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index c607ae65e7..78ff933bdf 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -216,6 +216,7 @@ "QwenImageInpaintPipeline", "QwenImagePipeline", "QwenImageEditPipeline", + "QwenImageEditInpaintPipeline", "ReduxImageEncoder", "SanaControlNetPipeline", "SanaPAGPipeline", diff --git a/mindone/diffusers/models/BAK_model_loading_utils.py b/mindone/diffusers/models/BAK_model_loading_utils.py new file mode 100644 index 0000000000..34bc6059db --- /dev/null +++ b/mindone/diffusers/models/BAK_model_loading_utils.py @@ -0,0 +1,702 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import importlib +import json +import os +from collections import OrderedDict +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass, field +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from huggingface_hub import DDUFEntry +from huggingface_hub.utils import EntryNotFoundError + +import mindspore as ms +from mindspore import nn, ops + +from ...safetensors.mindspore import load as safe_load +from ...safetensors.mindspore import load_file as safe_load_file +from ..utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_FILE_EXTENSION, + WEIGHTS_INDEX_NAME, + _add_variant, + _get_model_file, + deprecate, + logging, +) + +logger = logging.get_logger(__name__) + +_CLASS_REMAPPING_DICT = { + "Transformer2DModel": { + "ada_norm_zero": "DiTTransformer2DModel", + "ada_norm_single": "PixArtTransformer2DModel", + } +} + + +def _fetch_remapped_cls_from_config(config, old_class): + previous_class_name = old_class.__name__ + remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None) + + # Details: + # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818 + if remapped_class_name: + # load diffusers library to import compatible and original scheduler + diffusers_library = importlib.import_module(__name__.split(".")[0] + ".diffusers") + remapped_class = getattr(diffusers_library, remapped_class_name) + logger.info( + f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type." + f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this" + " DOESN'T affect the final results." + ) + return remapped_class + else: + return old_class + + +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + disable_mmap: bool = False, +): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + # TODO: maybe refactor a bit this part where we pass a dict here + if isinstance(checkpoint_file, dict): + return checkpoint_file + try: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + if dduf_entries: + # tensors are loaded on cpu + with dduf_entries[checkpoint_file].as_mmap() as mm: + return safe_load(mm) + if disable_mmap: + return safe_load(open(checkpoint_file, "rb").read()) + else: + return safe_load_file(checkpoint_file) + else: + raise NotImplementedError( + f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}" + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. 
Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. " + ) + + +def _load_state_dict_into_model( + model_to_load, state_dict: OrderedDict, keep_in_fp32_modules=None, dtype=None, is_sharded=False +) -> List[str]: + # TODO: error_msgs is always empty for now. Maybe we need to rewrite MindSpore's `load_param_into_net`. + # Error msgs should contain caught exception like size mismatch instead of missing/unexpected keys. + # TODO: We should support loading float16 state_dict into float32 model, like PyTorch's behavior. + error_msgs = [] + # TODO: State dict loading in mindspore does not cast dtype correctly. We do it manually. It's might unsafe. + local_state = {k: v for k, v in model_to_load.parameters_and_names()} + for k, v in state_dict.items(): + if k in local_state: + # todo: unavailable mint interface + if ops.is_floating_point(v): + if ( + keep_in_fp32_modules is not None + and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules) + and dtype == ms.float16 + ): + v.set_dtype(ms.float32) + else: + v.set_dtype(local_state[k].dtype) + else: + v.set_dtype(local_state[k].dtype) + else: + pass # unexpect key keeps origin dtype + cm = silence_mindspore_logger() if is_sharded else nullcontext() + with cm: + ms.load_param_into_net(model_to_load, state_dict, strict_load=True) + return error_msgs + + +def _fetch_index_file( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ) + else: + index_file_in_repo = Path( + subfolder or "", + _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), + ).as_posix() + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=None, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + if not dduf_entries: + index_file = Path(index_file) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file + + +def _fetch_index_file_legacy( + is_local, + pretrained_model_name_or_path, + subfolder, + use_safetensors, + cache_dir, + variant, + force_download, + proxies, + local_files_only, + token, + revision, + user_agent, + commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, +): + if is_local: + index_file = Path( + pretrained_model_name_or_path, + subfolder or "", + SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, + ).as_posix() + splits = index_file.split(".") + split_index = -3 if ".cache" in index_file else -2 + splits = splits[:-split_index] + [variant] + splits[-split_index:] + index_file = ".".join(splits) + 
if os.path.exists(index_file): + deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." # noqa: E501 + deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False) + index_file = Path(index_file) + else: + index_file = None + else: + if variant is not None: + index_file_in_repo = Path( + subfolder or "", + SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, + ).as_posix() + splits = index_file_in_repo.split(".") + split_index = -2 + splits = splits[:-split_index] + [variant] + splits[-split_index:] + index_file_in_repo = ".".join(splits) + try: + index_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=index_file_in_repo, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=None, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + index_file = Path(index_file) + deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." # noqa: E501 + deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False) + except (EntryNotFoundError, EnvironmentError): + index_file = None + + return index_file + + +# =============================================== +# Sharded loading utils by huggingface Accelerate +# =============================================== + +_SAFE_MODEL_NAME = "model" +_SAFE_WEIGHTS_NAME = f"{_SAFE_MODEL_NAME}.safetensors" + + +# Copied from mindone.transformers.modeling_utils.silence_mindspore_logger +@contextmanager +def silence_mindspore_logger(): + ms_logger = ms.log._get_logger() + ms_level = ms_logger.level + ms_logger.setLevel("ERROR") + yield + ms_logger.setLevel(ms_level) + + +def load_checkpoint_and_dispatch( + model: nn.Cell, + checkpoint: Union[str, os.PathLike], + dtype: Optional[Union[str, ms.Type]] = None, + keep_in_fp32_modules=None, + strict: bool = False, +): + """ + Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are + loaded and adds the various hooks that will make this model run properly (even if split across devices). + + Args: + model (`mindspore.nn.Cell`): The model in which we want to load a checkpoint. + checkpoint (`str` or `os.PathLike`): + The folder checkpoint to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + dtype (`str` or `mindspore.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + force_hooks (`bool`, *optional*, defaults to `False`): + Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a + single device. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's + state_dict. 
+ + Example: + + ```python + >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch + >>> from huggingface_hub import hf_hub_download + >>> from transformers import AutoConfig, AutoModelForCausalLM + + >>> # Download the Weights + >>> checkpoint = "EleutherAI/gpt-j-6B" + >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin") + + >>> # Create a model and initialize it with empty weights + >>> config = AutoConfig.from_pretrained(checkpoint) + >>> with init_empty_weights(): + ... model = AutoModelForCausalLM.from_config(config) + + >>> # Load the checkpoint and dispatch it to the right devices + >>> model = load_checkpoint_and_dispatch( + ... model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"] + ... ) + ``` + """ + + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("mindspore.", "") + dtype = getattr(ms, dtype) + + checkpoint_files = None + index_filename = None + if os.path.isfile(checkpoint): + if str(checkpoint).endswith(".json"): + index_filename = checkpoint + else: + checkpoint_files = [checkpoint] + elif os.path.isdir(checkpoint): + # check if the whole state dict is present + potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == _SAFE_WEIGHTS_NAME] + if len(potential_state_safetensor) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] + else: + # otherwise check for sharded checkpoints + potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] + if len(potential_index) == 0: + raise ValueError( + f"{checkpoint} is not a folder containing a `.index.json` file or a {_SAFE_WEIGHTS_NAME} file" + ) + elif len(potential_index) == 1: + index_filename = os.path.join(checkpoint, potential_index[0]) + else: + raise ValueError( + f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." + ) + else: + raise ValueError( + "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " + f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." + ) + + if index_filename is not None: + checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename) as f: + index = json.loads(f.read()) + + if "weight_map" in index: + index = index["weight_map"] + checkpoint_files = sorted(list(set(index.values()))) + checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + + # Logic for missing/unexepected keys goes here. + unexpected_keys = set() + model_keys = set(model.parameters_dict().keys()) + is_sharded = index_filename is not None + cm = silence_mindspore_logger() if is_sharded else nullcontext() + with cm: + for checkpoint_file in checkpoint_files: + loaded_checkpoint = load_state_dict(checkpoint_file) + _ = _load_state_dict_into_model(model, loaded_checkpoint, keep_in_fp32_modules, dtype) + unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys) + del loaded_checkpoint + gc.collect() + + if not strict and len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {checkpoint} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint." 
# noqa E501 + ) + + return model + + +# ============================================= +# Sharded saving by huggingface huggingface_hub +# ============================================= + +TensorT = TypeVar("TensorT") +TensorSizeFn_T = Callable[[TensorT], int] +StorageIDFn_T = Callable[[TensorT], Optional[Any]] + +_MAX_SHARD_SIZE = "5GB" +_SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors" +_SIZE_UNITS = { + "TB": 10**12, + "GB": 10**9, + "MB": 10**6, + "KB": 10**3, +} + + +@dataclass +class StateDictSplit: + is_sharded: bool = field(init=False) + metadata: Dict[str, Any] + filename_to_tensors: Dict[str, List[str]] + tensor_to_filename: Dict[str, str] + + def __post_init__(self): + self.is_sharded = len(self.filename_to_tensors) > 1 + + +def split_state_dict_into_shards_factory( + state_dict: Dict[str, TensorT], + *, + get_storage_size: TensorSizeFn_T, + filename_pattern: str, + get_storage_id: StorageIDFn_T = lambda tensor: None, + max_shard_size: Union[int, str] = _MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + get_storage_size (`Callable[[Tensor], int]`): + A function that returns the size of a tensor when saved on disk in bytes. + get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): + A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the + same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage + during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. 
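The sharding policy described in the docstring above is greedy in key order. Because `get_storage_size` is injectable, the bookkeeping can be exercised with plain integers standing in for tensors (purely illustrative; real callers pass tensors and a byte-size function):

# sizes mimic the docstring example: [6, 6, 2, 6, 2, 2] with a shard limit of 10
fake_state_dict = {"a": 6, "b": 6, "c": 2, "d": 6, "e": 2, "f": 2}

split = split_state_dict_into_shards_factory(
    fake_state_dict,
    get_storage_size=lambda t: t,                     # each "tensor" already is its own size here
    filename_pattern="model{suffix}.safetensors",
    max_shard_size=10,
)

print(split.is_sharded)            # True
print(split.filename_to_tensors)   # {'model-00001-of-00003.safetensors': ['a'],
                                   #  'model-00002-of-00003.safetensors': ['b', 'c'],
                                   #  'model-00003-of-00003.safetensors': ['d', 'e', 'f']}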
+ """ + storage_id_to_tensors: Dict[Any, List[str]] = {} + + shard_list: List[Dict[str, TensorT]] = [] + current_shard: Dict[str, TensorT] = {} + current_shard_size = 0 + total_size = 0 + + if isinstance(max_shard_size, str): + max_shard_size = parse_size_to_int(max_shard_size) + + for key, tensor in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(tensor, str): + logger.info("Skipping tensor %s as it is a string (bnb serialization)", key) + continue + + # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block` + storage_id = get_storage_id(tensor) + if storage_id is not None: + if storage_id in storage_id_to_tensors: + # We skip this tensor for now and will reassign to correct shard later + storage_id_to_tensors[storage_id].append(key) + continue + else: + # This is the first tensor with this storage_id, we create a new entry + # in the storage_id_to_tensors dict => we will assign the shard id later + storage_id_to_tensors[storage_id] = [key] + + # Compute tensor size + tensor_size = get_storage_size(tensor) + + # If this tensor is bigger than the maximal size, we put it in its own shard + if tensor_size > max_shard_size: + total_size += tensor_size + shard_list.append({key: tensor}) + continue + + # If this tensor is going to tip up over the maximal size, we split. + # Current shard already has some tensors, we add it to the list of shards and create a new one. + if current_shard_size + tensor_size > max_shard_size: + shard_list.append(current_shard) + current_shard = {} + current_shard_size = 0 + + # Add the tensor to the current shard + current_shard[key] = tensor + current_shard_size += tensor_size + total_size += tensor_size + + # Add the last shard + if len(current_shard) > 0: + shard_list.append(current_shard) + nb_shards = len(shard_list) + + # Loop over the tensors that share the same storage and assign them together + for storage_id, keys in storage_id_to_tensors.items(): + # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard + for shard in shard_list: + if keys[0] in shard: + for key in keys: + shard[key] = state_dict[key] + break + + # If we only have one shard, we return it => no need to build the index + if nb_shards == 1: + filename = filename_pattern.format(suffix="") + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors={filename: list(state_dict.keys())}, + tensor_to_filename={key: filename for key in state_dict.keys()}, + ) + + # Now that each tensor is assigned to a shard, let's assign a filename to each shard + tensor_name_to_filename = {} + filename_to_tensors = {} + for idx, shard in enumerate(shard_list): + filename = filename_pattern.format(suffix=f"-{idx+1:05d}-of-{nb_shards:05d}") + for key in shard: + tensor_name_to_filename[key] = filename + filename_to_tensors[filename] = list(shard.keys()) + + # Build the index and return + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors=filename_to_tensors, + tensor_to_filename=tensor_name_to_filename, + ) + + +def parse_size_to_int(size_as_str: str) -> int: + """ + Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). + + Supported units are "TB", "GB", "MB", "KB". + + Args: + size_as_str (`str`): The size to convert. Will be directly returned if an `int`. 
+ + Example: + + ```py + >>> parse_size_to_int("5MB") + 5000000 + ``` + """ + size_as_str = size_as_str.strip() + + # Parse unit + unit = size_as_str[-2:].upper() + if unit not in _SIZE_UNITS: + raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") + multiplier = _SIZE_UNITS[unit] + + # Parse value + try: + value = float(size_as_str[:-2].strip()) + except ValueError as e: + raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e + + return int(value * multiplier) + + +def split_torch_state_dict_into_shards( + state_dict: Dict[str, "ms.Tensor"], + *, + filename_pattern: str = _SAFETENSORS_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = _MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, ms.Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + + Example: + ```py + >>> import json + >>> import os + >>> from safetensors.torch import save_file as safe_save_file + >>> from huggingface_hub import split_torch_state_dict_into_shards + + >>> def save_state_dict(state_dict: Dict[str, ms.Tensor], save_directory: str): + ... state_dict_split = split_torch_state_dict_into_shards(state_dict) + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): + ... shard = {tensor: state_dict[tensor] for tensor in tensors} + ... safe_save_file( + ... shard, + ... os.path.join(save_directory, filename), + ... metadata={"format": "pt"}, + ... ) + ... if state_dict_split.is_sharded: + ... index = { + ... "metadata": state_dict_split.metadata, + ... "weight_map": state_dict_split.tensor_to_filename, + ... } + ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: + ... 
f.write(json.dumps(index, indent=2)) + ``` + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_torch_storage_size, + ) + + +def get_torch_storage_size(tensor: "ms.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + """ + return tensor.nelement() * _get_dtype_size(tensor.dtype) + + +@lru_cache() +def _get_dtype_size(dtype: "ms.Type") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344 + """ + import mindspore as ms + + _SIZE = { + ms.int64: 8, + ms.float32: 4, + ms.int32: 4, + ms.bfloat16: 2, + ms.float16: 2, + ms.int16: 2, + ms.uint8: 1, + ms.int8: 1, + ms.bool_: 1, + ms.float64: 8, + } + return _SIZE[dtype] diff --git a/mindone/diffusers/models/BAK_modeling_utils.py b/mindone/diffusers/models/BAK_modeling_utils.py new file mode 100644 index 0000000000..a60071ceb6 --- /dev/null +++ b/mindone/diffusers/models/BAK_modeling_utils.py @@ -0,0 +1,1307 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/diffusers +# with modifications to run diffusers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import json +import os +import re +from collections import OrderedDict +from contextlib import ExitStack +from pathlib import Path +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union + +from huggingface_hub import DDUFEntry, create_repo +from huggingface_hub.utils import validate_hf_hub_args +from typing_extensions import Self + +import mindspore as ms +from mindspore import mint, nn +from mindspore.nn.utils import no_init_parameters + +from mindone.safetensors.mindspore import save_file as safe_save_file + +from .. import __version__ +from ..utils import ( + CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + logging, +) +from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card +from .model_loading_utils import ( + _fetch_index_file, + _fetch_index_file_legacy, + _load_state_dict_into_model, + load_state_dict, + split_torch_state_dict_into_shards, +) + + +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. 
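+
+    Example (illustrative sketch; `nullcontext` merely stands in for any real context manager):
+
+    ```py
+    >>> from contextlib import nullcontext
+
+    >>> managers = ContextManagers([nullcontext(), nullcontext()])
+    >>> with managers:
+    ...     pass  # every wrapped context manager is entered here and exited when the block ends
+    ```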
+ """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + for context_manager in self.context_managers: + self.stack.enter_context(context_manager) + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) + + +logger = logging.get_logger(__name__) + +_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + +def _get_pt2ms_mappings(m): + mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func) + for name, cell in m.cells_and_names(): + if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)): + mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(x.unsqueeze(dim=-2), name=x.name) + if "weight_norm_cell" in name: + ori_name = name.replace(".weight_norm_cell", "") + mappings[f"{ori_name}.weight_g"] = f"{ori_name}.weight_g", lambda x: ms.Parameter( + x.unsqueeze(dim=-2), name=x.name + ) + mappings[f"{ori_name}.weight_v"] = f"{ori_name}.weight_v", lambda x: ms.Parameter( + x.unsqueeze(dim=-2), name=x.name + ) + mappings[f"{ori_name}.bias"] = f"{name}.bias", lambda x: x + elif isinstance(cell, nn.Embedding): + mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x + elif isinstance(cell, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + mappings[f"{name}.weight"] = f"{name}.gamma", lambda x: x + mappings[f"{name}.bias"] = f"{name}.beta", lambda x: x + if isinstance( + cell, + ( + nn.BatchNorm1d, + nn.BatchNorm2d, + ), + ): + mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x + mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x + mappings[f"{name}.num_batches_tracked"] = None, lambda x: x + elif isinstance(cell, mint.nn.BatchNorm2d): + mappings[f"{name}.num_batches_tracked"] = None, lambda x: x.to(ms.float32) + return mappings + + +def _convert_state_dict(m, state_dict_pt): + if not state_dict_pt: + return state_dict_pt + pt2ms_mappings = _get_pt2ms_mappings(m) + state_dict_ms = {} + while state_dict_pt: + name_pt, data_pt = state_dict_pt.popitem() + name_ms, data_mapping = pt2ms_mappings.get(name_pt, (name_pt, lambda x: x)) + data_ms = data_mapping(data_pt) + if name_ms is not None: + state_dict_ms[name_ms] = data_ms + return state_dict_ms + + +def get_parameter_dtype(module: nn.Cell) -> ms.Type: + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for param in module.get_parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + +class ModelMixin(nn.Cell, PushToHubMixin): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. 
+ """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + _keep_in_fp32_modules = None + _skip_layerwise_casting_patterns = None + _supports_group_offloading = True + + def __init__(self): + super().__init__() + + self._gradient_checkpointing_func = None + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `nn.Cell`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." # noqa: E501 + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for _, m in self.cells_and_names()) + + def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None: + """ + Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + + Args: + gradient_checkpointing_func (`Callable`, *optional*): + The function to use for gradient checkpointing. If `None`, the default MindSpore checkpointing function + is used (`mindspore.nn.Cell.recompute_`). + """ + if not self._supports_gradient_checkpointing: + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute " + f"`_supports_gradient_checkpointing` to `True` in the class definition." + ) + + if gradient_checkpointing_func is None: + + def _gradient_checkpointing_func(module, *args): + module.recompute_(mode=True) + return module + + gradient_checkpointing_func = _gradient_checkpointing_func + + self._set_gradient_checkpointing(enable=True) + + def disable_gradient_checkpointing(self) -> None: + """ + Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if self._supports_gradient_checkpointing: + self._set_gradient_checkpointing(enable=False) + + def enable_flash_sdp(self, enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables flash scaled dot product attention. + """ + + # Recursively walk through all the children. 
+ # Any children which exposes the enable_flash_sdp method + # gets the message + def fn_recursive_set_mem_eff(module: nn.Cell): + if hasattr(module, "enable_flash_sdp"): + module.enable_flash_sdp(enabled) + + for child in module.cells(): + fn_recursive_set_mem_eff(child) + + for module in self.cells(): + if isinstance(module, nn.Cell): + fn_recursive_set_mem_eff(module) + + def set_flash_attention_force_cast_dtype(self, force_cast_dtype: Optional[ms.Type]): + r""" + Since the flash-attention operator in MindSpore only supports float16 and bfloat16 data types, we need to manually + set whether to force data type conversion. + + When the attention interface encounters data of an unsupported data type, + if `force_cast_dtype` is not None, the function will forcibly convert the data to `force_cast_dtype` for computation + and then restore it to the original data type afterward. If `force_cast_dtype` is None, it will fall back to the + original attention calculation using mathematical formulas. + + Parameters: + force_cast_dtype (Optional): The data type to which the input data should be forcibly converted. If None, no forced + conversion is performed. + """ + + # Recursively walk through all the children. + # Any children which exposes the set_flash_attention_force_cast_dtype method + # gets the message + def fn_recursive_set_mem_eff(module: nn.Cell): + if hasattr(module, "set_flash_attention_force_cast_dtype"): + module.set_flash_attention_force_cast_dtype(force_cast_dtype) + + for child in module.cells(): + fn_recursive_set_mem_eff(child) + + for module in self.cells(): + if isinstance(module, nn.Cell): + fn_recursive_set_mem_eff(module) + + def set_use_memory_efficient_attention_xformers(self, valid: bool, attention_op: Optional[Callable] = None) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: nn.Cell): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.cells(): + fn_recursive_set_mem_eff(child) + + for module in self.cells(): + if isinstance(module, nn.Cell): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Not supported for now. + + Examples: + + ```py + >>> import mindspore as ms + >>> from mindone.diffusers import UNet2DConditionModel + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", mindspore_dtype=ms.float16 + ... ) + >>> model.enable_xformers_memory_efficient_attention() + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self) -> None: + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). 
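+
+    Example (illustrative; continues the `enable_xformers_memory_efficient_attention` example above):
+
+    ```py
+    >>> model.disable_xformers_memory_efficient_attention()  # `model` as loaded in the example above
+    ```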
+ """ + self.set_use_memory_efficient_attention_xformers(False) + + def enable_layerwise_casting( + self, + storage_dtype: ms.Type, + compute_dtype: Optional[ms.Type] = None, + skip_modules_pattern: Optional[Tuple[str, ...]] = None, + skip_modules_classes: Optional[Tuple[Type[nn.Cell], ...]] = None, + non_blocking: bool = False, + ) -> None: + r""" + Activates layerwise casting for the current model. + + Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but + upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the + memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations + are negligible, mostly stemming from weight casting in normalization and modulation layers. + + By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch + embedding, positional embedding and normalization layers. This is because these layers are most likely + precision-critical for quality. If you wish to change this behavior, you can set the + `_skip_layerwise_casting_patterns` attribute to `None`, or call + [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments. + + Example: + Using [`~models.ModelMixin.enable_layerwise_casting`]: + + ```python + >>> from mindone.diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", mindspore_dtype=ms.bfloat16 + ... ) + + >>> # Enable layerwise casting via the model, which ignores certain modules by default + >>> transformer.enable_layerwise_casting(storage_dtype=ms.float8_e4m3fn, compute_dtype=ms.bfloat16) + ``` + + Args: + storage_dtype (`mindspore.Type`): + The dtype to which the model should be cast for storage. + compute_dtype (`mindspore.Type`): + The dtype to which the model weights should be cast during the forward pass. + skip_modules_pattern (`Tuple[str, ...]`, *optional*): + A list of patterns to match the names of the modules to skip during the layerwise casting process. If + set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT + layers. + skip_modules_classes (`Tuple[Type[nn.Cell], ...]`, *optional*): + A list of module classes to skip during the layerwise casting process. + non_blocking (`bool`, *optional*, defaults to `False`): + If `True`, the weight casting operations are non-blocking. + """ + raise NotImplementedError("`enable_layerwise_casting` is not yet supported.") + + def enable_group_offload( + self, + onload_device: str = "Ascend", + offload_device: str = "CPU", + offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, + record_stream: bool = False, + low_cpu_mem_usage=False, + ) -> None: + r""" + Activates group offloading for the current model. + + See [`~hooks.group_offloading.apply_group_offloading`] for more information. + + Example: + + ```python + >>> from mindone.diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", mindspore_dtype=ms.bfloat16 + ... ) + + >>> transformer.enable_group_offload( + ... onload_device="Ascend", + ... offload_device="CPU", + ... offload_type="leaf_level", + ... use_stream=True, + ... 
) + ``` + """ + raise NotImplementedError("`enable_group_offload` is not yet supported.") + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Optional[Callable] = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + max_shard_size: Union[int, str] = "10GB", + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~models.ModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `mindspore.save_checkpoint` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
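+
+    Example (illustrative sketch; the exact weight file names depend on `SAFETENSORS_WEIGHTS_NAME`,
+    `variant` and `max_shard_size`):
+
+    ```py
+    >>> from mindone.diffusers import UNet2DConditionModel
+
+    >>> unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+    >>> # keep every shard below 2GB; an index json is written automatically when sharding occurs
+    >>> unet.save_pretrained("./my_unet", max_shard_size="2GB")
+    ```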
+ """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Only save the model itself if we are using distributed training + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = {k: v for k, v in model_to_save.parameters_and_names()} + + # Save the model + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern + ) + + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, filepath, metadata={"format": "np"}) + else: + ms.save_checkpoint(shard, filepath) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." 
+ ) + else: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + + if push_to_hub: + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) + model_card.save(Path(save_directory, "README.md").as_posix()) + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self: + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + mindspore_dtype (`mindspore.Type`, *optional*): + Override the default `mindspore.Type` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. 
+ use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from mindone.diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at + runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + subfolder = kwargs.pop("subfolder", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) + disable_mmap = kwargs.pop("disable_mmap", False) + + if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + mindspore_dtype = ms.float32 + logger.warning( + f"Passed `mindspore_dtype` {mindspore_dtype} is not a `ms.Type`. Defaulting to `ms.float32`." + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + unused_kwargs = {} + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + dduf_entries=dduf_entries, + **kwargs, + ) + # no in-place modification of the original config. 
+ config = copy.deepcopy(config) + + # Check if `_keep_in_fp32_modules` is not None + # use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and ( + # hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) + # ) + use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None + + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + else: + keep_in_fp32_modules = [] + + is_sharded = False + resolved_model_file = None + + # Determine if we're loading from a directory of sharded checkpoints. + sharded_metadata = None + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + "dduf_entries": dduf_entries, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and (dduf_entries or index_file.is_file()): + is_sharded = True + + # load model + if from_flax: + raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") + else: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, + ) + elif use_safetensors: + try: + resolved_model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if resolved_model_file is None and not is_sharded: + resolved_model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + + if not isinstance(resolved_model_file, list): + resolved_model_file = [resolved_model_file] + + # set dtype to instantiate the model under: + # 1. If mindspore_dtype is not None, we use that dtype + # 2. 
If mindspore_dtype is float8, we don't use _set_default_mindspore_dtype and we downcast after loading the model + dtype_orig = None # noqa + if mindspore_dtype is not None: + if not isinstance(mindspore_dtype, ms.Type): + raise ValueError( + f"{mindspore_dtype} needs to be of type `mindspore.Type`, e.g. `mindspore.float16`, but is {type(mindspore_dtype)}." + ) + + with no_init_parameters(): + model = cls.from_config(config, **unused_kwargs) + + state_dict = None + if not is_sharded: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries) + # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. + model._fix_state_dict_keys_on_load(state_dict) + + if is_sharded: + loaded_keys = sharded_metadata["all_checkpoint_keys"] + else: + state_dict = _convert_state_dict(model, state_dict) + loaded_keys = list(state_dict.keys()) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + resolved_model_file, + pretrained_model_name_or_path, + loaded_keys, + ignore_mismatched_sizes=ignore_mismatched_sizes, + dtype=mindspore_dtype, + keep_in_fp32_modules=keep_in_fp32_modules, + dduf_entries=dduf_entries, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if mindspore_dtype is not None: + model = model.to(mindspore_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.set_train(False) + + if output_loading_info: + return model, loading_info + + return model + + def to(self, dtype: Optional[ms.Type] = None): + for p in self.get_parameters(): + p.set_dtype(dtype) + return self + + def half(self): + for p in self.get_parameters(): + p.set_dtype(ms.float16) + return self + + def float(self): + for p in self.get_parameters(): + p.set_dtype(ms.float32) + return self + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + resolved_model_file: List[str], + pretrained_model_name_or_path: Union[str, os.PathLike], + loaded_keys: List[str], + ignore_mismatched_sizes: bool = False, + dtype: Optional[Union[str, ms.Type]] = None, + keep_in_fp32_modules: Optional[List[str]] = None, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + ): + model_state_dict = {k: v for k, v in model.parameters_and_names()} + expected_keys = list(model_state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. 
+ if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + mismatched_keys = [] + + error_msgs = [] + + if state_dict is not None: + # load_state_dict will manage the case where we pass a dict instead of a file + # if state dict is not None, it means that we don't need to read the files from resolved_model_file also + resolved_model_file = [state_dict] + + if len(resolved_model_file) > 1: + resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") + + for shard_file in resolved_model_file: + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ) + if len(resolved_model_file) > 1: + error_msgs += _load_state_dict_into_model(model, state_dict, is_sharded=True) + else: + error_msgs += _load_state_dict_into_model(model, state_dict, is_sharded=False) + + offload_index = None + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" # noqa + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." 
+ ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + # Adapted from `transformers` modeling_utils.py + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be split when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, ModelMixin): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.cells()) + return list(_no_split_modules) + + @classmethod + def _set_default_mindspore_dtype(cls, dtype: ms.Type) -> ms.Type: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`mindspore.Type`): + a floating dtype to set to. + + Returns: + `mindspore.Type`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `ms.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. + """ + raise NotImplementedError("`_set_default_mindspore_dtype` is not yet supported.") + + @property + def dtype(self) -> ms.Type: + """ + `mindspore.Type`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (trainable or non-embedding) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters. 
+ exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embedding parameters. + + Returns: + `int`: The number of parameters. + + Example: + + ```py + from mindone.diffusers import UNet2DConditionModel + + model_id = "runwayml/stable-diffusion-v1-5" + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") + unet.num_parameters(only_trainable=True) + 859520964 + ``` + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.cells_and_names() + if isinstance(module_type, mint.nn.Embedding) + ] + total_parameters = [ + parameter for name, parameter in self.parameters_and_names() if name not in embedding_param_names + ] + else: + total_parameters = list(self.get_parameters()) + + total_numel = [] + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + total_numel.append(param.numel()) + + return sum(total_numel) + + def _set_gradient_checkpointing(self, enable: bool = True) -> None: + is_gradient_checkpointing_set = False + + for name, module in self.cells_and_names(): + if hasattr(module, "recompute_"): + logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'") + module.recompute_(enable) + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to " + f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`." + ) + + def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: + """ + This function fix the state dict of the model to take into account some changes that were made in the model + architecture: + - deprecated attention blocks (happened before we introduced sharded checkpoint, + so this is why we apply this method only when loading non sharded checkpoints for now) + """ + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.name_cells().items(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + 
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + + # TODO : MindSpore 2.6 share weight bug. Unable to load WTE and LM-Head layer weights properly. It will be + # deleted until fixed load_state_dict_into_model and parameters_and_names。 + if hasattr(self, "wte_lm_share") and self.wte_lm_share: + state_dict["transformer.transformer.wte.embedding_table"] = state_dict["transformer.lm_head.weight"] + + return state_dict + + def get_submodule(self, target: str) -> nn.Cell: + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + For example, let's say you have an ``nn.Cell`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) + ) + (linear): Dense(input_channels=100, output_channels=200, has_bias=True) + ) + ) + + (The diagram shows an ``nn.Cell`` ``A``. ``A`` has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To check whether or not we have the ``linear`` submodule, we + would call ``get_submodule("net_b.linear")``. To check whether + we have the ``conv`` submodule, we would call + ``get_submodule("net_b.net_c.conv")``. + + The runtime of ``get_submodule`` is bounded by the degree + of module nesting in ``target``. A query against + ``named_modules`` achieves the same result, but it is O(N) in + the number of transitive modules. So, for a simple check to see + if some submodule exists, ``get_submodule`` should always be + used. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + + Returns: + nn.Cell: The submodule referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Cell`` + """ + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: nn.Cell = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError(mod.cls_name + " has no " "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, nn.Cell): + raise AttributeError("`" + item + "` is not " "an nn.Module") + + return mod + + +class LegacyModelMixin(ModelMixin): + r""" + A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + # To prevent dependency import problem. + from .model_loading_utils import _fetch_remapped_cls_from_config + + # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls. 
+ kwargs_copy = kwargs.copy() + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, _, _ = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) diff --git a/mindone/diffusers/models/model_loading_utils.py b/mindone/diffusers/models/model_loading_utils.py index 34bc6059db..0bcdc6a983 100644 --- a/mindone/diffusers/models/model_loading_utils.py +++ b/mindone/diffusers/models/model_loading_utils.py @@ -33,9 +33,9 @@ import mindspore as ms from mindspore import nn, ops +from mindspore.ops import Cast from ...safetensors.mindspore import load as safe_load -from ...safetensors.mindspore import load_file as safe_load_file from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, @@ -47,6 +47,7 @@ ) logger = logging.get_logger(__name__) +cpu_cast = Cast().set_device("CPU") _CLASS_REMAPPING_DICT = { "Transformer2DModel": { @@ -97,7 +98,7 @@ def load_state_dict( if disable_mmap: return safe_load(open(checkpoint_file, "rb").read()) else: - return safe_load_file(checkpoint_file) + return ms.load_checkpoint(checkpoint_file, format="safetensors") else: raise NotImplementedError( f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}" @@ -140,11 +141,11 @@ def _load_state_dict_into_model( and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules) and dtype == ms.float16 ): - v.set_dtype(ms.float32) + state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k) else: - v.set_dtype(local_state[k].dtype) + state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k) else: - v.set_dtype(local_state[k].dtype) + state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k) else: pass # unexpect key keeps origin dtype cm = silence_mindspore_logger() if is_sharded else nullcontext() diff --git a/mindone/diffusers/models/modeling_patch.py b/mindone/diffusers/models/modeling_patch.py new file mode 100644 index 0000000000..cd9b97c2d4 --- /dev/null +++ b/mindone/diffusers/models/modeling_patch.py @@ -0,0 +1,49 @@ +import inspect +from functools import wraps + +import mindspore as ms +from mindspore import mint, nn + +SKIP_CLASSES = {nn.Dropout} +# Store original __init__ for manual restore +_ORIG_INITS = {} + + +def patch_nn_default_dtype(dtype=ms.float32, force=False): + """ + Iterate over all Cells under nn and mint.nn, + automatically set or force the default dtype in __init__ if supported. 
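+
+    A minimal sketch of the intended call pattern (purely illustrative; whether a given layer picks up
+    the dtype depends on its `__init__` exposing a `dtype` argument):
+
+    ```py
+    >>> import mindspore as ms
+    >>> from mindspore import mint
+
+    >>> patch_nn_default_dtype(dtype=ms.bfloat16, force=True)
+    >>> layer = mint.nn.Linear(4, 4)  # constructed with dtype=ms.bfloat16 if Linear accepts `dtype`
+    >>> restore_nn_default_dtype()  # always undo the patch once model construction is done
+    ```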
+ + Args: + dtype (mindspore.dtype): target dtype to enforce + force (bool): if True, even when user passes dtype explicitly, override it + """ + for module in [ms.nn, mint.nn]: + for name in dir(module): + attr = getattr(module, name) + if inspect.isclass(attr) and issubclass(attr, nn.Cell): + if attr in SKIP_CLASSES: + continue # skip specified classes + sig = inspect.signature(attr.__init__) + if "dtype" in sig.parameters: + if attr not in _ORIG_INITS: + _ORIG_INITS[attr] = attr.__init__ + + _orig_init = attr.__init__ + + @wraps(_orig_init) + def _new_init(self, *args, _orig_init=_orig_init, **kwargs): + if force or "dtype" not in kwargs: + kwargs["dtype"] = dtype + return _orig_init(self, *args, **kwargs) + + setattr(attr, "__init__", _new_init) + + +def restore_nn_default_dtype(): + """ + Manually restore the original __init__ of all patched nn / mint.nn Cells. + """ + for cls, orig_init in _ORIG_INITS.items(): + cls.__init__ = orig_init + _ORIG_INITS.clear() diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py index a60071ceb6..20f48d9246 100644 --- a/mindone/diffusers/models/modeling_utils.py +++ b/mindone/diffusers/models/modeling_utils.py @@ -58,6 +58,7 @@ load_state_dict, split_torch_state_dict_into_shards, ) +from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype class ContextManagers: @@ -819,7 +820,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) with no_init_parameters(): + if mindspore_dtype is not None: + patch_nn_default_dtype(dtype=mindspore_dtype, force=True) model = cls.from_config(config, **unused_kwargs) + if mindspore_dtype is not None: + restore_nn_default_dtype() state_dict = None if not is_sharded: @@ -874,7 +879,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P def to(self, dtype: Optional[ms.Type] = None): for p in self.get_parameters(): - p.set_dtype(dtype) + if p.dtype != dtype: + p.set_dtype(dtype) return self def half(self): diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index 173bd904cc..d0b64ec0ab 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -396,6 +396,7 @@ from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( QwenImageEditPipeline, + QwenImageEditInpaintPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline, diff --git a/mindone/diffusers/pipelines/qwenimage/__init__.py b/mindone/diffusers/pipelines/qwenimage/__init__.py index b6f05c169f..fcbc0f2067 100644 --- a/mindone/diffusers/pipelines/qwenimage/__init__.py +++ b/mindone/diffusers/pipelines/qwenimage/__init__.py @@ -10,6 +10,7 @@ "pipeline_qwenimage_img2img": ["QwenImageImg2ImgPipeline"], "pipeline_qwenimage_inpaint": ["QwenImageInpaintPipeline"], "pipeline_qwenimage_edit": ["QwenImageEditPipeline"], + "pipeline_qwenimage_edit_inpaint": ["QwenImageEditInpaintPipeline"], } if TYPE_CHECKING: @@ -17,6 +18,7 @@ from .pipeline_qwenimage_edit import QwenImageEditPipeline from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline + from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline else: import sys diff --git a/mindone/transformers/BAK_modeling_utils.py b/mindone/transformers/BAK_modeling_utils.py new file mode 100644 index 0000000000..8731ed7a42 --- /dev/null +++ 
b/mindone/transformers/BAK_modeling_utils.py @@ -0,0 +1,3210 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import gc +import json +import os +import re +import warnings +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.dynamic_module_utils import custom_object_save +from transformers.generation.configuration_utils import GenerationConfig +from transformers.safetensors_conversion import auto_conversion +from transformers.utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + DUMMY_INPUTS, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ModelOutput, + PushToHubMixin, + cached_file, + download_url, + extract_commit_hash, + find_adapter_config_file, + has_file, + is_offline_mode, + is_remote_url, + is_safetensors_available, + logging, + replace_return_docstrings, +) +from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files + +import mindspore as ms +from mindspore import Parameter, Tensor, mint, nn, ops +from mindspore.nn import CrossEntropyLoss, Identity + +from .activations import get_activation +from .generation.utils import GenerationMixin +from .integrations import PeftAdapterMixin +from .integrations.flash_attention import flash_attention_forward +from .integrations.sdpa_attention import sdpa_attention_forward +from .loss.loss_utils import LOSS_MAPPING +from .mindspore_adapter import dtype_to_str +from .mindspore_utils import ( # noqa: F401 + Conv1D, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_conv1d_layer, + prune_layer, + prune_linear_layer, +) +from .modeling_attn_mask_utils import dtype_to_min +from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available + +if is_safetensors_available(): + from safetensors import safe_open + + # from mindone.safetensors.mindspore import load_file as safe_load_file + from mindone.safetensors.mindspore import save_file as safe_save_file + +logger = logging.get_logger(__name__) + +_init_weights = True + + +def _get_pt2ms_mappings(m): + mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func) + for name, cell in m.cells_and_names(): + if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)): + mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter( + ops.expand_dims(x, axis=-2), name=x.name + ) + if "weight_norm_cell" in name: + ori_name = name.replace(".weight_norm_cell", "") + 
mappings[f"{ori_name}.weight_g"] = f"{ori_name}.weight_g", lambda x: ms.Parameter( + ops.expand_dims(x, axis=-2), name=x.name + ) + mappings[f"{ori_name}.weight_v"] = f"{ori_name}.weight_v", lambda x: ms.Parameter( + ops.expand_dims(x, axis=-2), name=x.name + ) + mappings[f"{ori_name}.bias"] = f"{name}.bias", lambda x: x + elif isinstance(cell, nn.Embedding): + mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x + elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + mappings[f"{name}.weight"] = f"{name}.gamma", lambda x: x + mappings[f"{name}.bias"] = f"{name}.beta", lambda x: x + if isinstance(cell, (nn.BatchNorm2d,)): + mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x + mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x + mappings[f"{name}.num_batches_tracked"] = None, lambda x: x + return mappings + + +def _get_pt2ms_mapped_k(mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix): + if has_prefix_module and not expects_prefix_module: + loaded_keys = [ + mappings.get(s[len(prefix) + 1 :], (s[len(prefix) + 1 :], lambda x: x))[0] + if s.startswith(prefix) + else mappings.get(s, (s, lambda x: x))[0] + for s in loaded_keys + ] + loaded_keys = [".".join([prefix, s]) for s in loaded_keys] + elif not has_prefix_module and expects_prefix_module: + loaded_keys = [ + mappings.get(".".join([prefix, s]), (".".join([prefix, s]), lambda x: x))[0] for s in loaded_keys + ] + loaded_keys = [s[len(prefix) + 1 :] if s.startswith(prefix) else s for s in loaded_keys] + else: + loaded_keys = [mappings.get(s, (s, lambda x: x))[0] for s in loaded_keys] + return loaded_keys + + +def _convert_state_dict(m, state_dict_pt, prefix=""): + if not state_dict_pt: + return state_dict_pt + pt2ms_mappings = _get_pt2ms_mappings(m) + state_dict_ms = {} + while state_dict_pt: + name_pt, data_pt = state_dict_pt.popitem() + for name, param in m.parameters_and_names(): + name_ms = param.name + length = len(prefix) + 1 + if name_pt.startswith(prefix): + if name_ms.rsplit(".", 1)[0] == name_pt.rsplit(".", 1)[0][length:] or name_ms == name_pt[length:]: + name_pt = name_pt[length:] + elif not name_pt.startswith(prefix): + if name_pt.rsplit(".", 1)[0] == name_ms.rsplit(".", 1)[0][length:] or name_pt == name_ms[length:]: + name_pt = ".".join([prefix, name_pt]) + name_ms, data_mapping = pt2ms_mappings.get(name_pt, (name_pt, lambda x: x)) + data_ms = data_mapping(data_pt) + if name_ms is not None: + state_dict_ms[name_ms] = data_ms + return state_dict_ms + + +@contextmanager +def silence_mindspore_logger(): + ms_logger = ms.log._get_logger() + ms_level = ms_logger.level + ms_logger.setLevel("ERROR") + yield + ms_logger.setLevel(ms_level) + + +def get_first_parameter_dtype(parameter: Union[nn.Cell, "ModuleUtilsMixin"]): + """ + Returns the first parameter dtype (can be non-floating) or asserts if none were found. + """ + return next(parameter.parameters_dict()).dtype + + +def get_parameter_dtype(parameter: Union[nn.Cell, "ModuleUtilsMixin"]): + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for t in parameter.get_parameters(): + last_dtype = t.dtype + if t.is_floating_point(): + return t.dtype + + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + +def get_state_dict_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype. 
+ """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + # if no floating dtype was found return whatever the first dtype is + return next(state_dict.values()).dtype + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one ms.Parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(ms.float32) + 4 + ``` + """ + if dtype == ms.bool_: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def shard_checkpoint( + state_dict: Dict[str, Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME +): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): The state dictionary of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`): + The name of the model save file. + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [{}] + last_block_size = 0 + total_size = 0 + + for key, weight in state_dict.items(): + weight_size = weight.numel() * dtype_byte_size(weight.dtype) + + # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one + # weight in the current shard. + if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0: + sharded_state_dicts.append({}) + last_block_size = 0 + + sharded_state_dicts[-1][key] = weight + last_block_size += weight_size + total_size += weight_size + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors") + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. 
+ """ + try: + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="np") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "np"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + return ms.load_checkpoint(checkpoint_file, format="safetensors") + else: + raise NotImplementedError( + f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}" + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " + ) + + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=False): + # # add prefix to the name of parameters + # if len(start_prefix) > 0: + # for name, param in model_to_load.parameters_and_names(): + # if param.name != name: + # logger.error( + # f"When Loading state dict into model {model_to_load.__class__.__name__}, the attribute 'name' of 'mindspore.ms.Parameter' object is {param.name} which should be {name}.\n" # noqa: E501 + # f"There are several possible reasons for this misalignment:\n" + # f" 1. {model_to_load.__class__.__name__} didn't call 'MSPreTrainedModel.post_init()' correctly.\n" + # f" 2. You have made changes to the model before loading the weights, which may be implicit. For example, you created an optimizer using the parameters of model.\n" # noqa: E501 + # f"If you encounter this error, please report it to the developer." + # ) + # param.name = start_prefix + name + + # TODO: error_msgs is always empty for now. Maybe we need to rewrite MindSpore's `load_param_into_net`. + # Error msgs should contain caught exception like size mismatch instead of missing/unexpected keys. + # TODO: We should support loading float16 state_dict into float32 model, like PyTorch's behavior. + error_msgs = [] + # TODO: State dict loading in mindspore does not cast dtype correctly. We do it manually. It's might unsafe. 
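+    # Note (descriptive comment, summarizing the block below): each loaded value is manually aligned to the
+    # dtype of the matching model parameter before `ms.load_param_into_net` is called; keys that are not
+    # found in the model are left as-is and keep the dtype stored in the checkpoint.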
+ local_state = {v.name: v for k, v in model_to_load.parameters_and_names()} + for k, v in state_dict.items(): + if k in local_state: + v.set_dtype(local_state[k].dtype) + else: + pass # unexpect key keeps origin dtype + cm = silence_mindspore_logger() if is_sharded else nullcontext() + with cm: + ms.load_param_into_net(model_to_load, state_dict, strict_load=True) + + # remove prefix from the name of parameters + if len(start_prefix) > 0: + for name, param in model_to_load.parameters_and_names(): + param.name = name + + return error_msgs + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +class ModuleUtilsMixin: + """ + A few utilities for `mindspore.nn.Cell`, to be used as a mixin. + """ + + def _get_name(self): + return self.__class__.__name__ + + def to(self, dtype: Optional[ms.Type] = None): + # FIXME: In ms 2.6.0 `tensor.set_dtype()` encountered a bug that it occurs wrong values. + # Resume to use self.register_buffer() in network and set dtype for buffer tensors after ms2.7.0 launched. + # Now we use `Parameter` and `Parameter.set_dtype()` instead. + + for p in self.get_parameters(): + p.set_dtype(dtype) + return self + + def float(self): + for p in self.get_parameters(): + p.set_dtype(ms.float32) + return self + + def half(self): + for p in self.get_parameters(): + p.set_dtype(ms.float16) + return self + + @property + def dtype(self) -> ms.Type: + """ + `ms.Type`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`Tensor`): An attention mask. + + Returns: + `Tensor`: The inverted attention mask. + """ + encoder_extended_attention_mask = None + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * dtype_to_min(self.dtype) + + return encoder_extended_attention_mask + + @staticmethod + def create_extended_attention_mask_for_decoder(input_shape, attention_mask): + batch_size, seq_length = input_shape + seq_ids = ops.arange(seq_length) + causal_mask = seq_ids[None, None, :].tile((batch_size, seq_length, 1)) <= seq_ids[None, :, None] + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = ops.cat( + [ + ops.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + # extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + extended_attention_mask = ops.mul(causal_mask.unsqueeze(1), attention_mask.unsqueeze(1).unsqueeze(1)) + return extended_attention_mask + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], dtype: ms.float32 = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
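+        # Illustrative example (assuming `dtype_to_min` returns the smallest finite value of the dtype,
+        # roughly -65504 for ms.float16): a padding mask [1, 1, 0] becomes [0.0, 0.0, -65504.0] after the
+        # `(1.0 - mask) * dtype_to_min(dtype)` transform below, so masked positions are suppressed before
+        # the softmax.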
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * dtype_to_min(dtype)
+        return extended_attention_mask
+
+    def get_head_mask(
+        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
+    ) -> Tensor:
+        """
+        Prepare the head mask if needed.
+
+        Args:
+            head_mask (`Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+            num_hidden_layers (`int`):
+                The number of hidden layers in the model.
+            is_attention_chunked (`bool`, *optional*, defaults to `False`):
+                Whether or not the attention scores are computed by chunks or not.
+
+        Returns:
+            `Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+            `[None]` for each layer.
+        """
+        if head_mask is not None:
+            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+            if is_attention_chunked is True:
+                head_mask = head_mask.unsqueeze(-1)
+        else:
+            head_mask = [None] * num_hidden_layers
+
+        return head_mask
+
+    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.tile((num_hidden_layers, 1, 1, 1, 1))
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
+        return head_mask
+
+    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+        """
+        Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+        Args:
+            only_trainable (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of trainable parameters
+
+            exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of non-embeddings parameters
+
+        Returns:
+            `int`: The number of parameters.
+        """
+
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight"
+                for name, module_type in self.cells_and_names()
+                if isinstance(module_type, nn.Embedding)
+            ]
+            total_parameters = [
+                param for name, param in self.parameters_and_names() if name not in embedding_param_names
+            ]
+        else:
+            total_parameters = list(self.get_parameters())
+
+        total_numel = []
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def estimate_tokens(self, input_dict: Dict[str, Union[ms.Tensor, Any]]) -> int:
+        """
+        Helper function to estimate the total number of tokens from the model inputs.
+
+        Args:
+            inputs (`dict`): The model inputs.
+
+        Returns:
+            `int`: The total number of tokens.
+ """ + if not hasattr(self, "warnings_issued"): + self.warnings_issued = {} + if self.main_input_name in input_dict: + return input_dict[self.main_input_name].numel() + elif "estimate_tokens" not in self.warnings_issued: + logger.warning( + "Could not estimate the number of tokens of the input, floating-point operations will not be computed" + ) + self.warnings_issued["estimate_tokens"] = True + return 0 + + def floating_point_ops(self, input_dict: Dict[str, Union[ms.Tensor, Any]], exclude_embeddings: bool = True) -> int: + """ + Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a + batch with this transformer model. Default approximation neglects the quadratic dependency on the number of + tokens (valid if `12 * d_model << sequence_length`) as laid out in [this + paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter + re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths. + + Args: + batch_size (`int`): + The batch size for the forward pass. + + sequence_length (`int`): + The number of tokens in each line of the batch. + + exclude_embeddings (`bool`, *optional*, defaults to `True`): + Whether or not to count embedding and softmax operations. + + Returns: + `int`: The number of floating-point operations. + """ + + return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + + +class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + r""" + Base class for all models. + + [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a MindSpore model, + taking as arguments: + + - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint. + - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model. + - **path** (`str`) -- A path to the TensorFlow checkpoint. + + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + model_tags = None + + _auto_class = None + _no_split_modules = None + _skip_keys_device_placement = None + _keep_in_fp32_modules = None + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. 
+ _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't + # trained, but which are either deterministic or tied variables) + _keys_to_ignore_on_save = None + # a list of `state_dict` keys that are potentially tied to another key in the state_dict. + _tied_weights_keys = None + + is_parallelizable = False + supports_gradient_checkpointing = False + + # Flash Attention 2 support + _supports_flash_attn_2 = False + + # SDPA support + _supports_sdpa = False + + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? + _supports_cache_class = False + _supports_static_cache = False + + # Has support for dynamic model input? + _supports_dynamic_input = False + + # Has support for a `QuantoQuantizedCache` instance as `past_key_values` + _supports_quantized_cache = False + + # This flag signal that the model can be used as an efficient backend in TGI and vLLM + # In practice, it means that they support attention interface functions, fully pass the kwargs + # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan + _supports_attention_backend = False + + @property + def dummy_inputs(self) -> Dict[str, Tensor]: + """ + `Dict[str, Tensor]`: Dummy inputs to do a forward pass in the network. + """ + return {"input_ids": Tensor(DUMMY_INPUTS)} + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a MindSpore model. + """ + return "ms" + + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + f"ms.Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + self.config = config + self.name_or_path = config.name_or_path + self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # Overwrite the class attribute to make it an instance attribute, so models like + # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute + # when a different component (e.g. language_model) is used. + self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + mindspore_dtype=None, + ): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. 
SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) + 4. The default model's implementation otherwise (`LlamaAttention` for example) . + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were ' + f"used when loading the model, which are not compatible." + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if config._attn_implementation not in ["eager", "paged_attention"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): + message = ( + f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. ' + f'The only possible arguments are `attn_implementation="eager"`' + f" (manual attention implementation)" + ) + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using scaled_dot_product_attention)' + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the + # user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + if use_flash_attention_2: + logger.warning_once( + "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a " + 'future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + if config._attn_implementation == "flash_attention_2": + cls._check_and_enable_flash_attn_2( + config, + mindspore_dtype=mindspore_dtype, + hard_check_only=False, + ) + elif requested_attn_implementation in [None, "sdpa"]: + # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. + config = cls._check_and_enable_sdpa( + config, + hard_check_only=False if requested_attn_implementation is None else True, + ) + + return config + + @property + def base_model(self) -> nn.Cell: + """ + `mindspore.nn.Cell`: The main body of the model. + """ + return getattr(self, self.base_model_prefix, self) + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`. + + Under the hood, on classes where this function returns True, some generation-specific changes are triggered: + for instance, the model instance will have a populated `generation_config` attribute. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. 
+ """ + # Directly inherits `GenerationMixin` -> can generate + if "GenerationMixin" in str(cls.__bases__): + return True + # The class inherits from a class that can generate (recursive check) -> can generate + for base in cls.__bases__: + if not hasattr(base, "can_generate"): + continue + if "PreTrainedModel" not in str(base) and base.can_generate(): + return True + # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + # was how we detected whether a model could generate. + if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): + logger.warning_once( + f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " + "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " + "to call `generate` and other related functions." + "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " + "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" + "\n - If you are the owner of the model architecture code, please modify your model class such that " + "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." + "\n - If you are not the owner of the model architecture class, please contact the model code owner " + "to update it." + ) + return True + # Otherwise, can't generate + return False + + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + mindspore_dtype=None, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 2 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute + `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_2: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_2_available(): + raise ImportError("FlashAttention2 has been toggled on, but it cannot be used due to some error") + + if mindspore_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 2.0 without specifying a MindSpore dtype. This might lead to unexpected behaviour" + ) + elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]: + logger.warning_once( + "Flash Attention 2.0 only supports ms.float16 and ms.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {mindspore_dtype}. You should run training or inference using " + f"Automatic Mixed-Precision via the `network=auto_mix_precision(network, ...)` decorator," + " or load the model with the `mindspore_dtype` argument. 
Example: `model = " + 'AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`' + ) + + if not hard_check_only: + config._attn_implementation = "flash_attention_2" + return config + + @property + def loss_function(self): + if hasattr(self, "_loss_function"): + return self._loss_function + + loss_type = getattr(self, "loss_type", None) + + if loss_type is None or loss_type not in LOSS_MAPPING: + logger.warning_once( + f"`loss_type={loss_type}` was set in the config but it is unrecognised." + f"Using the default loss: `ForCausalLMLoss`." + ) + loss_type = "ForCausalLM" + return LOSS_MAPPING[loss_type] + + @loss_function.setter + def loss_function(self, value): + self._loss_function = value + + @classmethod + def is_backend_compatible(cls): + return cls._supports_attention_backend + + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of SDPA for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` + to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through `scaled_dot_product_attention` yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. " + "If you believe this error is a bug, please open an issue in Transformers GitHub repository and " + 'load your model with the argument `attn_implementation="eager"` meanwhile. Example: ' + '`model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_sdpa_available(): + raise ImportError("SDPA requirements in Transformers are not met.") + + if not is_sdpa_available() or not cls._supports_sdpa: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + return config + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (str, *optional*): + Override the default torch_dtype and load the model under this dtype. + """ + # when we init a model from within another model (e.g. VLMs) and dispatch on FA2 + # a warning is raised that dtype should be fp16. Since we never pass dtype from within + # modeling code, we can try to infer it here same way as done in `from_pretrained` + if hasattr(config, "mindspore_dtype"): + mindspore_dtype = kwargs.pop("mindspore_dtype", config.mindspore_dtype) + else: + mindspore_dtype = kwargs.pop("torch_dtype", config.torch_dtype) + + if isinstance(mindspore_dtype, str): + mindspore_dtype = getattr(ms, mindspore_dtype) + elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + TORCH_TO_MINDSPORE_DTYPE_MAP = { + "torch.float32": ms.float32, + "torch.bfloat16": ms.bfloat16, + "torch.float16": ms.float16, + } + mindspore_dtype = str(mindspore_dtype) + mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype] + + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. + + if config._attn_implementation_internal is not None: + # In this case, the config has been created with the attn_implementation set by the user, which we + # should respect. 
+ attn_implementation = config._attn_implementation_internal + else: + attn_implementation = None + + config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + mindspore_dtype=mindspore_dtype, + ) + + model = cls(config, **kwargs) + + # We cannot set default mindspore dtype. So we need to cast model weights after creating. + if mindspore_dtype is not None: + model = model.to(mindspore_dtype) + + logger.info( + f"convert model:{model.__class__.__name__} parameters to mindspore_dtype {dtype_to_str(mindspore_dtype)}" + ) + + return model + + def get_input_embeddings(self) -> nn.Cell: + """ + Returns the model's input embeddings. + + Returns: + `nn.Cell`: A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + else: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Cell): + """ + Set model's input embeddings. + + Args: + value (`nn.Cell`): A cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_input_embeddings(value) + else: + raise NotImplementedError + + def get_output_embeddings(self) -> nn.Cell: + """ + Returns the model's output embeddings. + + Returns: + `nn.Cell`: A mindspore cell mapping hidden states to vocabulary. + """ + return None # Overwrite for models with output embeddings + + def _init_weights(self, module): + """ + Initialize the weights. This method should be overridden by derived class and is + the only initialization method that will be called when loading a checkpoint + using `from_pretrained`. Any attempt to initialize outside of this function + will be useless as the mindspore.common.initializer function are all replaced with skip. + """ + pass + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_hf_initialized", False): + return + self._init_weights(module) + module._is_hf_initialized = True + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. + """ + if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, self.decoder, self.base_model_prefix, "encoder" + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. 
+ self._dynamic_tied_weights_keys = tied_weights + + for name, module in self.cells_and_names(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + @staticmethod + def _tie_encoder_decoder_weights( + encoder: nn.Cell, decoder: nn.Cell, base_model_prefix: str, base_encoder_name: str + ): + uninitialized_encoder_weights: List[str] = [] + tied_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" + " weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Cell, + encoder_pointer: nn.Cell, + module_name: str, + base_encoder_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + total_decoder_name="", + total_encoder_name="", + ): + assert isinstance(decoder_pointer, nn.Cell) and isinstance( + encoder_pointer, nn.Cell + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight") + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias") + encoder_pointer.bias = decoder_pointer.bias + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()} + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" + " a circular dependency between two or more `nn.Modules` of your model." 
+ ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + base_encoder_name, + uninitialized_encoder_weights, + depth=depth + 1, + total_encoder_name=f"{total_encoder_name}.{encoder_name}", + total_decoder_name=f"{total_decoder_name}.{decoder_name}", + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively( + decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights + ) + + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + return tied_weights + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """Tie or clone module weights depending of whether we are using TorchScript or not""" + if self.config.torchscript: + try: + output_embeddings.weight = Parameter(input_embeddings.embedding_table) + except AttributeError: + # in case of mint.nn.Embedding + output_embeddings.weight = Parameter(input_embeddings.weight) + else: + try: + output_embeddings.weight = input_embeddings.embedding_table + except AttributeError: + # in case of mint.nn.Embedding + output_embeddings.weight = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias = mint.nn.functional.pad( + output_embeddings.bias, + ( + 0, + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], + ), + "constant", + 0, + ) + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The new number of tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `mindspore.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + Return: + `mindspore.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. 
+        """
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        if new_num_tokens is None and pad_to_multiple_of is None:
+            return model_embeds
+
+        # Update base model and current model config
+        self.config.vocab_size = model_embeds.embedding_table.shape[0]
+        self.vocab_size = model_embeds.embedding_table.shape[0]
+
+        # Tie weights again if needed
+        self.tie_weights()
+
+        return model_embeds
+
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+
+        old_embeddings_requires_grad = old_embeddings.embedding_table.requires_grad
+        new_embeddings.embedding_table.requires_grad = old_embeddings_requires_grad
+        self.set_input_embeddings(new_embeddings)
+
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            new_num_tokens = new_embeddings.embedding_table.shape[0]
+
+        # if word embeddings are not tied, make sure that lm head is resized as well
+        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
+            old_lm_head = self.get_output_embeddings()
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.weight.requires_grad = old_lm_head_requires_grad
+            self.set_output_embeddings(new_lm_head)
+
+        return self.get_input_embeddings()
+
+    def _get_resized_embeddings(
+        self,
+        old_embeddings: nn.Embedding,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+    ) -> nn.Embedding:
+        """
+        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end
+
+        Args:
+            old_embeddings (`mindspore.nn.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `mindspore.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+        Return:
+            `mindspore.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
+            `new_num_tokens` is `None`
+        """
+
+        if pad_to_multiple_of is not None:
+            if not isinstance(pad_to_multiple_of, int):
+                raise ValueError(
+                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, "
+                    f"which is not an integer. Please make sure to pass an integer"
+                )
+            if new_num_tokens is None:
+                new_num_tokens = old_embeddings.embedding_table.shape[0]
+            new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+        else:
+            logger.info(
+                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
+                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
+ " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + old_num_tokens, old_embedding_dim = old_embeddings.embedding_table.shape + + if old_num_tokens == new_num_tokens: + return old_embeddings + + if not isinstance(old_embeddings, nn.Embedding): + raise TypeError( + f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" + " should either use a different resize function or make sure that `old_embeddings` are an instance of" + f" {nn.Embedding}." + ) + + # Build new embeddings + new_embeddings = nn.Embedding( + new_num_tokens, + old_embedding_dim, + ) + new_embeddings.embedding_table.set_dtype(old_embeddings.embedding_table.dtype) + # initialize all new embeddings (in particular added tokens) + self._init_weights(new_embeddings) + + # Copy token embeddings from the previous weights + + # numbers of tokens to copy + n = min(old_num_tokens, new_num_tokens) + new_embeddings.embedding_table.data[:n, :] = old_embeddings.embedding_table.data[:n, :] + + # Replace weights in old_embeddings and return to maintain the same embedding type. + # This ensures correct functionality when a Custom Embedding class is passed as input. + # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979) + old_embeddings.embedding_table.set_data(new_embeddings.embedding_table.data) + old_embeddings.num_embeddings = new_embeddings.embedding_table.data.shape[0] + if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx: + old_embeddings.padding_idx = None + + return new_embeddings + + def _get_resized_lm_head( + self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + ) -> nn.Dense: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_lm_head (`mindspore.nn.Dense`): + Old lm head liner layer to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `mindspore.nn.Dense` module of the model without doing anything. transposed (`bool`, *optional*, defaults + to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, + vocab_size` else `vocab_size, lm_head_dim`. + + Return: + `mindspore.nn.Dense`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is + `None` + """ + if new_num_tokens is None: + return old_lm_head + + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.shape if not transposed else old_lm_head.weight.transpose().shape + ) + + if old_num_tokens == new_num_tokens: + return old_lm_head + + if not isinstance(old_lm_head, nn.Dense): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Dense}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Dense}." 
+ ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + new_lm_head = nn.Dense( + *new_lm_head_shape, + has_bias=has_new_lm_head_bias, + dtype=old_lm_head.weight.dtype, + ) + + # initialize new lm head (in particular added tokens) + self._init_weights(new_lm_head) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _init_added_embeddings_weights_with_mean( + self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ): + old_embeddings_weight = old_embeddings.weight.to(ms.float32) + mean_embeddings = mint.mean(old_embeddings_weight, axis=0) + + # Check if the covariance is positive definite. + is_covariance_psd = False + if is_covariance_psd: + raise NotImplementedError + else: + # Otherwise, just initialize with the mean. because distribution will not be created. + new_embeddings.weight[-1 * added_num_tokens :, :] = ( + mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype) + ) + + def _init_added_lm_head_weights_with_mean( + self, + old_lm_head, + new_lm_head, + old_lm_head_dim, + old_num_tokens, + added_num_tokens, + transposed=False, + ): + if transposed: + # Transpose to the desired shape for the function. + new_lm_head.weight = new_lm_head.weight.t() + old_lm_head.weight.data = old_lm_head.weight.t() + + # The same initialization logic as Embeddings. + self._init_added_embeddings_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens + ) + + if transposed: + # Transpose again to the correct shape. + new_lm_head.weight = new_lm_head.weight.t() + old_lm_head.weight = old_lm_head.weight.t() + + def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens): + bias_mean = mint.mean(old_lm_head.bias.data, axis=0, dtype=ms.float32) + bias_std = mint.std(old_lm_head.bias.data, axis=0).to(ms.float32) + new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std) + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + def resize_position_embeddings(self, new_num_position_embeddings: int): + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. 
If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. + """ + if _init_weights: + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + # MindSpore patch. Refresh name of parameters. + for name, param in self.parameters_and_names(): + param.name = name + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = ms.save_checkpoint, + push_to_hub: bool = False, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + state_dict (nested dictionary of `Tensor`): + The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only + save parts of the model or if special precautions need to be taken when recovering the state dictionary + of a model (like when using model parallelism). + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `ms.save_checkpoint` by another method. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). 
+ save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + + if "save_config" in kwargs: + warnings.warn( + "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." + ) + is_main_process = kwargs.pop("save_config") + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = self # we don't unwrap_model(self) in mindspore + + # save the string version of dtype to the config, e.g. convert ms.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = repr(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + # Save the config + if is_main_process: + if not _hf_peft_config_loaded: + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + # generation config built from the model config + the model config holds generation kwargs -> generate + # may revert to legacy behavior if the two don't match + if ( + model_to_save.generation_config._from_model_config + and model_to_save.config._get_non_default_generation_parameters() + ): + new_generation_config = GenerationConfig.from_model_config(model_to_save.config) + if new_generation_config != model_to_save.generation_config: + logger.warning( + "Your generation config was originally created from the model config, but the model " + "config has changed since then. Unless you pass the `generation_config` argument to this " + "model's `generate` calls, they will revert to the legacy behavior where the base " + "`generate` parameterization is loaded from the model config instead. " + "To avoid this behavior and this warning, we recommend you to overwrite the generation " + "config model attribute before calling the model's `save_pretrained`, preferably also " + "removing any generation kwargs from the model config. 
This warning will be raised to an " + "exception in v4.41." + ) + model_to_save.generation_config.save_pretrained(save_directory) + + if _hf_peft_config_loaded: + logger.info( + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." + ) + state_dict = model_to_save.get_adapter_state_dict() + + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will " + "be pre-pended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict + + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. " + "You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] + current_peft_config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = {k: v for k, v in model_to_save.parameters_and_names()} + + # Handle the case where some state_dict keys shouldn't be saved + if self._keys_to_ignore_on_save is not None: + for ignore_key in self._keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + + # Shard the model if it is too big. + if not _hf_peft_config_loaded: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + else: + weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME + + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. mindspore_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards.keys() + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. 
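At the caller level, the sharding logic above boils down to: the save directory ends up holding either a single weights file or several shards plus a JSON index. A hedged usage sketch (`MyMindSporeModel` is a placeholder for any concrete subclass of this `PreTrainedModel`; the shard base names come from the library's `SAFE_WEIGHTS_NAME` / `SAFE_WEIGHTS_INDEX_NAME` constants):

```python
# Placeholder model class and config; neither is defined in this patch.
model = MyMindSporeModel(config)
model.save_pretrained("./ckpt", max_shard_size="200MB", safe_serialization=True)

# Expected layout when the weights exceed `max_shard_size`:
#   ./ckpt/config.json
#   ./ckpt/<weights>-00001-of-00003.safetensors
#   ./ckpt/<weights>-00002-of-00003.safetensors
#   ./ckpt/<weights>-00003-of-00003.safetensors
#   ./ckpt/<weights>.safetensors.index.json   # the `index` written at the end of save_pretrained
```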
+ safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "np"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + r""" + Instantiate a pretrained mindspore model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + MindSpore model using the provided conversion scripts and loading the MindSpore model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. 
+ + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (`Dict[str, Tensor]`, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. + + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. 
+ Please refer to the mirror site for more information. + mindspore_dtype (`str` or `mindspore.Type`, *optional*): + Override the default `mindspore.Type` and load the model under a specific `dtype`. The different options + are: + + 1. `ms.float16` or `ms.bfloat16` or `ms.float32`: load in a specified + `dtype`, ignoring the model's `config.mindspore_dtype` if one exists. If not specified + - the model will get loaded in `ms.float32` (fp32). + + 2. `"auto"` - A `mindspore_dtype` entry in the `config.json` file of the model will be + attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in + the checkpoint that's of a floating point type and use that as `dtype`. This will load the model + using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how + the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + + + + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or + reach out to the authors and ask them to add this information to the model's card and to insert the + `mindspore_dtype` entry in `config.json` on the hub. + + + + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* mindspore_model..bin. `variant` is + ignored when using `from_tf` or `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + Examples: + + ```python + >>> from transformers import BertConfig, BertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = BertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. 
+ >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a TF checkpoint file instead of a MindSpore model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json") + >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) + >>> # Loading from a Flax checkpoint file instead of a MindSpore model (slower) + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True) + ``` + + * `low_cpu_mem_usage` algorithm: + + This is an experimental function that loads the model using ~1x model size CPU memory + + Here is how it works: + + 1. save which state_dict keys we have + 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory + 3. after the model has been instantiated switch to the meta device all params/buffers that + are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors + + """ + state_dict = kwargs.pop("state_dict", None) + from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + adapter_kwargs = kwargs.pop("adapter_kwargs", {}) + adapter_name = kwargs.pop("adapter_name", "default") + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." 
+ ) + token = use_auth_token + + if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: + adapter_kwargs["token"] = token + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + try: + _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None) + except Exception: + _adapter_model_path = None + adapter_kwargs = {} + + if _adapter_model_path is None: + _adapter_model_path = find_adapter_config_file( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + _commit_hash=commit_hash, + **adapter_kwargs, + ) + if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): + with open(_adapter_model_path, "r", encoding="utf-8") as f: + _adapter_model_path = pretrained_model_name_or_path + pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + + from_pt = not (from_tf | from_flax) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + # In case one passes a config to `from_pretrained` + "attn_implementation" + # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs + # Please see: https://github.com/huggingface/transformers/issues/28038 + + # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory + # we pop attn_implementation from the kwargs but this handles the case where users + # passes manually the config to `from_pretrained`. + config = copy.deepcopy(config) + + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp + model_kwargs = kwargs + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. 
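One practical consequence of the config handling just above: a `config` passed explicitly is deep-copied, and an `attn_implementation` keyword still wins over whatever the copy carries. A hedged sketch (`MyConfig` / `MyModel` are placeholders; the value may be further validated by `_autoset_attn_implementation` later in this method):

```python
# Placeholders for a concrete config/model pair shipped elsewhere in the library.
config = MyConfig.from_pretrained("path/to/ckpt")
config._attn_implementation = "eager"

model = MyModel.from_pretrained(
    "path/to/ckpt",
    config=config,                            # deep-copied, the caller's object is not mutated
    attn_implementation="flash_attention_2",  # overrides config._attn_implementation on the copy
)
```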
+ is_sharded = False + sharded_metadata = None + # Load model + loading_info = None + + # Keep in fp32 modules + keep_in_fp32_modules = None + use_keep_in_fp32_modules = False + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + elif from_flax and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." 
+ ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}," + f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." + ) + archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + resolved_archive_file, revision, is_sharded = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
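The fallback chain here mirrors, for remote repos, the same priority already applied to local directories: safetensors first, then a sharded safetensors index, then the single-file and sharded torch checkpoints. A standalone sketch of that local lookup (literal file names follow the upstream-transformers conventions; the real code goes through the `SAFE_WEIGHTS_NAME` / `WEIGHTS_NAME` constants and also honors `subfolder` and `variant`):

```python
import os

def pick_local_weights(folder: str):
    # Hedged sketch of the lookup cascade above; TF/Flax branches omitted.
    candidates = [
        ("model.safetensors", False),             # single-file safetensors
        ("model.safetensors.index.json", True),   # sharded safetensors -> is_sharded
        ("pytorch_model.bin", False),             # single-file torch checkpoint
        ("pytorch_model.bin.index.json", True),   # sharded torch checkpoint -> is_sharded
    ]
    for name, is_sharded in candidates:
        path = os.path.join(folder, name)
        if os.path.isfile(path):
            return path, is_sharded
    raise EnvironmentError(f"No weight files found in {folder}")
```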
+ resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" + f" {FLAX_WEIGHTS_NAME}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
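For context on what `get_checkpoint_shard_files` consumes: in the upstream-transformers layout (which this port appears to reuse, given the shared file-name constants), the index is a small JSON file mapping every parameter name to the shard that stores it. Illustrative content, written as the Python dict it parses to (keys and sizes are made up):

```python
# Illustrative `*.safetensors.index.json` content; names and sizes are invented.
index = {
    "metadata": {"total_size": 9_976_241_152},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}
# get_checkpoint_shard_files resolves every distinct file in `weight_map` and
# returns the shard paths plus `sharded_metadata`, which is what the code below
# reads for "all_checkpoint_keys" and, when present, "dtype".
```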
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="np") as f: + metadata = f.metadata() + + if metadata.get("format") in ("np", "pt"): + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info("A TensorFlow safetensors file is being loaded in a MindSpore model.") + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) + + # load pt weights early so that we know which dtype to init the model under + if from_pt: + if not is_sharded and state_dict is None: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + + # set dtype to instantiate the model under: + # 1. If mindspore_dtype is not None, we use that dtype + # 2. If mindspore_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + + if mindspore_dtype is not None: + config.mindspore_dtype = dtype_to_str(mindspore_dtype) + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.mindspore_dtype = mindspore_dtype + if isinstance(mindspore_dtype, str): + if mindspore_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + mindspore_dtype = config.torch_dtype + logger.info(f"Will use dtype={mindspore_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + mindspore_dtype = sharded_metadata["dtype"] + elif not is_sharded: + mindspore_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(resolved_archive_file[0]) + mindspore_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + f"Since the `torch_dtype` attribute can't be found in model's config object, " + f"will use dtype={mindspore_dtype} as derived from model's weights" + ) + else: + raise ValueError( + f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}' + ) + # TODO: We cannot set default mindspore dtype! + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = list(state_dict.keys()) + + config.name_or_path = pretrained_model_name_or_path + + config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. 
+ config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype + ) + + model = cls(config, *model_args, **model_kwargs) + + # Make sure to tie the weights correctly + model.tie_weights() + + # We cannot set default mindspore dtype. So we need to cast model weights after creating. + if mindspore_dtype is not None: + model = model.to(mindspore_dtype) + + logger.info( + f"convert model:{model.__class__.__name__} parameters to mindspore_dtype {dtype_to_str(mindspore_dtype)}" + ) + + # make sure we use the model's config since the __init__ call might have copied it + config = model.config + + # Check first if we are `from_pt` + if use_keep_in_fp32_modules: + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if from_tf: + raise NotImplementedError("loading tf checkpoint in mindspore model is not yet supported.") + elif from_flax: + raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") + elif from_pt: + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + dtype=mindspore_dtype, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if _adapter_model_path is not None: + model.load_adapter( + _adapter_model_path, + adapter_name=adapter_name, + token=token, + adapter_kwargs=adapter_kwargs, + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.set_train(False) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate() and pretrained_model_name_or_path is not None: + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." 
+ ) + pass + + if output_loading_info: + if loading_info is None: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + dtype=None, + keep_in_fp32_modules=None, + ): + model.tie_weights() + + # Retrieve missing & unexpected_keys + model_state_dict = {k: v for k, v in model.parameters_and_names()} + expected_keys = list(model_state_dict.keys()) + prefix = model.base_model_prefix + original_loaded_keys = loaded_keys + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # Mapping loaded_keys from pt to ms + pt2ms_mappings = _get_pt2ms_mappings(model) + loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix) + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + _prefix = f"{prefix}." + expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] + expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] + elif add_prefix_to_model: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = sorted(set(expected_keys) - set(loaded_keys)) + unexpected_keys = set(loaded_keys) - set(expected_keys) + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.parameters_and_names(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.set_dtype(ms.float32) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + base_model_expected_keys = list(k for k, v in model_to_load.parameters_and_names()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" 
+ ) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + if ( + state_dict[checkpoint_key].shape[-1] == 1 + and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() + ): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + pass + else: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + state_dict = _convert_state_dict(model, state_dict, prefix) + + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=False) + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + mismatched_keys = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + + # loading checkpoint + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + state_dict = _convert_state_dict(model, state_dict, prefix) + + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + + error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=True) + + # force memory release + del state_dict + gc.collect() + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + archs = [] if model.config.architectures is None else model.config.architectures + warner = logger.warning if model.__class__.__name__ in archs else logger.info + warner( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = {".".join(key.split(".")[:-1]) for key in names} + + # torch.nn.ParameterList is a special case where two parameter keywords + # are appended to the module name, *e.g.* bert.special_embeddings.0 + module_keys = module_keys.union( + {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()} + ) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + _prefix = f"{self.base_model_prefix}." + name = name[len(_prefix) :] if name.startswith(_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with a given auto class. 
This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See " + "https://huggingface.co/docs/transformers/troubleshooting" + "#incorrect-output-when-padding-tokens-arent-masked." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning_once(warn_string) + + +class PoolerStartLogits(nn.Cell): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense = mint.nn.Linear(config.hidden_size, 1) + + def construct(self, hidden_states: ms.Tensor, p_mask: Optional[ms.Tensor] = None) -> ms.Tensor: + """ + Args: + hidden_states (`ms.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`ms.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `ms.Tensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == ms.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Cell): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. 
+ """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense_0 = mint.nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = mint.nn.Tanh() + self.LayerNorm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = mint.nn.Linear(config.hidden_size, 1) + + def construct( + self, + hidden_states: ms.Tensor, + start_states: Optional[ms.Tensor] = None, + start_positions: Optional[ms.Tensor] = None, + p_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + """ + Args: + hidden_states (`ms.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`ms.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`ms.Tensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`ms.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `ms.Tensor`: The end logits for SQuAD. + """ + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(mint.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == ms.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Cell): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config): + super().__init__() + self.dense_0 = mint.nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = mint.nn.Tanh() + self.dense_1 = mint.nn.Linear(config.hidden_size, 1, bias=False) + + def construct( + self, + hidden_states: ms.Tensor, + start_states: Optional[ms.Tensor] = None, + start_positions: Optional[ms.Tensor] = None, + cls_index: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + """ + Args: + hidden_states (`ms.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`ms.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`ms.Tensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. 
+ + + + Returns: + `ms.Tensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(mint.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +@dataclass +class SquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`]. + + Args: + loss (`mindspore.Tensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`mindspore.Tensor` of shape `(batch_size, config.start_n_top)`, *optional*, + returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`mindspore.Tensor` of shape `(batch_size, config.start_n_top)`, *optional*, + returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`mindspore.Tensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, + returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`mindspore.Tensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, + returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`mindspore.Tensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + + """ + + loss: Optional[ms.Tensor] = None + start_top_log_probs: Optional[ms.Tensor] = None + start_top_index: Optional[ms.Tensor] = None + end_top_log_probs: Optional[ms.Tensor] = None + end_top_index: Optional[ms.Tensor] = None + cls_logits: Optional[ms.Tensor] = None + + +class SQuADHead(nn.Cell): + r""" + A SQuAD head inspired by XLNet. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. 
+ """ + + def __init__(self, config): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) + def construct( + self, + hidden_states: ms.Tensor, + start_positions: Optional[ms.Tensor] = None, + end_positions: Optional[ms.Tensor] = None, + cls_index: Optional[ms.Tensor] = None, + is_impossible: Optional[ms.Tensor] = None, + p_mask: Optional[ms.Tensor] = None, + return_dict: bool = False, + ) -> Union[SquadHeadOutput, Tuple[ms.Tensor]]: + """ + Args: + hidden_states (`mindspore.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`mindspore.Tensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`mindspore.Tensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`mindspore.Tensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`mindspore.Tensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`mindspore.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = mint.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = mint.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = mint.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, 
start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = mint.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = mint.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = mint.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return SquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +class SequenceSummary(nn.Cell): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. 
+ raise NotImplementedError + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Dense(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def construct(self, hidden_states: ms.Tensor, cls_index: Optional[ms.Tensor] = None) -> ms.Tensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` + where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(axis=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = ops.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=ms.int64, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +class AttentionInterface(MutableMapping): + """ + Dict-like object keeping track of allowed attention functions. You can easily add a new attention function + with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, + it needs to declare a new instance of this class inside the `modeling_.py`, and declare it on that instance. + """ + + # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if + # a new instance is created (in order to locally override a given function) + _global_mapping = { + "flash_attention_2": flash_attention_forward, + # "flex_attention": flex_attention_forward, # Mindspore dose not support flex_attention yet + "sdpa": sdpa_attention_forward, # Mindspore dose not support sdpa yet. 
Use vanilla attention to work around + } + + def __init__(self): + self._local_mapping = {} + + def __getitem__(self, key): + # First check if instance has a local override + if key in self._local_mapping: + return self._local_mapping[key] + return self._global_mapping[key] + + def __setitem__(self, key, value): + # Allow local update of the default functions without impacting other instances + self._local_mapping.update({key: value}) + + def __delitem__(self, key): + del self._local_mapping[key] + + def __iter__(self): + # Ensure we use all keys, with the overwritten ones on top + return iter({**self._global_mapping, **self._local_mapping}) + + def __len__(self): + return len(self._global_mapping.keys() | self._local_mapping.keys()) + + @classmethod + def register(cls, key: str, value: Callable): + cls._global_mapping.update({key: value}) + + def valid_keys(self) -> List[str]: + return list(self.keys()) + + +# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones +ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() + +# for BC +MSPreTrainedModel = PreTrainedModel diff --git a/mindone/transformers/modeling_patch.py b/mindone/transformers/modeling_patch.py new file mode 100644 index 0000000000..cd9b97c2d4 --- /dev/null +++ b/mindone/transformers/modeling_patch.py @@ -0,0 +1,49 @@ +import inspect +from functools import wraps + +import mindspore as ms +from mindspore import mint, nn + +SKIP_CLASSES = {nn.Dropout} +# Store original __init__ for manual restore +_ORIG_INITS = {} + + +def patch_nn_default_dtype(dtype=ms.float32, force=False): + """ + Iterate over all Cells under nn and mint.nn, + automatically set or force the default dtype in __init__ if supported. + + Args: + dtype (mindspore.dtype): target dtype to enforce + force (bool): if True, even when user passes dtype explicitly, override it + """ + for module in [ms.nn, mint.nn]: + for name in dir(module): + attr = getattr(module, name) + if inspect.isclass(attr) and issubclass(attr, nn.Cell): + if attr in SKIP_CLASSES: + continue # skip specified classes + sig = inspect.signature(attr.__init__) + if "dtype" in sig.parameters: + if attr not in _ORIG_INITS: + _ORIG_INITS[attr] = attr.__init__ + + _orig_init = attr.__init__ + + @wraps(_orig_init) + def _new_init(self, *args, _orig_init=_orig_init, **kwargs): + if force or "dtype" not in kwargs: + kwargs["dtype"] = dtype + return _orig_init(self, *args, **kwargs) + + setattr(attr, "__init__", _new_init) + + +def restore_nn_default_dtype(): + """ + Manually restore the original __init__ of all patched nn / mint.nn Cells. 
+ """ + for cls, orig_init in _ORIG_INITS.items(): + cls.__init__ = orig_init + _ORIG_INITS.clear() diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 8731ed7a42..bf348938eb 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -60,6 +60,8 @@ import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.nn import CrossEntropyLoss, Identity +from mindspore.nn.utils import no_init_parameters +from mindspore.ops import Cast from .activations import get_activation from .generation.utils import GenerationMixin @@ -77,6 +79,7 @@ prune_linear_layer, ) from .modeling_attn_mask_utils import dtype_to_min +from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available if is_safetensors_available(): @@ -86,6 +89,7 @@ from mindone.safetensors.mindspore import save_file as safe_save_file logger = logging.get_logger(__name__) +cpu_cast = Cast().set_device("CPU") _init_weights = True @@ -349,7 +353,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar local_state = {v.name: v for k, v in model_to_load.parameters_and_names()} for k, v in state_dict.items(): if k in local_state: - v.set_dtype(local_state[k].dtype) + state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k) else: pass # unexpect key keeps origin dtype cm = silence_mindspore_logger() if is_sharded else nullcontext() @@ -387,7 +391,8 @@ def to(self, dtype: Optional[ms.Type] = None): # Now we use `Parameter` and `Parameter.set_dtype()` instead. for p in self.get_parameters(): - p.set_dtype(dtype) + if p.dtype != dtype: + p.set_dtype(dtype) return self def float(self): @@ -950,7 +955,7 @@ def _from_config(cls, config, **kwargs): if isinstance(mindspore_dtype, str): mindspore_dtype = getattr(ms, mindspore_dtype) - elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + elif mindspore_dtype is not None: TORCH_TO_MINDSPORE_DTYPE_MAP = { "torch.float32": ms.float32, "torch.bfloat16": ms.bfloat16, @@ -977,8 +982,12 @@ def _from_config(cls, config, **kwargs): use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype, ) - - model = cls(config, **kwargs) + with no_init_parameters(): + if mindspore_dtype is not None: + patch_nn_default_dtype(dtype=mindspore_dtype, force=True) + model = cls(config, **kwargs) + if mindspore_dtype is not None: + restore_nn_default_dtype() # We cannot set default mindspore dtype. So we need to cast model weights after creating. if mindspore_dtype is not None: @@ -2348,7 +2357,12 @@ def from_pretrained( config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype ) - model = cls(config, *model_args, **model_kwargs) + with no_init_parameters(): + if mindspore_dtype is not None: + patch_nn_default_dtype(dtype=mindspore_dtype, force=True) + model = cls(config, *model_args, **model_kwargs) + if mindspore_dtype is not None: + restore_nn_default_dtype() # Make sure to tie the weights correctly model.tie_weights() diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py new file mode 100644 index 0000000000..4382a3d6e3 --- /dev/null +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py @@ -0,0 +1,274 @@ +# Copyright 2025 The HuggingFace Team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import unittest + +import numpy as np +import pytest +import torch +from PIL import Image +from ddt import data, ddt, unpack +from transformers import Qwen2_5_VLConfig + +import mindspore as ms + +from diffusers import ( + AutoencoderKLQwenImage, + QwenImageEditPipeline, + QwenImageTransformer2DModel, +) + +from mindone.diffusers.utils.testing_utils import ( + load_numpy_from_local_file, + slow, +) + +from ..pipeline_test_utils import ( + THRESHOLD_FP16, + THRESHOLD_FP32, + THRESHOLD_PIXEL, + PipelineTesterMixin, + floats_tensor, + get_module, + get_pipeline_components, + randn_tensor, +) + +test_cases = [ + {"mode": ms.PYNATIVE_MODE, "dtype": "float32"}, + {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, +] + +class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_config = [ + [ + "transformer", + "diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_qwenimage.QwenImageTransformer2DModel", + dict( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + guidance_embeds=False, + axes_dims_rope=(8, 4, 4), + ), + ], + [ + "vae", + "diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + "mindone.diffusers.models.autoencoders.autoencoder_kl_qwenimage.AutoencoderKLQwenImage", + dict( + base_dim=4 * 6, + z_dim=4, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + # fmt: on + ), + ], + [ + "scheduler", + "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + "mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + dict(), + ], + [ + "text_encoder", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + "mindone.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration", + dict( + config=Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + attention_dropout=0.0, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + rms_norm_eps=1e-06, + max_position_embeddings=128000, + hidden_size=16, + hidden_act="silu", + intermediate_size=16, + initializer_range=0.02, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + sliding_window=32768, #None + use_sliding_window=False, + use_cache=True, + attn_implementation="eager", 
+ rope_scaling={ + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + rope_theta=1000000.0, + ), + ), + ], + [ + "tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", + dict( + # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + local_files_only=True, + trust_remote_code=True, + ), + ], + [ + "processor", + "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", + "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", + dict( + # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + local_files_only=True, + trust_remote_code=True, + ), + ], + ] + + def get_dummy_components(self): + components = { + key: None + for key in [ + "transformer", + "vae", + "scheduler", + "text_encoder", + "tokenizer", + "processor", + ] + } + + def get_dummy_inputs(self, seed=0): + inputs = { + "prompt": "dance monkey", + "image": Image.new("RGB", (32, 32)), + "negative_prompt": "bad quality", + "num_inference_steps": 2, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "np", + } + + return inputs + + + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + ms.set_context(mode=mode) + + pt_components, ms_components = self.get_dummy_components() + pt_pipe_cls = get_module("diffusers.pipelines.qwenimage.QwenImageEditPipeline") + ms_pipe_cls = get_module("mindone.diffusers.pipelines.qwenimage.QwenImageEditPipeline") + + pt_pipe = pt_pipe_cls(**pt_components) + ms_pipe = ms_pipe_cls(**ms_components) + + pt_pipe.set_progress_bar_config(disable=None) + ms_pipe.set_progress_bar_config(disable=None) + + ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype) + pt_pipe = pt_pipe.to(pt_dtype) + ms_pipe = ms_pipe.to(ms_dtype) + + sys.modules[ms_pipe.__module__].randn_tensor = randn_tensor + sys.modules[ms_pipe.vae.diag_gauss_dist.__module__].randn_tensor = randn_tensor + + inputs = self.get_dummy_inputs() + + torch.manual_seed(0) + pt_image = pt_pipe(**inputs).images + torch.manual_seed(0) + ms_image = ms_pipe(**inputs)[0] + + pt_generated_image = pt_image[0] + ms_generated_image = ms_image[0] + + threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16 + assert np.max(np.linalg.norm(pt_generated_image - ms_generated_image) / np.linalg.norm(pt_generated_image)) < threshold + + +@slow +@ddt +class QwenImageImg2ImgPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): + @data(*test_cases) + @unpack + def test_inference(self, mode, dtype): + if dtype == "float32": + pytest.skip("Skipping this case since this pipeline will OOM in float32") + + ms.set_context(mode=mode) + ms_dtype = getattr(ms, dtype) + + # model_id = "Qwen/Qwen-Image-Edit" + model_id = "/data6/Qwen-Image-Edit" + + pipe = QwenImageEditPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) + + pipe.vae.enable_tiling() + + torch.manual_seed(0) + 
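
For reference, the fast test above accepts the MindSpore output when its relative Frobenius-norm error against the PyTorch reference stays below the dtype-specific threshold, while the integration test below gates on a mean absolute pixel difference; with synthetic arrays (made-up values, not real pipeline outputs) the two criteria reduce to:

    import numpy as np

    pt_img = np.random.default_rng(0).random((32, 32, 3), dtype=np.float32)
    ms_img = pt_img + 1e-4  # stand-in for a slightly different MindSpore output
    rel_err = np.linalg.norm(pt_img - ms_img) / np.linalg.norm(pt_img)  # fast test: compared to THRESHOLD_FP32 / THRESHOLD_FP16
    pixel_err = np.mean(np.abs(ms_img - pt_img))                        # integration test: compared to THRESHOLD_PIXEL
    print(rel_err, pixel_err)
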
image = pipe( + image=Image.new("RGB", (32, 32)), + prompt="dance monkey", + negative_prompt="bad quality", + )[0][0] + + # The text_coder causes deviations between ms and pt versions. However, the deviation\ + # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. + expected_image = load_numpy_from_local_file( + # "mindone-testing-arrays", + "/data4/mindone-testing-arrays", + f"qwenimage_edit_{dtype}.npy", + subfolder="qwenimage", + ) + + assert np.mean(np.abs(np.array(image, dtype=np.float32) - expected_image)) < THRESHOLD_PIXEL From 39168ee11d84f3a748a2221895e44a9aeb99e06b Mon Sep 17 00:00:00 2001 From: GUOGUO <55723162+Dong1017@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:08:15 +0800 Subject: [PATCH 47/77] 2025/9/5 15:07, edit-inpaint pipe --- .../pipelines/qwenimage/test_qwenimage_edit.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py index 4382a3d6e3..862e5a7c20 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random import sys import unittest @@ -41,7 +40,6 @@ THRESHOLD_FP32, THRESHOLD_PIXEL, PipelineTesterMixin, - floats_tensor, get_module, get_pipeline_components, randn_tensor, @@ -52,6 +50,7 @@ {"mode": ms.PYNATIVE_MODE, "dtype": "bfloat16"}, ] +@ddt class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_config = [ [ @@ -183,8 +182,10 @@ def get_dummy_components(self): "processor", ] } + return get_pipeline_components(components, self.pipeline_config) + - def get_dummy_inputs(self, seed=0): + def get_dummy_inputs(self): inputs = { "prompt": "dance monkey", "image": Image.new("RGB", (32, 32)), @@ -238,7 +239,7 @@ def test_inference(self, mode, dtype): @slow @ddt -class QwenImageImg2ImgPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): +class QwenImageEditPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase): @data(*test_cases) @unpack def test_inference(self, mode, dtype): From fdfb3a324827abba2d1dd4c70ced8d96aa42ae80 Mon Sep 17 00:00:00 2001 From: GUOGUO <55723162+Dong1017@users.noreply.github.com> Date: Fri, 5 Sep 2025 17:40:46 +0800 Subject: [PATCH 48/77] 2025/9/5 17:40, fix some bugs --- mindone/diffusers/__init__.py | 1 + mindone/diffusers/pipelines/__init__.py | 1 + .../pipeline_qwenimage_edit_inpaint.py | 35 ++++++++++--------- mindone/transformers/modeling_utils.py | 2 +- .../qwenimage/test_qwenimage_edit.py | 3 ++ 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 78ff933bdf..510bbb6ce3 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -548,6 +548,7 @@ PixArtSigmaPAGPipeline, PixArtSigmaPipeline, QwenImageEditPipeline, + QwenImageEditInpaintPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline, diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index d0b64ec0ab..7623e81e35 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -245,6 +245,7 @@ "wan": ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"], "qwenimage": [ "QwenImageEditPipeline", + 
"QwenImageEditInpaintPipeline", "QwenImageImg2ImgPipeline", "QwenImageInpaintPipeline", "QwenImagePipeline", diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index a290a50d28..64b08da1f6 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -31,7 +31,7 @@ from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging #, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import randn_tensor, pynative_context +from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -258,11 +258,12 @@ def _get_qwen_prompt_embeds( ) hidden_states = outputs.hidden_states[-1] - split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = self._extract_masked_hidden(hidden_states, ms.Tensor(model_inputs.attention_mask)) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( + [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( @@ -418,12 +419,12 @@ def _encode_vae_image(self, image: ms.Tensor, generator: np.random.Generator): with pynative_context(): if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])) + retrieve_latents(self.vae, self.vae.encode(image[i : i + 1])[0]) for i in range(image.shape[0]) ] image_latents = mint.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae, self.vae.encode(image)) + image_latents = retrieve_latents(self.vae, self.vae.encode(image)[0]) latents_mean = ( ms.Tensor(self.vae.config.latents_mean) @@ -950,7 +951,7 @@ def __call__( raise ValueError("guidance_scale is required for guidance-distilled model.") elif self.transformer.config.guidance_embeds: guidance = mint.full([1], guidance_scale, dtype=ms.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand((latents.shape[0],)) elif not self.transformer.config.guidance_embeds and guidance_scale is not None: logger.warning( f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." 
@@ -996,18 +997,18 @@ def __call__( noise_pred = noise_pred[:, : latents.shape[1]] if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] + # with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] neg_noise_pred = neg_noise_pred[:, : latents.shape[1]] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index bf348938eb..1c2a0b16b1 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -955,7 +955,7 @@ def _from_config(cls, config, **kwargs): if isinstance(mindspore_dtype, str): mindspore_dtype = getattr(ms, mindspore_dtype) - elif mindspore_dtype is not None: + elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): TORCH_TO_MINDSPORE_DTYPE_MAP = { "torch.float32": ms.float32, "torch.bfloat16": ms.bfloat16, diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py index 862e5a7c20..e2a395503f 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py @@ -204,6 +204,9 @@ def get_dummy_inputs(self): @data(*test_cases) @unpack def test_inference(self, mode, dtype): + if dtype == "bfloat16": + print("The PyTorch pipeline in bfloat16 requires ~1 hrs.") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() From 97b3d8e2e3b45be9c86eed5ea8762557b99cd10e Mon Sep 17 00:00:00 2001 From: GUOGUO <55723162+Dong1017@users.noreply.github.com> Date: Mon, 15 Sep 2025 11:28:50 +0800 Subject: [PATCH 49/77] modified qwenimage 2025/9/15 clean no-use notes --- .../autoencoders/autoencoder_kl_qwenimage.py | 4 ---- .../models/transformers/transformer_qwenimage.py | 9 +-------- .../pipelines/qwenimage/pipeline_qwenimage.py | 6 +++--- .../pipelines/qwenimage/pipeline_qwenimage_edit.py | 4 +--- .../qwenimage/pipeline_qwenimage_edit_inpaint.py | 5 +---- .../qwenimage/pipeline_qwenimage_img2img.py | 4 +--- .../qwenimage/pipeline_qwenimage_inpaint.py | 2 -- .../pipelines/qwenimage/test_qwenimage.py | 10 +++------- .../pipelines/qwenimage/test_qwenimage_edit.py | 14 ++++---------- .../pipelines/qwenimage/test_qwenimage_img2img.py | 10 +++------- .../pipelines/qwenimage/test_qwenimage_inpaint.py | 10 +++------- 11 files changed, 20 insertions(+), 58 deletions(-) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index b5f888cec8..d3210da882 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -327,10 +327,6 @@ def construct(self, 
x): # apply attention x = ops.flash_attention_score(q, k, v, 1, scalar_value=1/math.sqrt(q.shape[-1]), input_layout="BNSD") - # x = ops.operations.nn_ops.FlashAttentionScore(1, input_layout="BNSD")( - # q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), None, None, None, None - # )[3].to(q.dtype) - x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) # output projection diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index 72988d1489..296ce96171 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -123,7 +123,6 @@ def apply_rotary_emb_qwen( cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] - # cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit @@ -217,7 +216,7 @@ def construct(self, video_fhw, txt_seq_lens): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw rope_key = f"{idx}_{height}_{width}" - # jit-related, 25/8/18. Remain to fix. + # TODO: @jit, 25/8/18. Remain to fix. # if not torch.compiler.is_compiling(): # if rope_key not in self.rope_cache: # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) @@ -267,12 +266,6 @@ class QwenDoubleStreamAttnProcessor2_0: _attention_backend = None - # def __init__(self): - # if not hasattr(F, "scaled_dot_product_attention"): - # raise ImportError( - # "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - # ) - def __call__( self, attn: Attention, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 0822f2495b..037b4e2256 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -616,7 +616,6 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand((latents.shape[0],)).to(latents.dtype) - # with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -630,7 +629,6 @@ def __call__( )[0] if do_true_cfg: - # with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -683,7 +681,9 @@ def __call__( latents.dtype ) latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + # TODO: we use pynative mode here since cache in vae.decode which not supported in graph mode + with pynative_context(): + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 85a0f5e5bc..0cde15e896 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -29,7 +29,7 @@ from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging #, scale_lora_layers, 
unscale_lora_layers +from ...utils import logging from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -772,7 +772,6 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand((latents.shape[0],)).to(latents.dtype) - # with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -787,7 +786,6 @@ def __call__( noise_pred = noise_pred[:, : latents.shape[1]] if do_true_cfg: - # with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 64b08da1f6..ee5e760089 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -30,7 +30,7 @@ from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging #, scale_lora_layers, unscale_lora_layers +from ...utils import logging from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -263,7 +263,6 @@ def _get_qwen_prompt_embeds( attn_mask_list = [mint.ones(e.shape[0], dtype=ms.int64) for e in split_hidden_states] max_seq_len = max([e.shape[0] for e in split_hidden_states]) prompt_embeds = mint.stack( - [mint.cat([u, u.new_zeros((max_seq_len - u.shape[0], u.shape[1]))]) for u in split_hidden_states] ) encoder_attention_mask = mint.stack( @@ -982,7 +981,6 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - # with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -997,7 +995,6 @@ def __call__( noise_pred = noise_pred[:, : latents.shape[1]] if do_true_cfg: - # with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 4b2c719e3c..be6e891b94 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -13,7 +13,7 @@ from ...loaders import QwenImageLoraLoaderMixin from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging #, scale_lora_layers, unscale_lora_layers +from ...utils import logging from ...utils.mindspore_utils import randn_tensor, pynative_context from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput @@ -718,7 +718,6 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand((latents.shape[0],)).to(latents.dtype) - # with self.transformer.cache_context("cond"): 
noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -732,7 +731,6 @@ def __call__( )[0] if do_true_cfg: - # with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index bae2aaf342..49a0f74f18 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -884,7 +884,6 @@ def __call__( self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand((latents.shape[0],)).to(latents.dtype) - # with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -898,7 +897,6 @@ def __call__( )[0] if do_true_cfg: - # with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py index bc60ba2cbe..95ce05fef6 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage.py @@ -145,9 +145,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( - # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" local_files_only=True, trust_remote_code=True, ), @@ -228,8 +226,7 @@ def test_inference(self, mode, dtype): ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) - # model_id = "Qwen/Qwen-Image" - model_id = "/data6/Qwen-Image" + model_id = "Qwen/Qwen-Image" pipe = QwenImagePipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) pipe.vae.enable_tiling() @@ -243,8 +240,7 @@ def test_inference(self, mode, dtype): # The text_coder causes deviations between ms and pt versions. However, the deviation\ # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
expected_image = load_numpy_from_local_file( - # "mindone-testing-arrays", - "/data4/mindone-testing-arrays", + "mindone-testing-arrays", f"qwenimage_t2i_{dtype}.npy", subfolder="qwenimage", ) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py index e2a395503f..ad4e751a87 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py @@ -149,9 +149,7 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( - # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" local_files_only=True, trust_remote_code=True, ), @@ -161,9 +159,7 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", dict( - # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" local_files_only=True, trust_remote_code=True, ), @@ -252,8 +248,7 @@ def test_inference(self, mode, dtype): ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) - # model_id = "Qwen/Qwen-Image-Edit" - model_id = "/data6/Qwen-Image-Edit" + model_id = "Qwen/Qwen-Image-Edit" pipe = QwenImageEditPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) @@ -269,8 +264,7 @@ def test_inference(self, mode, dtype): # The text_coder causes deviations between ms and pt versions. However, the deviation\ # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
expected_image = load_numpy_from_local_file( - # "mindone-testing-arrays", - "/data4/mindone-testing-arrays", + "mindone-testing-arrays", f"qwenimage_edit_{dtype}.npy", subfolder="qwenimage", ) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py index 5df9730dd1..8d7232d9f0 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_img2img.py @@ -138,9 +138,7 @@ class QwenImageImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( - # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" local_files_only=True, trust_remote_code=True, ), @@ -240,8 +238,7 @@ def test_inference(self, mode, dtype): ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) - # model_id = "Qwen/Qwen-Image" - model_id = "/data6/Qwen-Image" + model_id = "Qwen/Qwen-Image" image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image pipe = QwenImageImg2ImgPipeline.from_pretrained(model_id, mindspore_dtype=ms_dtype) @@ -258,8 +255,7 @@ def test_inference(self, mode, dtype): # The text_coder causes deviations between ms and pt versions. However, the deviation\ # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
expected_image = load_numpy_from_local_file( - # "mindone-testing-arrays", - "/data4/mindone-testing-arrays", + "mindone-testing-arrays", f"qwenimage_i2i_{dtype}.npy", subfolder="qwenimage", ) diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py index 8dca8335bf..733d54ce9b 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_inpaint.py @@ -152,9 +152,7 @@ class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", "transformers.models.qwen2.tokenization_qwen2.Qwen2Tokenizer", dict( - # pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" - # pretrained_model_name_or_path="./hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", - pretrained_model_name_or_path="tests/diffusers_tests/pipelines/qwenimage/hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", + pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" local_files_only=True, trust_remote_code=True, ), @@ -259,8 +257,7 @@ def test_inference(self, mode, dtype): ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) - # model_id = "Qwen/Qwen-Image" - model_id = "/data6/Qwen-Image" + model_id = "Qwen/Qwen-Image" image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)) # load given image mask_image = ms.mint.ones((1, 1, 32, 32)) @@ -279,8 +276,7 @@ def test_inference(self, mode, dtype): # The text_coder causes deviations between ms and pt versions. However, the deviation\ # is within THRESHOLD_PIXEL when using the same intermediate results of text_encoder. 
expected_image = load_numpy_from_local_file( - # "mindone-testing-arrays", - "/data4/mindone-testing-arrays", + "mindone-testing-arrays", f"qwenimage_inpaint_{dtype}.npy", subfolder="qwenimage", ) From 440e22619eb3cc8445cb5cec35c5d7b31cd0ad3c Mon Sep 17 00:00:00 2001 From: GUOGUO <55723162+Dong1017@users.noreply.github.com> Date: Mon, 15 Sep 2025 17:28:01 +0800 Subject: [PATCH 50/77] 2025/9/15 seamless_m4t submit add model seamless_m4t --- .../models/auto/configuration_auto.py | 2 + .../transformers/models/auto/modeling_auto.py | 9 +- .../models/seamless_m4t/__init__.py | 18 + .../seamless_m4t/modeling_seamless_m4t.py | 4056 +++++++++++++++++ .../models/seamless_m4t/__init__.py | 0 .../test_modeling_seamless_m4t.py | 398 ++ 6 files changed, 4482 insertions(+), 1 deletion(-) create mode 100644 mindone/transformers/models/seamless_m4t/__init__.py create mode 100644 mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py create mode 100644 tests/transformers_tests/models/seamless_m4t/__init__.py create mode 100644 tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index 6e3230205e..23252e2ba0 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -100,6 +100,7 @@ ("roberta", "RobertaConfig"), ("recurrent_gemma", "RecurrentGemmaConfig"), ("rembert", "RemBertConfig"), + ("seamless_m4t", "SeamlessM4TConfig"), ("swin", "SwinConfig"), ("siglip", "SiglipConfig"), ("siglip_vision_model", "SiglipVisionConfig"), @@ -195,6 +196,7 @@ ("recurrent_gemma", "RecurrentGemma"), ("rembert", "RemBERT"), ("swin", "Swin Transformer"), + ("seamless_m4t", "SeamlessM4T"), ("siglip", "SigLIP"), ("siglip_vision_model", "SiglipVisionModel"), ("smolvlm", "SmolVLM"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index 0eae2e3d30..56430abf51 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -96,6 +96,7 @@ ("qwen2_vl", "Qwen2VLModel"), ("roberta", "RobertaModel"), ("rembert", "RemBertModel"), + ("siglip", "SiglipModel"), ("siglip_vision_model", "SiglipVisionModel"), ("smolvlm", "SmolVLMModel"), @@ -372,6 +373,7 @@ ("mvp", "MvpForConditionalGeneration"), ("mt5", "MT5ForConditionalGeneration"), ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForTextToText"), ("t5", "T5ForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ] @@ -379,6 +381,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ + ("seamless_m4t", "SeamlessM4TForSpeechToText"), ("speecht5", "SpeechT5ForSpeechToText"), ("whisper", "WhisperForConditionalGeneration"), ] @@ -554,7 +557,11 @@ ] ) -MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict() +MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict( + [ + ("seamless_m4t", "SeamlessM4TForTextToSpeech") + ] +) MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ diff --git a/mindone/transformers/models/seamless_m4t/__init__.py b/mindone/transformers/models/seamless_m4t/__init__.py new file mode 100644 index 0000000000..0205df03ca --- /dev/null +++ b/mindone/transformers/models/seamless_m4t/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
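
With the auto-mapping entries added above, the new model becomes reachable through the auto classes. A hypothetical usage sketch, assuming `AutoModelForSpeechSeq2Seq` is exported by `mindone.transformers` and using the checkpoint name referenced in the modeling file below:

    # assumption: mindone.transformers exposes AutoModelForSpeechSeq2Seq like upstream transformers
    from mindone.transformers import AutoModelForSpeechSeq2Seq

    # "seamless_m4t" now maps to SeamlessM4TForSpeechToText in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    # so the auto class resolves the checkpoint to the MindSpore implementation
    model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/hf-seamless-m4t-medium")
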
+# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_seamless_m4t import * \ No newline at end of file diff --git a/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py new file mode 100644 index 0000000000..22a024c997 --- /dev/null +++ b/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -0,0 +1,4056 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MindSpore SeamlessM4T model.""" + +import copy +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +from transformers import SeamlessM4TConfig +from transformers.utils import ModelOutput, logging + +import mindspore as ms +from mindspore import Tensor, mint, nn +from mindspore.mint.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Wav2Vec2BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/hf-seamless-m4t-medium" +_CONFIG_FOR_DOC = "SeamlessM4TConfig" + + +@dataclass +class SeamlessM4TGenerationOutput(ModelOutput): + """ + Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`], + [`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForTextToSpeech`]. + + Args: + waveform (`ms.Tensor` of shape `(batch_size, sequence_length)`): + The final audio waveform predicted by the model. + waveform_lengths (`ms.Tensor` of shape `(batch_size,)`, *optional*): + The length in samples of each element in the `waveform` batch. + sequences (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + The generated translated sequences. This is the output of the text-to-text or the speech-to-text models. 
+ The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished + early due to the `eos_token_id`. + unit_sequences (`ms.Tensor` of shape `(batch_size, unit_sequence_length)`, *optional*): + The generated translated unit sequences. This is the output of the text-to-units model. The second + dimension (unit_sequence_length) is either equal to `t2u_max_length` or shorter if all batches finished + early due to the `t2u_eos_token_id`. + """ + + waveform: Optional[ms.Tensor] = None + waveform_lengths: Optional[ms.Tensor] = None + sequences: Optional[Tuple[ms.Tensor]] = None + unit_sequences: Optional[Tuple[ms.Tensor]] = None + + +############ UTILS ################ + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: ms.Tensor x: + + Returns: ms.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (mint.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: ms.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _compute_new_attention_mask(hidden_states: ms.Tensor, seq_lens: ms.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + + Args: + hidden_states (`ms.Tensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. + seq_lens (`ms.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + + Returns: + `ms.Tensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = mint.arange(mask_seq_len).expand((batch_size, -1)) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand((-1, mask_seq_len)) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + +def format_speech_generation_kwargs(kwargs): + """ + Format kwargs for SeamlessM4T models that generate speech, attribute kwargs to either the text generation or the + speech generation models. + + Args: + kwargs (`dict`)`: + Keyword arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. 
+ - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + """ + # attribute kwargs to models + kwargs_text = {} + kwargs_speech = {} + for key, value in kwargs.items(): + if key.startswith("text_"): + key = key[len("text_") :] + kwargs_text[key] = value + elif key.startswith("speech_"): + key = key[len("speech_") :] + kwargs_speech[key] = value + elif key == "generation_config": + kwargs_text[key] = value + else: + # If the key is already in a specific config, then it's been set with a + # submodules specific value and we don't override + if key not in kwargs_text: + kwargs_text[key] = value + if key not in kwargs_speech: + kwargs_speech[key] = value + return kwargs_text, kwargs_speech + + +############ SPEECH ENCODER related code ################ + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SeamlessM4TConformer, feat_extract_activation->speech_encoder_hidden_act +class SeamlessM4TConformerPositionalConvEmbedding(nn.Cell): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + self.conv = weight_norm(self.conv, name="weight", dim=2) + self.padding = SeamlessM4TConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + + def construct(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2->SeamlessM4T, num_attention_heads->speech_encoder_attention_heads +class SeamlessM4TConformerRotaryPositionalEmbedding(nn.Cell): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.speech_encoder_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (mint.arange(0, dim, 2, dtype=ms.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def construct(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant + time_stamps = mint.arange(sequence_length).type_as(self.inv_freq) + freqs = mint.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = mint.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + 
sin_embeddings = embeddings.sin()[:, None, None, :] + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = mint.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) + return self.cached_rotary_positional_embedding + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2->SeamlessM4T +class SeamlessM4TConformerRelPositionalEmbedding(nn.Cell): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(ms.Tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype: + self.pe = self.pe.to(dtype=x.dtype) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (iSeamlessM4T +class SeamlessM4TConformerSamePadLayer(nn.Cell): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def construct(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class SeamlessM4TConformerFeatureProjection(nn.Cell): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def construct(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerFeedForward(nn.Cell): + def __init__(self, config, act_fn=None, dropout=None): + super().__init__() + dropout = dropout if dropout is not None else config.speech_encoder_dropout + act_fn = act_fn if act_fn is not None else config.speech_encoder_hidden_act + + self.intermediate_dropout = nn.Dropout(dropout) + self.intermediate_dense = nn.Linear(config.hidden_size, config.speech_encoder_intermediate_size) + self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn + + self.output_dense = nn.Linear(config.speech_encoder_intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(dropout) + + def construct(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerConvolutionModule(nn.Cell): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd 
number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding="same", + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + self.pointwise_conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def construct(self, hidden_states, attention_mask=None): + hidden_states = self.layer_norm(hidden_states) + + # Ensure that we do not leak padded positions in depthwise convolution. + # Put 0 where necessary + if attention_mask is not None: + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SeamlessM4TConformerSelfAttention(nn.Cell): + """Construct a SeamlessM4TConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. 
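A brief, framework-free illustration of the rotate-half rotary update that `_apply_rotary_embedding` performs below (pure Python, for intuition only; `apply_rotary` and its arguments are illustrative, not part of the model):

import math

def apply_rotary(x, theta):
    # split the head vector into two halves; the module computes hidden * cos + rotate_half(hidden) * sin,
    # where rotate_half(x) = (-x2, x1)
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    cos, sin = math.cos(theta), math.sin(theta)
    rotated = [-v for v in x2] + list(x1)
    return [v * cos + r * sin for v, r in zip(x, rotated)]

# a zero angle leaves the vector unchanged
assert apply_rotary([1.0, 0.0], 0.0) == [1.0, 0.0]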
+ """ + + def __init__(self, config, use_position_embeddings=True): + super().__init__() + + self.head_size = config.hidden_size // config.speech_encoder_attention_heads + self.num_heads = config.speech_encoder_attention_heads + self.position_embeddings_type = config.position_embeddings_type if use_position_embeddings else None + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.speech_encoder_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(mint.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(mint.zeros(self.num_heads, self.head_size)) + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention.forward + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + relative_position_embeddings: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = mint.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = mint.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = mint.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = 
self.linear_out(hidden_states) + + return hidden_states, probs + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = mint.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = mint.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = mint.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = mint.zeros((*scores_bd.size()[:3], 1), dtype=scores_bd.dtype) + scores_bd_padded = mint.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. 
sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class SeamlessM4TConformerEncoderLayer(nn.Cell): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.speech_encoder_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = SeamlessM4TConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = SeamlessM4TConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = SeamlessM4TConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = SeamlessM4TConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def construct( + self, + hidden_states, + attention_mask: Optional[ms.Tensor] = None, + relative_position_embeddings: Optional[ms.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[ms.Tensor] = None, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class SeamlessM4TConformerEncoder(nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = SeamlessM4TConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = SeamlessM4TConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.dropout = nn.Dropout(config.speech_encoder_dropout) + self.layers = nn.CellList( + [SeamlessM4TConformerEncoderLayer(config) for _ in range(config.speech_encoder_layers)] + ) + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * mint.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = mint.rand([]) + + skip_the_layer = ( + True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False + ) + if not skip_the_layer: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SeamlessM4TConformerAdapterLayer(nn.Cell): + def 
__init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.adaptor_dropout + + self.kernel_size = config.adaptor_kernel_size + self.stride = config.adaptor_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = SeamlessM4TConformerSelfAttention(config, use_position_embeddings=False) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim) + self.ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=dropout) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + pad = self.kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + + return seq_lens.floor() + + def construct( + self, + hidden_states, + attention_mask: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask) + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _prepare_4d_attention_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. 
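# Editor's note for clarity: the two pooling convolutions above shrink the time axis by `stride`,
# so the incoming `attention_mask` no longer lines up with `hidden_states`. The helper
# `_compute_sub_sample_lengths_from_attention_mask` re-derives per-sample lengths with the usual
# Conv1d length formula, floor((L + 2 * pad - kernel_size) / stride + 1); with illustrative values
# kernel_size=8, stride=8, pad=4 and an unpadded length of 100, it yields floor(13.5) = 13 frames,
# and a fresh additive 4D mask is then built from those lengths for the attention call below.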
+ hidden_states, attn_weigths = self.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + residual + + return hidden_states + + +class SeamlessM4TConformerAdapter(nn.Cell): + def __init__(self, config): + super().__init__() + + self.layers = nn.CellList(SeamlessM4TConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + + def construct(self, hidden_states, attention_mask): + # down project hidden_states if necessary + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + return hidden_states + + +############ TEXT / UNITS related code ################ + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T +class SeamlessM4TScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def construct(self, input_ids: ms.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class SeamlessM4TSinusoidalPositionalEmbedding(nn.Cell): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = mint.exp(mint.arange(half_dim, dtype=ms.int64).float() * -emb) + emb = mint.arange(num_embeddings, dtype=ms.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = mint.cat([mint.sin(emb), mint.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = mint.cat([emb, mint.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(mint.get_default_dtype()) + + def construct( + self, input_ids: ms.Tensor = None, inputs_embeds: ms.Tensor = None, past_key_values_length: int = 0 + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. 
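# Editor's worked example (illustrative, not executed): with padding_idx = 1 and
# input_ids = [[5, 7, 1, 1]], the non-padding mask is [1, 1, 0, 0], its cumulative sum is
# [1, 2, 2, 2], re-masking gives [1, 2, 0, 0], and adding padding_idx yields position ids
# [2, 3, 1, 1]: real tokens count upward from padding_idx + 1 while padded slots stay at padding_idx.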
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: ms.Tensor + + Returns: ms.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = mint.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=ms.int64 + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +class SeamlessM4TAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.bart.modeling_bart.BartAttention.__init__ with Bart->SeamlessM4T + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[SeamlessM4TConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + past_key_value: Optional[Tuple[ms.Tensor]] = None, + attention_mask: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `encoder_hidden_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = mint.cat([past_key_value[0], key_states], dim=2) + value_states = mint.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(ms.Tensor, ms.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(ms.Tensor, ms.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = mint.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = mint.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
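# Shape walk-through (editor's note): after the view and transpose above, `attn_output` is
# (bsz, tgt_len, num_heads, head_dim); the reshape below merges the last two axes into
# embed_dim = num_heads * head_dim, giving (bsz, tgt_len, embed_dim) ready for the output projection.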
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size +class SeamlessM4TFeedForwardNetwork(nn.Cell): + def __init__(self, config: SeamlessM4TConfig, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, config.hidden_size) + self.dropout = nn.Dropout(config.activation_dropout) + self.act = ACT2FN[config.activation_function] + + def construct(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.fc2.weight, ms.Tensor) + and hidden_states.dtype != self.fc2.weight.dtype + and (self.fc2.weight.dtype != mint.int8 and self.fc2.weight.dtype != mint.uint8) + ): + hidden_states = hidden_states.to(self.fc2.weight.dtype) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SeamlessM4TEncoderLayer(nn.Cell): + def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): + super().__init__() + encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim + encoder_attention_heads = ( + config.encoder_attention_heads if encoder_attention_heads is None else encoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=encoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: ms.Tensor, + output_attentions: bool = False, + ) -> ms.Tensor: + """ + Args: + hidden_states (`ms.Tensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`ms.Tensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. 
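For orientation, the layer applies a pre-LayerNorm residual pattern: normalize, run the sub-layer, apply dropout, then add the residual, once for self-attention and once for the feed-forward network. A minimal, framework-free sketch of that data flow (all names below are placeholders for illustration, not the module's attributes):

def encoder_layer_sketch(x, self_attn, ffn, norm1, norm2, dropout):
    # self-attention block: normalize, attend, dropout, add residual
    residual = x
    x = residual + dropout(self_attn(norm1(x)))
    # feed-forward block: normalize, transform, dropout, add residual
    residual = x
    return residual + dropout(ffn(norm2(x)))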
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SeamlessM4TDecoderLayer(nn.Cell): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = SeamlessM4TAttention( + self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=decoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + past_key_value: Optional[Tuple[ms.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> ms.Tensor: + """ + Args: + hidden_states (`ms.Tensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`ms.Tensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`ms.Tensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`ms.Tensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + past_key_value (`Tuple(ms.Tensor)`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +############ SUB-MODELS related code ################ + + +class SeamlessM4TPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SeamlessM4TConfig + base_model_prefix = "seamless_m4t" + supports_gradient_checkpointing = True + _no_split_modules = ["SeamlessM4TEncoderLayer", "SeamlessM4TDecoderLayer", "SeamlessM4TConformerEncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SeamlessM4TConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, SeamlessM4TConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SeamlessM4TConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride + pad = kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1 + + return seq_lens.floor() + + def compute_last_hidden_states_per_sample( + self, + hidden_states: Tuple[Tuple[ms.Tensor]], + beam_indices: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + """ + Computes the last hidden states. + + Parameters: + hidden_states (`Tuple[Tuple[ms.Tensor]]`): + The generated hidden states. Tuple (one element for each generated token) of tuples (one element for + each layer of the decoder) of ms.Tensor of shape (batch_size*num_beams*num_return_sequences, + generated_length, hidden_size). + beam_indices (`ms.Tensor`, *optional*): + Beam indices of generated token id at each generation step. `ms.Tensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + + Return: + `ms.Tensor`: A `ms.Tensor` of shape `(batch_size*num_return_sequences, sequence_length, hidden_size)` + containing + the last hidden states. + ```""" + # 1. First, let's compute last_hidden_states from hidden_states. + # For each generation step, takes the hidden state from the last layer. + # shape: (batch_size*vocab_size*num_return_sequences, # generation_steps, hidden_dim) + last_hidden_states = mint.concat([hidden_states[-1] for hidden_states in hidden_states], dim=1) + + # 2. In absence of `beam_indices`, we can assume that we come from e.g. 
greedy search, which is equivalent + # to a beam search approach were the first (and only) beam is always selected + # in that case, return directly last_hidden_states + if beam_indices is None: + return last_hidden_states + + # 3. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices.clone()[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards anyways + beam_indices[beam_indices_mask] = 0 + + # 5. expand beam_indices to last_hidden_states dim + beam_indices = beam_indices.unsqueeze(-1) + beam_indices = beam_indices.expand(-1, -1, last_hidden_states.shape[-1]) + + # 6. select the right candidate for each beam + # in other words, new_last_hidden_states[i,j,k] = last_hidden_states[beam_indices[i,j,k], j, k] for all i, j, k + last_hidden_states = mint.gather(last_hidden_states, 0, beam_indices) + + return last_hidden_states + +class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel): + main_input_name = "input_features" + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.feature_projection = SeamlessM4TConformerFeatureProjection(config) + self.encoder = SeamlessM4TConformerEncoder(config) + self.intermediate_ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=0.0) + self.adapter = SeamlessM4TConformerAdapter(config) if config.add_adapter else None + self.inner_layer_norm = nn.LayerNorm(config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_features: Optional[ms.Tensor], + attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError( + """Both `input_features` and `inputs_embeds` are `None` in `SeamlessM4TSpeechEncoder.forward`. 
+ Make sure one of them is not `None`.""" + ) + + hidden_states = self.feature_projection(input_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + hidden_states = self.inner_layer_norm(hidden_states) + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# inspired from MBart and NllbMoe +class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + is_t2u_encoder: bool = False, + ): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + embed_dim = config.hidden_size + + self.is_t2u_encoder = is_t2u_encoder + self.max_source_positions = config.max_position_embeddings + + if not self.is_t2u_encoder: + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + + layers = [] + for _ in range(config.encoder_layers): + layers.append( + SeamlessM4TEncoderLayer( + config, + encoder_attention_heads=config.encoder_attention_heads, + encoder_ffn_dim=config.encoder_ffn_dim, + ) + ) + + self.layers = nn.CellList(layers) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and self.is_t2u_encoder: + raise ValueError( + "You cannot pass input_ids to the encoder of the text_to_units model. Pass inputs_embeds instead." + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.is_t2u_encoder: + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos + else: + hidden_states = inputs_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = mint.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.forward, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: 
SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for _ in range(config.decoder_layers): + layers.append( + SeamlessM4TDecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + ) + ) + self.layers = nn.CellList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`ms.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`ms.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + + hidden_states = 
nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = mint.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + super().__init__(config) + + self.encoder = SeamlessM4TEncoder(config, is_t2u_encoder=True) + self.decoder = SeamlessM4TDecoder(config, embed_tokens_decoder) + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[ms.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + 
        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin):
+    _keys_to_ignore_on_load_missing = [
+        "vocoder",
+        "speech_encoder",
+        "text_encoder",
+        "text_decoder",
+    ]
+    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
+
+    def __init__(
+        self,
+        config: SeamlessM4TConfig,
+        embed_tokens_decoder: Optional[nn.Embedding] = None,
+    ):
+        # update config - used principally for bos_token_id etc.
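+        # (Each `t2u_*` entry of the shared config overrides its un-prefixed counterpart on a deep copy,
+        # e.g. `t2u_vocab_size` -> `vocab_size` and `t2u_pad_token_id` -> `pad_token_id`, so the
+        # text-to-unit sub-model is configured independently without mutating the caller's config.)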
+ config = copy.deepcopy(config) + for param, val in config.to_dict().items(): + if param.startswith("t2u_"): + config.__setattr__(param[4:], val) + super().__init__(config) + + self.model = SeamlessM4TTextToUnitModel(config, embed_tokens_decoder) + + self.lm_head = nn.Linear(config.hidden_size, config.t2u_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: ms.Tensor): + return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always 
the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + def _tie_weights(self) -> None: + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + +############ VOCODER related code ################ + + +HIFIGAN_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is a Mindspore [mindspore.nn.Cell](https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell) + sub-class. Use it as a regular Mindspore Cell and refer to the Mindspore documentation for all matter related + to general usage and behavior. + + Parameters: + config ([`SeamlessM4TConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Cell): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.CellList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.CellList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def construct(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class SeamlessM4TVariancePredictor(nn.Cell): + def __init__(self, config): + super().__init__() + + embed_dim = config.unit_embed_dim + kernel_size = config.variance_predictor_kernel_size + var_pred_dropout = config.var_pred_dropout + + self.conv1 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + self.activation_fuction = nn.ReLU() + self.ln1 = nn.LayerNorm(embed_dim) + self.dropout_module = nn.Dropout(p=var_pred_dropout) + self.conv2 = nn.Conv1d( + embed_dim, + 
embed_dim, + kernel_size=kernel_size, + padding=1, + ) + self.ln2 = nn.LayerNorm(embed_dim) + self.proj = nn.Linear(embed_dim, 1) + + def construct(self, hidden_states: Tensor) -> Tensor: + # Input: B x T x C; Output: B x T + hidden_states = self.conv1(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln1(hidden_states)) + hidden_states = self.conv2(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln2(hidden_states)) + return self.proj(hidden_states).squeeze(dim=2) + + +class SeamlessM4THifiGan(nn.Cell): + def __init__(self, config: SeamlessM4TConfig): + super().__init__() + model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim + self.leaky_relu_slope = config.leaky_relu_slope + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.CellList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.CellList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + def construct(self, input_embeds: ms.Tensor) -> ms.Tensor: + r""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + + Args: + spectrogram (`ms.Tensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + model_in_dim)`, or un-batched and of shape `(sequence_length, model_in_dim)`. Note that `model_in_dim` + is the sum of `config.unit_embed_dim`, `config.lang_embed_dim` and `config.spkr_embed_dim`. + + Returns: + `ms.Tensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. 
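+
+            Note that each transposed convolution in the upsampler increases the temporal resolution by its
+            stride, so `num_frames` is approximately `sequence_length` multiplied by the product of
+            `config.upsample_rates`; the residual blocks and the pre-/post-convolutions are padded so that
+            they (approximately) preserve the length.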
+ """ + + hidden_states = self.conv_pre(input_embeds) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = mint.tanh(hidden_states) + + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +class SeamlessM4TCodeHifiGan(PreTrainedModel): + config_class = SeamlessM4TConfig + main_input_name = "input_embeds" + _no_split_modules = [] + + def __init__(self, config): + super().__init__(config) + + self.pad_token_id = config.t2u_pad_token_id + self.dur_predictor = SeamlessM4TVariancePredictor(config) + + self.unit_embedding = nn.Embedding(config.unit_hifi_gan_vocab_size, config.unit_embed_dim) + self.speaker_embedding = nn.Embedding(config.vocoder_num_spkrs, config.spkr_embed_dim) + self.language_embedding = nn.Embedding(config.vocoder_num_langs, config.lang_embed_dim) + + self.hifi_gan = SeamlessM4THifiGan(config) + + # Initialize weights and apply final processing + self.post_init() + + def _get_dur_output_lengths(self, input_ids, dur_out): + """ + Computes the output length after the duration layer. + """ + unit_lengths = (input_ids != self.pad_token_id).sum(1) + + # take care of edge cases where no padding or too many padding + unit_lengths = mint.clamp(unit_lengths, 0, dur_out.shape[1] - 1) + + cumulative_dur_out = mint.cumsum(dur_out, dim=1) + unit_lengths = cumulative_dur_out.gather(dim=1, index=unit_lengths.unsqueeze(1)).squeeze() + + return unit_lengths + + def _get_output_hifigan_lengths(self, input_lengths: Union[ms.Tensor, int]): + """ + Computes the output length of the hifigan convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return ( + mint.div(input_length + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + 1 + ) + + def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1 + + # conv_pre + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + # upsampler + for i, (upsample_rate, kernel_size) in enumerate( + zip(self.config.upsample_rates, self.config.upsample_kernel_sizes) + ): + input_lengths = _transpose_conv_out_length( + input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2 + ) + + # resblock + for i in range(len(self.config.upsample_rates)): + for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes): + for dil in dilation: + input_lengths = _conv_out_length( + input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil + ) + + for dil in dilation: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1) + + # conv_post + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + return input_lengths + + def construct( + self, input_ids: ms.Tensor, spkr_id: ms.Tensor, lang_id: ms.Tensor + ) -> Tuple[ms.Tensor]: 
+ """ + Args: + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTextToUnitForConditionalGeneration`]. [What are input + IDs?](../glossary#input-ids) + spkr_id (`int`, *optional*): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + tgt_lang (`str`, *optional*): + The language id to use as target language for translation. + """ + hidden_states = self.unit_embedding(input_ids).transpose(1, 2) + spkr = self.speaker_embedding(spkr_id).transpose(1, 2) + lang = self.language_embedding(lang_id).transpose(1, 2) + + log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) + dur_out = mint.clamp(mint.round((mint.exp(log_dur_pred) - 1)).long(), min=1) + # B x C x T + if hidden_states.size(0) == 1: + hidden_states = mint.repeat_interleave(hidden_states, dur_out.view(-1), dim=2) + else: + # if batched sample, need to interleave per sample, and pad -> loss of parallelism + if hidden_states.shape[0] > 1 and self.training: + logger.warning( + """`self.training=True` and you use batching. You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + mint.repeat_interleave(hidden_state, duration, dim=-1).transpose(0, 1) + for (hidden_state, duration) in zip(hidden_states, dur_out) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True).transpose(1, 2) + + spkr = spkr.repeat(1, 1, hidden_states.shape[-1]) + lang = lang.repeat(1, 1, hidden_states.shape[-1]) + hidden_states = mint.cat([lang, hidden_states, spkr], dim=1) + + hidden_states = self.hifi_gan(hidden_states) + + unit_lengths = self._get_dur_output_lengths(input_ids, dur_out) + lengths = self._get_output_hifigan_lengths(unit_lengths) + + return hidden_states, lengths + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.apply_weight_norm() + weight_norm(self.hifi_gan.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.hifi_gan.conv_post) + + +############ WHOLE MODEL related code ################ + + +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, 
config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + 
inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_ids=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`ms.Tensor` of varying shape depending on the modality, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. 
+ prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `ms.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `ms.Tensor`. The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # prepare text_decoder_input_ids + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {', '.join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = ms.Tensor([[text_tgt_lang_id]] * batch_size) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. 
Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + + return super().generate( + input_ids, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_decoder", "t2u_model", "vocoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def construct( + self, + input_features: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else 
self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_features=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_features (`ms.Tensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + + tgt_lang (`str`, *optional*): + The language to use as target language for translation. 
+ generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `ms.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `ms.Tensor`. The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + input_features = input_features if input_features is not None else kwargs.pop("inputs") + if tgt_lang is not None: + inputs = kwargs.get("input_embeds") if input_features is None else input_features + inputs = ( + inputs + if inputs is not None + else kwargs.get("encoder_outputs", {"last_hidden_state": None})["last_hidden_state"] + ) + batch_size = len(inputs) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. 
Please specify a `tgt_lang` in + {', '.join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = ms.Tensor([[text_tgt_lang_id]] * batch_size) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + return super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForTextToText`." + "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`." + "If you want to generate speech, use the `.generate` method." + ) + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_ids: Optional[ms.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[ms.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates 
translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + + Returns: + `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`: + - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size, + sequence_length)`and and `waveform_lengths` which gives the length of each sample. + """ + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. 
Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = ms.Tensor([[text_tgt_lang_id]] * batch_size) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + mint.arange(batch_size) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = ms.Tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = mint.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = ms.Tensor([[vocoder_tgt_lang_id]] * len(unit_ids)) + + spkr_id = ms.Tensor([[spkr_id]] * len(unit_ids)) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + 
waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_encoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def construct( + self, + input_features: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as 
`SeamlessM4TForSpeechToText`. It doesn't use `self.t2u_model`." + "If you want to generate speech, use the `generate` method." + ) + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_features: Optional[ms.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[ms.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_features, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_features (`ms.Tensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. 
This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            spkr_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+                    text model and speech model respectively. They take priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+
+        Returns:
+            `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
+            - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
+            - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
+              sequence_length)` and `waveform_lengths` which gives the length of each sample.
+        """
+        batch_size = len(input_features) if input_features is not None else len(kwargs.get("inputs_embeds"))
+
+        if tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+        else:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports
+                        more languages for text translation than for speech synthesis."""
+                    )
+
+        kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs)
+        kwargs_text["output_hidden_states"] = True
+        kwargs_text["return_dict_in_generate"] = True
+        kwargs_text["output_scores"] = True
+
+        text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
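+        # (the target language id is used as the first forced decoder token, so text decoding starts in `tgt_lang`)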
+ text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = ms.Tensor([[text_tgt_lang_id]] * batch_size) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_features, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get last_hidden_state from encoder + encoder_hidden_states = self.speech_encoder(input_features=input_features, attention_mask=attention_mask)[0] + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + mint.arange(batch_size) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = ms.Tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = mint.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = ms.Tensor([[vocoder_tgt_lang_id]] * len(unit_ids)) + + spkr_id = ms.Tensor([[spkr_id]] * len(unit_ids)) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return 
waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config, current_modality="text"): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.current_modality = current_modality + if current_modality == "speech": + self.main_input_name = "input_features" + + # these models already call post_init in their initialization + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def set_modality(self, modality="text"): + if modality == "text": + self.main_input_name = "input_ids" + self.current_modality = "text" + elif modality == "speech": + self.main_input_name = "input_features" + self.current_modality = "speech" + else: + raise ValueError(f"`modality={modality}` is not a valid modality. It must be `text` or `speech`.") + + def get_encoder(self): + if self.current_modality == "text": + return self.text_encoder + else: + return self.speech_encoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + input_features: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + decoder_input_ids: Optional[ms.Tensor] = None, + decoder_attention_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[ms.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[ms.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not 
None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if input_ids is None and input_features is None and inputs_embeds is None and encoder_outputs is None: + raise ValueError( + "`input_ids`,`input_features`, `inputs_embeds` and `encoder_outputs` are all empty. Make sure at least one of them is not." + ) + elif input_features is not None: + if input_ids is not None: + logger.warning( + "`input_ids` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through the `speech_encoder`. " + "Make sure that `input_features` and `input_ids` are mutually exclusive." + ) + + if inputs_embeds is not None: + logger.warning( + "`inputs_embeds` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through `speech_encoder`. " + "`inputs_embeds` will be ignored." + ) + + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." + ) + + self.set_modality("speech") + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + elif input_ids is not None or inputs_embeds is not None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." 
+            )
+            self.set_modality("text")
+            encoder_outputs = self.text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        encoder_attention_mask = attention_mask
+        # input modality = speech so new attention mask
+        if self.current_modality == "speech" and attention_mask is not None:
+            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask)
+            encoder_attention_mask = _compute_new_attention_mask(
+                hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.text_decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        lm_logits = self.lm_head(decoder_outputs[0])
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            outputs = decoder_outputs + encoder_outputs
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def generate(
+        self,
+        input_ids: Optional[ms.Tensor] = None,
+        input_features: Optional[ms.Tensor] = None,
+        return_intermediate_token_ids: Optional[bool] = None,
+        tgt_lang: Optional[str] = None,
+        spkr_id: Optional[int] = 0,
+        generate_speech: Optional[bool] = True,
+        **kwargs,
+    ) -> Union[ms.Tensor, SeamlessM4TGenerationOutput]:
+        """
+        Generates translated token ids and/or translated audio waveforms.
+
+        This method successively calls the `.generate` function of two different sub-models. You can specify keyword
+        arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments
+        that will be passed to one of them.
+
+        For example, calling `.generate(input_ids=input_ids, num_beams=4, speech_do_sample=True)` will successively
+        perform beam-search decoding on the text model, and multinomial beam-search sampling on the speech model.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
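+
+        Example (a minimal usage sketch; the checkpoint name, the `AutoProcessor` loading and the
+        numpy-to-Tensor conversion below are illustrative assumptions, not something this file defines):
+
+        >>> import mindspore as ms
+        >>> from transformers import AutoProcessor
+        >>> from mindone.transformers import SeamlessM4TModel
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")  # assumed checkpoint
+        >>> model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+        >>> # tokenize the source text; converting the numpy output to ms.Tensor is assumed here
+        >>> inputs = processor(text="Hello, my dog is cute.", src_lang="eng", return_tensors="np")
+        >>> input_ids = ms.Tensor(inputs["input_ids"])
+
+        >>> # un-prefixed kwargs reach both sub-models; `text_*` kwargs only reach the text model,
+        >>> # `speech_*` kwargs only reach the speech (unit) model
+        >>> waveform, waveform_lengths = model.generate(
+        ...     input_ids=input_ids, tgt_lang="fra", text_num_beams=4, speech_do_sample=True
+        ... )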
+
+        Args:
+            input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            input_features (`ms.Tensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio. Note that if `generate_speech=True`, this parameter will be
+                ignored.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            spkr_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            generate_speech (`bool`, *optional*, defaults to `True`):
+                If `False`, will only return the text tokens and won't generate speech.
+
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+                    text model and speech model respectively. They take priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+        Returns:
+            `Union[SeamlessM4TGenerationOutput, Tuple[Tensor], ModelOutput]`:
+            - If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
+            - If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of
+              shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample.
+            - If `generate_speech=False`, it returns `ModelOutput`.
+        """
+        if input_ids is None and input_features is None and kwargs.get("inputs_embeds", None) is None:
+            raise ValueError(
+                "`input_ids`, `input_features` and `inputs_embeds` are all empty. Make sure at least one of them is not."
+            )
+
+        if generate_speech and tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+
+        if tgt_lang is not None:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}.
Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + batch_size = ( + len(input_features) + if input_features is not None + else (len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))) + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = ms.Tensor([[text_tgt_lang_id]] * batch_size) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + if input_features is not None: + self.set_modality("speech") + if input_ids is not None: + logger.warning( + "`input_features` and `input_ids` are both non empty. `input_features` will be used in priority " + "through the speech encoder. Make sure `input_features=None` if you want to use the text encoder." + ) + text_generation_output = super().generate(input_features=input_features, **kwargs_text) + else: + self.set_modality("text") + text_generation_output = super().generate(input_ids=input_ids, input_features=None, **kwargs_text) + sequences = text_generation_output.sequences + + if not generate_speech: + return text_generation_output + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get encoder last hidden states + if self.current_modality == "speech": + # get last_hidden_state from encoder - must do a pass through the speech encoder + encoder_hidden_states = self.speech_encoder( + input_features=input_features, attention_mask=attention_mask + ).last_hidden_state + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + else: + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) 
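+        # sequences_scores is flat over beams, shape (batch_size * num_return_sequences,): reshaping to
+        # (batch_size, num_return_sequences) and taking argmax picks the best beam per sample, and adding
+        # batch_idx * num_return_sequences maps the result back to flat indices into `sequences`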
+ if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + mint.arange(batch_size) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = ms.Tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = mint.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = ms.Tensor([[vocoder_tgt_lang_id]] * len(unit_ids)) + + spkr_id = ms.Tensor([[spkr_id]] * len(unit_ids)) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +__all__ = [ + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4TForSpeechToText", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TCodeHifiGan", + "SeamlessM4THifiGan", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", +] diff --git a/tests/transformers_tests/models/seamless_m4t/__init__.py b/tests/transformers_tests/models/seamless_m4t/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py new file mode 100644 index 0000000000..896370c898 --- /dev/null +++ 
b/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -0,0 +1,398 @@ +"""Adapted from https://github.com/huggingface/transformers/tree/main/tests//models/albert/test_modeling_albert.py.""" + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the Mindspore SeamlessM4T model.""" + +import inspect + +import numpy as np +import pytest +import torch +from transformers import SeamlessM4TConfig + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) + +from tests.transformers_tests.models.modeling_common import ( + floats_numpy, + ids_numpy, + random_attention_mask, +) + +# CrossEntropyLoss not support bf16 +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3} +MODES = [1] + +class SeamlessM4TModelTester: + def __init__( + self, + parent, + input_modality="speech", + batch_size=2, + seq_length=4, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + max_new_tokens=None, + num_labels=3, + num_choices=4, + scope=None, + vocab_size=20, + t2u_vocab_size=20, + hidden_size=6, + num_hidden_layers=2, + intermediate_size=6, + max_position_embeddings=256, + encoder_layers=2, + decoder_layers=2, + encoder_ffn_dim=6, + decoder_ffn_dim=6, + t2u_encoder_layers=2, + t2u_decoder_layers=2, + t2u_encoder_ffn_dim=6, + t2u_decoder_ffn_dim=6, + num_heads=2, + vocoder_num_spkrs=5, + vocoder_num_langs=5, + upsample_initial_channel=32, + unit_embed_dim=25, + spkr_embed_dim=6, + lang_embed_dim=6, + num_conv_pos_embeddings=8, + unit_hifi_gan_vocab_size=20, + t2u_num_langs=0, + t2u_max_new_tokens=25, + t2u_offset_tgt_lang=0, + vocoder_offset=0, + ): + self.parent = parent + self.input_modality = input_modality + + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.decoder_ffn_dim = decoder_ffn_dim + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + 
self.num_heads = num_heads + self.num_attention_heads = num_heads + + self.vocoder_num_spkrs = vocoder_num_spkrs + self.vocoder_num_langs = vocoder_num_langs + self.upsample_initial_channel = upsample_initial_channel + self.unit_embed_dim = unit_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.lang_embed_dim = lang_embed_dim + + self.max_new_tokens = max_new_tokens + + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.t2u_num_langs = t2u_num_langs + self.t2u_max_new_tokens = t2u_max_new_tokens + self.t2u_offset_tgt_lang = t2u_offset_tgt_lang + self.vocoder_offset = vocoder_offset + + def prepare_config_and_inputs(self): + if self.input_modality == "text": + inputs = ids_numpy([self.batch_size, self.seq_length], self.vocab_size - 1) + else: + inputs = ids_numpy([self.batch_size, self.seq_length, 160], self.vocab_size - 1).float() + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + decoder_input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size - 1) + + lm_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, inputs, decoder_input_ids, input_mask, lm_labels + + def get_config(self): + return SeamlessM4TConfig( + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + t2u_vocab_size=self.t2u_vocab_size, + hidden_size=self.hidden_size, + speech_encoder_layers=self.num_heads, + speech_encoder_intermediate_size=self.intermediate_size, + max_position_embeddings=self.max_position_embeddings, + encoder_layers=self.encoder_layers, + decoder_layers=self.decoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + decoder_ffn_dim=self.decoder_ffn_dim, + t2u_encoder_layers=self.t2u_encoder_layers, + t2u_decoder_layers=self.t2u_decoder_layers, + t2u_encoder_ffn_dim=self.t2u_encoder_ffn_dim, + t2u_decoder_ffn_dim=self.t2u_decoder_ffn_dim, + num_attention_heads=self.num_heads, + encoder_attention_heads=self.num_heads, + decoder_attention_heads=self.num_heads, + t2u_encoder_attention_heads=self.num_heads, + t2u_decoder_attention_heads=self.num_heads, + speech_encoder_attention_heads=self.num_heads, + unit_hifigan_vocab_vise=self.t2u_vocab_size, + vocoder_num_spkrs=self.vocoder_num_spkrs, + vocoder_num_langs=self.vocoder_num_langs, + upsample_initial_channel=self.upsample_initial_channel, + unit_embed_dim=self.unit_embed_dim, + spkr_embed_dim=self.spkr_embed_dim, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + lang_embed_dim=self.lang_embed_dim, + max_new_tokens=self.max_new_tokens, + unit_hifi_gan_vocab_size=self.unit_hifi_gan_vocab_size, + t2u_num_langs=self.t2u_num_langs, + t2u_max_new_tokens=self.t2u_max_new_tokens, + t2u_offset_tgt_lang=self.t2u_offset_tgt_lang, + vocoder_offset=self.vocoder_offset, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ) = config_and_inputs + + input_name = "input_ids" if self.input_modality == "text" else "input_features" + + inputs_dict = { + input_name: input_ids, + "attention_mask": input_mask, + "decoder_input_ids": decoder_input_ids, + "labels": lm_labels, + } + return config, inputs_dict + +model_tester = SeamlessM4TModelTester() +( 
+ config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, +) = model_tester.prepare_config_and_inputs_for_common() +# config_has_num_labels = copy.deepcopy(config) +# config_has_num_labels.num_labels = model_tester.num_labels + +Seamless_m4t_CASES = [ + [ + "SeamlessM4TForSpeechToSpeech", + "transformers.SeamlessM4TForSpeechToSpeech", + "mindone.transformers.SeamlessM4TForSpeechToSpeech", + (config,), + {}, + ( + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ), + {}, + { + "logits": 0, + "encoder_last_hidden_state": 2, + }, + ], + [ + "SeamlessM4TForSpeechToText", + "transformers.SeamlessM4TForSpeechToText", + "mindone.transformers.SeamlessM4TForSpeechToText", + (config,), + {}, + ( + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ), + {}, + { + "logits": 0, + "encoder_last_hidden_state": 2, + }, + ], + [ + "SeamlessM4TForTextToSpeech", + "transformers.SeamlessM4TForTextToSpeech", + "mindone.transformers.SeamlessM4TForTextToSpeech", + (config,), + {}, + ( + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ), + {}, + { + "logits": 0, + "encoder_last_hidden_state": 2, + }, + ], + [ + "SeamlessM4TForTextToText", + "transformers.SeamlessM4TForTextToText", + "mindone.transformers.SeamlessM4TForTextToText", + (config,), + {}, + ( + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ), + {}, + { + "logits": 0, + "encoder_last_hidden_state": 2, + }, + ], + [ + "SeamlessM4TModel", + "transformers.SeamlessM4TModel", + "mindone.transformers.SeamlessM4TModel", + (config,), + {}, + ( + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ), + {}, + { + "logits": 0, + "encoder_last_hidden_state": 2, + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in Seamless_m4t_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + ms_inputs_kwargs["return_dict"] = False + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, 
ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) From 2e060cb118434e13183c14c3fe1c1aeb699657aa Mon Sep 17 00:00:00 2001 From: Dong1017 Date: Wed, 17 Sep 2025 08:26:36 +0800 Subject: [PATCH 51/77] 2025/9/17 seamless_m4t ut --- mindone/transformers/__init__.py | 7 + .../seamless_m4t/modeling_seamless_m4t.py | 283 +++++++++--------- .../test_modeling_seamless_m4t.py | 128 ++++---- 3 files changed, 225 insertions(+), 193 deletions(-) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index e4b7d051d0..13004f6ca3 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -520,6 +520,13 @@ RobertaModel, RobertaPreTrainedModel, ) +from .models.seamless_m4t import ( + SeamlessM4TForSpeechToSpeech, + SeamlessM4TForSpeechToText, + SeamlessM4TForTextToSpeech, + SeamlessM4TForTextToText, + SeamlessM4TModel, +) from .models.siglip import ( SiglipForImageClassification, SiglipImageProcessor, diff --git a/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 22a024c997..4b8e7c6951 100644 --- a/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/mindone/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -19,6 +19,7 @@ import copy import math +import numpy as np from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -26,11 +27,13 @@ from transformers.utils import ModelOutput, logging import mindspore as ms -from mindspore import Tensor, mint, nn +from mindspore import mint, nn, ops, Parameter, Tensor from mindspore.mint.nn import CrossEntropyLoss +from mindspore.common.initializer import initializer, Constant, HeNormal, Normal, Uniform, XavierUniform from ...activations import ACT2FN from ...generation import GenerationMixin +from ...mindspore_adapter._conv import Conv1d, ConvTranspose1d from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -128,9 +131,9 @@ def _compute_new_attention_mask(hidden_states: ms.Tensor, seq_lens: ms.Tensor): """ batch_size, mask_seq_len = hidden_states.shape[:2] - indices = mint.arange(mask_seq_len).expand((batch_size, -1)) + indices = mint.arange(mask_seq_len).expand((batch_size, -1),) - bool_mask = indices >= seq_lens.unsqueeze(1).expand((-1, mask_seq_len)) + bool_mask = indices >= seq_lens.unsqueeze(1).expand((-1, mask_seq_len),) mask = hidden_states.new_ones((batch_size, mask_seq_len)) @@ -185,7 +188,7 @@ def format_speech_generation_kwargs(kwargs): class SeamlessM4TConformerPositionalConvEmbedding(nn.Cell): def __init__(self, config): super().__init__() - self.conv = nn.Conv1d( + self.conv = Conv1d( config.hidden_size, config.hidden_size, kernel_size=config.num_conv_pos_embeddings, @@ -256,23 +259,23 @@ def __init__(self, config): self.max_len = config.max_source_positions self.d_model = config.hidden_size self.pe = None - self.extend_pe(ms.Tensor(0.0).expand(1, self.max_len)) + self.extend_pe(ms.Tensor(0.0).expand((1, self.max_len),)) def extend_pe(self, x): # Reset the positional encodings if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.shape[1] >= x.shape[1] * 2 - 1: 
if self.pe.dtype != x.dtype: self.pe = self.pe.to(dtype=x.dtype) return # Suppose `i` is the position of query vector and `j` is the # position of key vector. We use positive relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i