diff --git a/Project.toml b/Project.toml
index 6433bd3c8..875475577 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.24.3"
+version = "0.24.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl
index 5cce69de4..2b8cdff9a 100644
--- a/src/inference/hmc.jl
+++ b/src/inference/hmc.jl
@@ -159,7 +159,14 @@ function DynamicPPL.initialstep(
     metricT = getmetricT(spl.alg)
     metric = metricT(length(theta))
     ℓ = LogDensityProblemsAD.ADgradient(
-        Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+        Turing.LogDensityFunction(
+            vi,
+            model,
+            # Use the leaf-context from the `model` in case the user has
+            # contextualized the model with something like `PriorContext`
+            # to sample from the prior.
+            DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+        )
     )
     logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
     ∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
@@ -265,7 +272,11 @@ end
 function get_hamiltonian(model, spl, vi, state, n)
     metric = gen_metric(n, spl, state)
     ℓ = LogDensityProblemsAD.ADgradient(
-        Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+        Turing.LogDensityFunction(
+            vi,
+            model,
+            DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context))
+        )
     )
     ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
     ∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ)
@@ -538,7 +549,12 @@ function HMCState(
 
     # Get the initial log pdf and gradient functions.
     ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
-    logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+    logπ = Turing.LogDensityFunction(
+        vi,
+        model,
+        DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+    )
+
     # Get the metric type.
     metricT = getmetricT(spl.alg)
 
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index 66e2f68f8..ddbeaa2c5 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -375,7 +375,14 @@ function propose!!(
 
     # Make a new transition.
     densitymodel = AMH.DensityModel(
-        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+        Base.Fix1(
+            LogDensityProblems.logdensity,
+            Turing.LogDensityFunction(
+                vi,
+                model,
+                DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+            )
+        )
     )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
 
@@ -403,7 +410,14 @@ function propose!!(
 
     # Make a new transition.
     densitymodel = AMH.DensityModel(
-        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+        Base.Fix1(
+            LogDensityProblems.logdensity,
+            Turing.LogDensityFunction(
+                vi,
+                model,
+                DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+            )
+        )
     )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
 
diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl
index d3c427d39..535353c43 100644
--- a/test/inference/hmc.jl
+++ b/test/inference/hmc.jl
@@ -216,4 +216,11 @@
         res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
         @test Array(res1) == Array(res2) == Array(res3)
     end
+
+    @turing_testset "prior" begin
+        alg = NUTS(1000, 0.8)
+        gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
+        chain = sample(gdemo_default_prior, alg, 10_000)
+        check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.2)
+    end
 end
diff --git a/test/inference/mh.jl b/test/inference/mh.jl
index dc9628b6e..8e52aec9b 100644
--- a/test/inference/mh.jl
+++ b/test/inference/mh.jl
@@ -216,4 +216,16 @@
         vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)
         @test !DynamicPPL.islinked(vi, spl)
     end
+
+    @turing_testset "prior" begin
+        # HACK: MH can be so bad for this prior model for some reason that it's difficult to
+        # find a non-trivial `atol` where the tests will pass for all seeds. Hence we fix it :/
+        rng = StableRNG(10)
+        alg = MH()
+        gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
+        burnin = 10_000
+        n = 10_000
+        chain = sample(rng, gdemo_default_prior, alg, n; discard_initial = burnin, thinning=10)
+        check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
+    end
 end
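For context, a minimal sketch of the workflow this patch fixes, mirroring the new tests: previously HMC/NUTS and MH hard-coded `DynamicPPL.DefaultContext()` (i.e. the joint density), so a model contextualized with `PriorContext` would still be sampled as a posterior. The model below is a hypothetical stand-in for the test suite's `gdemo_default`; only `contextualize`, `PriorContext`, `NUTS`, and `sample` are taken from the diff itself.

```julia
using Turing
using DynamicPPL

# Hypothetical model, analogous to the test suite's `gdemo_default`.
@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

model = gdemo(1.5)

# Wrap the model in a `PriorContext` so that sampling targets the prior.
# With this patch, the samplers pass `DynamicPPL.leafcontext(model.context)`
# into the `SamplingContext`, so the `PriorContext` is respected instead of
# being silently replaced by `DefaultContext`.
prior_model = DynamicPPL.contextualize(model, DynamicPPL.PriorContext())
chain = sample(prior_model, NUTS(1000, 0.8), 10_000)

# The marginals of `chain` should now match the priors, e.g.
# mean(chain[:s]) should approach mean(InverseGamma(2, 3)) == 3.0
# and mean(chain[:m]) should approach 0.
```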