Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring over OneElement for scalar getindex #717

Merged
merged 6 commits into from
Jun 13, 2023
Merged

Bring over OneElement for scalar getindex #717

merged 6 commits into from
Jun 13, 2023

Conversation

oxinabox
Copy link
Member

I think this is maybe is a precondition for FluxML/Zygote.jl#1328
so that we know that deleting the Zygote rules will not cause performance regressions.

Still WIP as i chase down failing tests.

@oxinabox
Copy link
Member Author

We should think about improving inferability so it stops making unionings over wether or not it is going to reshape.
One part i think is making the axes at type-param here.
But i think we also need changes in CRC

"""
    OneElement(val, ind, axes) <: AbstractArray

Extremely simple `struct` used for the gradient of scalar `getindex`.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
  val::T
  ind::I
  OneElement(val::T, ind::I, axes) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I, axes}(val, ind)
end
Base.size(::OneElement{<:Any,<:Any,<:Any, A}) where A = map(length, A)
Base.axes(::OneElement{<:Any,<:Any,<:Any, A}) where A = A
Base.getindex(xs::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==xs.ind, xs.val, zero(T))

Still this is a small union and doesn't really matter
but that is why i have to drop a lot of the inference tests

@oxinabox
Copy link
Member Author

Yota failure is unrelated

@oxinabox
Copy link
Member Author

Without #718
this should be broken on 1.9 (CI / Julia 1)
but should work on Julia 1.6

@oxinabox oxinabox marked this pull request as ready for review May 19, 2023 15:12
@oxinabox oxinabox requested a review from mcabbott May 19, 2023 15:12
@oxinabox
Copy link
Member Author

oxinabox commented May 22, 2023

Added some performance enhancements

Before

julia> x = rand(300,300);

julia> dx = OneElement(3.14, (231, 21), axes(x));

julia> @btime $x + $(collect(dx));  # Baseline
  31.978 μs (2 allocations: 703.17 KiB)

julia> @btime $x + $dx;
  61.043 μs (2 allocations: 703.17 KiB)

After:

julia> @btime $x + $(dx);
  21.056 μs (2 allocations: 703.17 KiB)

The orig version of this PR (and thus more or less what is in Zygote` was 2x slower than baseline, for accumulation,
with my fix it is now 30% faster than baseline.
(for a 300x300 matrix. This gets better the bigger the matrix).

For allocation no change vs what in Zygote, it remains basically ∞x faster than baseline -- because it just constant folds out

julia> @btime collect(OneElement(3.14, (231, 21), axes($x)));  # approximate baseline
  54.806 μs (2 allocations: 703.17 KiB)

julia> @btime OneElement(3.14, (231, 21), axes($x));
  1.748 ns (0 allocations: 0 bytes)

I think this is worth a the type instability of going to a small union

@oxinabox
Copy link
Member Author

bumping @mcabbott or @ToucheSir for a review

@mcabbott
Copy link
Member

One of the things I don't like about OneElement is that it's tied to Array. Somewhere I had a prototype which stored BroadcastStyle too, so that it remembers that you indexed e.g. a CuArray.

Maybe worth mentioning that JuliaArrays/FillArrays.jl#235 added essentially the same struct that Zygote has to FillArrays. CR could consider getting it from there, although it's so simple that owning it is fine. I think FillArrays may overload some mul! methods etc, maybe worth thinking about whether any of those are desirable for CR's purposes.

Sorry not really a review, a bit swamped.

@ToucheSir
Copy link
Contributor

I have nothing intelligent to add on top of what Michael said, but it's cool that FillArrays has OneElement now. Would a port of the benchmark in FluxML/Zygote.jl#962 help with making a decision here? I can look into that if there's interest.

@oxinabox
Copy link
Member Author

oxinabox commented Jun 1, 2023

FillArrays is still pending in #46
Though we could start with it here, and add it else where when opertunity arrises.
Would ideally have JuliaArrays/FillArrays.jl#260 addressed first

I am happy merging this as is, with other stuff to come in follow-ups.
I want to unblock deleting the rules from Zygote.
but I do need one of you to to approve.

@ToucheSir
Copy link
Contributor

Happy to get this landed first. Before I approve, does anything need to be done to make CI happy?

@oxinabox
Copy link
Member Author

ok, now everything we expect to be passing is passing

@oxinabox oxinabox merged commit b36e66c into main Jun 13, 2023
@oxinabox oxinabox deleted the ox/oneelement branch June 13, 2023 04:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants