58 changes: 25 additions & 33 deletions src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -1,16 +1,14 @@
 # ruff: noqa: N806
 import os
+import shutil
+import tempfile
 from pathlib import Path
 
 import peft
 import torch
-from diffusers import FluxTransformer2DModel
-from transformers import CLIPTextModel
 
 from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
     _convert_peft_state_dict_to_kohya_state_dict,
-    load_multi_model_peft_checkpoint,
-    save_multi_model_peft_checkpoint,
 )
 from invoke_training._shared.checkpoints.serialization import save_state_dict

@@ -59,42 +57,36 @@
 }
 
 
-def save_flux_peft_checkpoint(
+def save_flux_peft_checkpoint_single_file(
     checkpoint_dir: Path | str,
     transformer: peft.PeftModel | None,
-    text_encoder_1: peft.PeftModel | None,
-    text_encoder_2: peft.PeftModel | None,
 ):
-    models = {}
-    if transformer is not None:
-        models[FLUX_PEFT_TRANSFORMER_KEY] = transformer
-    if text_encoder_1 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1
-    if text_encoder_2 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2
+    assert isinstance(transformer, peft.PeftModel)
+    if (
+        hasattr(transformer, "config")
+        and isinstance(transformer.config, dict)
+        and "_name_or_path" not in transformer.config
+    ):
+        transformer.config["_name_or_path"] = None
 
-    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)
+    # Normalize output path and ensure parent exists when saving as file
+    out_path = Path(checkpoint_dir)
 
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+        # Save PEFT adapter into temporary directory
+        transformer.save_pretrained(str(tmp_path))
 
-def load_flux_peft_checkpoint(
-    checkpoint_dir: Path | str,
-    transformer: FluxTransformer2DModel,
-    text_encoder_1: CLIPTextModel,
-    text_encoder_2: CLIPTextModel,
-    is_trainable: bool = False,
-):
-    models = load_multi_model_peft_checkpoint(
-        checkpoint_dir=checkpoint_dir,
-        models={
-            FLUX_PEFT_TRANSFORMER_KEY: transformer,
-            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
-            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
-        },
-        is_trainable=is_trainable,
-        raise_if_subdir_missing=False,
-    )
+        # Move adapter_model.safetensors out of the temp dir to the requested location
+        src_file = tmp_path / "adapter_model.safetensors"
+        if not src_file.exists():
+            raise FileNotFoundError(f"Expected adapter file not found in temporary directory: {src_file}")
 
+        # Always rename/move to exactly the path provided by checkpoint_dir
+        dest_file = out_path
+        dest_file.parent.mkdir(parents=True, exist_ok=True)
+
-    return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY]
+        shutil.move(str(src_file), str(dest_file))
 
 
 def save_flux_kohya_checkpoint(
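Since save_flux_peft_checkpoint_single_file keeps only adapter_model.safetensors and discards the adapter_config.json that peft's save_pretrained writes next to it, loading the file back requires re-creating the LoraConfig that was used for training. A minimal sketch of one way to do that with the standard peft and safetensors APIs; the helper name load_flux_lora_single_file is hypothetical and not part of this PR:

import peft
from diffusers import FluxTransformer2DModel
from safetensors.torch import load_file


def load_flux_lora_single_file(
    checkpoint_path: str,
    transformer: FluxTransformer2DModel,
    lora_config: peft.LoraConfig,
) -> peft.PeftModel:
    # Re-wrap the base transformer with the same LoRA config used at training time
    # (rank, alpha, target modules), since adapter_config.json is not preserved.
    peft_transformer = peft.get_peft_model(transformer, lora_config)
    # Load the raw adapter weights written by save_flux_peft_checkpoint_single_file.
    state_dict = load_file(checkpoint_path)
    # Inject the adapter weights into the PEFT-wrapped model.
    peft.set_peft_model_state_dict(peft_transformer, state_dict)
    return peft_transformer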
15 changes: 7 additions & 8 deletions src/invoke_training/pipelines/flux/lora/train.py
@@ -33,7 +33,7 @@
 from invoke_training._shared.flux.encoding_utils import encode_prompt
 from invoke_training._shared.flux.lora_checkpoint_utils import (
     save_flux_kohya_checkpoint,
-    save_flux_peft_checkpoint,
+    save_flux_peft_checkpoint_single_file,
 )
 from invoke_training._shared.flux.model_loading_utils import load_models_flux
 from invoke_training._shared.flux.validation import generate_validation_images_flux
@@ -63,9 +63,8 @@ def _save_flux_lora_checkpoint(

     if lora_checkpoint_format == "invoke_peft":
         model_type = ModelType.FLUX_LORA_PEFT
-        save_flux_peft_checkpoint(
-            Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
-        )
+        save_flux_peft_checkpoint_single_file(Path(save_path), transformer=transformer)
+
     elif lora_checkpoint_format == "kohya":
         model_type = ModelType.FLUX_LORA_KOHYA
         save_flux_kohya_checkpoint(
@@ -424,16 +423,15 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
     if config.train_transformer:
         transformer_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             target_modules=config.flux_lora_target_modules,
         )
         transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)
 
     if config.train_text_encoder:
         text_encoder_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             # init_lora_weights="gaussian",
             target_modules=config.text_encoder_lora_target_modules,
         )
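For context on the lora_alpha change above: peft scales each LoRA update by lora_alpha / r, so lora_alpha=1.0 shrinks the effective update as the rank grows, while lora_alpha=config.lora_rank_dim keeps the effective scale at 1.0 for any rank (the Diffusers convention referenced in the removed TODO). A small illustration with hypothetical rank values:

# Effective LoRA scaling in peft: scaling = lora_alpha / r
for rank in (4, 16, 64):  # hypothetical rank values
    old_scale = 1.0 / rank   # lora_alpha=1.0  -> 0.25, 0.0625, 0.015625
    new_scale = rank / rank  # lora_alpha=rank -> always 1.0
    print(f"rank={rank}: old scale={old_scale}, new scale={new_scale}")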
@@ -530,7 +528,8 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint",
         max_checkpoints=config.max_checkpoints,
-        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
+        # The PEFT checkpoint is now also written as a single file by save_flux_peft_checkpoint_single_file, so always use the .safetensors extension.
+        extension=".safetensors",
     )
 
     # Train!
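To see the single-file save pattern from lora_checkpoint_utils.py end to end, here is a self-contained toy version that uses a tiny stand-in module instead of a FLUX transformer; the model, rank, and output path are hypothetical and chosen only so the snippet runs anywhere:

import shutil
import tempfile
from pathlib import Path

import peft
import torch

# Tiny stand-in model (not a FLUX transformer) wrapped with a LoRA adapter.
base_model = torch.nn.Sequential(torch.nn.Linear(8, 8))
lora_config = peft.LoraConfig(r=4, lora_alpha=4, target_modules=["0"])
peft_model = peft.get_peft_model(base_model, lora_config)

out_file = Path("checkpoint-demo.safetensors")  # hypothetical output path
with tempfile.TemporaryDirectory() as tmpdir:
    # save_pretrained writes adapter_config.json and adapter_model.safetensors into tmpdir...
    peft_model.save_pretrained(tmpdir)
    # ...and only the weights file is moved out, leaving a single checkpoint file.
    shutil.move(str(Path(tmpdir) / "adapter_model.safetensors"), str(out_file))
print(out_file.exists())  # True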