Skip to content

Commit 7e4a756

Browse files
authored
Merge pull request #119 from mschauer/mass
test old and new mass matrix
2 parents 7cc11af + 49241d9 commit 7e4a756

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/not_fact_samplers.jl

+14-1
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,27 @@ function ZigZagBoomerang.reflect!(∇ϕx, t, x, v, F::BouncyParticle{<:Any, <:An
158158
v .-= (2*dot(∇ϕx, v)/dot(∇ϕx, z)) * z
159159
v
160160
end
161+
function reflect!(∇ϕx, t, x, θ, F::BouncyParticle)
162+
θ .-= (2*dot(∇ϕx, θ)/normsq(F.L\∇ϕx))*(F.L'\(F.L\∇ϕx))
163+
θ
164+
end
161165
function ZigZagBoomerang.refresh!(rng, t, x, v, F::BouncyParticle{<:Any, <:Any, <:Any, <:AbstractPDMat})
162-
ρ̄ = sqrt(1-F.ρ^2)
166+
ρ̄ = sqrt(1 - F.ρ^2)
163167
v .*= F.ρ
164168
s = local_speed(t, x, v, F)
165169
u = (s*ρ̄)*PDMats.unwhiten(F.U, randn(rng, length(v)))
166170
v .+= u
167171
record_rate(v, F)
168172
end
173+
function ZigZagBoomerang.refresh!(rng, t, x, v, F::BouncyParticle)
174+
ρ̄ = sqrt(1 - F.ρ^2)
175+
v .*= F.ρ
176+
s = local_speed(t, x, v, F)
177+
u = (s*ρ̄)*(F.L'\randn(rng, length(v)))
178+
v .+= u
179+
record_rate(v, F)
180+
end
181+
169182
function mass_adapt_init(M::InvChol)
170183
Cholesky(M.R)
171184
end

test/maintest.jl

+36-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ end
176176
t0 = 0.0
177177
θ0 = randn(d)
178178
x0 = randn(d)
179-
M = UpperTriangular(I + 0.4randn(d,d))
179+
M = UpperTriangular(I + 0.4randn(d,d)')
180180
c = 20.0
181181
B = BouncyParticle(missing, missing, # ignored
182182
1.0, # momentum refreshment rate
@@ -206,6 +206,41 @@ end
206206
@test mean(abs.(cov(xs) - inv(Matrix(Γ)))) < 3/sqrt(length(ts))
207207
end
208208

209+
@testset "Bouncy Particle Sampler (arbitrary mass matrix 2)" begin
210+
Random.seed!(2)
211+
t0 = 0.0
212+
θ0 = randn(d)
213+
x0 = randn(d)
214+
M = LowerTriangular(I + 0.4randn(d,d))
215+
c = 20.0
216+
B = BouncyParticle(missing, missing, # ignored
217+
1.0, # momentum refreshment rate
218+
0.9, # momentum correlation / only gradually change momentum in refreshment/momentum update
219+
missing, # metric
220+
M
221+
)
222+
223+
∇ϕ!(y, t, x, args...) = mul!(y, Γ, x)
224+
(t, x, v, args...) = dot(v, Γ, x), dot(v, Γ, v)
225+
n = 800
226+
trace, _, acc, more = @time pdmp(
227+
dϕ, # return first two directional derivatives of negative target log-likelihood in direction v
228+
∇ϕ!, # return gradient of negative target log-likelihood
229+
t0, x0, θ0, # initial state and duration
230+
n, # number of samples
231+
ZigZagBoomerang.LocalBound(c), # use Hessian information
232+
B; # sampler
233+
adapt=false, # adapt bound c
234+
progress=true, # show progress bar
235+
)
236+
@show more
237+
@show acc[1]/acc[2]
238+
ts, xs = sep(trace)
239+
@show length(ts)
240+
@test mean(abs.(mean(xs))) < 3/sqrt(length(ts))
241+
@test mean(abs.(cov(xs) - inv(Matrix(Γ)))) < 3/sqrt(length(ts))
242+
end
243+
209244
@testset "Bouncy Particle Sampler (adapted mass matrix)" begin
210245
Random.seed!(2)
211246
t0 = 0.0

0 commit comments

Comments
 (0)