Skip to content

Commit 2cb12ef

Browse files
Weigghts:
1 parent 2d9563d commit 2cb12ef

File tree

3 files changed

+31
-16
lines changed

3 files changed

+31
-16
lines changed

src/algorithms/apf.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export AuxiliaryParticleFilter, APF
22

3-
mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
4-
N::Integer
3+
mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <:
4+
AbstractParticleFilter{N}
55
resampler::RS
66
aux::Vector # Auxiliary weights
77
end
@@ -10,20 +10,22 @@ function AuxiliaryParticleFilter(
1010
N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic()
1111
)
1212
conditional_resampler = ESSResampler(threshold, resampler)
13-
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
13+
return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(
14+
conditional_resampler, zeros(N)
15+
)
1416
end
1517

1618
const APF = AuxiliaryParticleFilter
1719

1820
function initialise(
1921
rng::AbstractRNG,
2022
model::StateSpaceModel{T},
21-
filter::AuxiliaryParticleFilter;
23+
filter::AuxiliaryParticleFilter{N};
2224
ref_state::Union{Nothing,AbstractVector}=nothing,
2325
kwargs...,
24-
) where {T}
25-
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
26-
initial_weights = fill(-log(T(filter.N)), filter.N)
26+
) where {T,N}
27+
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
28+
initial_weights = fill(-log(T(N)), N)
2729

2830
return update_ref!(
2931
ParticleContainer(initial_states, initial_weights), ref_state, filter
@@ -75,12 +77,12 @@ end
7577

7678
function update(
7779
model::StateSpaceModel{T},
78-
filter::AuxiliaryParticleFilter,
80+
filter::AuxiliaryParticleFilter{N},
7981
step::Integer,
8082
states::ParticleContainer,
8183
observation;
8284
kwargs...,
83-
) where {T}
85+
) where {T,N}
8486
@debug "step $step"
8587
log_increments = map(
8688
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
@@ -90,7 +92,7 @@ function update(
9092
states.filtered.log_weights = states.proposed.log_weights + log_increments
9193
states.filtered.particles = states.proposed.particles
9294

93-
return (states, logsumexp(log_increments) - log(T(filter.N)))
95+
return (states, logsumexp(log_increments) - log(T(N)))
9496
end
9597

9698
function step(

src/algorithms/bootstrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function update(
7171
end
7272

7373
function reset_weights!(
74-
state::ParticleState{T,WT}, idxs, filter::BootstrapFilter{N}
74+
state::ParticleState{T,WT}, idxs, ::BootstrapFilter{N}
7575
) where {T,WT<:Real,N}
7676
fill!(state.log_weights, -log(WT(N)))
7777
return state

src/algorithms/ffbs.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function (callback::WeightedParticleRecorderCallback)(
2525
return nothing
2626
end
2727

28-
function smooth(
28+
function sample(
2929
rng::Random.AbstractRNG,
3030
model::StateSpaceModel{T,LDT},
3131
alg::FFBS{<:BootstrapFilter{N}},
@@ -40,21 +40,34 @@ function smooth(
4040
)
4141

4242
particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...)
43+
44+
# Backward sampling - exact
4345
idx_ref = rand(rng, Categorical(weights(particles.filtered)), M)
4446
trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M)
4547

4648
trajectories[end, :] = particles.filtered[idx_ref]
4749
for step in (n_timestep - 1):-1:1
4850
for j in 1:M
49-
transitions = map(
50-
x ->
51-
SSMProblems.logdensity(model.dyn, step, x, trajectories[step+1]; kwargs...),
51+
backward_weights = backward(
52+
model::StateSpaceModel,
53+
step,
54+
trajectories[step + 1],
5255
recorder.particles[step, :],
56+
recorder.log_weights[step, :];
57+
kwargs...,
5358
)
54-
backward_weights = recorder.log_weights[step, :] + transitions
5559
ancestor = rand(rng, Categorical(softmax(backward_weights)))
5660
trajectories[step, j] = recorder.particles[step, ancestor]
5761
end
5862
end
5963
return trajectories
6064
end
65+
66+
function backward(
67+
model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs...
68+
) where {T,WT}
69+
transitions = map(
70+
x -> SSMProblems.logdensity(model.dyn, step, x, state; kwargs...), particles
71+
)
72+
return log_weights + transitions
73+
end

0 commit comments

Comments
 (0)