From 0f64af1f594bdf3d35e51814d3f519924d396f3f Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Wed, 16 Jul 2025 17:22:05 +0800
Subject: [PATCH 1/5] add Muon Optimizer

---
 mindone/trainers/muon.py | 235 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 235 insertions(+)
 create mode 100644 mindone/trainers/muon.py

diff --git a/mindone/trainers/muon.py b/mindone/trainers/muon.py
new file mode 100644
index 0000000000..afcae63ab8
--- /dev/null
+++ b/mindone/trainers/muon.py
@@ -0,0 +1,235 @@
+import math
+from typing import List, Optional, Tuple, Union
+
+import mindspore as ms
+import mindspore.mint as mint
+import mindspore.ops as ops
+from mindspore import Parameter, ParameterTuple, Tensor
+from mindspore.experimental.optim.optimizer import Optimizer
+
+_muon_opt = ops.MultitypeFuncGraph("muon_opt")
+
+
+@_muon_opt.register(
+    "Float",
+    "Float",
+    "Float",
+    "Float",
+    "Bool",
+    "Int",
+    "Float",
+    "Tensor",
+    "Tensor",
+    "Tensor",
+    "Tensor",
+    "Tensor",
+    "Tensor",
+    "Float",
+    "Bool",
+)
+def _update_run_op(
+    mu: float,
+    beta1: float,
+    beta2: float,
+    eps: float,
+    nesterov: bool,
+    ns_steps: int,
+    weight_decay: float,
+    lr: Parameter,
+    step: Parameter,
+    param: Parameter,
+    m: Parameter,
+    v: Parameter,
+    g: Tensor,
+    ratio: float,
+    use_muon: bool,
+) -> bool:
+    if weight_decay != 0:
+        param.mul_(1 - lr * weight_decay)
+
+    v_next = None
+    if use_muon:
+        # Muon branch
+        if g.ndim > 2:
+            g = g.view(g.shape[0], -1)
+        m_next = mu * m + g
+        if nesterov:
+            g = g.add(m_next, alpha=mu)
+        else:
+            g = m_next
+        g = zeropower_via_newtonschulz5(g, steps=ns_steps)
+        param.add_(-(lr * ratio) * g)
+    else:
+        # AdamW branch
+        m_next = mint.lerp(g, m, beta1)
+        v_next = mint.lerp(mint.square(g), v, beta2)
+        g = m_next / (eps + mint.sqrt(v_next))
+        bias_correction1 = 1 - mint.pow(beta1, step)
+        bias_correction2 = 1 - mint.pow(beta2, step)
+        scale = bias_correction1 / bias_correction2**0.5
+        param.add_(-(lr / scale) * g)
+
+    ops.assign(m, m_next)
+    if not use_muon:
+        ops.assign(v, v_next)
+    return True
+
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.shape[0] > G.shape[1]:
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (mint.norm(X) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.shape[0] > G.shape[1]:
+        X = X.T
+    return X
+
+
+class Muon(Optimizer):
+    """Following https://github.com/MoonshotAI/Moonlight"""
+
+    def __init__(
+        self,
+        lr: Union[float, Tensor] = 1e-3,
+        wd: float = 0.1,
+        muon_params: Optional[List[Parameter]] = None,
+        momentum: float = 0.95,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+        adamw_params: Optional[List[Parameter]] = None,
+        adamw_betas: Tuple[float, float] = (0.9, 0.95),
+        adamw_eps: float = 1e-8,
+    ) -> None:
+        defaults = dict(
+            lr=lr,
+            wd=wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+        params = list(muon_params)
+        adamw_params = list(adamw_params) if adamw_params is not None else []
+        params.extend(adamw_params)
+        super().__init__(params, defaults)
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        use_muon = list()
+        for p in muon_params:
+            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
+            assert p.ndim == 2, p.ndim
+            use_muon.append(True)
+
+        for p in adamw_params:
+            # Do not use Muon for parameters in adamw_params
+            use_muon.append(False)
+        self.use_muon = tuple(use_muon)
+
+        self.exp_avg = self.parameters.clone("exp_avg", init="zeros")
+        self.exp_avg_sq = ParameterTuple(
+            [
+                (
+                    Parameter(mint.zeros(x.shape, dtype=x.dtype), name="exp_avg_sq." + x.name)
+                    if not use_muon
+                    else Parameter([], name="exp_avg_sq." + x.name)
+                )
+                for x, use_muon in zip(self.parameters, self.use_muon)
+            ]
+        )
+
+        self.lr_ratio = tuple([self._cal_lr_ratio(x, use_muon) for x, use_muon in zip(self.parameters, self.use_muon)])
+
+        self.state_step = Parameter(Tensor(0, dtype=ms.int32))
+        self.increase_tensor = Tensor(1, dtype=ms.int32)
+
+    def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2) -> float:
+        if not use_muon:
+            return 1.0
+
+        A, B = param.shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = rms_scale * math.sqrt(max(A, B))
+        return adjusted_ratio
+
+    @ms.jit
+    def muon(
+        self,
+        momentum: float,
+        beta1: float,
+        beta2: float,
+        eps: float,
+        nesterov: bool,
+        ns_steps: int,
+        weight_decay: float,
+        lr: Parameter,
+        gradients: Tuple[Tensor, ...],
+        ratio: Tuple[float, ...],
+        use_muon: Tuple[bool, ...],
+        start_id: int,
+        end_id: int,
+    ) -> bool:
+        optim_result = self.hyper_map(
+            ops.partial(
+                _muon_opt,
+                momentum,
+                beta1,
+                beta2,
+                eps,
+                nesterov,
+                ns_steps,
+                weight_decay,
+                lr,
+                self.state_step,
+            ),
+            self.parameters[start_id:end_id],
+            self.exp_avg[start_id:end_id],
+            self.exp_avg_sq[start_id:end_id],
+            gradients[start_id:end_id],
+            ratio[start_id:end_id],
+            use_muon[start_id:end_id],
+        )
+        return optim_result
+
+    def construct(self, gradients: Tuple[Tensor, ...]) -> bool:
+        self.state_step += self.increase_tensor
+        for group_id, group in enumerate(self.param_groups):
+            beta1, beta2 = group["adamw_betas"]
+            start_id = self.group_start_id[group_id]
+            end_id = self.group_start_id[group_id + 1]
+
+            self.muon(
+                group["momentum"],
+                beta1,
+                beta2,
+                group["adamw_eps"],
+                group["nesterov"],
+                group["ns_steps"],
+                group["weight_decay"],
+                group["lr"],
+                gradients,
+                self.lr_ratio,
+                self.use_muon,
+                start_id,
+                end_id,
+            )
+
+        return True

From e18706e38feb17c61b6fbd94e61c3cb035ea7dfb Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Wed, 16 Jul 2025 17:36:28 +0800
Subject: [PATCH 2/5] add test script

---
 .../benchmark/mindspore/toy_train_ms.py       | 177 +++++++++
 mindone/trainers/benchmark/torch/toy_train.py | 349 ++++++++++++++++++
 mindone/trainers/muon.py                      |   1 +
 3 files changed, 527 insertions(+)
 create mode 100644 mindone/trainers/benchmark/mindspore/toy_train_ms.py
 create mode 100644 mindone/trainers/benchmark/torch/toy_train.py

diff --git a/mindone/trainers/benchmark/mindspore/toy_train_ms.py b/mindone/trainers/benchmark/mindspore/toy_train_ms.py
new file mode 100644
index 0000000000..d8633d10f4
--- /dev/null
+++ b/mindone/trainers/benchmark/mindspore/toy_train_ms.py
@@ -0,0 +1,177 @@
+"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py"""
+
+import os
+import time
+from functools import partial
+
+import numpy as np
+from datasets import load_dataset
+from loguru import logger
+from optim import AdamW, Muon
+from tqdm import tqdm
+from transformers import Qwen2Config, Qwen2Tokenizer
+from transformers.optimization import _get_cosine_schedule_with_warmup_lr_lambda
+
+import mindspore as ms
+from mindspore.dataset import GeneratorDataset
+from mindspore.experimental.optim import Optimizer
+from mindspore.experimental.optim.lr_scheduler import LambdaLR
+
+from mindone.transformers import Qwen2ForCausalLM
+
+
+def get_cosine_schedule_with_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    lr_lambda = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class MoonDataset:
+    def __init__(self, dataset_name, dataset, tokenizer, max_length=512):
+        self.dataset_name = dataset_name
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.texts = dataset["train"]["text"]
+        self.max_length = max_length
+        self.tokens = []
+        self._tokenize_texts()
+
+    def _tokenize_texts(self):
+        if os.path.exists(f"{self.dataset_name}.npy"):
+            self.tokens = np.load(f"{self.dataset_name}.npy")
+        else:
+            for text in tqdm(self.texts, desc="Tokenizing texts"):
+                encoded = self.tokenizer.encode(text, add_special_tokens=True)
+                self.tokens.extend(encoded)
+            np.save(f"{self.dataset_name}.npy", self.tokens)
+
+    def __len__(self):
+        return len(self.tokens) // self.max_length
+
+    def __getitem__(self, idx):
+        start_idx = idx * (self.max_length)
+        end_idx = start_idx + (self.max_length)
+        token_slice = self.tokens[start_idx:end_idx]
+        data = np.asarray(token_slice, dtype=np.int32)
+        return data
+
+
+def get_model_and_dataloader(model_name, dataset_name, hidden_size):
+    name2path = {
+        "openwebtext-100k": "Elriggs/openwebtext-100k",
+    }
+    train_dataset = load_dataset(name2path[dataset_name])
+    if model_name == "qwen":
+        tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
+    else:
+        assert 0, f"model {model_name} not supported"
+    train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer)
+    # mike: default shuffle = True, for comparison set it to be False
+    train_loader = GeneratorDataset(train_dataset, column_names="input_ids", shuffle=False).batch(8)
+
+    if model_name == "qwen":
+        config = Qwen2Config(
+            attention_dropout=0.0,
+            bos_token_id=151643,
+            eos_token_id=151643,
+            hidden_act="silu",
+            hidden_size=hidden_size,
+            initializer_range=0.02,
+            intermediate_size=4864,
+            max_position_embeddings=513,
+            max_window_layers=12,
+            model_type="qwen2",
+            num_attention_heads=16,
+            num_hidden_layers=12,
+            num_key_value_heads=16,
+            rms_norm_eps=1e-06,
+            rope_theta=1000000.0,
+            sliding_window=1024,
+            tie_word_embeddings=True,
+            torch_dtype="bfloat16",
+            use_cache=True,
+            use_mrope=False,
+            use_sliding_window=False,
+            vocab_size=151936,
+        )
+        model = Qwen2ForCausalLM(config)
+    else:
+        assert 0, f"model {model_name} not supported"
+    return model, train_loader
+
+
+def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1):
+    if optimizer_name == "adamw":
+        return AdamW(model.get_parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95))
+    elif optimizer_name == "muon":
+        muon_params = [
+            p
+            for name, p in model.parameters_and_names()
+            if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+        ]
+        adamw_params = [
+            p
+            for name, p in model.parameters_and_names()
+            if not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
+        ]
+
+        return Muon(
+            lr=lr,
+            wd=wd,
+            muon_params=muon_params,
+            adamw_params=adamw_params,
+        )
+    else:
+        assert 0, "optimizer not supported"
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="qwen")
+    parser.add_argument("--optimizer", type=str, default="adamw")
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--wd", type=float, default=0.1)
+    parser.add_argument("--dataset", type=str, default="openwebtext-100k")
+    parser.add_argument("--hidden_size", type=int, default=1024)
+    args = parser.parse_args()
+    logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log")
+
+    model, train_loader = get_model_and_dataloader(args.model, args.dataset, args.hidden_size)
+    optimizer = get_optimizer(args.optimizer, model, lr=args.lr)
+
+    model.set_train(True)
+    epoch = 1
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=100,
+        num_training_steps=len(train_loader) * epoch,
+        num_cycles=0.5,
+    )
+
+    grad_fn = ms.value_and_grad(model, grad_position=None, weights=optimizer.parameters, has_aux=True)
+    for epoch in range(epoch):
+        for step, batch in enumerate(train_loader.create_tuple_iterator()):
+            (input_ids,) = batch
+            (loss, _), grads = grad_fn(input_ids=input_ids, labels=input_ids, return_dict=False)
+            ms.synchronize()
+            start = time.time()
+            optimizer(grads)
+            ms.synchronize()
+            duration = time.time() - start
+            lr_scheduler.step()
+            logger.info(
+                f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr'].item():.5f} "
+                f"Optimizer update time: {duration:.3f} Training loss: {loss.item()}"
+            )
diff --git a/mindone/trainers/benchmark/torch/toy_train.py b/mindone/trainers/benchmark/torch/toy_train.py
new file mode 100644
index 0000000000..905ec30218
--- /dev/null
+++ b/mindone/trainers/benchmark/torch/toy_train.py
@@ -0,0 +1,349 @@
+"""Copied from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py"""
+
+import math
+import os
+import time
+
+import torch
+from datasets import load_dataset
+from loguru import logger
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Tokenizer, get_cosine_schedule_with_warmup
+
+
+class MoonDataset(Dataset):
+    def __init__(self, dataset_name, dataset, tokenizer, max_length=512):
+        self.dataset_name = dataset_name
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.texts = dataset["train"]["text"]
+        self.max_length = max_length
+        self.tokens = []
+        self._tokenize_texts()
+
+    def _tokenize_texts(self):
+        if os.path.exists(f"{self.dataset_name}.bin"):
+            self.tokens = torch.load(f"{self.dataset_name}.bin")
+        else:
+            for text in tqdm(self.texts, desc="Tokenizing texts"):
+                encoded = self.tokenizer.encode(text, add_special_tokens=True)
+                self.tokens.extend(encoded)
+            torch.save(self.tokens, f"{self.dataset_name}.bin")
+
+    def __len__(self):
+        return len(self.tokens) // self.max_length
+
+    def __getitem__(self, idx):
+        start_idx = idx * (self.max_length)
+        end_idx = start_idx + (self.max_length)
+        token_slice = self.tokens[start_idx:end_idx]
+        data = torch.tensor(token_slice, dtype=torch.long)
+        return data
+
+
+# This code snippet is a modified version adapted from the following GitHub repository:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+@torch.compile
+def zeropower_via_newtonschulz5(G, steps):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        muon_params: The parameters to be optimized by Muon.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
+        {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        adamw_lr: The learning rate for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        adamw_wd: The weight decay for the internal AdamW.
+    """
+
+    def __init__(
+        self,
+        lr=1e-3,
+        wd=0.1,
+        muon_params=None,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_params=None,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        params = list(muon_params)
+        adamw_params = list(adamw_params) if adamw_params is not None else []
+        params.extend(adamw_params)
+        super().__init__(params, defaults)
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        for p in muon_params:
+            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
+            assert p.ndim == 2, p.ndim
+            self.state[p]["use_muon"] = True
+        for p in adamw_params:
+            # Do not use Muon for parameters in adamw_params
+            self.state[p]["use_muon"] = False
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            ############################
+            #           Muon           #
+            ############################
+
+            params = [p for p in group["params"] if self.state[p]["use_muon"]]
+            # import pdb; pdb.set_trace()
+            lr = group["lr"]
+            wd = group["wd"]
+            momentum = group["momentum"]
+
+            # generate weight updates in distributed fashion
+            for p in params:
+                # sanity check
+                g = p.grad
+                if g is None:
+                    continue
+                if g.ndim > 2:
+                    g = g.view(g.size(0), -1)
+                assert g is not None
+
+                # calc update
+                state = self.state[p]
+                if "momentum_buffer" not in state:
+                    state["momentum_buffer"] = torch.zeros_like(g)
+                buf = state["momentum_buffer"]
+                buf.mul_(momentum).add_(g)
+                if group["nesterov"]:
+                    g = g.add(buf, alpha=momentum)
+                else:
+                    g = buf
+                u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+
+                # scale update
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+
+                # apply weight decay
+                p.data.mul_(1 - lr * wd)
+
+                # apply update
+                p.data.add_(u, alpha=-adjusted_lr)
+
+            ############################
+            #       AdamW backup       #
+            ############################
+
+            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            lr = group["lr"]
+            beta1, beta2 = group["adamw_betas"]
+            eps = group["adamw_eps"]
+            weight_decay = group["wd"]
+
+            for p in params:
+                g = p.grad
+                if g is None:
+                    continue
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                    state["moment1"] = torch.zeros_like(g)
+                    state["moment2"] = torch.zeros_like(g)
+                state["step"] += 1
+                step = state["step"]
+                buf1 = state["moment1"]
+                buf2 = state["moment2"]
+                buf1.lerp_(g, 1 - beta1)
+                buf2.lerp_(g.square(), 1 - beta2)
+
+                g = buf1 / (eps + buf2.sqrt())
+
+                bias_correction1 = 1 - beta1**step
+                bias_correction2 = 1 - beta2**step
+                scale = bias_correction1 / bias_correction2**0.5
+                p.data.mul_(1 - lr * weight_decay)
+                p.data.add_(g, alpha=-lr / scale)
+
+        return loss
+
+
+def get_model_and_dataloader(model_name, dataset_name, hidden_size):
+    name2path = {
+        "openwebtext-100k": "Elriggs/openwebtext-100k",
+    }
+    train_dataset = load_dataset(name2path[dataset_name])
+    if model_name == "qwen":
+        tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
+    else:
+        assert 0, f"model {model_name} not supported"
+    train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer)
+    # mike: default shuffle = True, for comparison set it to be False
+    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=False)
+
+    if model_name == "qwen":
+        config = Qwen2Config(
+            attention_dropout=0.0,
+            bos_token_id=151643,
+            eos_token_id=151643,
+            hidden_act="silu",
+            hidden_size=hidden_size,
+            initializer_range=0.02,
+            intermediate_size=4864,
+            max_position_embeddings=513,
+            max_window_layers=12,
+            model_type="qwen2",
+            num_attention_heads=16,
+            num_hidden_layers=12,
+            num_key_value_heads=16,
+            rms_norm_eps=1e-06,
+            rope_theta=1000000.0,
+            sliding_window=1024,
+            tie_word_embeddings=True,
+            torch_dtype="bfloat16",
+            use_cache=True,
+            use_mrope=False,
+            use_sliding_window=False,
+            vocab_size=151936,
+        )
+        model = Qwen2ForCausalLM(config)
+    else:
+        assert 0, f"model {model_name} not supported"
+    return model, train_loader
+
+
+def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1):
+    if optimizer_name == "adamw":
+        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95))
+    elif optimizer_name == "muon":
+        muon_params = [
+            p
+            for name, p in model.named_parameters()
+            if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+        ]
+        adamw_params = [
+            p
+            for name, p in model.named_parameters()
+            if not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
+        ]
+
+        return Muon(
+            lr=lr,
+            wd=wd,
+            muon_params=muon_params,
+            adamw_params=adamw_params,
+        )
+    else:
+        assert 0, "optimizer not supported"
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="qwen")
+    parser.add_argument("--optimizer", type=str, default="adamw")
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--wd", type=float, default=0.1)
+    parser.add_argument("--dataset", type=str, default="openwebtext-100k")
+    parser.add_argument("--hidden_size", type=int, default=1024)
+    args = parser.parse_args()
+    logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log")
+
+    model, train_loader = get_model_and_dataloader(args.model, args.dataset, args.hidden_size)
+    optimizer = get_optimizer(args.optimizer, model, lr=args.lr)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    model.train()
+    epoch = 1
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=100,
+        num_training_steps=len(train_loader) * epoch,
+        num_cycles=0.5,
+    )
+    for epoch in range(epoch):
+        for step, batch in enumerate(train_loader):
+            batch = batch.to(device)
+            input_ids = batch
+            outputs = model(input_ids=input_ids, labels=input_ids)
+            loss = outputs.loss
+            loss.backward()
+            torch.cuda.synchronize()
+            start = time.time()
+            optimizer.step()
+            torch.cuda.synchronize()
+            duration = time.time() - start
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            logger.info(
+                f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']:.5f} Optimizer update time: {duration:.3f} Training loss: {loss.item()}"
+            )
diff --git a/mindone/trainers/muon.py b/mindone/trainers/muon.py
index afcae63ab8..b2e3a65a10 100644
--- a/mindone/trainers/muon.py
+++ b/mindone/trainers/muon.py
@@ -1,3 +1,4 @@
+"""Modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py"""
 import math
 from typing import List, Optional, Tuple, Union
 

From 34008ae73219ad24905ee3d44147a58f9e967530 Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Wed, 16 Jul 2025 17:47:03 +0800
Subject: [PATCH 3/5] move test scripts

---
 .../trainer_tests/muon}/mindspore/toy_train_ms.py             | 0
 .../benchmark => tests/trainer_tests/muon}/torch/toy_train.py | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename {mindone/trainers/benchmark => tests/trainer_tests/muon}/mindspore/toy_train_ms.py (100%)
 rename {mindone/trainers/benchmark => tests/trainer_tests/muon}/torch/toy_train.py (100%)

diff --git a/mindone/trainers/benchmark/mindspore/toy_train_ms.py b/tests/trainer_tests/muon/mindspore/toy_train_ms.py
similarity index 100%
rename from mindone/trainers/benchmark/mindspore/toy_train_ms.py
rename to tests/trainer_tests/muon/mindspore/toy_train_ms.py
diff --git a/mindone/trainers/benchmark/torch/toy_train.py b/tests/trainer_tests/muon/torch/toy_train.py
similarity index 100%
rename from mindone/trainers/benchmark/torch/toy_train.py
rename to tests/trainer_tests/muon/torch/toy_train.py

From ad4302223899c44e842f8d7b51d3a9a781e6a6cf Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Wed, 16 Jul 2025 18:49:01 +0800
Subject: [PATCH 4/5] improve speed

---
 mindone/trainers/muon.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mindone/trainers/muon.py b/mindone/trainers/muon.py
index b2e3a65a10..2250bba609 100644
--- a/mindone/trainers/muon.py
+++ b/mindone/trainers/muon.py
@@ -95,9 +95,9 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     X = X / (mint.norm(X) + 1e-7)
     # Perform the NS iterations
     for _ in range(steps):
-        A = X @ X.T
-        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
+        A = mint.matmul(X, X.T)
+        B = b * A + c * mint.matmul(A, A)  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + mint.matmul(B, X)
 
     if G.shape[0] > G.shape[1]:
         X = X.T

From f062a1479742fd57eba07c8513078b328b8422fc Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Thu, 17 Jul 2025 15:04:03 +0800
Subject: [PATCH 5/5] improve speed 2

---
 mindone/trainers/muon.py                      | 58 ++++++++++---------
 .../muon/mindspore/toy_train_ms.py            |  4 +-
 2 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/mindone/trainers/muon.py b/mindone/trainers/muon.py
index 2250bba609..1d9ff7277e 100644
--- a/mindone/trainers/muon.py
+++ b/mindone/trainers/muon.py
@@ -37,7 +37,7 @@ def _update_run_op(
     ns_steps: int,
     weight_decay: float,
     lr: Parameter,
-    step: Parameter,
+    denom: Parameter,
     param: Parameter,
     m: Parameter,
     v: Parameter,
@@ -48,30 +48,20 @@ def _update_run_op(
     if weight_decay != 0:
         param.mul_(1 - lr * weight_decay)
 
-    v_next = None
     if use_muon:
-        # Muon branch
-        if g.ndim > 2:
-            g = g.view(g.shape[0], -1)
-        m_next = mu * m + g
+        m.mul_(mu).add_(g)
         if nesterov:
-            g = g.add(m_next, alpha=mu)
+            g = g.add(m, alpha=mu)
         else:
-            g = m_next
+            g = m
         g = zeropower_via_newtonschulz5(g, steps=ns_steps)
-        param.add_(-(lr * ratio) * g)
+        param.add_(lr * g, alpha=-ratio)
     else:
-        # AdamW branch
         m_next = mint.lerp(g, m, beta1)
         v_next = mint.lerp(mint.square(g), v, beta2)
         g = m_next / (eps + mint.sqrt(v_next))
-        bias_correction1 = 1 - mint.pow(beta1, step)
-        bias_correction2 = 1 - mint.pow(beta2, step)
-        scale = bias_correction1 / bias_correction2**0.5
-        param.add_(-(lr / scale) * g)
-
-    ops.assign(m, m_next)
-    if not use_muon:
+        param.add_(-(lr / denom) * g)
+        ops.assign(m, m_next)
         ops.assign(v, v_next)
     return True
 
@@ -86,21 +76,30 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
-    assert len(G.shape) == 2
-    a, b, c = (3.4445, -4.7750, 2.0315)
+    shape = G.shape
+
+    if len(shape) > 2:
+        G = G.view(G.shape[0], -1)
+    assert len(shape) == 2
+
+    a, b, c = 3.4445, -4.7750, 2.0315
     X = G.bfloat16()
     if G.shape[0] > G.shape[1]:
-        X = X.T
+        X = mint.t(X)
+
     # Ensure spectral norm is at most 1
     X = X / (mint.norm(X) + 1e-7)
     # Perform the NS iterations
     for _ in range(steps):
         A = mint.matmul(X, X.T)
-        B = b * A + c * mint.matmul(A, A)  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + mint.matmul(B, X)
+        B = mint.addmm(A, A, A, beta=b, alpha=c)  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = mint.addmm(X, B, X, beta=a)
 
     if G.shape[0] > G.shape[1]:
-        X = X.T
+        X = mint.t(X)
+
+    if len(shape) > 2:
+        X = X.view(*shape)
     return X
 
 
@@ -136,7 +135,7 @@ def __init__(
         use_muon = list()
         for p in muon_params:
             # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-            assert p.ndim == 2, p.ndim
+            assert p.ndim >= 2, p.ndim
             use_muon.append(True)
 
         for p in adamw_params:
@@ -160,6 +159,7 @@ def __init__(
 
         self.state_step = Parameter(Tensor(0, dtype=ms.int32))
         self.increase_tensor = Tensor(1, dtype=ms.int32)
+        self.denom = Parameter(Tensor(1.0, dtype=ms.float32))
 
     def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2) -> float:
         if not use_muon:
@@ -171,7 +171,7 @@ def _cal_lr_ratio(self, param: Parameter, use_muon: bool, rms_scale: float = 0.2
         adjusted_ratio = rms_scale * math.sqrt(max(A, B))
         return adjusted_ratio
 
-    @ms.jit
+    @ms.jit(jit_level="O1")
     def muon(
         self,
         momentum: float,
@@ -188,6 +188,10 @@ def muon(
         start_id: int,
         end_id: int,
     ) -> bool:
+        bias_correction1 = 1 - beta1**self.state_step
+        bias_correction2 = 1 - beta2**self.state_step
+        ops.assign(self.denom, bias_correction1 / bias_correction2**0.5)
+
         optim_result = self.hyper_map(
             ops.partial(
                 _muon_opt,
@@ -199,7 +203,7 @@ def muon(
                 ns_steps,
                 weight_decay,
                 lr,
-                self.state_step,
+                self.denom,
             ),
             self.parameters[start_id:end_id],
             self.exp_avg[start_id:end_id],
@@ -211,7 +215,7 @@ def muon(
         return optim_result
 
     def construct(self, gradients: Tuple[Tensor, ...]) -> bool:
-        self.state_step += self.increase_tensor
+        self.state_step.add_(self.increase_tensor)
         for group_id, group in enumerate(self.param_groups):
             beta1, beta2 = group["adamw_betas"]
             start_id = self.group_start_id[group_id]
diff --git a/tests/trainer_tests/muon/mindspore/toy_train_ms.py b/tests/trainer_tests/muon/mindspore/toy_train_ms.py
index d8633d10f4..f4471fd14a 100644
--- a/tests/trainer_tests/muon/mindspore/toy_train_ms.py
+++ b/tests/trainer_tests/muon/mindspore/toy_train_ms.py
@@ -7,16 +7,16 @@
 import numpy as np
 from datasets import load_dataset
 from loguru import logger
-from optim import AdamW, Muon
 from tqdm import tqdm
 from transformers import Qwen2Config, Qwen2Tokenizer
 from transformers.optimization import _get_cosine_schedule_with_warmup_lr_lambda
 
 import mindspore as ms
 from mindspore.dataset import GeneratorDataset
-from mindspore.experimental.optim import Optimizer
+from mindspore.experimental.optim import AdamW, Optimizer
 from mindspore.experimental.optim.lr_scheduler import LambdaLR
 
+from mindone.trainers.muon import Muon
 from mindone.transformers import Qwen2ForCausalLM
 
 
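Usage sketch: a minimal example of driving the Muon optimizer added in this series, distilled from the toy_train_ms.py benchmark above. The Dense network, MAE loss, and random batch below are hypothetical stand-ins; the parameter split and the optimizer(grads) call follow the benchmark's get_optimizer helper and training loop.

import mindspore as ms
import mindspore.mint as mint
import mindspore.nn as nn

from mindone.trainers.muon import Muon

# Hypothetical toy network; any MindSpore cell with >= 2-D weights works here.
net = nn.Dense(32, 16)
loss_fn = nn.MAELoss()

# Route >= 2-D weights (excluding embedding / lm_head parameters) to Muon and
# everything else to the AdamW branch, as get_optimizer() does in toy_train_ms.py.
muon_params = [
    p
    for name, p in net.parameters_and_names()
    if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
]
adamw_params = [
    p
    for name, p in net.parameters_and_names()
    if not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
]
optimizer = Muon(lr=1e-3, wd=0.1, muon_params=muon_params, adamw_params=adamw_params)


def forward(x, y):
    return loss_fn(net(x), y)


grad_fn = ms.value_and_grad(forward, grad_position=None, weights=optimizer.parameters)

x = mint.randn(8, 32)
y = mint.randn(8, 16)
loss, grads = grad_fn(x, y)
optimizer(grads)  # construct() dispatches the Muon / AdamW update for each parameter group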