From 6a307cd31ce62667773dc7958927255bdb4699c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jun 2023 12:42:42 +0100 Subject: [PATCH] Bump DynamicPPL to 0.23 (#2001) * bump dppl test versions * also bump bijectors * bump AdvancedVI versions * revert Bijectors bump * bumped vi and bijectors too * breaking change * removed refernce to Bijectors.setadbackend * make use of DynamicPPL.make_evaluate_args_and_kwargs * bump DPPL version * bump DPPL version for tests * fixed bug in TracedModel * forgot to remove some lines * just drop the kwargs completely :( * Update container.jl * Update container.jl * will now error if we're using a model with kwargs and SMC * added reference to issue * added test for keyword models failing * make this a breaking change * made error message more informative * makde it slightly less informative * fixed typo in checking for TRaceModel * finally fixed the if-statement.. * Fix test error * fixed tests maybe * now fixed maybe * Update test/inference/Inference.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 5 ++-- src/essential/container.jl | 45 +++++++++++++---------------------- test/Project.toml | 2 +- test/inference/AdvancedSMC.jl | 6 +++++ test/inference/Inference.jl | 24 +++++++++++++++---- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index 3c062e1b7..11ee56215 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,7 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.25.3" +version = "0.26.0" + [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -47,7 +48,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.21.5, 0.22" +DynamicPPL = "0.23" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" diff --git a/src/essential/container.jl b/src/essential/container.jl index 19abca423..b68cb5b09 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -9,36 +9,21 @@ function TracedModel( model::Model, sampler::AbstractSampler, varinfo::AbstractVarInfo, - rng::Random.AbstractRNG -) + rng::Random.AbstractRNG, +) context = SamplingContext(rng, sampler, DefaultContext()) - evaluator = _get_evaluator(model, varinfo, context) - return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator) -end - -# TODO: maybe move to DynamicPPL -@generated function _get_evaluator( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [ - :($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames - ] - # We want to give `context` precedence over `model.context` while also - # preserving the leaf context of `context`. We can do this by - # 1. Set the leaf context of `model.context` to `leafcontext(context)`. - # 2. Set leaf context of `context` to the context resulting from (1). - # The result is: - # `context` -> `childcontext(context)` -> ... -> `model.context` - # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` - return quote - context_new = DynamicPPL.setleafcontext( - context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context)) - ) - (model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...)) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + if kwargs !== nothing && !isempty(kwargs) + error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.") end + return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( + model, + sampler, + varinfo, + (model.f, args...) + ) end - function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel}) newtask = copy(model.ctask) newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator)) @@ -73,10 +58,12 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} +function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} + # Extract the `args`. args = trace.model.ctask.args - _, _, container, = args - rng = container.rng + # From `args`, extract the `SamplingContext`, which contains the RNG. + sampling_context = args[3] + rng = sampling_context.rng trace.rng = rng return trace end diff --git a/test/Project.toml b/test/Project.toml index 790eb4f33..adf489a6c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.21.5, 0.22" +DynamicPPL = "0.23" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" diff --git a/test/inference/AdvancedSMC.jl b/test/inference/AdvancedSMC.jl index c0bf738df..de6e65e40 100644 --- a/test/inference/AdvancedSMC.jl +++ b/test/inference/AdvancedSMC.jl @@ -173,6 +173,12 @@ end @test length(unique(c[:m])) == 1 @test length(unique(c[:s])) == 1 end + + # https://github.com/TuringLang/Turing.jl/issues/2007 + @turing_testset "keyword arguments not supported" begin + @model kwarg_demo(; x = 2) = return x + @test_throws ErrorException sample(kwarg_demo(), PG(1), 10) + end end # @testset "pmmh.jl" begin diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index bafe119bd..13658c3a4 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -259,11 +259,27 @@ return priors end - chain = sample(gauss2(; x=x), PG(10), 10) - chain = sample(gauss2(; x=x), SMC(), 10) + @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) + @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) - chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) - chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) + @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) + @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) + + @model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV} + priors = TV(undef, 2) + priors[1] ~ InverseGamma(2, 3) # s + priors[2] ~ Normal(0, sqrt(priors[1])) # m + for i in 1:length(x) + x[i] ~ Normal(priors[2], sqrt(priors[1])) + end + return priors + end + + chain = sample(gauss3(x), PG(10), 10) + chain = sample(gauss3(x), SMC(), 10) + + chain = sample(gauss3(x, Vector{Real}), PG(10), 10) + chain = sample(gauss3(x, Vector{Real}), SMC(), 10) end @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]