Skip to content
13 changes: 11 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@
@uniform
@groupsize
@ndrange
synchronize
allocate
```

### Reduction

```@docs
@groupreduce
@warp_groupreduce
KernelAbstractions.shfl_down
KernelAbstractions.supports_warp_reduction
```

## Host language

```@docs
synchronize
allocate
KernelAbstractions.zeros
```

Expand Down
2 changes: 2 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ argconvert(k::Kernel{T}, arg) where {T} =
supports_enzyme(::Backend) = false
function __fake_compiler_job end

include("groupreduction.jl")

###
# Extras
# - LoopInfo
Expand Down
120 changes: 120 additions & 0 deletions src/groupreduction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export @groupreduce, @warp_groupreduce

"""
@groupreduce op val neutral [groupsize]

Perform group reduction of `val` using `op`.

# Arguments

- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.

- `groupsize` specifies size of the workgroup.
If a kernel does not specifies `groupsize` statically, then it is required to
provide `groupsize`.
Also can be used to perform reduction accross first `groupsize` threads
(if `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
macro groupreduce(op, val)
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro groupreduce(op, val, groupsize)
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize)))))
end

# Shared-(local-)memory tree reduction over one workgroup.
# `groupsize` is a compile-time value (via `Val`) so `@localmem` can size the
# scratch buffer statically.
function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize}
    storage = @localmem T groupsize

    # Stage each participating workitem's value into local memory.
    local_idx = @index(Local)
    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
    @synchronize()

    # Binary tree reduction: each step combines element `i` with `i + s`,
    # halving the active stride `s`. The `other_idx ≤ groupsize` check keeps
    # reads in bounds when `s` overshoots.
    # NOTE(review): for an odd `groupsize` the last element is never paired on
    # the first step (s = groupsize ÷ 2) — presumably only power-of-two group
    # sizes are intended; confirm with callers.
    s::UInt64 = groupsize ÷ 0x02
    while s > 0x00
        if (local_idx - 0x01) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        # Barrier must be reached by ALL workitems each iteration, so it stays
        # outside the `if` above.
        @synchronize()
        s >>= 0x01
    end

    # Only workitem 1 picks up the final result; others return `val` unchanged.
    if local_idx == 0x01
        @inbounds val = storage[local_idx]
    end
    return val
end

# Warp groupreduce.

"""
@warp_groupreduce op val neutral [groupsize]

Perform group reduction of `val` using `op`.
Each warp within a workgroup performs its own reduction using [`shfl_down`](@ref) intrinsic,
followed by final reduction over results of individual warp reductions.

!!! note

Use [`supports_warp_reduction`](@ref) to query if given backend supports warp reduction.
"""
macro warp_groupreduce(op, val, neutral)
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro warp_groupreduce(op, val, neutral, groupsize)
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize)))))
end

"""
shfl_down(val::T, offset::Integer)::T where T

Read `val` from a lane with higher id given by `offset`.
"""
function shfl_down end
supports_warp_reduction() = false

"""
supports_warp_reduction(::Backend)

Query if given backend supports [`shfl_down`](@ref) intrinsic and thus warp reduction.
"""
supports_warp_reduction(::Backend) = false

# Assume warp is 32 lanes.
const __warpsize = UInt32(32)
# Maximum number of warps (for a groupsize = 1024).
const __warp_bins = UInt32(32)

# Reduce `val` across the lanes of one warp using `shfl_down`: at each step a
# lane combines its value with the value `shift` lanes above it, halving the
# active width. Lane 1 ends up holding the reduction of the whole warp.
@inline function __warp_reduce(val, op)
    shift::UInt32 = __warpsize >> 0x01
    while shift > 0x00
        neighbour = shfl_down(val, shift)
        val = op(val, neighbour)
        shift >>= 0x01
    end
    return val
end

# Two-stage warp-based group reduction:
#   1. every warp reduces its own lanes with `__warp_reduce`;
#   2. per-warp results are staged in local memory and reduced by the first warp.
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
    # One bin per warp, sized for the maximum warp count (groupsize ≤ 1024).
    storage = @localmem T __warp_bins

    local_idx = @index(Local)
    # 1-based lane index within the warp and 1-based warp index.
    lane = (local_idx - 0x01) % __warpsize + 0x01
    warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01

    # Each warp performs a reduction and writes results into its own bin in `storage`.
    val = __warp_reduce(val, op)
    @inbounds lane == 0x01 && (storage[warp_id] = val)
    @synchronize()

    # Final reduction of the `storage` on the first warp.
    # Lanes beyond the warp count load `neutral` so they don't perturb `op`.
    # NOTE(review): assumes `groupsize` is a multiple of `__warpsize`;
    # otherwise `groupsize ÷ __warpsize` truncates — confirm intended usage.
    within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize
    @inbounds val = within_storage ? storage[lane] : neutral
    warp_id == 0x01 && (val = __warp_reduce(val, op))
    # Only the first workitem holds the full-group result.
    return val
end
70 changes: 70 additions & 0 deletions test/groupreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Reduce all of `x` with `op` within a single workgroup; the first workitem
# stores the result into `y[1]`. Workitems past `length(x)` contribute `neutral`.
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
    idx = @index(Global)
    element = idx > length(x) ? neutral : x[idx]
    reduced = @groupreduce(op, element)
    if idx == 1
        y[1] = reduced
    end
end

# Like `groupreduce_1!`, but reduces only the first `groupsize` workitems,
# with the group size supplied explicitly as a compile-time `Val`.
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    idx = @index(Global)
    element = idx > length(x) ? neutral : x[idx]
    reduced = @groupreduce(op, element, groupsize)
    if idx == 1
        y[1] = reduced
    end
end

# Warp-based variant of `groupreduce_1!`, exercising `@warp_groupreduce`
# with the statically known workgroup size.
@kernel cpu = false function warp_groupreduce_1!(y, x, op, neutral)
    idx = @index(Global)
    element = idx > length(x) ? neutral : x[idx]
    reduced = @warp_groupreduce(op, element, neutral)
    if idx == 1
        y[1] = reduced
    end
end

# Warp-based variant of `groupreduce_2!`: reduces the first `groupsize`
# workitems with `@warp_groupreduce`, group size passed as a `Val`.
@kernel cpu = false function warp_groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    idx = @index(Global)
    element = idx > length(x) ? neutral : x[idx]
    reduced = @warp_groupreduce(op, element, neutral, groupsize)
    if idx == 1
        y[1] = reduced
    end
end

# Run the `@groupreduce` / `@warp_groupreduce` test suite for `backend`,
# using array type `AT` for device allocations. Warp tests run only when the
# backend reports `supports_warp_reduction`.
function groupreduce_testsuite(backend, AT)
    # TODO should be a better way of querying max groupsize
    groupsizes = "$backend" == "oneAPIBackend" ?
        (256,) :
        (256, 512, 1024)

    @testset "@groupreduce" begin
        @testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
            # Reducing a vector of ones with `+` makes the expected result
            # simply the number of participating workitems.
            x = AT(ones(T, n))
            y = AT(zeros(T, 1))
            neutral = zero(T)
            op = +

            groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
            @test Array(y)[1] == n

            # Explicit `groupsize` restricts the reduction to the first
            # `groupsize` workitems, so that is the expected sum.
            for groupsize in (64, 128)
                groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
                @test Array(y)[1] == groupsize
            end
        end
    end

    if KernelAbstractions.supports_warp_reduction(backend())
        @testset "@warp_groupreduce" begin
            @testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
                x = AT(ones(T, n))
                y = AT(zeros(T, 1))
                neutral = zero(T)
                op = +

                warp_groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
                @test Array(y)[1] == n

                for groupsize in (64, 128)
                    warp_groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
                    @test Array(y)[1] == groupsize
                end
            end
        end
    end
    # Explicit return (Blue style) — the function is called for its test
    # side effects, not a value.
    return
end
8 changes: 8 additions & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ include("reflection.jl")
include("examples.jl")
include("convert.jl")
include("specialfunctions.jl")
include("groupreduce.jl")

function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
@conditional_testset "Unittests" skip_tests begin
Expand Down Expand Up @@ -92,6 +93,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
examples_testsuite(backend_str)
end

# TODO @index(Local) only works as a top-level expression on CPU.
if backend != CPU
@conditional_testset "@groupreduce" skip_tests begin
groupreduce_testsuite(backend, AT)
end
end

return
end

Expand Down
Loading