Skip to content

Commit 9cba757

Browse files
authored
Merge pull request #108 from mschauer/tilde
Simplify for type stability
2 parents 1cb36f0 + 9f8f59d commit 9cba757

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ZigZagBoomerang"
22
uuid = "36347407-b186-4a6a-8c98-4f4567861712"
33
authors = ["Sebastiano Grazzi and Moritz Schauer"]
4-
version = "0.11.2"
4+
version = "0.11.3"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/not_fact_samplers.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -150,29 +150,31 @@ function pdmp(∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::Union{BouncyParticle, Boom
150150
end
151151

152152

153-
function pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c::Bound, abc, (t′, renew), τref, v, (acc, num),
154-
Flow::BouncyParticle, args...; subsample=false, oscn=false, factor=1.5, adapt=false)
153+
##################################
154+
155+
function pdmp_inner!(rng, dϕ::F1, ∇ϕ!::F2, ∇ϕx, t, x, θ, c::Bound, abc, (t′, renew), τref, (acc, num),
156+
Flow::BouncyParticle, args...; subsample=false, oscn=false, factor=1.5, adapt=false) where {F1, F2}
155157
while true
156158
if τref < t′
157-
t, x, θ = move_forward!(τref - t, t, x, θ, Flow)
159+
t, _ = move_forward!(τref - t, t, x, θ, Flow)
158160
refresh!(rng, θ, Flow)
159161
θdϕ, v = (t, x, θ, args...)
160162
#∇ϕx = grad_correct!(∇ϕx, x, Flow)
161163
l = λ(θdϕ, Flow)
162164
τref = t + waiting_time_ref(rng, Flow)
163165
abc = ab(x, θ, c, θdϕ, v, Flow)
164166
t′, renew = next_time(t, abc, rand(rng))
165-
return t, x, θ, (acc, num), c, abc, (t′, renew), τref, v
167+
return t, (acc, num), c, abc, (t′, renew), τref
166168
elseif renew
167169
τ = t′ - t
168-
t, x, θ = move_forward!(τ, t, x, θ, Flow)
170+
t, _ = move_forward!(τ, t, x, θ, Flow)
169171
θdϕ, v = (t, x, θ, args...)
170172
#∇ϕx = grad_correct!(∇ϕx, x, Flow)
171173
abc = ab(x, θ, c, θdϕ, v, Flow)
172174
t′, renew = next_time(t, abc, rand(rng))
173175
else
174176
τ = t′ - t
175-
t, x, θ = move_forward!(τ, t, x, θ, Flow)
177+
t, _ = move_forward!(τ, t, x, θ, Flow)
176178
θdϕ, v = (t, x, θ, args...)
177179
#∇ϕx = grad_correct!(∇ϕx, x, Flow)
178180
l, lb = λ(θdϕ, Flow), pos(abc[1] + abc[2]*τ)
@@ -189,13 +191,13 @@ function pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c::Bound, abc, (t′, r
189191
@assert Flow.L == I
190192
oscn!(rng, θ, ∇ϕx, Flow.ρ; normalize=false)
191193
else
192-
θ = reflect!(∇ϕx, x, θ, Flow)
194+
reflect!(∇ϕx, x, θ, Flow)
193195
end
194196
θdϕ, v = (t, x, θ, args...)
195197
#∇ϕx = grad_correct!(∇ϕx, x, Flow)
196198
abc = ab(x, θ, c, θdϕ, v, Flow)
197199
t′, renew = next_time(t, abc, rand(rng))
198-
!subsample && return t, x, θ, (acc, num), c, abc, (t′, renew), τref, v
200+
!subsample && return t, (acc, num), c, abc, (t′, renew), τref
199201
else
200202
abc = ab(x, θ, c, θdϕ, v, Flow)
201203
t′, renew = next_time(t, abc, rand(rng))
@@ -204,7 +206,6 @@ function pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c::Bound, abc, (t′, r
204206
end
205207
end
206208
"""
207-
208209
pdmp(dϕ, ∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::BouncyParticle, args...; oscn=false, adapt=false, subsample=false, progress=false, progress_stops = 20, islocal = false, seed=Seed(), factor=2.0)
209210
210211
The first directional derivative `dϕ[1]` tells me if I move up or down the potential. The second directional derivative `dϕ[2]` tells me how fast the first changes. The gradient `∇ϕ!` tells me the surface I want to reflect on.
@@ -275,7 +276,7 @@ function pdmp(dϕ, ∇ϕ!, t0, x0, θ0, T, c::Bound, Flow::BouncyParticle, args.
275276

276277
t′, renew = next_time(t, abc, rand(rng))
277278
while t < T
278-
t, x, θ, (acc, num), c, abc, (t′, renew), τref, v = pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, v, (acc, num), Flow, args...; oscn=oscn, subsample=subsample, factor=factor, adapt=adapt)
279+
t, (acc, num), c, abc, (t′, renew), τref = pdmp_inner!(rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, (acc, num), Flow, args...; oscn=oscn, subsample=subsample, factor=factor, adapt=adapt)
279280
push!(Ξ, event(t, x, θ, Flow))
280281

281282
if t > tstop

src/notfactiter.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ function rawevent(t, x, θ, Z::Union{BouncyParticle,Boomerang})
5454
t, x, θ, nothing
5555
end
5656

57+
######
5758

5859
function iterate(FS::NotFactSampler{<:Any, <:Tuple})
5960
t0, (x0, θ0) = FS.u0
@@ -71,16 +72,17 @@ function iterate(FS::NotFactSampler{<:Any, <:Tuple})
7172
abc = ab(x, θ, c, θdϕ, v, Flow)
7273

7374
t′, renew = next_time(t, abc, rand(rng))
74-
iterate(FS, ((t => (x, θ)), ∇ϕx, (acc, num), c, abc, (t′, renew), τref, v))
75+
iterate(FS, ((t => (x, θ)), ∇ϕx, (acc, num), c, abc, (t′, renew), τref))
7576
end
77+
using Test
7678

7779

78-
function iterate(FS::NotFactSampler{<:Any, <:Tuple}, (u, ∇ϕx, (acc, num), c, abc, (t′, renew), τref, v))
80+
function iterate(FS::NotFactSampler{<:Any, <:Tuple}, (u, ∇ϕx, (acc, num), c, abc, (t′, renew), τref))
7981
t, (x, θ) = u
8082
dϕ, ∇ϕ! = FS.∇ϕ![1], FS.∇ϕ![2]
81-
t, x, θ, (acc, num), c, abc, (t′, renew), τref, v = pdmp_inner!(FS.rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, v, (acc, num), FS.F, FS.args...; FS.kargs...)
83+
t, (acc, num), c, abc, (t′, renew), τref = pdmp_inner!(FS.rng, dϕ, ∇ϕ!, ∇ϕx, t, x, θ, c, abc, (t′, renew), τref, (acc, num), FS.F, FS.args...; FS.kargs...)
8284

8385
ev = rawevent(t, x, θ, FS.F)
8486
u = t => (x, θ)
85-
return ev, (u, ∇ϕx, (acc, num), c, abc, (t′, renew), τref, v)
87+
return ev, (u, ∇ϕx, (acc, num), c, abc, (t′, renew), τref)
8688
end

0 commit comments

Comments
 (0)