Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex number support #480

Closed
wants to merge 1 commit into from
Closed

Complex number support #480

wants to merge 1 commit into from

Conversation

MikeInnes
Copy link
Member

This gets the absolute basics going, and hopefully can serve as a platform for anyone who wants to add more complex derivatives to our library. (You can make a PR against this one if it's not merged in yet.)

julia> Tracker.gradient(x -> abs2(log(x)), 1+2im)
(1.2076065567220924 - 0.20091567785600395im (tracked),)

Adding this was a bit less simple than it should have been due to the lack of an AbstractComplex equivalent, which we have to hack around (and that's quite tied to the AD internals).

I'd probably like to get complex broadcast working before merging this, but that's a bit fiddly, so it'll depend on whether I get round to it.

@simonbyrne
Copy link

Adding this was a bit less simple than it should have been due to the lack of an AbstractComplex equivalent, which we have to hack around (and that's quite tied to the AD internals).

JuliaLang/julia#26666

@ssfrr
Copy link

ssfrr commented Jan 10, 2019

I think another necessary feature would be the ability to do arithmetic on the tracked complex numbers. With this PR I tried a scalar loss function:

target = rand() + im*rand()
loss(x) = abs2(target-x)
dloss(x) = Tracker.gradient(loss, x)[1]

# initial value
x = rand() + im*rand()
@show loss(x)

for _ in 1:10
  global x -= 0.1*dloss(x)
  @show loss(x)
end

This gives an MethodError trying to apply loss. For a more basic example you get errors with these:

x = 1.0+1.0im

param(x) + x
x + param(x)
param(x)+param(x)

@chriscoey
Copy link

any updates on this? I have the need for taking gradients and Hessians of logdet of a Hermitian PSD matrix

@MikeInnes
Copy link
Member Author

Closing in favour of FluxML/Tracker.jl#16 since Tracker has moved out of Flux now.

It's not that likely that this PR will develop any further, though; if you want complex support you should check out Zygote.

@MikeInnes MikeInnes closed this Apr 23, 2019
@CarloLucibello CarloLucibello deleted the mji/complex branch April 7, 2022 07:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants