@@ -25,6 +25,43 @@ function (callback::WeightedParticleRecorderCallback)(
2525 return nothing
2626end
2727
28+ function gen_trajectory (
29+ rng:: Random.AbstractRNG ,
30+ model:: StateSpaceModel ,
31+ particles:: AbstractMatrix{T} , # Need better container
32+ log_weights:: AbstractMatrix{WT} ,
33+ forward_state,
34+ n_timestep:: Int ;
35+ kwargs...
36+ ) where {T,WT}
37+ trajectory = Vector {T} (undef, n_timestep)
38+ trajectory[end ] = forward_state
39+ for step in (n_timestep - 1 ): - 1 : 1
40+ backward_weights = backward (
41+ model,
42+ step,
43+ trajectory[step + 1 ],
44+ particles[step, :],
45+ log_weights[step, :];
46+ kwargs... ,
47+ )
48+ ancestor = rand (rng, Categorical (softmax (backward_weights)))
49+ trajectory[step] = particles[step, ancestor]
50+ end
51+ return trajectory
52+ end
53+
54+
55+ function backward (
56+ model:: StateSpaceModel , step:: Integer , state, particles:: T , log_weights:: WT ; kwargs...
57+ ) where {T,WT}
58+ transitions = map (particles) do prev_state
59+ SSMProblems. logdensity (model. dyn, step, prev_state, state; kwargs... )
60+ end
61+ return log_weights + transitions
62+ end
63+
64+
2865function sample (
2966 rng:: Random.AbstractRNG ,
3067 model:: StateSpaceModel{T,LDT} ,
@@ -34,6 +71,7 @@ function sample(
3471 callback= nothing ,
3572 kwargs... ,
3673) where {T,LDT,N}
74+
3775 n_timestep = length (obs)
3876 recorder = WeightedParticleRecorderCallback (
3977 Array {eltype(model.dyn)} (undef, n_timestep, N), Array {T} (undef, n_timestep, N)
@@ -46,28 +84,8 @@ function sample(
4684 trajectories = Array {eltype(model.dyn)} (undef, n_timestep, M)
4785
4886 trajectories[end , :] = particles. filtered[idx_ref]
49- for step in (n_timestep - 1 ): - 1 : 1
50- for j in 1 : M
51- backward_weights = backward (
52- model:: StateSpaceModel ,
53- step,
54- trajectories[step + 1 ],
55- recorder. particles[step, :],
56- recorder. log_weights[step, :];
57- kwargs... ,
58- )
59- ancestor = rand (rng, Categorical (softmax (backward_weights)))
60- trajectories[step, j] = recorder. particles[step, ancestor]
61- end
87+ for j in 1 : M
88+ trajectories[:, j] = gen_trajectory (rng, model, recorder. particles, recorder. log_weights, trajectories[end , j], n_timestep)
6289 end
6390 return trajectories
6491end
65-
66- function backward (
67- model:: StateSpaceModel , step:: Integer , state, particles:: T , log_weights:: WT ; kwargs...
68- ) where {T,WT}
69- transitions = map (
70- x -> SSMProblems. logdensity (model. dyn, step, x, state; kwargs... ), particles
71- )
72- return log_weights + transitions
73- end
0 commit comments