# Origin: src/cuda/cudnn.jl
"""
    rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)

Reverse-mode differentiation rule for `batchnorm`.

Run the forward pass by calling `batchnorm` with the given arguments and keywords, and
return its result together with a pullback closure.

The pullback accepts a single cotangent `Δ` for the output (there is exactly one primal
output, so there is exactly one cotangent), unthunks it, and forwards it to `∇batchnorm`.
It returns one tangent per primal input, in order:

- `NoTangent()` for the `batchnorm` function itself,
- the values returned by `∇batchnorm`, splatted in place,
- `NoTangent()` for each of `running_mean`, `running_var` and `momentum`.

NOTE(review): this assumes `∇batchnorm` returns exactly three gradients, for `g`, `b`
and `x` in that order, so that the tuple has seven entries — confirm against its
definition.
"""
function ChainRulesCore.rrule(
    ::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...
)
    y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)

    # The pullback must take exactly one argument: a two-argument closure cannot be
    # called by an AD backend that passes only the cotangent of `y`.
    function batchnorm_pullback(Δ)
        # `Δ` may arrive as a lazy thunk; materialise it before handing it on.
        grad = ∇batchnorm(
            g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...
        )
        # Keep the gradient collection whole and splat it once, so the number of
        # tangents lines up with the number of primal inputs.
        return (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
    end

    return y, batchnorm_pullback
end