From c2ec380f8ea89f407cd68a01e2fa530e58fac1a4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Aug 2024 21:37:02 +0100 Subject: [PATCH] add compatibility with `Turing.Experimental.Gibbs` --- Project.toml | 4 ++-- docs/src/general.md | 33 +++++++++++++++++++++++++++++ ext/SliceSamplingTuringExt.jl | 39 +++++++++++++++++++++++++++++++++-- test/Project.toml | 2 +- test/turing.jl | 21 +++++++++++++++++++ 5 files changed, 94 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index dfecbd6..fc2aac2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SliceSampling" uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf" -version = "0.5.0" +version = "0.6.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -31,7 +31,7 @@ LogDensityProblemsAD = "1" Random = "1" Requires = "1" SimpleUnPack = "1" -Turing = "0.31, 0.32, 0.33" +Turing = "0.33" julia = "1.7" [extras] diff --git a/docs/src/general.md b/docs/src/general.md index 39aa4a5..d238cef 100644 --- a/docs/src/general.md +++ b/docs/src/general.md @@ -60,6 +60,39 @@ model = demo() sample(model, externalsampler(sampler), n_samples) ``` +### Conditional sampling in a `Turing.Experimental.Gibbs` sampler +`SliceSampling.jl` be used as a conditional sampler in `Turing.Experimental.Gibbs`. + +```@example turinggibbs +using Distributions +using FillArrays +using Turing +using SliceSampling + +@model function simple_choice(xs) + p ~ Beta(2, 2) + z ~ Bernoulli(p) + for i in 1:length(xs) + if z == 1 + xs[i] ~ Normal(0, 1) + else + xs[i] ~ Normal(2, 1) + end + end +end + +sampler = Turing.Experimental.Gibbs( + ( + p = externalsampler(SliceSteppingOut(2.0)), + z = PG(20, :z) + ) +) + +n_samples = 1000 +model = simple_choice([1.5, 2.0, 0.3]) +sample(model, sampler, n_samples) +``` + ## Drawing Samples For drawing samples using the algorithms provided by `SliceSampling`, the user only needs to call: ```julia diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index 08a4243..213c281 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -5,18 +5,53 @@ if isdefined(Base, :get_extension) using LogDensityProblemsAD using Random using SliceSampling - using Turing: Turing + using Turing + # using Turing: Turing, Experimental else using ..LogDensityProblemsAD using ..Random using ..SliceSampling - using ..Turing: Turing + using ..Turing + #using ..Turing: Turing, Experimental end +# Required for using the slice samplers as `externalsampler`s in Turing +# begin Turing.Inference.getparams( ::Turing.DynamicPPL.Model, sample::SliceSampling.Transition ) = sample.params +# end + +# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing +# begin +Turing.Inference.getparams( + ::Turing.DynamicPPL.Model, + state::SliceSampling.UnivariateSliceState +) = state.transition.params + +Turing.Inference.getparams( + ::Turing.DynamicPPL.Model, + state::SliceSampling.GibbsState +) = state.transition.params + +Turing.Inference.getparams( + ::Turing.DynamicPPL.Model, + state::SliceSampling.HitAndRunState +) = state.transition.params + +Turing.Experimental.gibbs_requires_recompute_logprob( + model_dst, + ::Turing.DynamicPPL.Sampler{ + <: Turing.Inference.ExternalSampler{ + <: SliceSampling.AbstractSliceSampling, A, U + } + }, + sampler_src, + state_dst, + state_src +) where {A,U} = false +# end function SliceSampling.initial_sample( rng::Random.AbstractRNG, diff --git a/test/Project.toml b/test/Project.toml index 1d423cd..85d8c2a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,5 +18,5 @@ MCMCTesting = "0.3" Random = "1" StableRNGs = "1" Test = "1" -Turing = "0.31" +Turing = "0.33" julia = "1.6" diff --git a/test/turing.jl b/test/turing.jl index 7705bdc..6df11e5 100644 --- a/test/turing.jl +++ b/test/turing.jl @@ -35,4 +35,25 @@ progress=false, ) end + + @testset "gibbs($sampler)" for sampler in [ + RandPermGibbs(Slice(1)), + RandPermGibbs(SliceSteppingOut(1)), + RandPermGibbs(SliceDoublingOut(1)), + Slice(1), + SliceSteppingOut(1), + SliceDoublingOut(1), + ] + sample( + model, + Turing.Experimental.Gibbs( + ( + s = externalsampler(sampler), + m = externalsampler(sampler), + ), + ), + n_samples, + progress=false, + ) + end end