Commit 2c44d10

change build-backend = setuptools -> uv_build
fix type errors
1 parent bd31e7e · commit 2c44d10

8 files changed (+29 −29 lines)

.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
@@ -8,14 +8,14 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.12
+    rev: v0.12.12
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: check-case-conflict
       - id: check-symlinks
@@ -52,5 +52,6 @@ repos:
     hooks:
       - id: ty
         name: ty check
-        entry: ty check .
+        entry: ty check
         language: python
+        additional_dependencies: [ty]

examples/functorch_mlp_ensemble.ipynb

Lines changed: 10 additions & 11 deletions
@@ -159,8 +159,6 @@
     "\n",
     "\n",
     "# visualize different noise levels and then train model on the noisiest one\n",
-    "labels: list[Tensor] = []\n",
-    "points: list[Tensor] = []\n",
     "for noise_std in (0, 0.05, 0.1, 0.2):\n",
     "    points, labels = make_spirals(100, noise_std=noise_std)\n",
     "\n",
@@ -223,25 +221,25 @@
     "\n",
     "\n",
     "def train_step_fn(\n",
-    "    weights: Tensor, batch: Tensor, targets: Tensor, lr: float = 0.2\n",
-    ") -> tuple[Tensor, Tensor, Tensor]:\n",
+    "    weights: list[Tensor], batch: Tensor, targets: Tensor, lr: float = 0.2\n",
+    ") -> tuple[Tensor, Tensor, tuple[Tensor, ...]]:\n",
     "    \"\"\"This function performs a single training step.\n",
     "\n",
     "    Args:\n",
-    "        weights (Tensor): Model weights.\n",
+    "        weights (list[Tensor]): Model weights.\n",
     "        batch (Tensor): Mini-batch of training samples.\n",
     "        targets (Tensor): Ground truth labels for the mini-batch.\n",
     "        lr (float, optional): Learning rate. Defaults to 0.2.\n",
     "\n",
     "    Returns:\n",
-    "        tuple[Tensor, Tensor, Tensor]: Loss, accuracy, and updated weights.\n",
+    "        tuple[Tensor, Tensor, tuple[Tensor, ...]]: Loss, accuracy, and updated weights.\n",
     "    \"\"\"\n",
     "\n",
-    "    def compute_loss(weights: Tensor, batch: Tensor, targets: Tensor) -> Tensor:\n",
+    "    def compute_loss(weights: list[Tensor], batch: Tensor, targets: Tensor) -> Tensor:\n",
     "        output = func_model(weights, batch)\n",
     "        return loss_fn(output, targets)\n",
     "\n",
-    "    def accuracy(weights: Tensor, batch: Tensor, targets: Tensor) -> Tensor:\n",
+    "    def accuracy(weights: list[Tensor], batch: Tensor, targets: Tensor) -> Tensor:\n",
     "        output = func_model(weights, batch)\n",
     "        return (output.argmax(dim=1) == targets).float().mean()\n",
     "\n",
@@ -256,7 +254,7 @@
     "\n",
     "    acc = accuracy(new_weights, batch, targets)\n",
     "\n",
-    "    return loss, acc, new_weights"
+    "    return loss, acc, tuple(new_weights)"
    ]
   },
   {
@@ -303,7 +301,7 @@
     "\n",
     "metrics = {}\n",
     "for step in range(n_train_steps):\n",
-    "    loss, acc, weights = train_step_fn(weights, points, labels)\n",
+    "    loss, acc, weights = train_step_fn(list(weights), points, labels)\n",
     "    if step % 100 == 0:\n",
     "        metrics[step] = {\"loss\": loss.item(), \"acc\": acc.item()}\n",
     "\n",
@@ -362,8 +360,9 @@
     "batched_weights = initialize_ensemble(n_models=5)\n",
     "for step in tqdm(range(n_train_steps), desc=\"training MLP ensemble\"):\n",
     "    losses, accuracies, batched_weights = parallel_train_step_fn(\n",
-    "        batched_weights, points, labels\n",
+    "        list(batched_weights), points, labels\n",
     "    )\n",
+    "    batched_weights = list(batched_weights)\n",
     "\n",
     "    loss_dict = {f\"model {idx}\": loss for idx, loss in enumerate(losses, 1)}\n",
     "    writer.add_scalars(\"training/loss\", loss_dict, step)\n",

pyproject.toml

Lines changed: 6 additions & 5 deletions
@@ -39,19 +39,20 @@ excel = ["openpyxl"]
 [project.scripts]
 tb-reducer = "tensorboard_reducer:main"
 
-[tool.setuptools.packages]
-find = { include = ["tensorboard_reducer*"], exclude = ["tests*"] }
+[build-system]
+requires = ["uv_build>=0.7.19"]
+build-backend = "uv_build"
 
-[tool.distutils.bdist_wheel]
-universal = true
+[tool.uv.build-backend]
+module-name = "tensorboard_reducer"
+module-root = ""
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 addopts = "-p no:warnings"
 
 [tool.ruff]
 target-version = "py311"
-include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
 
 [tool.ruff.lint]
 select = ["ALL"]
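
Note: a quick way to sanity-check an install built with the new uv_build backend is to query the package metadata; this sketch assumes the distribution name is tensorboard-reducer:

from importlib.metadata import version

# should print the installed version regardless of which build backend produced the wheel
print(version("tensorboard-reducer"))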

tensorboard_reducer/event_loader.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def reload(self) -> EventAccumulator:
             self._process_event(event)
         return self
 
-    def _process_event(self, event: Event) -> None:
+    def _process_event(self, event: type[Event]) -> None:
         """Called whenever an event is loaded."""
         if self._first_event_timestamp is None:
             self._first_event_timestamp = event.wall_time

tensorboard_reducer/load.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def load_tb_events(
     if handle_dup_steps == "mean":
         df_scalar = df_scalar.groupby(df_scalar.index).mean()
     elif handle_dup_steps in ("keep-first", "keep-last"):
-        keep = handle_dup_steps.removeprefix(  # ty: ignore[possibly-unbound-attribute] # noqa: E501
+        keep = handle_dup_steps.removeprefix(  # ty: ignore[possibly-unbound-attribute]
            "keep-"
        )
        df_scalar = df_scalar[~df_scalar.index.duplicated(keep=keep)]
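
Note: this branch strips the "keep-" prefix and then drops duplicate steps accordingly. A minimal sketch with hypothetical data showing what the keep-first/keep-last path does:

import pandas as pd

df_scalar = pd.DataFrame({"loss": [0.9, 0.8, 0.7]}, index=[0, 1, 1])  # step 1 logged twice

handle_dup_steps = "keep-last"
keep = handle_dup_steps.removeprefix("keep-")  # -> "last"
df_scalar = df_scalar[~df_scalar.index.duplicated(keep=keep)]
# remaining rows: step 0 (loss 0.9) and the last step-1 entry (loss 0.7)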

tensorboard_reducer/write.py

Lines changed: 4 additions & 3 deletions
@@ -72,10 +72,10 @@ def write_tb_events(
         list[str]: List of paths to the new TensorBoard event files.
     """
     try:
-        from torch.utils.tensorboard import SummaryWriter
+        from torch.utils.tensorboard import SummaryWriter  # noqa: PLC0415
     except ImportError:
         try:
-            from tensorflow.summary import SummaryWriter
+            from tensorflow.summary import SummaryWriter  # noqa: PLC0415
         except ImportError:
             raise ImportError(
                 "Cannot import SummaryWriter from torch nor tensorflow. "
@@ -172,7 +172,8 @@
     # names and tag names as 2nd level
     dict_of_dfs = {op: pd.DataFrame(dic) for op, dic in data_to_write.items()}
     df_out = pd.concat(dict_of_dfs, axis=1)
-    df_out.columns = df_out.columns.swaplevel(0, 1)
+    if isinstance(df_out.columns, pd.MultiIndex):
+        df_out.columns = df_out.columns.swaplevel(i=0, j=1)
     df_out.index.name = "step"
 
     # let pandas handle compression inference from extensions (.csv.gz, .json.bz2, etc.)
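
Note: the isinstance check narrows df_out.columns to pd.MultiIndex before calling swaplevel, which is what pd.concat over a dict of DataFrames produces. A small sketch with hypothetical data:

import pandas as pd

dict_of_dfs = {"mean": pd.DataFrame({"loss": [0.5], "acc": [0.9]})}
df_out = pd.concat(dict_of_dfs, axis=1)  # columns: ("mean", "loss"), ("mean", "acc")

if isinstance(df_out.columns, pd.MultiIndex):
    # swap so the tag names form the outer level: ("loss", "mean"), ("acc", "mean")
    df_out.columns = df_out.columns.swaplevel(i=0, j=1)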

tests/test_init.py

Lines changed: 2 additions & 5 deletions
@@ -2,11 +2,11 @@
 
 import re
 
+import tensorboard_reducer as tbr
+
 
 def test_init_imports() -> None:
     """Test that all expected imports are available from tensorboard_reducer."""
-    import tensorboard_reducer as tbr
-
     assert callable(tbr.load_tb_events)
     assert callable(tbr.reduce_events)
     assert callable(tbr.write_tb_events)
@@ -16,7 +16,4 @@ def test_init_imports() -> None:
 
 def test_init_version() -> None:
     """Test that the version is available when the package is installed."""
-    # Back up the current state
-    import tensorboard_reducer as tbr
-
     assert re.match(r"\d+\.\d+\.\d+", tbr.__version__) is not None

tests/test_main.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def test_main_report_version(capsys: pytest.CaptureFixture[str], arg: str) -> None:
     with pytest.raises(SystemExit) as exc_info:
         main([arg])
 
+    assert isinstance(exc_info.value, SystemExit)
     assert exc_info.value.code == 0
 
     stdout, stderr = capsys.readouterr()
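
Note: the added isinstance assert makes the SystemExit type explicit before reading .code. A self-contained sketch of the pattern, with a hypothetical main standing in for the real CLI entry point:

import pytest


def main(argv: list[str]) -> None:  # hypothetical stand-in for the package's CLI
    raise SystemExit(0)


def test_exit_code() -> None:
    with pytest.raises(SystemExit) as exc_info:
        main(["--version"])

    # some type checkers don't narrow exc_info.value to SystemExit on their own,
    # so the explicit isinstance assert keeps the .code access well-typed
    assert isinstance(exc_info.value, SystemExit)
    assert exc_info.value.code == 0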
