Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 27, 2024
1 parent 9835fe8 commit dfeafa5
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions operators/cuda/rotary.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
namespace contrib {

/**
* Y = Rotary(X) is equivalent to if side == LEFT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = X[...,N/2:]
* Y[...,N/2:] = -X[...,:N/2]
*
* And the opposite if side == RIGHT.
*/
* Y = Rotary(X) is equivalent to if side == LEFT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = X[...,N/2:]
* Y[...,N/2:] = -X[...,:N/2]
*
* And the opposite if side == RIGHT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = -X[...,N/2:]
* Y[...,N/2:] = X[...,:N/2]
*/
template <typename T>
struct Rotary {
template <typename TDict>
Expand All @@ -26,20 +31,17 @@ struct Rotary {
std::string side = dict.TryToGetAttributeWithDefault("side", empty);
if (side == "left") {
side_ = RotarySide::LEFT;
}
else if (side == "right") {
} else if (side == "right") {
side_ = RotarySide::RIGHT;
}
else {
} else {
return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
}

return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
const ortc::Tensor<int64_t>& split,
ortc::Tensor<T>& output) const {

OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor<T>& input,
const ortc::Tensor<int64_t>& split, ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
auto input_shape = input.Shape();
T* output_data = output.Allocate(input_shape);
Expand All @@ -54,19 +56,15 @@ struct Rotary {
}
const int64_t* split_data = split.Data();
if (split_data[0] != split_data[1]) {
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
return {kOrtxErrorInvalidArgument, "Only equal split is allowed."};
}
if (split_data[0] * 2 != input_shape[input_shape.size()-1]) {
if (split_data[0] != split_data[1] != input_shape[input_shape.size() - 1]) {
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
}

LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
static_cast<int>(input_shape[input_shape.size()-1]),
input_data,
split_data,
output_data,
side_);
LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), input_length,
static_cast<int>(input_shape[input_shape.size() - 1]), input_data, split_data, output_data,
side_);
return {};
}

Expand All @@ -76,7 +74,7 @@ struct Rotary {
return OrtMemType::OrtMemTypeDefault;
}

private:
private:
RotarySide side_;
};

Expand Down

0 comments on commit dfeafa5

Please sign in to comment.