diff --git a/src/pooling.jl b/src/pooling.jl index 538ff349f..a35c92f9e 100644 --- a/src/pooling.jl +++ b/src/pooling.jl @@ -177,3 +177,33 @@ for pool in [:maxpool, :meanpool] return Ω, $pullback end end + + +function topk(x::AbstractArray{T,N}, k; rev=false, dims=nothing) where {T,N} + if dims === nothing + y = vec(x) + perm = partialsortperm(y, 1:k; rev) + return y[perm], linear_to_cartesian(x, perm) + else + @assert dims isa Int + sz1 = size(x)[1:dims-1] + sz2 = size(x)[dims+1:end] + slice1 = CartesianIndices(sz1) + slice2 = CartesianIndices(sz2) + perm = similar(x, Int, (sz1..., k, sz2...)) + y = similar(x, (sz1..., k, sz2...)) + for I1 in slice1 + for I2 in slice2 + xI = x[I1,:,I2] + permI = partialsortperm(x[I1,:,I2], 1:k; rev) + perm[I1,:,I2] .= permI + y[I1,:,I2] .= xI[permI] + end + end + return y, perm + end +end + +function linear_to_cartesian(x, i) + CartesianIndices(x)[i] +end