Skip to content

Commit

Permalink
Avoid using unused one/zero values.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Dec 9, 2021
1 parent 92e9d74 commit 92eb9fe
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ using ..CUDA: i32
(eq && a′ == b′) || lt(a′, b′)
end

# To allow sorting tuples of numbers:
@inline _zero(x) = Base.zero(x)
@inline _zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)

@inline _one(x) = Base.one(x)
@inline _one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)


# Batch partitioning
"""
Expand Down Expand Up @@ -80,7 +73,12 @@ Uses block y index to decide which values to operate on.
sync_threads()
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
val = idx0 <= hi ? values[idx0] : _one(eltype(values))
val = if idx0 <= hi
values[idx0]
else
Ref{eltype(values)}()[] # undef
# if idx0 > hi, val, comparison and dest_idx are unused
end
comparison = flex_lt(pivot, val, parity, lt, by)

@inbounds if idx0 <= hi
Expand Down Expand Up @@ -190,7 +188,7 @@ Must only run on 1 SM.
swap = if threadIdx().x <= to_move
vals[lo + a + threadIdx().x]
else
_zero(eltype(vals)) # unused value
Ref{eltype(vals)}()[] # undef
end
sync_threads()
if threadIdx().x <= to_move
Expand Down Expand Up @@ -222,7 +220,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b

@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
sync_threads()
old_val = _zero(eltype(swap))

log_blockDim = begin
out = 0
Expand All @@ -245,8 +242,10 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
to_swap = (i & k) == 0 && bitonic_lt(l, i) || (i & k) != 0 && bitonic_lt(i, l)
to_swap = to_swap == (i < l)

if to_swap
@inbounds old_val = swap[l + 1]
old_val = if to_swap
@inbounds swap[l + 1]
else
Ref{eltype(swap)}()[] # undef
end
sync_threads()
if to_swap
Expand Down Expand Up @@ -279,7 +278,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
swap[buddy]
else
_zero(eltype(swap)) # unused value
Ref{eltype(swap)}()[] # undef
end
sync_threads()
if 1 <= buddy <= L && threadIdx().x <= L
Expand Down

0 comments on commit 92eb9fe

Please sign in to comment.