Skip to content

Commit af14f84

Browse files
committed
tweak
1 parent 520efbe commit af14f84

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

src/destructure.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
2+
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
33
const NoT = NoTangent()
44

55
"""
@@ -11,11 +11,11 @@ Differentiable.
1111
1212
# Example
1313
```jldoctest
14-
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0])))
15-
([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
14+
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im])))
15+
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
1616
17-
julia> re([10,20,30])
18-
(x = [10.0, 20.0], y = (sin, [30.0]))
17+
julia> re([3, 5-im, 7+11im])
18+
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
1919
```
2020
"""
2121
function destructure(x)
@@ -27,7 +27,7 @@ end
2727
Restructure(Model, ..., length)
2828
2929
This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30-
new parameters from vector `p`. If the model is callable, then `re(x, p)` .
30+
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
3131
3232
# Example
3333
```julia
@@ -107,22 +107,22 @@ end
107107

108108
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
109109
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
110-
_rebuild_back(dx) = (NoT, NoT, NoT, _accumulate!(x, dx, off, dflat))
110+
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat))
111111
_rebuild(x, off, flat; len), _rebuild_back
112112
end
113113

114114
# This is the gradient of model reconstruction, accumulating duplicates:
115-
function _accumulate!(x, dx, off, flat::AbstractVector)
115+
function _grad!(x, dx, off, flat::AbstractVector)
116116
x′, _ = functor(typeof(x), x)
117117
dx′, _ = functor(typeof(x), dx)
118118
off′, _ = functor(typeof(x), off)
119-
foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119+
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
120120
flat
121121
end
122-
function _accumulate!(x, dx, off::Integer, flat::AbstractVector)
122+
function _grad!(x, dx, off::Integer, flat::AbstractVector)
123123
@views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes
124124
flat
125125
end
126-
_accumulate!(x, dx::Zero, off, flat::AbstractVector) = nothing
127-
_accumulate!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
126+
_grad!(x, dx::Zero, off, flat::AbstractVector) = nothing
127+
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
128128

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
168168
@testset verbose=true "Destructure" begin
169169
include("destructure.jl")
170170
end
171-
@info "finished feature testing"
172171
@testset verbose=true "Optimisation Rules" begin
173172
include("rules.jl")
174173
end

0 commit comments

Comments
 (0)