Skip to content

Commit 347d001

Browse files
authored
Merge pull request #100 from mschauer/newversion
Make Bouncy Particle sampler FAST
2 parents 3b0f81d + 9ec6813 commit 347d001

11 files changed

+518
-310
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ZigZagBoomerang"
22
uuid = "36347407-b186-4a6a-8c98-4f4567861712"
33
authors = ["Sebastiano Grazzi and Moritz Schauer"]
4-
version = "0.11.0"
4+
version = "0.11.1"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/ZigZagBoomerang.jl

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Seed() = gen_seed(UInt64, 2)
1313
# ZigZag1d and Boomerang1d reference implementation
1414
include("types.jl")
1515
include("common.jl")
16+
include("oscn.jl")
1617
include("dynamics.jl")
1718
export ZigZag1d, Boomerang1d, ZigZag, FactBoomerang
1819
const LocalZigZag = ZigZag

src/dynamics.jl

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ function refresh!(rng, θ, F::BouncyParticle)
117117
θ .+= u
118118
θ
119119
end
120+
120121
function refresh!(rng, θ, F::Boomerang)
121122
ρ̄ = sqrt(1-F.ρ^2)
122123
θ .*= F.ρ

src/not_fact_samplers.jl

+146-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ function grad_correct!(y, x, F::Boomerang)
1111
y .-= (F.L'\(F.L\(x - F.μ)))
1212
y
1313
end
14-
λ(∇ϕx, θ, F::Union{BouncyParticle, Boomerang}) = pos(dot(∇ϕx, θ))
15-
14+
# Event rate computed from a full gradient: `pos` (defined elsewhere; presumably the
# positive part) applied to the inner product of the gradient `∇ϕx` and the velocity `θ`.
# `F` only selects the method and is not read.
function λ(∇ϕx::AbstractVector, θ, F::Union{BouncyParticle, Boomerang})
    directional = dot(∇ϕx, θ)
    return pos(directional)
end
15+
# Event rate computed from an already available scalar directional derivative `θdϕ`:
# `pos` (defined elsewhere; presumably the positive part) of that value.
# `F` only selects the method and is not read.
function λ(θdϕ::Number, F::Union{BouncyParticle, Boomerang})
    return pos(θdϕ)
end
1616
#=
1717
function refresh!(rng, θ, F::BouncyParticle)
1818
ρ̄ = sqrt(1-F.ρ^2)
@@ -26,9 +26,13 @@ end
2626
# Coefficients `(a, b, Δ)` handed to `next_time` for a global bound with the
# Bouncy Particle sampler. `∇ϕx` and `v` are accepted for a uniform call signature
# but are not used here; the third component is `Inf`.
function ab(x, θ, C::GlobalBound, ∇ϕx, v, B::BouncyParticle)
    intercept = C.c + θ' * (B.Γ * (x - B.μ))
    slope = θ' * (B.Γ * θ)
    return (intercept, slope, Inf)
end
29-
function ab(x, θ, C::LocalBound, ∇ϕx, v, B::BouncyParticle)
29+
# Coefficients `(a, b, Δ)` handed to `next_time` for a local bound when the full
# gradient `∇ϕx` is available: the intercept is `C.c` plus the directional derivative
# `dot(θ, ∇ϕx)`, the slope is `v`, and the third component shrinks with the largest
# velocity component (`norm(θ, Inf)`) and with `C.c`.
function ab(x, θ, C::LocalBound, ∇ϕx::AbstractVector, v, B::BouncyParticle)
    intercept = C.c + dot(θ, ∇ϕx)
    horizon = 2.0 / C.c / norm(θ, Inf)
    return (intercept, v, horizon)
end
32+
# Coefficients `(a, b, Δ)` handed to `next_time` for a local bound when the directional
# derivative along `θ` is already available as the scalar `vdϕ` (no gradient needed).
# The intercept is `C.c + vdϕ`, the slope is `v`, and the third component shrinks with
# the largest velocity component (`norm(θ, Inf)`) and with `C.c`.
# The `::Number` annotation already guarantees the argument type, so no runtime
# type assertion is needed.
function ab(x, θ, C::LocalBound, vdϕ::Number, v, B::BouncyParticle)
    intercept = C.c + vdϕ
    horizon = 2.0 / C.c / norm(θ, Inf)
    return (intercept, v, horizon)
end
3236

3337
function ab(x, θ, C::GlobalBound, ∇ϕx, v, B::Boomerang)
3438
(sqrt(normsq(θ) + normsq((x - B.μ)))*C.c, 0.0, Inf)
@@ -83,6 +87,7 @@ function pdmp_inner!(rng, ∇ϕ!, ∇ϕx, t, x, θ, c::Bound, abc, (t′, renew)
8387
end
8488
θ = reflect!(∇ϕx, x, θ, Flow)
8589
∇ϕx, v = ∇ϕ!(∇ϕx, t, x, θ, args...)
90+
∇ϕx = grad_correct!(∇ϕx, x, Flow)
8691
abc = ab(x, θ, c, ∇ϕx, v, Flow)
8792
t′, renew = next_time(t, abc, rand(rng))
8893
!subsample && return t, x, θ, (acc, num), c, abc, (t′, renew), τref, v
@@ -144,6 +149,144 @@ function pdmp(∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::Union{BouncyParticle, Boom
144149
return Ξ, (t, x, θ), (acc, num), c
145150
end
146151

152+
153+
"""
    pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, v, (acc, num),
                Flow::BouncyParticle, args...;
                subsample=false, oscn=false, factor=1.5, adapt=false)

Advance the Bouncy Particle sampler to the next event that is reported to the caller
and return `t, x, θ, (acc, num), c, abc, (t′, renew), τref, v`.

`dϕ(t, x, θ, args...)` must return a pair: the directional derivative of the potential
along `θ` and a second value `v` that is passed on to `ab` when the bound is rebuilt.
The full gradient `∇ϕ!(∇ϕx, t, x, args...)` (which fills `∇ϕx` in place) is evaluated
only when a proposed event is accepted.

Each pass through the loop handles one of three cases:

- `τref < t′`: move to the refreshment time, call `refresh!` on `θ`, draw the next
  `τref`, rebuild the bound and return.
- `renew`: move to `t′` and rebuild the bound without proposing an event.
- otherwise: move to `t′`, count a proposal in `num` and accept it when
  `rand(rng)*lb <= l`, where `l` is the current rate and `lb` the bound at `t′`.
  On acceptance `acc` is incremented and `θ` is updated by `reflect!` (or `oscn!`).

# Keywords
- `subsample`: when `false`, return after every accepted event; when `true`, keep
  looping so that the function only returns from the refreshment branch.
- `oscn`: update `θ` with `oscn!` (using `Flow.ρ`) instead of `reflect!`;
  asserts `Flow.L == I`.
- `adapt`, `factor`: when the rate exceeds the bound at an accepted event, multiply `c`
  by `factor` if `adapt` is set, otherwise raise an error.
"""
function pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c::Bound, abc, (t′, renew), τref, v, (acc, num),
        Flow::BouncyParticle, args...; subsample=false, oscn=false, factor=1.5, adapt=false)
    while true
        if τref < t′
            # Refreshment comes before the next proposed event.
            t, x, θ = move_forward!(τref - t, t, x, θ, Flow)
            refresh!(rng, θ, Flow)
            θdϕ, v = dϕ(t, x, θ, args...)
            τref = t + waiting_time_ref(rng, Flow)
            abc = ab(x, θ, c, θdϕ, v, Flow)
            t′, renew = next_time(t, abc, rand(rng))
            return t, x, θ, (acc, num), c, abc, (t′, renew), τref, v
        elseif renew
            # The bound has to be renewed at t′: move there and rebuild it.
            τ = t′ - t
            t, x, θ = move_forward!(τ, t, x, θ, Flow)
            θdϕ, v = dϕ(t, x, θ, args...)
            abc = ab(x, θ, c, θdϕ, v, Flow)
            t′, renew = next_time(t, abc, rand(rng))
        else
            # Proposed event at t′: accept or reject by comparing rate and bound.
            τ = t′ - t
            t, x, θ = move_forward!(τ, t, x, θ, Flow)
            θdϕ, v = dϕ(t, x, θ, args...)
            l, lb = λ(θdϕ, Flow), pos(abc[1] + abc[2]*τ)
            num += 1
            if rand(rng)*lb <= l
                acc += 1
                if l > lb
                    !adapt && error("Tuning parameter `c` too small.")
                    c *= factor
                end
                # The full gradient is only needed to change direction.
                ∇ϕ!(∇ϕx, t, x, args...)
                # Internal consistency check between `dϕ` and `∇ϕ!`.
                @assert dot(θ, ∇ϕx) ≈ θdϕ
                if oscn
                    @assert Flow.L == I
                    oscn!(rng, θ, ∇ϕx, Flow.ρ; normalize=false)
                else
                    θ = reflect!(∇ϕx, x, θ, Flow)
                end
                # The direction changed, so the directional derivative must be recomputed.
                θdϕ, v = dϕ(t, x, θ, args...)
                abc = ab(x, θ, c, θdϕ, v, Flow)
                t′, renew = next_time(t, abc, rand(rng))
                !subsample && return t, x, θ, (acc, num), c, abc, (t′, renew), τref, v
            else
                # Rejected: keep the direction and rebuild the bound from the current state.
                abc = ab(x, θ, c, θdϕ, v, Flow)
                t′, renew = next_time(t, abc, rand(rng))
            end
        end
    end
end
"""
    pdmp(dϕ, ∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::BouncyParticle, args...;
         oscn=false, adapt=false, subsample=false, progress=false,
         progress_stops=20, islocal=false, seed=Seed(), factor=2.0)

Run the Bouncy Particle sampler from time `t0` until the time exceeds `T`, using
directional derivatives to decide on events and the full gradient only to reflect.

`dϕ(t, x, θ, args...)` returns the first two directional derivatives of the potential
in direction `θ`. The first tells whether the particle moves up or down the potential,
the second how fast the first changes. The gradient `∇ϕ!(y, t, x, args...)` defines the
surface to reflect on.

    dϕ = function (t, x, v, args...) # two directional derivatives
        u = ForwardDiff.derivative(t -> -ℓ(x + t*v), Dual{:hSrkahPmmC}(0.0, 1.0))
        u.value, u.partials[]
    end
    ∇ϕ! = function (y, t, x, args...)
        ForwardDiff.gradient!(y, ℓ, x)
        y .= -y
        y
    end

The remaining arguments:

    d = 25 # number of parameters
    t0 = 0.0
    x0 = zeros(d) # starting point sampler
    θ0 = randn(d) # starting direction sampler
    T = 200. # end time (similar to number of samples in MCMC)
    c = 50.0 # initial guess for the bound

    # define BouncyParticle sampler (has two relevant parameters)
    Z = BouncyParticle(∅, ∅, # ignored
        10.0, # momentum refreshment rate
        0.95, # momentum correlation / only gradually change momentum in refreshment/momentum update
        0.0, # ignored
        I # left cholesky factor of momentum precision
    )

    trace, final, (acc, num), cs = @time pdmp(
        dϕ, # return first two directional derivatives of negative target log-likelihood in direction v
        ∇ϕ!, # return gradient of negative target log-likelihood
        t0, x0, θ0, T, # initial state and duration
        ZigZagBoomerang.LocalBound(c), # use Hessian information
        Z; # sampler
        oscn=false, # no orthogonal subspace pCR
        adapt=true, # adapt bound c
        progress=true, # show progress bar
        subsample=true # keep only samples at refreshment times
    )

    # to obtain direction change times and points of piecewise linear trace
    t, x = ZigZagBoomerang.sep(trace)

Returns the trace, the final state `(t, x, θ)`, the pair `(acc, num)` of accepted and
proposed events, and the (possibly adapted) bound `c`.

The keyword `islocal` is accepted but not used by this method.
"""
function pdmp(dϕ, ∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::BouncyParticle, args...;
        oscn=false, adapt=false, subsample=false, progress=false, progress_stops = 20,
        islocal = false, seed=Seed(), factor=2.0)
    # `∇ϕx` is a work buffer for the gradient; it only needs the shape of `θ0`.
    t, x, θ, ∇ϕx = t0, copy(x0), copy(θ0), copy(θ0)
    rng = Rng(seed)
    Ξ = Trace(t0, x0, θ0, Flow)
    # Refreshment times are absolute (the inner loop compares `τref` with `t′`),
    # so the first one is measured from the start time.
    τref = t + waiting_time_ref(rng, Flow)
    θdϕ, v = dϕ(t, x, θ, args...)
    num = acc = 0
    abc = ab(x, θ, c, θdϕ, v, Flow)
    if progress
        prg = Progress(progress_stops, 1)
    else
        prg = missing
    end
    stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
    # With `stops == 0` this is `Inf`, so the progress branch below is never taken.
    tstop = T/stops

    t′, renew = next_time(t, abc, rand(rng))
    while t < T
        t, x, θ, (acc, num), c, abc, (t′, renew), τref, v = pdmp_inner!(
            rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, v, (acc, num), Flow, args...;
            oscn=oscn, subsample=subsample, factor=factor, adapt=adapt)
        push!(Ξ, event(t, x, θ, Flow))

        if t > tstop
            tstop += T/stops
            next!(prg)
        end
    end
    ismissing(prg) || ProgressMeter.finish!(prg)
    return Ξ, (t, x, θ), (acc, num), c
end
289+
147290
# Decide whether `f` needs wrapping by splatting its method table into `wrap_`,
# which dispatches on the number of methods.
function wrap(f)
    return wrap_(f, methods(f)...)
end
148291
# Fallback for any number of trailing arguments: hand `f` back unchanged.
function wrap_(f, args...)
    return f
end
149292
# Exactly one trailing argument `m` (a single method of `f`): wrap `f` in `Wrapper`
# when `m.nargs <= 4`, otherwise return `f` itself.
@inline function wrap_(f, m)
    if m.nargs <= 4
        return Wrapper(f)
    else
        return f
    end
end

src/oscn.jl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
"""
    oscn!(rng, v, ∇ψx, ρ; normalize=false)

Orthogonal subspace Crank-Nicolson step with autocorrelation `ρ` for
standard Gaussian or Uniform on the sphere (`normalize = true`).

The component of `v` along `∇ψx` changes sign. The component orthogonal to `∇ψx` is
multiplied by `ρ` and perturbed by Gaussian noise that is scaled by `sqrt(1 - ρ^2)` and
projected onto the orthogonal complement of `∇ψx`, so that the variance of the
orthogonal part is `ρ^2 + (1 - ρ^2) = 1`. For `ρ == 1` no noise is drawn and the step
is a plain reflection. With `normalize = true` the orthogonal part is rescaled so that
the result has unit norm (assuming the parallel part has norm at most one).

`v` is overwritten in place and returned.
"""
function oscn!(rng, v, ∇ψx, ρ; normalize=false)
    # Decompose v into the part parallel to ∇ψx and the ρ-scaled orthogonal part
    vₚ = (dot(v, ∇ψx)/dot(∇ψx, ∇ψx))*∇ψx
    v⊥ = ρ*(v - vₚ)
    if ρ == 1
        @. v = v - 2vₚ
    else
        # Sample noise with standard deviation sqrt(1 - ρ^2) and project out ∇ψx
        z = randn!(rng, similar(v)) * sqrt(1.0f0 - ρ^2)
        z -= (dot(z, ∇ψx)/dot(∇ψx, ∇ψx))*∇ψx
        if normalize
            # Rescale the orthogonal part so the squared norms add up to one
            s = sqrt(1 - norm(vₚ)^2)/norm(v⊥ + z)
            @. v = -vₚ + s*(v⊥ + z)
        else
            @. v = -vₚ + v⊥ + z
        end
    end
    return v
end

0 commit comments

Comments
 (0)