diff --git a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
index 786ad2aa..bd52fe7b 100644
--- a/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
+++ b/src/invoke_training/_shared/flux/lora_checkpoint_utils.py
@@ -1,16 +1,14 @@
 # ruff: noqa: N806
 import os
+import shutil
+import tempfile
 from pathlib import Path
 
 import peft
 import torch
-from diffusers import FluxTransformer2DModel
-from transformers import CLIPTextModel
 
 from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
     _convert_peft_state_dict_to_kohya_state_dict,
-    load_multi_model_peft_checkpoint,
-    save_multi_model_peft_checkpoint,
 )
 from invoke_training._shared.checkpoints.serialization import save_state_dict
 
@@ -59,42 +57,36 @@
 }
 
 
-def save_flux_peft_checkpoint(
+def save_flux_peft_checkpoint_single_file(
     checkpoint_dir: Path | str,
     transformer: peft.PeftModel | None,
-    text_encoder_1: peft.PeftModel | None,
-    text_encoder_2: peft.PeftModel | None,
 ):
-    models = {}
-    if transformer is not None:
-        models[FLUX_PEFT_TRANSFORMER_KEY] = transformer
-    if text_encoder_1 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1
-    if text_encoder_2 is not None:
-        models[FLUX_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2
+    assert isinstance(transformer, peft.PeftModel)
+    if (
+        hasattr(transformer, "config")
+        and isinstance(transformer.config, dict)
+        and "_name_or_path" not in transformer.config
+    ):
+        transformer.config["_name_or_path"] = None
 
-    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)
+    # Normalize output path and ensure parent exists when saving as file
+    out_path = Path(checkpoint_dir)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+        # Save PEFT adapter into temporary directory
+        transformer.save_pretrained(str(tmp_path))
 
 
-def load_flux_peft_checkpoint(
-    checkpoint_dir: Path | str,
-    transformer: FluxTransformer2DModel,
-    text_encoder_1: CLIPTextModel,
-    text_encoder_2: CLIPTextModel,
-    is_trainable: bool = False,
-):
-    models = load_multi_model_peft_checkpoint(
-        checkpoint_dir=checkpoint_dir,
-        models={
-            FLUX_PEFT_TRANSFORMER_KEY: transformer,
-            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
-            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
-        },
-        is_trainable=is_trainable,
-        raise_if_subdir_missing=False,
-    )
+        # Move adapter_model.safetensors out of the temp dir to the requested location
+        src_file = tmp_path / "adapter_model.safetensors"
+        if not src_file.exists():
+            raise FileNotFoundError(f"Expected adapter file not found in temporary directory: {src_file}")
+
+        # Always rename/move to exactly the path provided by checkpoint_dir
+        dest_file = out_path
+        dest_file.parent.mkdir(parents=True, exist_ok=True)
 
-    return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY]
+        shutil.move(str(src_file), str(dest_file))
 
 
 def save_flux_kohya_checkpoint(
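With this change, an `invoke_peft` FLUX LoRA checkpoint is a single `.safetensors` file holding the contents of peft's `adapter_model.safetensors`. A minimal sketch of inspecting such a file, not part of this PR: the checkpoint path is hypothetical, and the key naming mentioned in the comment is peft's usual `save_pretrained` output rather than anything introduced here.

```python
# Minimal sketch (not part of this PR): inspect a checkpoint written by
# save_flux_peft_checkpoint_single_file. The path below is hypothetical.
from safetensors.torch import load_file

state_dict = load_file("output/checkpoints/checkpoint-step00001000.safetensors")
for name, tensor in sorted(state_dict.items()):
    # peft adapter keys typically look like "base_model.model.<module>.lora_A.weight".
    print(f"{name}: dtype={tensor.dtype}, shape={tuple(tensor.shape)}")
```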
diff --git a/src/invoke_training/pipelines/flux/lora/train.py b/src/invoke_training/pipelines/flux/lora/train.py
index 1ae48b71..17653fb5 100644
--- a/src/invoke_training/pipelines/flux/lora/train.py
+++ b/src/invoke_training/pipelines/flux/lora/train.py
@@ -33,7 +33,7 @@
 from invoke_training._shared.flux.encoding_utils import encode_prompt
 from invoke_training._shared.flux.lora_checkpoint_utils import (
     save_flux_kohya_checkpoint,
-    save_flux_peft_checkpoint,
+    save_flux_peft_checkpoint_single_file,
 )
 from invoke_training._shared.flux.model_loading_utils import load_models_flux
 from invoke_training._shared.flux.validation import generate_validation_images_flux
@@ -63,9 +63,8 @@ def _save_flux_lora_checkpoint(
 
     if lora_checkpoint_format == "invoke_peft":
         model_type = ModelType.FLUX_LORA_PEFT
-        save_flux_peft_checkpoint(
-            Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
-        )
+        save_flux_peft_checkpoint_single_file(Path(save_path), transformer=transformer)
+
     elif lora_checkpoint_format == "kohya":
         model_type = ModelType.FLUX_LORA_KOHYA
         save_flux_kohya_checkpoint(
@@ -424,8 +423,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
     if config.train_transformer:
         transformer_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             target_modules=config.flux_lora_target_modules,
         )
         transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)
@@ -433,7 +431,7 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
     if config.train_text_encoder:
         text_encoder_lora_config = peft.LoraConfig(
             r=config.lora_rank_dim,
-            lora_alpha=1.0,
+            lora_alpha=config.lora_rank_dim,
             # init_lora_weights="gaussian",
             target_modules=config.text_encoder_lora_target_modules,
         )
@@ -530,7 +528,8 @@ def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = N
         base_dir=ckpt_dir,
         prefix="checkpoint",
         max_checkpoints=config.max_checkpoints,
-        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
+        # Both formats are now saved as a single .safetensors file; see save_flux_peft_checkpoint_single_file.
+        extension=".safetensors",
     )
 
     # Train!
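On the `lora_alpha` change: peft scales the low-rank update by `lora_alpha / r`, so the old `lora_alpha=1.0` shrank the adapter's contribution by `1/r`, while `lora_alpha=config.lora_rank_dim` gives a scale of 1.0, matching the diffusers convention noted in the removed TODO. A small illustration, with `rank` standing in for `config.lora_rank_dim`:

```python
# Illustrative only: effective LoRA scaling (lora_alpha / r) before and after
# this change. "rank" stands in for config.lora_rank_dim.
import peft

rank = 16
old_cfg = peft.LoraConfig(r=rank, lora_alpha=1.0)   # previous setting
new_cfg = peft.LoraConfig(r=rank, lora_alpha=rank)  # this PR

print(old_cfg.lora_alpha / old_cfg.r)  # 0.0625 -> adapter output scaled down by 1/rank
print(new_cfg.lora_alpha / new_cfg.r)  # 1.0    -> full-strength adapter output
```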