1
1
2
- using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
2
+ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
3
3
const NoT = NoTangent ()
4
4
5
5
"""
@@ -11,11 +11,11 @@ Differentiable.
11
11
12
12
# Example
13
13
```jldoctest
14
- julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 ])))
15
- ([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
14
+ julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im ])))
15
+ (ComplexF64 [1.0 + 0.0im , 2.0 + 0.0im , 3.0 + 4.0im ], Restructure(NamedTuple, ..., 3))
16
16
17
- julia> re([10,20,30 ])
18
- (x = [10 .0, 20 .0], y = (sin, [30.0 ]))
17
+ julia> re([3, 5-im, 7+11im ])
18
+ (x = [3 .0, 5 .0], y = (sin, ComplexF64[7.0 + 11.0im ]))
19
19
```
20
20
"""
21
21
function destructure (x)
27
27
Restructure(Model, ..., length)
28
28
29
29
This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30
- new parameters from vector `p`. If the model is callable, then `re(x, p)` .
30
+ new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)` .
31
31
32
32
# Example
33
33
```julia
@@ -107,22 +107,22 @@ end
107
107
108
108
function ChainRulesCore. rrule (:: typeof (_rebuild), x, off, flat; len)
109
109
dflat = map! (zero, similar (flat, float (eltype (flat))), flat)
110
- _rebuild_back (dx) = (NoT, NoT, NoT, _accumulate ! (x, dx , off, dflat))
110
+ _rebuild_back (dx) = (NoT, NoT, NoT, _grad ! (x, unthunk (dx) , off, dflat))
111
111
_rebuild (x, off, flat; len), _rebuild_back
112
112
end
113
113
114
114
# This is the gradient of model reconstruction, accumulating duplicates:
115
- function _accumulate ! (x, dx, off, flat:: AbstractVector )
115
+ function _grad ! (x, dx, off, flat:: AbstractVector )
116
116
x′, _ = functor (typeof (x), x)
117
117
dx′, _ = functor (typeof (x), dx)
118
118
off′, _ = functor (typeof (x), off)
119
- foreach ((xᵢ, dxᵢ, oᵢ) -> _accumulate ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119
+ foreach ((xᵢ, dxᵢ, oᵢ) -> _grad ! (xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
120
120
flat
121
121
end
122
- function _accumulate ! (x, dx, off:: Integer , flat:: AbstractVector )
122
+ function _grad ! (x, dx, off:: Integer , flat:: AbstractVector )
123
123
@views flat[off .+ (1 : length (x))] .+ = dx # must visit all tied nodes
124
124
flat
125
125
end
126
- _accumulate ! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
127
- _accumulate ! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
126
+ _grad ! (x, dx:: Zero , off, flat:: AbstractVector ) = nothing
127
+ _grad ! (x, dx:: Zero , off:: Integer , flat:: AbstractVector ) = nothing # ambiguity
128
128
0 commit comments