Skip to content

Commit

Permalink
add compatibility with Turing.Experimental.Gibbs
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Aug 12, 2024
1 parent 2e2efed commit c2ec380
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SliceSampling"
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
version = "0.5.0"
version = "0.6.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -31,7 +31,7 @@ LogDensityProblemsAD = "1"
Random = "1"
Requires = "1"
SimpleUnPack = "1"
Turing = "0.31, 0.32, 0.33"
Turing = "0.33"
julia = "1.7"

[extras]
Expand Down
33 changes: 33 additions & 0 deletions docs/src/general.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,39 @@ model = demo()
sample(model, externalsampler(sampler), n_samples)
```

### Conditional sampling in a `Turing.Experimental.Gibbs` sampler
`SliceSampling.jl` be used as a conditional sampler in `Turing.Experimental.Gibbs`.

```@example turinggibbs
using Distributions
using FillArrays
using Turing
using SliceSampling
@model function simple_choice(xs)
p ~ Beta(2, 2)
z ~ Bernoulli(p)
for i in 1:length(xs)
if z == 1
xs[i] ~ Normal(0, 1)
else
xs[i] ~ Normal(2, 1)
end
end
end
sampler = Turing.Experimental.Gibbs(
(
p = externalsampler(SliceSteppingOut(2.0)),
z = PG(20, :z)
)
)
n_samples = 1000
model = simple_choice([1.5, 2.0, 0.3])
sample(model, sampler, n_samples)
```

## Drawing Samples
For drawing samples using the algorithms provided by `SliceSampling`, the user only needs to call:
```julia
Expand Down
39 changes: 37 additions & 2 deletions ext/SliceSamplingTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,53 @@ if isdefined(Base, :get_extension)
using LogDensityProblemsAD
using Random
using SliceSampling
using Turing: Turing
using Turing
# using Turing: Turing, Experimental
else
using ..LogDensityProblemsAD
using ..Random
using ..SliceSampling
using ..Turing: Turing
using ..Turing
#using ..Turing: Turing, Experimental
end

# Required for using the slice samplers as `externalsampler`s in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
sample::SliceSampling.Transition
) = sample.params
# end

# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.UnivariateSliceState
) = state.transition.params

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.GibbsState
) = state.transition.params

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.HitAndRunState
) = state.transition.params

Turing.Experimental.gibbs_requires_recompute_logprob(
model_dst,
::Turing.DynamicPPL.Sampler{
<: Turing.Inference.ExternalSampler{
<: SliceSampling.AbstractSliceSampling, A, U
}
},
sampler_src,
state_dst,
state_src
) where {A,U} = false
# end

function SliceSampling.initial_sample(
rng::Random.AbstractRNG,
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ MCMCTesting = "0.3"
Random = "1"
StableRNGs = "1"
Test = "1"
Turing = "0.31"
Turing = "0.33"
julia = "1.6"
21 changes: 21 additions & 0 deletions test/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,25 @@
progress=false,
)
end

@testset "gibbs($sampler)" for sampler in [
RandPermGibbs(Slice(1)),
RandPermGibbs(SliceSteppingOut(1)),
RandPermGibbs(SliceDoublingOut(1)),
Slice(1),
SliceSteppingOut(1),
SliceDoublingOut(1),
]
sample(
model,
Turing.Experimental.Gibbs(
(
s = externalsampler(sampler),
m = externalsampler(sampler),
),
),
n_samples,
progress=false,
)
end
end

0 comments on commit c2ec380

Please sign in to comment.