Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sorting of tuples of numbers #1196

Merged
merged 3 commits on
Jan 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ Uses block y index to decide which values to operate on.
sync_threads()
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
val = idx0 <= hi ? values[idx0] : one(eltype(values))
comparison = flex_lt(pivot, val, parity, lt, by)
@inbounds if idx0 <= hi
val = values[idx0]
comparison = flex_lt(pivot, val, parity, lt, by)
end

@inbounds if idx0 <= hi
sums[threadIdx().x] = 1 & comparison
Expand All @@ -85,9 +87,11 @@ Uses block y index to decide which values to operate on.

cumsum!(sums)

dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
@inbounds if idx0 <= hi && dest_idx <= length(swap)
swap[dest_idx] = val
@inbounds if idx0 <= hi
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
if dest_idx <= length(swap)
swap[dest_idx] = val
end
end
sync_threads()

Expand Down Expand Up @@ -180,10 +184,8 @@ Must only run on 1 SM.
c = n_eff() - d
to_move = min(b, c)
sync_threads()
swap = if threadIdx().x <= to_move
vals[lo + a + threadIdx().x]
else
zero(eltype(vals)) # unused value
if threadIdx().x <= to_move
swap = vals[lo + a + threadIdx().x]
end
sync_threads()
if threadIdx().x <= to_move
Expand Down Expand Up @@ -215,7 +217,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b

@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
sync_threads()
old_val = zero(eltype(swap))

log_blockDim = begin
out = 0
Expand Down Expand Up @@ -269,10 +270,8 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
for level in 0:L
# get left/right neighbor depending on even/odd level
buddy = threadIdx().x - 1i32 + 2i32 * (1i32 & (threadIdx().x % 2i32 != level % 2i32))
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
swap[buddy]
else
zero(eltype(swap)) # unused value
if 1 <= buddy <= L && threadIdx().x <= L
buddy_val = swap[buddy]
end
sync_threads()
if 1 <= buddy <= L && threadIdx().x <= L
Expand Down Expand Up @@ -738,7 +737,7 @@ Each view is indexed along block x dim: one view per pseudo-block
@inbounds swap[threadIdx().x, threadIdx().y] = vals[index+one(I)]
end
sync_threads()
return @view swap[:, threadIdx().y]
return @inbounds @view swap[:, threadIdx().y]
end

"""
Expand Down
5 changes: 5 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ end
end
end

# XXX: some tests here make compute-sanitizer hang, but only on CI.
# This may be related to the container set-up; try again once we use Sandbox.jl.

@testset "interface" begin
@testset "quicksort" begin
# pre-sorted
Expand All @@ -302,6 +305,7 @@ end
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.QuickSort)
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.QuickSort)
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.QuickSort)
@not_if_sanitize @test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.QuickSort)

# non-uniform distributions
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 2); alg=CUDA.QuickSort)
Expand Down Expand Up @@ -345,6 +349,7 @@ end
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.BitonicSort)
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.BitonicSort)
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.BitonicSort)
@not_if_sanitize @test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.BitonicSort)

# test various sizes
@test check_sort!(Float32, 1, x -> rand(Float32); alg=CUDA.BitonicSort)
Expand Down