Skip to content

Commit

Permalink
Bump DynamicPPL to 0.23 (#2001)
Browse files Browse the repository at this point in the history
* bump dppl test versions

* also bump bijectors

* bump AdvancedVI versions

* revert Bijectors bump

* bumped vi and bijectors too

* breaking change

* removed refernce to Bijectors.setadbackend

* make use of DynamicPPL.make_evaluate_args_and_kwargs

* bump DPPL version

* bump DPPL version for tests

* fixed bug in TracedModel

* forgot to remove some lines

* just drop the kwargs completely :(

* Update container.jl

* Update container.jl

* will now error if we're using a model with kwargs and SMC

* added reference to issue

* added test for keyword models failing

* make this a breaking change

* made error message more informative

* makde it slightly less informative

* fixed typo in checking for TRaceModel

* finally fixed the if-statement..

* Fix test error

* fixed tests maybe

* now fixed maybe

* Update test/inference/Inference.jl

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Jun 15, 2023
1 parent a38c709 commit 6a307cd
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 36 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.25.3"
version = "0.26.0"


[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -47,7 +48,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand Down
45 changes: 16 additions & 29 deletions src/essential/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,21 @@ function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG
)
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
evaluator = _get_evaluator(model, varinfo, context)
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
end

# TODO: maybe move to DynamicPPL
@generated function _get_evaluator(
model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
unwrap_args = [
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
]
# We want to give `context` precedence over `model.context` while also
# preserving the leaf context of `context`. We can do this by
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
# 2. Set leaf context of `context` to the context resulting from (1).
# The result is:
# `context` -> `childcontext(context)` -> ... -> `model.context`
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
return quote
context_new = DynamicPPL.setleafcontext(
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
)
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
if kwargs !== nothing && !isempty(kwargs)
error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.")
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model,
sampler,
varinfo,
(model.f, args...)
)
end


function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
newtask = copy(model.ctask)
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
Expand Down Expand Up @@ -73,10 +58,12 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
return trace
end

function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
# Extract the `args`.
args = trace.model.ctask.args
_, _, container, = args
rng = container.rng
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
Expand Down
6 changes: 6 additions & 0 deletions test/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ end
@test length(unique(c[:m])) == 1
@test length(unique(c[:s])) == 1
end

# https://github.com/TuringLang/Turing.jl/issues/2007
@turing_testset "keyword arguments not supported" begin
@model kwarg_demo(; x = 2) = return x
@test_throws ErrorException sample(kwarg_demo(), PG(1), 10)
end
end

# @testset "pmmh.jl" begin
Expand Down
24 changes: 20 additions & 4 deletions test/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,27 @@
return priors
end

chain = sample(gauss2(; x=x), PG(10), 10)
chain = sample(gauss2(; x=x), SMC(), 10)
@test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10)
@test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10)

chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)

@model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV}
priors = TV(undef, 2)
priors[1] ~ InverseGamma(2, 3) # s
priors[2] ~ Normal(0, sqrt(priors[1])) # m
for i in 1:length(x)
x[i] ~ Normal(priors[2], sqrt(priors[1]))
end
return priors
end

chain = sample(gauss3(x), PG(10), 10)
chain = sample(gauss3(x), SMC(), 10)

chain = sample(gauss3(x, Vector{Real}), PG(10), 10)
chain = sample(gauss3(x, Vector{Real}), SMC(), 10)
end
@testset "new interface" begin
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
Expand Down

2 comments on commit 6a307cd

@yebai
Copy link
Member

@yebai yebai commented on 6a307cd Jun 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85652

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.26.0 -m "<description of version>" 6a307cd31ce62667773dc7958927255bdb4699c8
git push origin v0.26.0

Please sign in to comment.