SIMD/Loop framework upgrade (#2937)
Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger committed Sep 18, 2024
1 parent a6ebca0 commit 9dd7c4a
Showing 13 changed files with 1,577 additions and 820 deletions.
6 changes: 3 additions & 3 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
@@ -261,7 +261,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// Store f32 values back to the (normal layout) output.
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + l;
create.vec.storeIE(vecF32H, alloc, outputAF, {});
create.vec.storeIE(vecF32H, alloc, outputAF);
create.vec.storeIE(
vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
});
@@ -277,8 +277,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
Value vecF32L = convertOp.getResult(1);
// Save into archVL value buffer.
Value bufferF32 = create.mem.alignedAlloca(bufferType);
create.vec.storeIE(vecF32H, bufferF32, {litZero}, {});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {});
create.vec.storeIE(vecF32H, bufferF32, {litZero});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf});
// Save the remaining values as scalars.
create.scf.forLoop(litZero.getValue(),
remainingScalarValues.getValue(), 1,
2 changes: 1 addition & 1 deletion src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp
@@ -124,7 +124,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
// Nothing to write.
} else {
// Loop to copy the data.
createAffine.forLoopIE(zeroIE, writeUBs[i], 1,
createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/,
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
loopIndices.emplace_back(loopInd[0]);
genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref,
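
The hunk above threads the new useParallel flag through the common forLoopIE interface declared in src/Dialect/Krnl/DialectBuilder.hpp (see below). A minimal usage sketch of the krnl-side variant, not part of this diff; the builder createKrnl, the bounds zeroIE/ubIE, and the memrefs src/dst are hypothetical placeholders:

    // Sketch only: copy src[i] into dst[i]. The signature is
    // (lb, ub, step, useParallel, body), and the body receives the loop index.
    createKrnl.forLoopIE(zeroIE, ubIE, /*step*/ 1, /*useParallel*/ false,
        [&](KrnlBuilder &kb, mlir::ValueRange loopInd) {
          mlir::Value x = kb.load(src, {loopInd[0]});
          kb.store(x, dst, {loopInd[0]});
        });
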
9 changes: 4 additions & 5 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1527,8 +1527,7 @@ static LogicalResult getPartiallyFlattenedSimdCode(

create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly,
useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF},
[&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
SmallVectorImpl<Value> &resVals, int64_t VL) {
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Type currElementType = outputElementType;
if (VL > 1)
@@ -1557,9 +1556,9 @@ static LogicalResult getPartiallyFlattenedSimdCode(
res = emitPostProcessingFor<OP_TYPE>(rewriter, create.getLoc(),
op, currElementType, accumulated);
}
resVals.emplace_back(res);
}); // SIMD kernel.
}); // Outer loops.
return res;
}}); // SIMD kernel.
}); // Outer loops.

rewriter.replaceOp(op, alloc);
return success();
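
With the upgraded framework, simdIterateIE takes a list of body functions, each returning the single output value it is responsible for, instead of appending results to a vector. A minimal standalone sketch, not part of this diff, assuming a MultiDialectBuilder named create (with Krnl and Math builders), index-expression bounds lb/ub, and hypothetical flattened memrefs flatA, flatB, flatOut with access functions aAF, bAF, outAF:

    // Sketch only: elementwise add of two inputs. Each lambda returns one
    // output value: a VL-wide vector, or a scalar when the clean-up
    // iterations run with VL == 1.
    create.krnl.simdIterateIE(lb, ub, VL, /*fullySimd*/ false,
        /*useParallel*/ false, {flatA, flatB}, {aAF, bAF}, {flatOut}, {outAF},
        {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
          MultiDialectBuilder<MathBuilder> create(kb);
          return create.math.add(inputVals[0], inputVals[1]);
        }});
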
265 changes: 128 additions & 137 deletions src/Conversion/ONNXToKrnl/Math/Reduction.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/NN/Normalization.cpp
@@ -255,7 +255,7 @@ struct ONNXInstanceNormalizationOpLowering
for (int d = 0; d < rank - 2; ++d)
inputAccessFct.emplace_back(spatial_loopInd[d]);
// tmp += input[n,c, spatial dims]
Value oldSum = create.krnl.load(tmpMemRef, {});
Value oldSum = create.krnl.load(tmpMemRef);
Value val = create.krnl.load(inputMemRef, inputAccessFct);
val = create.math.sub(val, mean);
val = create.math.mul(val, val);
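
Several hunks in this commit drop a trailing empty {} argument from load/store calls; the declarations in src/Dialect/Krnl/DialectBuilder.hpp (see below) default both indices and offsets to empty, making the explicit {} redundant. A minimal sketch, not part of this diff, assuming a KrnlBuilder createKrnl, a rank-0 (scalar) memref tmpMemRef, and a hypothetical value newSum:

    // Sketch only: the trailing {} for indices/offsets can be omitted.
    mlir::Value oldSum = createKrnl.load(tmpMemRef);  // was: load(tmpMemRef, {})
    createKrnl.store(newSum, tmpMemRef);              // was: store(newSum, tmpMemRef, {})
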
7 changes: 3 additions & 4 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -70,8 +70,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
outputAF.emplace_back(zero);
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
[&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
SmallVectorImpl<Value> &resVals, int64_t VL) {
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Value x = inputVals[0];
// Scale
@@ -87,8 +86,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
// Saturate
Value saturateX = create.math.clip(adjustX, qMin, qMax);
Value res = create.math.cast(quantizedElementType, saturateX);
resVals.emplace_back(res);
});
return res;
}});
if (totVL > 1)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "quantizationLinear whole tensor");
45 changes: 32 additions & 13 deletions src/Dialect/Krnl/DialectBuilder.cpp
@@ -222,37 +222,56 @@ KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops,
});
}

void KrnlBuilder::forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step,
bool useParallel, KrnlLoopBodyFn builderFn) const {
ValueRange originalLoopDef = defineLoops(1);
llvm::SmallVector<Value, 1> optLoopDef(1, originalLoopDef[0]);
if (step > 1) {
// Block loop by step.
ValueRange blockedLoopDef = block(originalLoopDef[0], step);
optLoopDef[0] = blockedLoopDef[0];
}
if (useParallel)
parallel(optLoopDef[0]);
iterateIE(originalLoopDef, optLoopDef, {lb}, {ub}, builderFn);
}

void KrnlBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL,
bool fullySimd, bool useParallel, ArrayRef<Value> inputs,
ArrayRef<DimsExpr> inputAFs, ArrayRef<Value> outputs,
ArrayRef<DimsExpr> outputAFs,
function_ref<void(KrnlBuilder &b, ArrayRef<Value> inputVals,
llvm::SmallVectorImpl<Value> &resultVals, int64_t VL)>
bodyBuilderFn) const {
ArrayRef<KrnlSimdIterateBodyFn> iterateBodyFnList) const {
onnx_mlir::impl::simdIterateIE<KrnlBuilder, KrnlBuilder>(*this, lb, ub, VL,
fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs,
bodyBuilderFn);
iterateBodyFnList);
}

void KrnlBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL,
bool fullySimd, ArrayRef<Value> inputs, ArrayRef<DimsExpr> inputAFs,
ArrayRef<Value> tmps, ArrayRef<DimsExpr> tmpAFs, ArrayRef<Value> outputs,
ArrayRef<DimsExpr> outputAFs, ArrayRef<Value> initVals,
/* reduction function (simd or scalar) */
function_ref<void(const KrnlBuilder &b, ArrayRef<Value> inputVals,
ArrayRef<Value> tmpVals, llvm::SmallVectorImpl<Value> &resultVals,
int64_t VL)>
reductionBuilderFn,
ArrayRef<KrnlSimdReductionBodyFn> reductionBodyFnList,
/* post reduction function (simd to scalar + post processing)*/
function_ref<void(const KrnlBuilder &b, ArrayRef<Value> tmpVals,
llvm::SmallVectorImpl<Value> &scalarOutputs, int64_t VL)>
postProcessingBuilderFn) const {
ArrayRef<KrnlSimdPostReductionBodyFn> postReductionBodyFnList) const {
onnx_mlir::impl::simdReduceIE<KrnlBuilder, KrnlBuilder>(*this, lb, ub, VL,
fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals,
reductionBuilderFn, postProcessingBuilderFn);
reductionBodyFnList, postReductionBodyFnList);
}

void KrnlBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL,
bool fullySimd, Value input, DimsExpr inputAF, Value tmp, DimsExpr tmpAF,
Value output, DimsExpr outputAF, Value initVal,
/* reduction functions (simd or scalar) */
KrnlSimdReductionBodyFn reductionBodyFn,
/* post reduction functions (post processing ONLY)*/
KrnlSimdPostReductionBodyFn postReductionBodyFn) const {
onnx_mlir::impl::simdReduce2DIE<KrnlBuilder, KrnlBuilder>(*this, lb, ub, VL,
fullySimd, input, inputAF, tmp, tmpAF, output, outputAF, initVal,
reductionBodyFn, postReductionBodyFn);
}

void KrnlBuilder::yield(mlir::ValueRange iterArgs) const {
void KrnlBuilder::yield(ValueRange iterArgs) const {
b().create<KrnlYieldOp>(loc(), iterArgs);
}

96 changes: 60 additions & 36 deletions src/Dialect/Krnl/DialectBuilder.hpp
@@ -30,12 +30,12 @@ struct KrnlBuilder : public DialectBuilder {
KrnlBuilder(const DialectBuilder &db) : DialectBuilder(db) {}
virtual ~KrnlBuilder() {}

// Common load/store interface (krnl/affine/memref)
// Add offsets (if any) to the least significant dims.
mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {},
mlir::ValueRange offsets = {}) const;
mlir::Value loadIE(mlir::Value memref, mlir::ArrayRef<IndexExpr> indices = {},
mlir::ValueRange offsets = {}) const;
// Add offsets (if any) to the least significant dims.
void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {},
mlir::ValueRange offsets = {}) const;
void storeIE(mlir::Value val, mlir::Value memref,
@@ -70,11 +70,12 @@ struct KrnlBuilder : public DialectBuilder {
// Iterate over optimized loops given the original loops, lbs and ubs. The
// lambda function implements the body of the loop and receives a KRNL
// builder and the loop indices.
using KrnlLoopBodyFn =
mlir::function_ref<void(KrnlBuilder &, mlir::ValueRange)>;

void iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops,
mlir::ValueRange lbs, mlir::ValueRange ubs,
mlir::function_ref<void(
KrnlBuilder &createKrnl, mlir::ValueRange indices)>
bodyBuilderFn) const;
KrnlLoopBodyFn bodyBuilderFn) const;
mlir::KrnlIterateOp iterate(mlir::ValueRange originalLoops,
mlir::ValueRange optimizedLoops, mlir::ValueRange lbs,
mlir::ValueRange ubs, mlir::ValueRange inits,
@@ -87,31 +88,38 @@ struct KrnlBuilder : public DialectBuilder {
// Same versions with Index Expressions for bounds.
void iterateIE(mlir::ValueRange originalLoops,
mlir::ValueRange optimizedLoops, mlir::ArrayRef<IndexExpr> lbs,
mlir::ArrayRef<IndexExpr> ubs,
mlir::function_ref<void(
KrnlBuilder &createKrnl, mlir::ValueRange indices)>
bodyBuilderFn) const;
mlir::ArrayRef<IndexExpr> ubs, KrnlLoopBodyFn bodyBuilderFn) const;
mlir::KrnlIterateOp iterateIE(mlir::ValueRange originalLoops,
mlir::ValueRange optimizedLoops, mlir::ArrayRef<IndexExpr> lbs,
mlir::ArrayRef<IndexExpr> ubs, mlir::ValueRange inits,
mlir::function_ref<void(KrnlBuilder &createKrnl, mlir::ValueRange indices,
mlir::ValueRange blockIters)>
bodyBuilderFn) const;

// Common loop interface (krnl/affine/scf).
void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel,
KrnlLoopBodyFn builderFn) const;

// Common simd loop interface (krnl/affine/scf).
/*
Iterate over a loop executing the loop body in SIMD mode (of vector length
VL) from lb to ub. A scalar loop may execute up to VL-1 loop
iterations when the trip count is not a multiple of VL. If fullySimd is
true, then the call assumes that the trip count is a multiple of VL.
This call needs be given each of the memref inputs to the loop body, given
as an ordered pair memref value and its corresponding access function. Same
hold for all the memref outputs of the loop body.
This simdIterateIE call must be given each of the memref inputs to the loop
body, each as an ordered pair of a memref value and its corresponding access
function. The same holds for all the memref outputs of the loop body.
The loop body is constructed by calling each of the KrnlSimdIterateBodyFn
given in the list. Each function is responsible for returning one output
value. The returned values are eventually stored in the output memrefs at a
location given by its respective output access function.
The loop body is given a KRNL builder, a list of loaded input (same order
as the input's memrefs and access functions). It will generate values that
must be placed in the result list in the same order as the output's memrefs
and access functions.
To generate its output, each KrnlSimdIterateBodyFn function is given
a KRNL builder, a list of loaded inputs (in the same order
as the inputs' memrefs and access functions), and the current vector length
(VL). VL is either the original VL or 1 (when executing in scalar mode).
It will be the responsibility of this call to load each of the inputs and
store each of the outputs. When operating in SIMD mode, every input and
@@ -129,45 +137,61 @@
Dialect/Mlir/DialectBuilder.hpp.inc.
*/

using KrnlSimdIterateBodyFn = impl::SimdIterateBodyFn<KrnlBuilder>;
void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd,
bool useParallel, mlir::ArrayRef<mlir::Value> inputs,
mlir::ArrayRef<DimsExpr> inputAFs, mlir::ArrayRef<mlir::Value> outputs,
mlir::ArrayRef<DimsExpr> outputAFs,
mlir::function_ref<void(KrnlBuilder &b,
mlir::ArrayRef<mlir::Value> inputVals,
llvm::SmallVectorImpl<mlir::Value> &resultVals, int64_t VL)>
bodyBuilderFn) const;
mlir::ArrayRef<KrnlSimdIterateBodyFn> bodyBuilderFnList) const;

/*
Works similarly to simdIterateIE, but performs a reduction to a single
scalar per output value. Inputs must be strided in their innermost
dimensions. Temps are used to hold the temporary results (partial results
per SIMD lane), and the outputs hold the final scalar reduction results.
Two functions are given: reductionBuilderFn to perform the partial
reductions into the temporary values tmps, finishing with up to VL partial
reductions
The second function: postProcessingBuilderFn performs the reductions of the
up to VL partial reductions into a final scalar reduction to be stored into
the outputs (a scalar value). For some reductions, post processing is also
needed, for example, mean reduction divide the accumulated sum by the
number of elements. That step is also performed here.
Two function lists are given. The first, reductionBodyFnList, performs the
partial reductions into the temporary values tmps, finishing with up to VL
partial reductions. The second, postReductionBodyFnList, reduces the up to
VL partial reductions into a final scalar reduction to be stored into the
outputs (a scalar value). For some reductions, post-processing is also
needed; for example, a mean reduction divides the accumulated sum by the
number of elements. That step is also performed here.
*/
using KrnlSimdReductionBodyFn = impl::SimdReductionBodyFn<KrnlBuilder>;
using KrnlSimdPostReductionBodyFn =
impl::SimdPostReductionBodyFn<KrnlBuilder>;

void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd,
mlir::ArrayRef<mlir::Value> inputs, mlir::ArrayRef<DimsExpr> inputAFs,
mlir::ArrayRef<mlir::Value> tmps, mlir::ArrayRef<DimsExpr> tmpAFs,
mlir::ArrayRef<mlir::Value> outputs, mlir::ArrayRef<DimsExpr> outputAFs,
mlir::ArrayRef<mlir::Value> initVals,
/* reduction function (simd or scalar) */
mlir::function_ref<void(const KrnlBuilder &b,
mlir::ArrayRef<mlir::Value> inputVals,
mlir::ArrayRef<mlir::Value> tmpVals,
llvm::SmallVectorImpl<mlir::Value> &resultVals, int64_t VL)>
reductionBuilderFn,
mlir::ArrayRef<KrnlSimdReductionBodyFn> reductionBodyFnList,
/* post reduction function (simd to scalar + post processing)*/
mlir::function_ref<void(const KrnlBuilder &b,
mlir::ArrayRef<mlir::Value> tmpVals,
llvm::SmallVectorImpl<mlir::Value> &scalarOutputs, int64_t VL)>
postProcessingBuilderFn) const;
mlir::ArrayRef<KrnlSimdPostReductionBodyFn> postReductionBodyFnList)
const;

/*
Same as simdReduceIE, but performs VL reductions at once. It expects at
least VL iterations in the second-to-last dimension of the inputs/outputs.
Unlike simdReduceIE, the second function is for post-processing only; in
simdReduceIE, that function was also used to reduce the SIMD temporary
reduction into a single scalar.
Also, at this time, simdReduce2DIE processes only one reduction at a time,
whereas simdReduceIE can process an arbitrary number of reductions.
*/
void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd,
mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF,
mlir::Value output, DimsExpr outputAF, mlir::Value initVal,
/* reduction functions (simd or scalar) */
KrnlSimdReductionBodyFn reductionBodyFn,
/* post reduction functions (post processing ONLY)*/
KrnlSimdPostReductionBodyFn postReductionBodyFn) const;

void yield(mlir::ValueRange iterArgs) const;
