RFC: strip most types from gradient output #1362

Draft · mcabbott wants to merge 1 commit into base: main
Conversation

mcabbott
Contributor

This is a draft of a way to start addressing #1334, for comment.

It implements what I called level 2 here: #1334 (comment)

On arrays like these, no change. Natural and structural representations agree:

julia> Enzyme.gradient(Reverse, first, Diagonal([1,2.]))
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅ 
  ⋅   0.0

julia> using SparseArrays, StaticArrays

julia> Enzyme.gradient(Reverse, sum, sparse([5 0 6.]))
1×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
 1.0   ⋅   1.0

julia> Enzyme.gradient(Reverse, sum, PermutedDimsArray(sparse([1 2; 3 0.]), (2,1)))
2×2 PermutedDimsArray(::SparseMatrixCSC{Float64, Int64}, (2, 1)) with eltype Float64:
 1.0  1.0
 1.0  0.0

julia> Enzyme.gradient(Reverse, first, reshape(SA[1,2,3,4.]',2,2))
2×2 reshape(adjoint(::SVector{4, Float64}), 2, 2) with eltype Float64:
 1.0  0.0
 0.0  0.0

On arrays like these, it does not know how to construct the natural representation, so it doesn't try.
(I know how, but the fields of the result would not line up with the existing ones.)

julia> Enzyme.gradient(Reverse, sum, Symmetric(rand(3,3)))
(data = [1.0 2.0 2.0; 0.0 1.0 2.0; 0.0 0.0 1.0], uplo = nothing)

julia> Enzyme.gradient(Reverse, first, reshape(LinRange(1,2,4)',2,2))
(parent = (parent = (start = 1.0, stop = 0.0, len = nothing, lendiv = nothing),), dims = (nothing, nothing), mi = ())

Arrays of non-diff objects cannot be wrapped up in array structs:

julia> Enzyme.gradient(Reverse, floatfirst, Diagonal([1,2,3]))
(diag = nothing,)

julia> Enzyme.gradient(Reverse, floatfirst, SA[1,2,3]')
(parent = (data = (nothing, nothing, nothing),),)
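
(floatfirst is not defined in this excerpt; presumably it is a small helper along these lines, returning a Float64 so that the call is differentiable even though the array elements are Ints.)

julia> floatfirst(x) = float(first(x));  # hypothetical reconstruction, not taken from the PR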

make_zero uses an IdDict cache to preserve identity between different branches of the struct.
At present this PR does not always do so:

julia> mutable struct TwoThings{A,B}; a::A; b::B; end

julia> nt = (x=TwoThings(3.0, 4.0), y=TwoThings(3.0, 4.0));

julia> nt.x === nt.y
false

julia> grad = Enzyme.gradient(Reverse, nt -> nt.x.a + nt.y.a + 20nt.x.b + 20nt.y.b, nt)
(x = (a = [11.0, 11.0], b = 22.0), y = (a = [11.0, 11.0], b = 22.0))

julia> grad.x === grad.y  # new identity created
true

# example 2

julia> arrs = [[1,2.], [3,4.]];

julia> grad = Enzyme.gradient(Reverse, nt -> sum(sum(sum, x)::Float64 for x in nt), (a = arrs, b = arrs))
(a = [[2.0, 2.0], [2.0, 2.0]], b = [[2.0, 2.0], [2.0, 2.0]])

julia> grad.a === grad.b  # container array identity is not preserved
false

julia> grad.a[1] === grad.b[1]  # array of numbers
true

A simple Flux model: no functional change, the gradient just looks a little different for this model:

julia> using Flux

julia> model = Chain(Embedding(reshape(1:6, 2,3) .+ 0.0), softmax);

julia> Enzyme.gradient(Reverse, m -> sum(abs2, m(1)), model)
(layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),)

Comments:

  • I'm not sure what I think about the failure to preserve === relations between some mutable objects of the original in the gradient. Some of this could be solved by adding an IdDict cache, like make_zero does.

  • The function (called strip_types for now) probably needs to be public, so that you can call it yourself after constructing dx = make_zero(x), and so that you can overload it for your own array wrappers; see the sketch after this list.

  • Projecting things like Symmetric to their covariant representation probably needs to be opt-in, by somehow telling gradient that you want this. (That's level 4 here: Supporting covariant derivatives #1334 (comment) .) Could be implemented as additional methods of this function, something like strip_types(x, dx, Val(true), cache)?

  • Surely all the code is in the wrong place, and needs tests.
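
For illustration only (none of this is in the draft, and the exact signature is still up for discussion), a user-defined wrapper might opt in by adding a method that rebuilds the wrapper from the structural gradient. This assumes a hypothetical two-argument form strip_types(x, dx), where x is the original object and dx its structural (NamedTuple) gradient:

# Hypothetical example: a trivial vector wrapper whose natural and
# structural gradients line up field-for-field.
struct MyVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::MyVec) = size(v.data)
Base.getindex(v::MyVec, i::Int) = v.data[i]

# Rebuild the wrapper-typed gradient from the structural (data = ...,) NamedTuple.
# Assumes strip_types is public and accepts (x, dx); neither is decided yet.
Enzyme.strip_types(x::MyVec, dx::NamedTuple) = MyVec(dx.data)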

# Arrays whose elements are all non-differentiable collapse to nothing
strip_types(x::Array{<:Union{Symbol, Char, AbstractString, Nothing}}) = nothing

# Containers to recurse into
strip_types(x::Union{Tuple, NamedTuple, Array}) = map(strip_types, x) # need to worry about undef?
Member

This won't work, for example, for a recursive type (and will instead infinite loop presumably?)

Contributor Author

Indeed. I guess the implementation is a sketch, and a real one would need some kind of IdDict cache for this purpose too.
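
Something along these lines, perhaps (purely illustrative, with a made-up name; not part of this draft): thread an IdDict through the recursion so that each mutable object is converted exactly once, which both preserves === relations between branches and terminates on self-referential structures.

strip_types_cached(x) = strip_types_cached(x, IdDict())

function strip_types_cached(x::Array, cache::IdDict)
    haskey(cache, x) && return cache[x]
    out = similar(x, Any)
    cache[x] = out                  # record before recursing, so cycles hit the cache
    for i in eachindex(x)
        isassigned(x, i) && (out[i] = strip_types_cached(x[i], cache))
    end
    return out
end

strip_types_cached(x::Union{Tuple, NamedTuple}, cache::IdDict) =
    map(xi -> strip_types_cached(xi, cache), x)

strip_types_cached(x, cache::IdDict) = x    # leaves (numbers, nothing, ...) pass through unchanged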
