Commit: clang

xadupre committed May 2, 2024
1 parent 3cdf72a, commit 27d4b5c

Showing 2 changed files with 88 additions and 80 deletions.
operators/contrib/cuda/scatter_nd_of_shape.cu: 135 changes (71 additions, 64 deletions)
@@ -5,7 +5,8 @@

 namespace ortops {

-#define _ENFORCE(cond, msg) if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
+#define _ENFORCE(cond, msg) \
+  if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);

 #ifndef HIP_LONG
 #define HIP_LONG int32_t
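Note that even after the reformat, _ENFORCE still expands to a bare if statement, so a trailing else at a call site could bind to it unexpectedly. A common hardening, shown here only as a sketch and not part of this commit, wraps the body in do { } while (0):

  // Sketch (not in this commit): statement-safe variant of the macro.
  #define _ENFORCE(cond, msg)                                          \
    do {                                                               \
      if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);     \
    } while (0)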
@@ -17,14 +18,16 @@ namespace ortops {

 struct GridDim {
   enum : CUDA_LONG {
-    maxThreadsPerBlock = 256, // max threads per block
-    maxElementsPerThread = 4, // max element processed per thread
+    maxThreadsPerBlock = 256,     // max threads per block
+    maxElementsPerThread = 4,     // max elements processed per thread
   };
 };

-template <typename T> __device__ __forceinline__ void _add_inplace(T &x, const T a) { x += a; }
+template <typename T>
+__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }

-template<> __device__ __forceinline__ void _add_inplace(half &x, const half a) {
+template <>
+__device__ __forceinline__ void _add_inplace(half& x, const half a) {
 #if __CUDA_ARCH__ < 700
   x = __float2half(__half2float(x) + __half2float(a));
 #else
@@ -34,8 +37,8 @@ template<> __device__ __forceinline__ void _add_inplace(half &x, const half a) {

 template <typename T>
 __global__ void
-addition_inplace_kernel(T *__restrict__ output_data, const int64_t *__restrict__ indices_data,
-                        const T *__restrict__ updates_data, const CUDA_LONG indice_size,
+addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
+                        const T* __restrict__ updates_data, const CUDA_LONG indice_size,
                         const CUDA_LONG nrows, const CUDA_LONG stride) {
   HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
   if (id >= stride)
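The rest of the kernel body is collapsed in this view. As a hedged sketch of the strategy the visible signature implies (not the committed code): each thread owns one column id within a row of width stride, zero-fills that column across all nrows rows (the ConstantOfShape half of the fusion), then walks every index serially and accumulates the matching update, so no two threads ever write the same element and no atomicAdd is needed:

  // Sketch only (not the committed body): zero-init, then serial accumulation.
  for (CUDA_LONG row = 0; row < nrows; ++row)
    output_data[row * stride + id] = T(0);     // ConstantOfShape: zero this column
  for (CUDA_LONG i = 0; i < indice_size; ++i)
    _add_inplace(output_data[indices_data[i] * stride + id],   // duplicate indices
                 updates_data[i * stride + id]);               // simply accumulate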
@@ -55,104 +58,108 @@ addition_inplace_kernel(T *__restrict__ output_data, const int64_t *__restrict__
 //////////////////

 template <typename T>
-void *ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
+void* ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
   return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release();
 }

-template <typename T> const char *ScatterNDOfShapeOp<T>::GetName() const {
+template <typename T>
+const char* ScatterNDOfShapeOp<T>::GetName() const {
   return "ScatterNDOfShape";
 }

-template <typename T> const char *ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
+template <typename T>
+const char* ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
   return "CUDAExecutionProvider";
 }

-template <typename T> size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; };
+template <typename T>
+size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; }

 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetInputType(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  case 2:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    case 2:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }

 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetInputType(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  case 2:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    case 2:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }

 template <typename T>
 OrtMemType ScatterNDOfShapeOp<T>::GetInputMemoryType(std::size_t index) const {
   switch (index) {
-  case 0:
-    return OrtMemTypeCPUInput;
-  case 1:
-  case 2:
-    return OrtMemTypeDefault;
-  default:
-    ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return OrtMemTypeCPUInput;
+    case 1:
+    case 2:
+      return OrtMemTypeDefault;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }

 template <typename T>
 OrtCustomOpInputOutputCharacteristic
 ScatterNDOfShapeOp<T>::GetInputCharacteristic(std::size_t index) const {
   switch (index) {
-  case 0:
-  case 1:
-  case 2:
-    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+    case 1:
+    case 2:
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+    default:
+      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
   }
 }

-template <typename T> size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }
+template <typename T>
+size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }

 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetOutputType(std::size_t index) const {
   // D, scale D
   switch (index) {
-  case 0:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }

 template <>
 ONNXTensorElementDataType ScatterNDOfShapeOp<half>::GetOutputType(std::size_t index) const {
   // D, scale D
   switch (index) {
-  case 0:
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }

 template <typename T>
 OrtCustomOpInputOutputCharacteristic
 ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
   switch (index) {
-  case 0:
-    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-  default:
-    ORTX_CXX_API_THROW("Wrong output index", ORT_RUNTIME_EXCEPTION);
+    case 0:
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+    default:
+      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
   }
 }

@@ -161,8 +168,8 @@ ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
 ///////////////////

 template <typename T>
-ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi &api,
-                                                  const OrtKernelInfo *info) {
+ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi& api,
+                                                  const OrtKernelInfo* info) {
   char value_string[1000];
   std::size_t size = 1000;
   ThrowOnError(api, api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size));
@@ -178,7 +185,8 @@ ScatterNDOfShapeKernel<T>::ScatterNDOfShapeKernel(const OrtApi &api,
   maxThreadPerBlock_ = prop.maxThreadsPerBlock;
 }

-template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *context) {
+template <typename T>
+void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
   Ort::KernelContext ctx(context);

   int n_inputs = ctx.GetInputCount();
@@ -197,13 +205,13 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *

   auto memi = updates.GetTensorMemoryInfo();
   _ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
-          "Tensor updates is not on GPU.");
+           "Tensor updates is not on GPU.");

   auto mem = shape.GetTensorMemoryInfo();
   _ENFORCE(
       mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
       "Input shape is not on CPU.");
-  const int64_t *X = shape.GetTensorData<int64_t>();
+  const int64_t* X = shape.GetTensorData<int64_t>();
   std::vector<int64_t> dims(X, X + dimensions[0]);
   output = ctx.GetOutput(0, dims);

Expand All @@ -212,12 +220,11 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *
if (reduction_ == Reduction::Add &&
indices_shape[indices_shape.size() - 1] == 1 && input_shape.size() == 2 &&
input_shape[input_shape.size() - 1] >= maxThreadPerBlock_) {

size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
size_t update_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(updates_shape));

_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1],
"Size mismatch.");
"Size mismatch.");

ComputeNoAtomic(stream, input_shape, indices_shape, output.GetTensorMutableData<T>(),
indices.GetTensorData<int64_t>(), updates.GetTensorData<T>());
@@ -227,11 +234,11 @@ template <typename T> void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *
 }

 template <typename T>
-void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t &stream,
-                                                const std::vector<int64_t> &input_shape,
-                                                const std::vector<int64_t> &indices_shape,
-                                                T *output_data, const int64_t *indices_data,
-                                                const T *updates_data) const {
+void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t& stream,
+                                                const std::vector<int64_t>& input_shape,
+                                                const std::vector<int64_t>& indices_shape,
+                                                T* output_data, const int64_t* indices_data,
+                                                const T* updates_data) const {
   // The kernel is slow if there are a lot of duplicates.
   // reduction_ == Reduction::add
   // indices_shape[indices_shape.size() - 1] == 1
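The launch itself is collapsed below this point. Given the kernel's guard (if (id >= stride) return;), a plausible configuration, sketched here under that assumption rather than taken from the committed body, spreads one thread per output column, capped at maxThreadPerBlock_:

  // Sketch: one thread per output column; names as in this file, with
  // indice_size, nrows and stride assumed to be derived from the shapes.
  int threads_per_block = std::min<int>(static_cast<int>(stride), maxThreadPerBlock_);
  int blocks = static_cast<int>((stride + threads_per_block - 1) / threads_per_block);
  addition_inplace_kernel<T><<<blocks, threads_per_block, 0, stream>>>(
      output_data, indices_data, updates_data,
      static_cast<CUDA_LONG>(indice_size), static_cast<CUDA_LONG>(nrows),
      static_cast<CUDA_LONG>(stride));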
@@ -257,4 +264,4 @@ void ScatterNDOfShapeKernel<T>::ComputeNoAtomic(cudaStream_t &stream,
 static ScatterNDOfShapeOp<float> _op32;
 static ScatterNDOfShapeOp<half> _op16;

-} // namespace ortops
+}  // namespace ortops
operators/contrib/cuda/scatter_nd_of_shape.h: 33 changes (17 additions, 16 deletions)
@@ -15,18 +15,19 @@ enum class Reduction : int {
 };

 /**
- * This kernel implementation the fusion of ConstantOfShape and ScatterND.
- * The implementation does not use OrtLiteCustom as the input shape (first input)
- * is expected to be on CPU wheeras the other outputs are expected to be on CUDA.
- */
-template <typename T> struct ScatterNDOfShapeKernel {
-  ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info);
-  void Compute(OrtKernelContext *context);
-
-private:
-  void ComputeNoAtomic(cudaStream_t &stream, const std::vector<int64_t> &input_shape,
-                       const std::vector<int64_t> &indices_shape, T *output_data,
-                       const int64_t *indices_data, const T *updates_data) const;
+ * This kernel implements the fusion of ConstantOfShape and ScatterND.
+ * The implementation does not use OrtLiteCustom as the input shape (first input)
+ * is expected to be on CPU whereas the other inputs are expected to be on CUDA.
+ */
+template <typename T>
+struct ScatterNDOfShapeKernel {
+  ScatterNDOfShapeKernel(const OrtApi& api, const OrtKernelInfo* info);
+  void Compute(OrtKernelContext* context);
+
+ private:
+  void ComputeNoAtomic(cudaStream_t& stream, const std::vector<int64_t>& input_shape,
+                       const std::vector<int64_t>& indices_shape, T* output_data,
+                       const int64_t* indices_data, const T* updates_data) const;

   Reduction reduction_;
   int maxThreadPerBlock_;
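For readers who want the fused semantics spelled out, the following is a minimal host-side reference, assuming the specialized case the kernel targets (2-D output, last indices dimension equal to 1, reduction "add"); the function name and layout conventions are illustrative, not part of this commit:

  #include <cstdint>
  #include <vector>

  // output = zeros(shape); then out[indices[i], :] += updates[i, :] for every i.
  template <typename T>
  std::vector<T> scatter_nd_of_shape_ref(const std::vector<int64_t>& shape,    // [nrows, stride]
                                         const std::vector<int64_t>& indices,  // n row indices
                                         const std::vector<T>& updates) {      // n * stride values
    const int64_t stride = shape[1];
    std::vector<T> out(static_cast<size_t>(shape[0] * stride), T(0));  // ConstantOfShape part
    for (size_t i = 0; i < indices.size(); ++i)                        // ScatterND(add) part;
      for (int64_t j = 0; j < stride; ++j)                             // duplicates accumulate
        out[indices[i] * stride + j] += updates[i * stride + j];
    return out;
  }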
@@ -37,9 +38,9 @@ struct ScatterNDOfShapeOp
     : Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
   typedef Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> parent_type;
   ScatterNDOfShapeOp() : parent_type() {}
-  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const;
-  const char *GetName() const;
-  const char *GetExecutionProviderType() const;
+  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
+  const char* GetName() const;
+  const char* GetExecutionProviderType() const;

   std::size_t GetInputTypeCount() const;
   ONNXTensorElementDataType GetInputType(std::size_t index) const;
@@ -51,4 +52,4 @@ struct ScatterNDOfShapeOp
   OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
 };

-} // namespace ortops
+}  // namespace ortops
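As a quick sanity check of those semantics, the reference sketch above can be exercised with a duplicate index to show the add-reduction (compile together with that sketch; again illustrative only):

  #include <cstdio>

  int main() {
    std::vector<int64_t> shape{3, 4};
    std::vector<int64_t> indices{2, 0, 2};    // row 2 appears twice
    std::vector<float> updates(3 * 4, 1.0f);  // three rows of ones
    auto out = scatter_nd_of_shape_ref<float>(shape, indices, updates);
    for (int64_t r = 0; r < 3; ++r) {         // expect: row 0 -> 1s, row 1 -> 0s, row 2 -> 2s
      for (int64_t c = 0; c < 4; ++c) std::printf("%g ", out[r * 4 + c]);
      std::printf("\n");
    }
    return 0;
  }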
