Skip to content

Commit 9def4f7

Browse files
committed
[mlir][vector] to_elements implements VectorUnrollOpInterface
1 parent 4e92eee commit 9def4f7

File tree

5 files changed

+117
-4
lines changed

5 files changed

+117
-4
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
198198
// the target shape.
199199
int getComputeVectorSize(int64_t size);
200200

201+
// GetNativeVectorShape implementation for to_elements ops.
202+
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ToElementsOp op);
203+
201204
// GetNativeVectorShape implementation for reduction ops.
202205
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
203206

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def Vector_FMAOp :
761761

762762
def Vector_ToElementsOp : Vector_Op<"to_elements", [
763763
InferTypeOpAdaptor, Pure,
764+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
764765
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
765766
let summary = "operation that decomposes a vector into all its scalar elements";
766767
let description = [{
@@ -807,6 +808,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
807808
let assemblyFormat = "$source attr-dict `:` type($source)";
808809
let hasFolder = 1;
809810
let hasCanonicalizer = 1;
811+
let extraClassDeclaration = [{
812+
813+
VectorType getSourceVectorType() {
814+
return ::llvm::cast<VectorType>(getSource().getType());
815+
}
816+
817+
}];
810818
}
811819

812820
def Vector_FromElementsOp : Vector_Op<"from_elements", [

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) {
14351435
return 1;
14361436
}
14371437

1438+
SmallVector<int64_t>
1439+
mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) {
1440+
VectorType srcVectorType = op.getSourceVectorType();
1441+
assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1442+
int64_t vectorSize =
1443+
mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1444+
return {vectorSize};
1445+
}
1446+
14381447
SmallVector<int64_t>
14391448
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
14401449
VectorType srcVectorType = op.getSourceVectorType();
@@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
14651474
}
14661475

14671476
return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1468-
.Case<vector::ReductionOp, vector::TransposeOp>(
1477+
.Case<vector::ReductionOp, vector::TransposeOp, vector::ToElementsOp>(
14691478
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
14701479
.Default(std::nullopt);
14711480
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,6 +2377,9 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
23772377
//===----------------------------------------------------------------------===//
23782378
// ToElementsOp
23792379
//===----------------------------------------------------------------------===//
2380+
std::optional<SmallVector<int64_t, 4>> ToElementsOp::getShapeForUnroll() {
2381+
return llvm::to_vector<4>(getSourceVectorType().getShape());
2382+
}
23802383

23812384
/// Returns true if all the `operands` are defined by `defOp`.
23822385
/// Otherwise, returns false.

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
834834
vector::UnrollVectorOptions options;
835835
};
836836

837+
/// Takes a 1 dimensional `vector.to_element` op and attempts to change it to
838+
/// the target shape.
839+
///
840+
/// ```
841+
/// // In SPIR-V's default environment vector of size 8
842+
/// // are not allowed.
843+
/// %elements:8 = vector.to_elements %v : vector<8xf32>
844+
///
845+
/// ===>
846+
///
847+
/// %v_0_to_3 = vector.extract %v[0] : vector<4xf32> from vector<8xf32>
848+
/// %v_4_to_7 = vector.extract %v[4] : vector<4xf32> from vector<8xf32>
849+
/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
850+
/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
851+
/// ```
852+
///
853+
/// This pattern may fail if the rank is not divisible by to a native shape
854+
/// or if the rank is already in the target shape and therefore it may be
855+
/// skipped.
856+
struct ToElementsToTargetShape final
857+
: public OpRewritePattern<vector::ToElementsOp> {
858+
ToElementsToTargetShape(MLIRContext *context,
859+
const vector::UnrollVectorOptions &options,
860+
PatternBenefit benefit = 1)
861+
: OpRewritePattern<vector::ToElementsOp>(context, benefit),
862+
options(options) {}
863+
864+
LogicalResult matchAndRewrite(vector::ToElementsOp op,
865+
PatternRewriter &rewriter) const override {
866+
auto targetShape = getTargetShape(options, op);
867+
if (!targetShape)
868+
return failure();
869+
870+
// We have
871+
// source_rank = N * target_rank
872+
int64_t source_rank = op.getSourceVectorType().getShape().front();
873+
int64_t target_rank = targetShape->front();
874+
int64_t N = source_rank / target_rank;
875+
876+
// Transformation where
877+
// s = source_rank and
878+
// t = target_rank
879+
// ```
880+
// %e:s = vector.to_elements %v : vector<sxf32>
881+
//
882+
// ===>
883+
//
884+
// // N vector.extract_strided_slice of size t
885+
// %v0 = vector.extract_strided_slice %v
886+
// {offsets = [0*t], sizes = [t], strides = [1]}
887+
// : vector<txf32> from vector<sxf32>
888+
// %v1 = vector.extract_strided_slice %v
889+
// {offsets = [1*t], sizes = [t], strides = [1]}
890+
// : vector<txf32> from vector<sxf32>
891+
// ...
892+
// %vNminus1 = vector.extract_strided_slice $v
893+
// {offsets = [(N-1)*t], sizes = [t], strides = [1]}
894+
// : vector<txf32> from vector<sxf32>
895+
//
896+
// // N vector.to_elements of size t vectors.
897+
// %e0:t = vector.to_elements %v0 : vector<txf32>
898+
// %e1:t = vector.to_elements %v1 : vector<txf32>
899+
// ...
900+
// %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
901+
// ```
902+
SmallVector<Value> subVectors;
903+
SmallVector<int64_t> strides(targetShape->size(), 1);
904+
for (int64_t i = 0; i < N; i++) {
905+
SmallVector<int64_t> elementOffsets = {i * target_rank};
906+
Value subVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
907+
op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides);
908+
subVectors.push_back(subVector);
909+
}
910+
911+
SmallVector<Value> elements;
912+
for (const Value subVector : subVectors) {
913+
auto elementsOp =
914+
vector::ToElementsOp::create(rewriter, op.getLoc(), subVector);
915+
llvm::append_range(elements, elementsOp.getResults());
916+
}
917+
918+
rewriter.replaceOp(op, elements);
919+
return success();
920+
}
921+
922+
private:
923+
vector::UnrollVectorOptions options;
924+
};
925+
837926
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
838927
/// outermost dimension of the operand. For example:
839928
///
840929
/// ```
841-
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
930+
/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
842931
///
843932
/// ==>
844933
///
@@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
865954
FailureOr<SmallVector<Value>> result =
866955
vector::unrollVectorValue(source, rewriter);
867956
if (failed(result)) {
957+
// Only fails if operand is 1-dimensional.
868958
return failure();
869959
}
870960
SmallVector<Value> vectors = *result;
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131103
UnrollReductionPattern, UnrollMultiReductionPattern,
10141104
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151105
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016-
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
1017-
options, benefit);
1106+
UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
1107+
patterns.getContext(), options, benefit);
10181108
}
10191109

10201110
void mlir::vector::populateVectorToElementsUnrollPatterns(

0 commit comments

Comments
 (0)