Refactor ck_tile fMHA forward example #1249
base: develop
Copyright (c) 2018-2023 -> 2024
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
CK_TILE_HOST void reference_batched_elementwise(const ATensorView& a_b_m_n,
I suggest using Tensor instead of TensorView, because we also have
struct tensor_view.
Using TensorView can easily confuse readers.
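To illustrate the naming suggestion, here is a minimal sketch of a reference elementwise function templated on the tensor types themselves, using `ATensor`/`BTensor`/`CTensor` as template parameter names. `ToyTensor` and `toy_batched_elementwise` are hypothetical stand-ins written for this example; they are not the ck_tile `HostTensor` class or the actual `reference_batched_elementwise` signature, whose remaining parameters are not shown in the diff above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major (batch, rows, cols) host tensor; NOT ck_tile::HostTensor.
template <typename T>
struct ToyTensor
{
    using value_type = T;

    std::size_t batch, rows, cols;
    std::vector<T> data; // value-initialized, so every element starts at zero

    ToyTensor(std::size_t b, std::size_t m, std::size_t n)
        : batch(b), rows(m), cols(n), data(b * m * n)
    {
    }

    T& operator()(std::size_t b, std::size_t m, std::size_t n)
    {
        return data[(b * rows + m) * cols + n];
    }

    const T& operator()(std::size_t b, std::size_t m, std::size_t n) const
    {
        return data[(b * rows + m) * cols + n];
    }
};

// Templated on whole tensor types (named *Tensor, not *TensorView), so any type
// that exposes batch/rows/cols, value_type and operator()(b, m, n) can be passed.
template <typename ATensor, typename BTensor, typename CTensor, typename BinaryOp>
void toy_batched_elementwise(const ATensor& a, const BTensor& b, CTensor& c, BinaryOp op)
{
    assert(a.batch == c.batch && a.rows == c.rows && a.cols == c.cols);
    assert(b.batch == c.batch && b.rows == c.rows && b.cols == c.cols);

    for(std::size_t i = 0; i < c.batch; ++i)
        for(std::size_t j = 0; j < c.rows; ++j)
            for(std::size_t k = 0; k < c.cols; ++k)
                c(i, j, k) =
                    static_cast<typename CTensor::value_type>(op(a(i, j, k), b(i, j, k)));
}
```

With this shape, the input and output element types can differ (for example `float` inputs and a `double` output) without extra data-type template parameters, since each tensor type carries its own `value_type`.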
include/ck_tile/host/host_tensor.hpp
Outdated
std::size_t GetOffsetFromMultiIndex(Is... is) const
std::enable_if_t<((std::is_integral_v<Is> && std::is_convertible_v<Is, std::size_t>)&&...),
std::size_t>
GetOffsetFromMultiIndex(Is... is) const
Do we need to sync the naming style of the functions?
Sure, I will rename the functions.
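For readers unfamiliar with the constraint in the diff above, the following sketch shows how an `std::enable_if_t` return type combined with a fold expression restricts a variadic offset function to integral indices. `ToyDescriptor`, `get_offset` and `accepts_indices_v` are hypothetical names for this illustration, and the snake_case spelling is only an assumption about what the renamed function might look like; none of this is the ck_tile implementation.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical strided descriptor; NOT the ck_tile host tensor descriptor.
template <std::size_t Rank>
struct ToyDescriptor
{
    std::array<std::size_t, Rank> lengths;
    std::array<std::size_t, Rank> strides;

    // The fold expression requires EVERY index type to be integral and
    // convertible to std::size_t; otherwise this overload is removed (SFINAE).
    template <typename... Is>
    std::enable_if_t<((std::is_integral_v<Is> && std::is_convertible_v<Is, std::size_t>) && ...),
                     std::size_t>
    get_offset(Is... is) const
    {
        static_assert(sizeof...(Is) == Rank, "pass exactly one index per dimension");
        const std::array<std::size_t, Rank> idx{static_cast<std::size_t>(is)...};

        std::size_t offset = 0;
        for(std::size_t d = 0; d < Rank; ++d)
        {
            assert(idx[d] < lengths[d]); // bounds check in debug builds
            offset += idx[d] * strides[d];
        }
        return offset;
    }
};

// Detection helper: true when get_offset(Is...) is a valid call expression.
template <typename, typename... Is>
struct accepts_indices_impl : std::false_type
{
};

template <typename... Is>
struct accepts_indices_impl<
    std::void_t<decltype(std::declval<const ToyDescriptor<sizeof...(Is)>&>().get_offset(
        std::declval<Is>()...))>,
    Is...> : std::true_type
{
};

template <typename... Is>
inline constexpr bool accepts_indices_v = accepts_indices_impl<void, Is...>::value;
```

Because the constraint sits in the return type, a call such as `get_offset(1, 2.0)` fails overload resolution at the call site instead of silently truncating the floating-point index.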
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
CK_TILE_HOST void reference_batched_gemm(const ATensorView& a_b_m_k,
TensorView -> Tensor
@@ -9,11 +9,13 @@

namespace ck_tile {

template <typename CDataType, typename MaskingType>
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
template <typename CTensorView, typename MaskingType>
TensorView->HostTensor
typename CompDataType,
typename BDataType,
template <typename CompDataType,
typename ATensorView,
TensorView->Tensor
LGTM
We need to wait for @danyao12 to merge his fmha bwd & dropout changes, then refactor all the updated example code together.
I will continue developing the fmha fwd + KV cache reference function based on the current design of
No description provided.