@@ -1055,6 +1055,62 @@ def clean_jaxpr(
1055
1055
return to_jaxpr_or_closed_jaxpr (closed_jaxpr , jaxpr )
1056
1056
1057
1057
1058
+ def clean_layer_tags_jaxpr (
1059
+ jaxpr : J ,
1060
+ only_remove_auto_tags : bool = False ,
1061
+ ) -> J :
1062
+ """Removes layer tags from a Jaxpr."""
1063
+
1064
+ closed_jaxpr = to_closed_jaxpr (jaxpr )
1065
+ eqns = []
1066
+ var_map = {}
1067
+
1068
+ for eqn in closed_jaxpr .jaxpr .eqns :
1069
+ if isinstance (eqn .primitive , tags .LayerTag ) and (
1070
+ not only_remove_auto_tags or "Auto" in eqn .params ["meta" ].name
1071
+ ):
1072
+ for ind1 , ind2 in enumerate (eqn .params ["meta" ].outputs_index ):
1073
+ var_map [eqn .outvars [ind1 ]] = eqn .invars [ind2 ]
1074
+ else :
1075
+ eqns .append (eqn )
1076
+
1077
+ eqns_new = []
1078
+ for eqn in eqns :
1079
+ new_invars = []
1080
+ for var in eqn .invars :
1081
+ if not isinstance (var , jex .core .Literal ) and var in var_map .keys ():
1082
+ new_invars .append (var_map [var ])
1083
+ else :
1084
+ new_invars .append (var )
1085
+ eqns_new .append (eqn .replace (invars = new_invars ))
1086
+
1087
+ closed_jaxpr = ClosedJaxpr (
1088
+ jaxpr = closed_jaxpr .jaxpr .replace (eqns = eqns_new ),
1089
+ consts = closed_jaxpr .consts ,
1090
+ )
1091
+
1092
+ return to_jaxpr_or_closed_jaxpr (closed_jaxpr , jaxpr )
1093
+
1094
+
1095
+ def clean_layer_tags_func (
1096
+ func : utils .Func ,
1097
+ func_args : utils .FuncArgs ,
1098
+ only_remove_auto_tags : bool = False ,
1099
+ ) -> utils .Func :
1100
+ """Removes layer tags from a function."""
1101
+ typed_jaxpr = jax .make_jaxpr (func )(* func_args )
1102
+ jaxpr_clean = clean_layer_tags_jaxpr (typed_jaxpr , only_remove_auto_tags )
1103
+
1104
+ def func_clean (* args ):
1105
+ # eval_jaxpr takes the jaxpr, its constants (literals), and arguments
1106
+ flattened_args = jax .tree_util .tree_leaves (args )
1107
+ return jax .core .eval_jaxpr (
1108
+ jaxpr_clean .jaxpr , jaxpr_clean .literals , * flattened_args
1109
+ )
1110
+
1111
+ return func_clean
1112
+
1113
+
1058
1114
# Prototype for clean_jaxpr using JAX's dce_jaxpr. Doesn't work because
1059
1115
# dce_jaxpr will remove any equations with no used outputs, regardless of the
1060
1116
# dce_rule for that equation's primitive. Adding an "effect" to loss/layer
0 commit comments