Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ck_tile fMHA forward example #1249

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ed524f6
Refactor ck_tile fMHA forward example
poyenc Apr 17, 2024
e555d5f
Re-order include directives
poyenc Apr 17, 2024
adc0d20
Unify naming style
poyenc Apr 17, 2024
b279d95
Add comment for intermediate tensors
poyenc Apr 17, 2024
91556a1
Remove qualified name
poyenc Apr 17, 2024
9ff7714
Use better comment for tensor views
poyenc Apr 17, 2024
fde9b86
Use standard way to determine iterator category
poyenc Apr 18, 2024
1659b37
Support more operations in permutation_iterator
poyenc Apr 18, 2024
4f8aced
Add zip_iterator<>
poyenc Apr 18, 2024
6b196bc
Remove unused include directive
poyenc Apr 18, 2024
9bb9361
Add transform_iterator<>
poyenc Apr 18, 2024
9153db0
Support operator- for zip_iterator<>
poyenc Apr 18, 2024
4d2b0ef
Remove unnecessary data member
poyenc Apr 18, 2024
f7e3b3c
Rename variables
poyenc Apr 23, 2024
edf08d3
Leave some untested reference functions untouched
poyenc Apr 23, 2024
65b77a0
Rename TensorView to Tensor in templated functions
poyenc Apr 23, 2024
23c7a30
Rename type trait tensor_view_value_t<> to tensor_value_t<>
poyenc Apr 23, 2024
6bd317a
Add is_tensor<> traits
poyenc Apr 23, 2024
a73d026
Update license date
poyenc Apr 23, 2024
461a77d
Unify method naming style
poyenc Apr 23, 2024
26e13c9
Align stride naming style to the FA
poyenc Apr 23, 2024
5947759
Merge remote-tracking branch 'origin/develop' into ck_tile/refactor-f…
poyenc Apr 23, 2024
e27a7b1
Fix wrong LSE tensor param type
poyenc Apr 24, 2024
7369f54
Add reference_batched_fmha()
poyenc Apr 24, 2024
9b6d8a7
Remove unnecessary local tensor
poyenc Apr 24, 2024
b5acb9f
Re-order parameters
poyenc Apr 24, 2024
775b7f0
Pass batch mode smoke tests
poyenc Apr 24, 2024
a307024
Rename reference_batched_fmha() to reference_mha_fwd()
poyenc Apr 24, 2024
e0632d9
Fix group mode wrong result in reference_mha_fwd()
poyenc Apr 24, 2024
3711ec5
Reuse the existing tensor shapes
poyenc Apr 24, 2024
f9cfc81
Rename variable 'lse' to 'store_lse'
poyenc Apr 24, 2024
2b389ee
Re-order reference_mha_fwd() parameters
poyenc Apr 24, 2024
1b28376
Construct optional<> upfroont calling reference function
poyenc Apr 25, 2024
f5f4524
Add _ref to the variables
poyenc Apr 25, 2024
4cbc6ff
Remove _host_ref variable name in reference function
poyenc Apr 25, 2024
8600de2
Remove namespace qualifiers
poyenc Apr 25, 2024
40f4be8
Merge remote-tracking branch 'origin/develop' into ck_tile/refactor-f…
poyenc Apr 25, 2024
c79c058
Rename Repeat<> as RepeatView<>
poyenc Apr 25, 2024
392c94d
Change bias layout
poyenc Apr 29, 2024
051d3e2
Merge remote-tracking branch 'origin/develop' into ck_tile/refactor-f…
poyenc Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 99 additions & 214 deletions example/ck_tile/01_fmha/fmha_fwd.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/mask.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
6 changes: 6 additions & 0 deletions include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,15 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/iterator.hpp"
#include "ck_tile/core/utility/iterator_range.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/permutation_iterator.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/ranges.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transform_iterator.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/unary_element_function.hpp"
#include "ck_tile/core/utility/zip_iterator.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/core/algorithm/cluster_descriptor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/algorithm/space_filling_curve.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/arch/utility.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/container/array.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/container/meta_data_buffer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
21 changes: 16 additions & 5 deletions include/ck_tile/core/container/span.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include <cstddef>

#include <array>
#include <cstddef>
#include <limits>
#include <type_traits>

namespace ck_tile {

// Sentinel element count equal to the largest std::size_t value. span::subspan() uses it as
// the default `count`, where it is reduced to the number of elements left after `offset`.
inline constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();

// Implementation of the C++20 std::span: a lightweight, non-owning reference to a sequence,
// whether it is a dynamic or a static range. It can also be seen as a view of a contiguous sequence.
// TODO: on device, do we need to consider that this is a pointer?
Expand Down Expand Up @@ -47,9 +51,9 @@ class span
{
}

template <typename Container>
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
: span(container.data(), container.size())
// Builds a span over any range accepted by std::data() and std::size(), delegating to the
// two-argument constructor with the resulting pointer and element count.
// NOTE(review): the parameter is a forwarding reference, so temporaries bind as well; the
// span presumably must not outlive the storage it was built from — confirm at call sites.
template <typename ContiguousRange>
CK_TILE_HOST_DEVICE constexpr span(ContiguousRange&& range)
: span(std::data(range), std::size(range))
{
}

Expand All @@ -70,6 +74,13 @@ class span

CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

// Returns a view of at most `count` elements starting `offset` elements into this span.
//
// offset: index of the first element of the result. Values larger than size() are clamped
//         to size(), which yields an empty span positioned at the end instead of letting
//         the unsigned subtraction below wrap around.
// count:  maximum number of elements in the result; the default (dynamic_extent) selects
//         everything from `offset` to the end.
CK_TILE_HOST_DEVICE constexpr span subspan(size_type offset,
                                           size_type count = dynamic_extent) const
{
    // Clamp first so that `size() - first` can never underflow.
    const size_type first       = (offset < size()) ? offset : size();
    const size_type remain_size = size() - first;

    // Plain comparison instead of std::min keeps this usable without <algorithm>.
    return {data() + first, (count < remain_size) ? count : remain_size};
}

private:
pointer ptr_;
size_type size_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/container/thread_buffer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/numeric/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/numeric/type_convert.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/numeric/vector_type.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/buffer_view.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/load_tile.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/null_tensor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/null_tile_window.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/shuffle_tile.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/slice_tile.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/store_tile.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/sweep_tile.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tensor_adaptor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tensor_coordinate.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tensor_view.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tile_distribution.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tile_elementwise.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/tile_window.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
96 changes: 96 additions & 0 deletions include/ck_tile/core/utility/iterator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iterator>
#include <type_traits>
#include <utility>

namespace ck_tile {

// Value type of the iterator (or pointer) type T, as reported by std::iterator_traits.
// References and cv-qualifiers on T are stripped first, so `It`, `const It` and `It&` all
// give the same answer. Spelled with the standard traits so that this header does not rely
// on a remove_cvref_t alias being declared by some earlier include.
template <typename T>
using iter_value_t =
    typename std::iterator_traits<std::remove_cv_t<std::remove_reference_t<T>>>::value_type;

// Type obtained by dereferencing an lvalue of type Iter, i.e. the type of `*it`.
template <typename Iter>
using iter_reference_t = decltype(*std::declval<std::add_lvalue_reference_t<Iter>>());

// Difference type of the iterator (or pointer) type T, as reported by std::iterator_traits.
// References and cv-qualifiers on T are stripped first. Spelled with the standard traits so
// that this header does not rely on a remove_cvref_t alias declared by an earlier include.
template <typename T>
using iter_difference_t =
    typename std::iterator_traits<std::remove_cv_t<std::remove_reference_t<T>>>::difference_type;

// Fallback: a type is not treated as an iterator unless the specialization below applies.
template <typename Iter, typename = void>
struct is_iterator : std::false_type
{
};

// Applies when all three expressions are well-formed:
//   *it    on a value of type Iter,
//   ++it   on an lvalue of type Iter,
//   it++   on an lvalue of type Iter.
template <typename Iter>
struct is_iterator<
    Iter,
    std::void_t<decltype(*std::declval<Iter>()),
                decltype(++std::declval<typename std::add_lvalue_reference<Iter>::type>()),
                decltype(std::declval<typename std::add_lvalue_reference<Iter>::type>()++)>>
    : std::true_type
{
};

// Convenience constant for is_iterator<Iter>::value.
template <typename Iter>
inline constexpr bool is_iterator_v = is_iterator<Iter>::value;

namespace detail {
// Stand-in value that offers a conversion to any requested type. The conversion operator is
// only declared here, never defined, so it is suited to unevaluated operands such as the
// decltype expression used by is_output_iterator.
struct Placeholder final
{
    // constexpr already implies inline, so no separate inline specifier is needed.
    template <typename Target>
    constexpr operator Target() const noexcept;
};
} // namespace detail

// Fallback: a type is not treated as an output iterator unless the specialization below
// applies.
template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};

// Applies when `*it = detail::Placeholder` is a well-formed expression. The resulting value
// is then whatever is_iterator_v reports for the same type.
template <typename Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<detail::Placeholder>())>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience constant for is_output_iterator<Iterator>::value.
template <typename Iterator>
inline constexpr bool is_output_iterator_v = is_output_iterator<Iterator>::value;

// Fallback: a type is not treated as a bidirectional iterator unless the specialization
// below applies.
template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};

// Applies when an lvalue of type Iterator supports both `--it` and `it--`. The resulting
// value is then whatever is_iterator_v reports for the same type.
template <typename Iterator>
struct is_bidirectional_iterator<
    Iterator,
    std::void_t<decltype(--std::declval<typename std::add_lvalue_reference<Iterator>::type>()),
                decltype(std::declval<typename std::add_lvalue_reference<Iterator>::type>()--)>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience constant for is_bidirectional_iterator<Iterator>::value.
template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;

// Fallback: a type is not treated as a random access iterator unless the specialization
// below applies.
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};

// Applies when `it + 1`, `it - 1` and `it[1]` are all well-formed for a value of type
// Iterator. The resulting value is then whatever is_iterator_v reports for the same type.
template <typename Iterator>
struct is_random_access_iterator<Iterator,
                                 std::void_t<decltype(std::declval<Iterator>() + 1),
                                             decltype(std::declval<Iterator>() - 1),
                                             decltype(std::declval<Iterator>()[1])>>
    : std::integral_constant<bool, is_iterator_v<Iterator>>
{
};

// Convenience constant for is_random_access_iterator<Iterator>::value.
template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;

} // namespace ck_tile
Loading