Skip to content

Commit 72a29fc

Browse files
KfacJaxDev
authored and committed
Add function to remove layer tags from jaxpr
PiperOrigin-RevId: 804950216
1 parent f7bbd0f commit 72a29fc

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

kfac_jax/_src/tag_graph_matcher.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,62 @@ def clean_jaxpr(
10551055
return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)
10561056

10571057

1058+
def clean_layer_tags_jaxpr(
    jaxpr: J,
    only_remove_auto_tags: bool = False,
) -> J:
  """Removes layer tags from a Jaxpr.

  Every layer tag equation selected for removal is dropped, and each of its
  output variables is substituted by the input variable it forwards, as given
  by `eqn.params["meta"].outputs_index`. The substitution is applied to the
  inputs of all remaining equations and to the outputs of the jaxpr itself,
  and it is followed transitively, so a removed tag that consumes the output
  of another removed tag is rewired to a variable that still exists.

  Args:
    jaxpr: The jaxpr to clean. It is converted with `to_closed_jaxpr` before
      processing.
    only_remove_auto_tags: If `True`, only layer tags whose
      `eqn.params["meta"].name` contains the substring `"Auto"` are removed and
      all other layer tags are kept. If `False`, every layer tag is removed.

  Returns:
    The cleaned jaxpr, converted back with `to_jaxpr_or_closed_jaxpr` using the
    original `jaxpr` argument as the reference.
  """

  closed_jaxpr = to_closed_jaxpr(jaxpr)

  # Maps each output variable of a removed tag to the input variable that the
  # tag forwarded to it.
  var_map = {}
  kept_eqns = []

  for eqn in closed_jaxpr.jaxpr.eqns:
    if isinstance(eqn.primitive, tags.LayerTag) and (
        not only_remove_auto_tags or "Auto" in eqn.params["meta"].name
    ):
      for out_ind, in_ind in enumerate(eqn.params["meta"].outputs_index):
        var_map[eqn.outvars[out_ind]] = eqn.invars[in_ind]
    else:
      kept_eqns.append(eqn)

  def resolve(var):
    # Literals are constants rather than equation outputs, so they are never
    # keys of `var_map`; they are skipped before the dict lookup, as in the
    # original membership check.
    # Follow the chain until reaching a variable that was not produced by a
    # removed tag; a single lookup would leave a dangling reference when one
    # removed tag feeds another.
    while not isinstance(var, jex.core.Literal) and var in var_map:
      var = var_map[var]
    return var

  new_eqns = [
      eqn.replace(invars=[resolve(var) for var in eqn.invars])
      for eqn in kept_eqns
  ]
  # The jaxpr outputs can themselves be outputs of a removed tag, so they need
  # the same substitution as equation inputs.
  new_outvars = [resolve(var) for var in closed_jaxpr.jaxpr.outvars]

  closed_jaxpr = ClosedJaxpr(
      jaxpr=closed_jaxpr.jaxpr.replace(eqns=new_eqns, outvars=new_outvars),
      consts=closed_jaxpr.consts,
  )

  return to_jaxpr_or_closed_jaxpr(closed_jaxpr, jaxpr)
1093+
1094+
1095+
def clean_layer_tags_func(
    func: utils.Func,
    func_args: utils.FuncArgs,
    only_remove_auto_tags: bool = False,
) -> utils.Func:
  """Builds a function that evaluates `func` with its layer tags removed.

  `func` is traced once with `jax.make_jaxpr` on `func_args`, the traced jaxpr
  is passed through `clean_layer_tags_jaxpr`, and the returned callable
  evaluates that cleaned jaxpr.

  Args:
    func: The function to trace.
    func_args: Example positional arguments used for tracing `func`.
    only_remove_auto_tags: Forwarded to `clean_layer_tags_jaxpr`.

  Returns:
    A callable taking positional arguments only. It flattens them into pytree
    leaves and returns the result of `jax.core.eval_jaxpr` on the cleaned
    jaxpr.
  """
  traced = jax.make_jaxpr(func)(*func_args)
  cleaned = clean_layer_tags_jaxpr(traced, only_remove_auto_tags)

  def func_clean(*args):
    # The jaxpr takes flat inputs, so the argument pytree is reduced to its
    # leaves before evaluation.
    leaves = jax.tree_util.tree_leaves(args)
    return jax.core.eval_jaxpr(cleaned.jaxpr, cleaned.literals, *leaves)

  return func_clean
1112+
1113+
10581114
# Prototype for clean_jaxpr using JAX's dce_jaxpr. Doesn't work because
10591115
# dce_jaxpr will remove any equations with no used outputs, regardless of the
10601116
# dce_rule for that equation's primitive. Adding an "effect" to loss/layer

kfac_jax/_src/tracer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,11 @@ def make_from_func(
321321
func_args = tuple(func_args)
322322

323323
if auto_register_tags:
324+
# Always remove existing auto layer tags before re-running the automatic
325+
# registration.
326+
func = tgm.clean_layer_tags_func(
327+
func, func_args, only_remove_auto_tags=True
328+
)
324329
func = tgm.auto_register_tags(
325330
func=func,
326331
func_args=func_args,
@@ -824,14 +829,12 @@ def forward() -> tuple[Array, ...]:
824829
# Loop through equations and evaluate them
825830
num_losses_passed = 0
826831
for eqn in processed_jaxpr.jaxpr.eqns:
827-
828-
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
829-
830832
if isinstance(eqn.primitive, tags.LossTag):
831833
num_losses_passed += 1
832834
if num_losses_passed == len(processed_jaxpr.loss_tags):
833835
break
834-
836+
else:
837+
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
835838
assert num_losses_passed == len(processed_jaxpr.loss_tags)
836839

837840
return tuple(read(layer_input_vars))
@@ -884,7 +887,6 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
884887
for eqn in processed_jaxpr.jaxpr.eqns:
885888

886889
input_values = read(eqn.invars)
887-
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
888890

889891
if isinstance(eqn.primitive, tags.LossTag):
890892
loss: LossFunction = tags.loss_eqn_construct_loss(eqn, *input_values)
@@ -896,6 +898,8 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
896898

897899
if num_losses_passed == len(processed_jaxpr.loss_tags):
898900
break
901+
else:
902+
write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
899903

900904
assert num_losses_passed == len(processed_jaxpr.loss_tags)
901905

0 commit comments

Comments
 (0)