|
159 | 159 | "\n",
|
160 | 160 | "\n",
|
161 | 161 | "# visualize different noise levels and then train model on the noisiest one\n",
|
162 |
| - "labels: list[Tensor] = []\n", |
163 |
| - "points: list[Tensor] = []\n", |
164 | 162 | "for noise_std in (0, 0.05, 0.1, 0.2):\n",
|
165 | 163 | " points, labels = make_spirals(100, noise_std=noise_std)\n",
|
166 | 164 | "\n",
|
|
223 | 221 | "\n",
|
224 | 222 | "\n",
|
225 | 223 | "def train_step_fn(\n",
|
226 |
| - " weights: Tensor, batch: Tensor, targets: Tensor, lr: float = 0.2\n", |
227 |
| - ") -> tuple[Tensor, Tensor, Tensor]:\n", |
| 224 | + " weights: list[Tensor], batch: Tensor, targets: Tensor, lr: float = 0.2\n", |
| 225 | + ") -> tuple[Tensor, Tensor, tuple[Tensor, ...]]:\n", |
228 | 226 | " \"\"\"This function performs a single training step.\n",
|
229 | 227 | "\n",
|
230 | 228 | " Args:\n",
|
231 |
| - " weights (Tensor): Model weights.\n", |
| 229 | + " weights (list[Tensor]): Model weights.\n", |
232 | 230 | " batch (Tensor): Mini-batch of training samples.\n",
|
233 | 231 | " targets (Tensor): Ground truth labels for the mini-batch.\n",
|
234 | 232 | " lr (float, optional): Learning rate. Defaults to 0.2.\n",
|
235 | 233 | "\n",
|
236 | 234 | " Returns:\n",
|
237 |
| - " tuple[Tensor, Tensor, Tensor]: Loss, accuracy, and updated weights.\n", |
| 235 | + " tuple[Tensor, Tensor, tuple[Tensor, ...]]: Loss, accuracy, and updated weights.\n", |
238 | 236 | " \"\"\"\n",
|
239 | 237 | "\n",
|
240 |
| - " def compute_loss(weights: Tensor, batch: Tensor, targets: Tensor) -> Tensor:\n", |
| 238 | + " def compute_loss(weights: list[Tensor], batch: Tensor, targets: Tensor) -> Tensor:\n", |
241 | 239 | " output = func_model(weights, batch)\n",
|
242 | 240 | " return loss_fn(output, targets)\n",
|
243 | 241 | "\n",
|
244 |
| - " def accuracy(weights: Tensor, batch: Tensor, targets: Tensor) -> Tensor:\n", |
| 242 | + " def accuracy(weights: list[Tensor], batch: Tensor, targets: Tensor) -> Tensor:\n", |
245 | 243 | " output = func_model(weights, batch)\n",
|
246 | 244 | " return (output.argmax(dim=1) == targets).float().mean()\n",
|
247 | 245 | "\n",
|
|
256 | 254 | "\n",
|
257 | 255 | " acc = accuracy(new_weights, batch, targets)\n",
|
258 | 256 | "\n",
|
259 |
| - " return loss, acc, new_weights" |
| 257 | + " return loss, acc, tuple(new_weights)" |
260 | 258 | ]
|
261 | 259 | },
|
262 | 260 | {
|
|
303 | 301 | "\n",
|
304 | 302 | "metrics = {}\n",
|
305 | 303 | "for step in range(n_train_steps):\n",
|
306 |
| - " loss, acc, weights = train_step_fn(weights, points, labels)\n", |
| 304 | + " loss, acc, weights = train_step_fn(list(weights), points, labels)\n", |
307 | 305 | " if step % 100 == 0:\n",
|
308 | 306 | " metrics[step] = {\"loss\": loss.item(), \"acc\": acc.item()}\n",
|
309 | 307 | "\n",
|
|
362 | 360 | "batched_weights = initialize_ensemble(n_models=5)\n",
|
363 | 361 | "for step in tqdm(range(n_train_steps), desc=\"training MLP ensemble\"):\n",
|
364 | 362 | " losses, accuracies, batched_weights = parallel_train_step_fn(\n",
|
365 |
| - " batched_weights, points, labels\n", |
| 363 | + " list(batched_weights), points, labels\n", |
366 | 364 | " )\n",
|
| 365 | + " batched_weights = list(batched_weights)\n", |
367 | 366 | "\n",
|
368 | 367 | " loss_dict = {f\"model {idx}\": loss for idx, loss in enumerate(losses, 1)}\n",
|
369 | 368 | " writer.add_scalars(\"training/loss\", loss_dict, step)\n",
|
|
0 commit comments