From fbbf0e4113818f7ace97e4804679d579f8144a27 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 16:22:58 -0500 Subject: [PATCH 01/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff2082df..ccea764cfc579 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp : def Vector_BroadcastOp : Vector_Op<"broadcast", [Pure, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba02100a..3e125e5c1f37b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2782,10 +2782,6 @@ void BroadcastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } -std::optional> BroadcastOp::getShapeForUnroll() { - return llvm::to_vector<4>(getResultVectorType().getShape()); -} - /// Return the dimensions of the result vector that were formerly ones in the /// source tensor and thus correspond to "dim-1" broadcasting. static llvm::SetVector From 1964d161457e71208189065fc3cf82f2341e26e7 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 16:33:14 -0500 Subject: [PATCH 02/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index ccea764cfc579..1d3f70a9813f7 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2758,7 +2758,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ def Vector_TransposeOp : Vector_Op<"transpose", [Pure, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]> { let summary = "vector transpose operation"; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3e125e5c1f37b..2d5580ec0ff81 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6716,10 +6716,6 @@ LogicalResult vector::TransposeOp::verify() { return success(); } -std::optional> TransposeOp::getShapeForUnroll() { - return llvm::to_vector<4>(getResultVectorType().getShape()); -} - void TransposeOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { setResultRanges(getResult(), argRanges.front()); From a0c6e4f90d38ab2609ebfce99fc1b28c623aeb11 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 16:39:13 -0500 Subject: [PATCH 03/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 1d3f70a9813f7..fd6196a156d0f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2057,7 +2057,7 @@ def Vector_GatherOp : Vector_Op<"gather", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods DeclareOpInterfaceMethods ]>, Arguments<(ins Arg, "", [MemRead]>:$base, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2d5580ec0ff81..cac8defb4d078 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5982,10 +5982,6 @@ Type GatherOp::getExpectedMaskType() { vecType.getScalableDims()); } -std::optional> GatherOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - /// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...] static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { auto vecType = dyn_cast(indexVec.getType()); From a6cbe0b42db5de0609455d3b1b575c006f6d3e4d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 16:43:37 -0500 Subject: [PATCH 04/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index fd6196a156d0f..fa613a86ad793 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -732,7 +732,7 @@ def Vector_ExtractOp : def Vector_FMAOp : Op, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ] # ElementwiseMappable.traits>, Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs, VectorOfAnyRankOf<[AnyFloat]>:$rhs, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index cac8defb4d078..b56e98dd6b595 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2374,14 +2374,6 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, results.push_back(llvm::cast(attr).getInt()); } -//===----------------------------------------------------------------------===// -// FmaOp -//===----------------------------------------------------------------------===// - -std::optional> FMAOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - //===----------------------------------------------------------------------===// // ToElementsOp //===----------------------------------------------------------------------===// From cd648dac74e3d607e4bf13c3e8bc7c65b0d5c698 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 16:47:12 -0500 Subject: [PATCH 05/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index fa613a86ad793..a85ea2e128e1f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1245,7 +1245,7 @@ def Vector_ExtractStridedSliceOp : def Vector_TransferReadOp : Vector_Op<"transfer_read", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b56e98dd6b595..f126f8dd6c4dd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5088,10 +5088,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) { return OpFoldResult(); } -std::optional> TransferReadOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - void TransferReadOp::getEffects( SmallVectorImpl> &effects) { From 5103187a4f7b4676bc2125297a632b1d8419f9be Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 17:12:57 -0500 Subject: [PATCH 06/15] [mlir][vector] Use getShapeForUnroll's default implementation. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index a85ea2e128e1f..acfa578a184b8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1653,7 +1653,7 @@ def Vector_TransferWriteOp : } def Vector_LoadOp : Vector_Op<"load", [ - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f126f8dd6c4dd..b030b060c6ba0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5762,10 +5762,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) { return OpFoldResult(); } -std::optional> LoadOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - FailureOr>> LoadOp::bubbleDownCasts(OpBuilder &builder) { return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), From 71e53e7f294286f280b012367515f53a81b2cdb9 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 17:21:35 -0500 Subject: [PATCH 07/15] Fix documentation --- mlir/include/mlir/Interfaces/VectorInterfaces.td | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index 6838c16fdf0fe..1223f5c0704ab 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { let methods = [ InterfaceMethod< /*desc=*/[{ - Return the shape ratio of unrolling to the target vector shape - `targetShape`. Return `std::nullopt` if the op cannot be unrolled to the - target vector shape. + Return the shape of the vector of this operation, which may be used to decide unrolling factors. + Return std::nullopt if the op is not applicable for unrolling. }], /*retTy=*/"::std::optional<::llvm::SmallVector>", /*methodName=*/"getShapeForUnroll", From 200773d78f4e57baf5d02b9531d97a289012399a Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 6 Nov 2025 15:23:07 -0500 Subject: [PATCH 08/15] Fix rebase --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index acfa578a184b8..a1c5298629e58 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2057,7 +2057,7 @@ def Vector_GatherOp : Vector_Op<"gather", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins Arg, "", [MemRead]>:$base, From aa4906a085fe94bc31d88ff9d0ac12131434ccae Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 4 Nov 2025 17:30:43 -0500 Subject: [PATCH 09/15] [mlir][vector] to_elements implements VectorUnrollOpInterface --- .../SPIRV/Transforms/SPIRVConversion.h | 3 + .../mlir/Dialect/Vector/IR/VectorOps.td | 8 ++ .../SPIRV/Transforms/SPIRVConversion.cpp | 11 ++- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 + .../Vector/Transforms/VectorUnroll.cpp | 96 ++++++++++++++++++- 5 files changed, 117 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 03ae54a8ae30a..f202c0ea88bd0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, // the target shape. int getComputeVectorSize(int64_t size); +// GetNativeVectorShape implementation for to_elements ops. +SmallVector getNativeVectorShapeImpl(vector::ToElementsOp op); + // GetNativeVectorShape implementation for reduction ops. SmallVector getNativeVectorShapeImpl(vector::ReductionOp op); diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index a1c5298629e58..51e9a9b986315 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -762,6 +762,7 @@ def Vector_FMAOp : def Vector_ToElementsOp : Vector_Op<"to_elements", [ InferTypeOpAdaptor, Pure, + DeclareOpInterfaceMethods, ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> { let summary = "operation that decomposes a vector into all its scalar elements"; let description = [{ @@ -808,6 +809,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [ let assemblyFormat = "$source attr-dict `:` type($source)"; let hasFolder = 1; let hasCanonicalizer = 1; + let extraClassDeclaration = [{ + + VectorType getSourceVectorType() { + return ::llvm::cast(getSource().getType()); + } + + }]; } def Vector_FromElementsOp : Vector_Op<"from_elements", [ diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index cb9b7f6ec2fd2..22097f5f2cdc6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) { return 1; } +SmallVector +mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) { + VectorType srcVectorType = op.getSourceVectorType(); + assert(srcVectorType.getRank() == 1); // Guaranteed by semantics + int64_t vectorSize = + mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0)); + return {vectorSize}; +} + SmallVector mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) { VectorType srcVectorType = op.getSourceVectorType(); @@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) { } return TypeSwitch>>(op) - .Case( + .Case( [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) .Default(std::nullopt); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b030b060c6ba0..4fe3b99f7fd6a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2377,6 +2377,9 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, //===----------------------------------------------------------------------===// // ToElementsOp //===----------------------------------------------------------------------===// +std::optional> ToElementsOp::getShapeForUnroll() { + return llvm::to_vector<4>(getSourceVectorType().getShape()); +} /// Returns true if all the `operands` are defined by `defOp`. /// Otherwise, returns false. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..c49718e0902a5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +/// Takes a 1 dimensional `vector.to_element` op and attempts to change it to +/// the target shape. +/// +/// ``` +/// // In SPIR-V's default environment vector of size 8 +/// // are not allowed. +/// %elements:8 = vector.to_elements %v : vector<8xf32> +/// +/// ===> +/// +/// %v_0_to_3 = vector.extract %v[0] : vector<4xf32> from vector<8xf32> +/// %v_4_to_7 = vector.extract %v[4] : vector<4xf32> from vector<8xf32> +/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32> +/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32> +/// ``` +/// +/// This pattern may fail if the rank is not divisible by to a native shape +/// or if the rank is already in the target shape and therefore it may be +/// skipped. +struct ToElementsToTargetShape final + : public OpRewritePattern { + ToElementsToTargetShape(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, op); + if (!targetShape) + return failure(); + + // We have + // source_rank = N * target_rank + int64_t source_rank = op.getSourceVectorType().getShape().front(); + int64_t target_rank = targetShape->front(); + int64_t N = source_rank / target_rank; + + // Transformation where + // s = source_rank and + // t = target_rank + // ``` + // %e:s = vector.to_elements %v : vector + // + // ===> + // + // // N vector.extract_strided_slice of size t + // %v0 = vector.extract_strided_slice %v + // {offsets = [0*t], sizes = [t], strides = [1]} + // : vector from vector + // %v1 = vector.extract_strided_slice %v + // {offsets = [1*t], sizes = [t], strides = [1]} + // : vector from vector + // ... + // %vNminus1 = vector.extract_strided_slice $v + // {offsets = [(N-1)*t], sizes = [t], strides = [1]} + // : vector from vector + // + // // N vector.to_elements of size t vectors. + // %e0:t = vector.to_elements %v0 : vector + // %e1:t = vector.to_elements %v1 : vector + // ... + // %eNminus1:t = vector.to_elements %vNminus1 : vector + // ``` + SmallVector subVectors; + SmallVector strides(targetShape->size(), 1); + for (int64_t i = 0; i < N; i++) { + SmallVector elementOffsets = {i * target_rank}; + Value subVector = rewriter.createOrFold( + op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides); + subVectors.push_back(subVector); + } + + SmallVector elements; + for (const Value subVector : subVectors) { + auto elementsOp = + vector::ToElementsOp::create(rewriter, op.getLoc(), subVector); + llvm::append_range(elements, elementsOp.getResults()); + } + + rewriter.replaceOp(op, elements); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the /// outermost dimension of the operand. For example: /// /// ``` -/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// %0:8 = vector.to_elements %v : vector<2x2x2xf32> /// /// ==> /// @@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern { FailureOr> result = vector::unrollVectorValue(source, rewriter); if (failed(result)) { + // Only fails if operand is 1-dimensional. return failure(); } SmallVector vectors = *result; @@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( From 228d0b142b14f95ef2dae0030fd39f48f14584b4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 5 Nov 2025 15:14:09 -0500 Subject: [PATCH 10/15] [mlir] Test vector.to_elements to spirv conversion. --- .../ConvertToSPIRV/vector-sizes.mlir | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir new file mode 100644 index 0000000000000..402c539a77093 --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s + +// COM: This file tests the current behaviour of the SignatureConversion +// COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V +// COM: sizes. + +// COM: vector's of rank 1 and size 1 will be changed +// COM: to scalars. Since vector.to_elements will also produce +// COM: a scalar, we expect the vector.to_elements to be folded +// COM: away. Please note that even if run-signature-conversion=false +// COM: The pattern FuncOpConversion will still run and change parameters +// COM: which fit this constraint. + +// CHECK-LABEL: spirv.func @vec_size_1 +// CHECK-SAME: (%[[ARG0:.+]]: f32) +func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) { + // CHECK-NEXT: spirv.ReturnValue %[[ARG0]] : f32 + %0:1 = vector.to_elements %arg0 : vector<1xf32> + return %0#0 : f32 +} + +// ----- + +// COM: vector's of rank 2, 3, 4 are allowed by SPIR-V. +// So they remain unchanged. FuncOpConversion will still +// run, but the signature converter will not convert these vectors. + +// CHECK-LABEL: spirv.func @vec_size_2 +// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>) +func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) { + // COM: A single result type is enforced by the semantics + + // CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> + %0:2 = vector.to_elements %arg0 : vector<2xf32> + + // CHECK-NEXT: spirv.ReturnValue %[[VAL]] + return %0#0 : f32 +} + +// ----- + +// COM: vector of rank 5 is the first one that doesn't fit +// COM: into SPIR-V's vectors. + +// COM: run-signature-conversion=false means that +// COM: this vector will not be unrolled. + +// CHECK-LABEL: func.func @vec_size_5 +// CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>) +func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) { + + // CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32> + + // COM: We have the following comment in VectorConvertToElementOp + // COM: + // COM: // Input vectors of size 1 are converted to scalars by the type converter. + // COM: // We cannot use `spirv::CompositeExtractOp` directly in this case. + // COM: // For a scalar source, the result is just the scalar itself. + // COM: + // COM: Which in this case means an unrealized conversion cast. + + // CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32 + %0:5 = vector.to_elements %arg0 : vector<5xf32> + + // CHECK-NEXT: spirv.ReturnValue %[[RETVAL]] : f32 + return %0#0 : f32 +} From 8fe386a4edfe8148e6bb57e7cbd84f2a82e02b78 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 5 Nov 2025 17:03:39 -0500 Subject: [PATCH 11/15] [mlir] Update unrollToElements tests --- .../Vector/Transforms/VectorUnroll.cpp | 5 ++-- .../ConvertToSPIRV/vector-unroll.mlir | 16 +++++++++++++ .../Vector/vector-to-elements-lowering.mlir | 23 +++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index c49718e0902a5..fd5a8f7c89d7d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1109,8 +1109,9 @@ void mlir::vector::populateVectorUnrollPatterns( void mlir::vector::populateVectorToElementsUnrollPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), UnrollVectorOptions(), - benefit); + auto options = UnrollVectorOptions().setNativeShape(SmallVector{4}); + patterns.add(patterns.getContext(), + options, benefit); } void mlir::vector::populateVectorFromElementsUnrollPatterns( diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir index 0957f67690b97..dcc55a7868978 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir @@ -120,6 +120,22 @@ func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) { // ----- +// CHECK-LABEL: @unroll_to_elements_8xf32 +func.func @unroll_to_elements_8xf32() -> (f32, f32) { + + // CHECK: %[[VEC:.+]] = "test.op" + // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [0] + // CHECK: %[[V1:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [4] + // CHECK: %[[ELEMS0:.+]]:4 = vector.to_elements %[[V0]] + // CHECK: %[[ELEMS1:.+]]:4 = vector.to_elements %[[V1]] + // CHECK: return %[[ELEMS0]]#3, %[[ELEMS1]]#0 + %0 = "test.op"() : () -> (vector<8xf32>) + %1:8 = vector.to_elements %0 : vector<8xf32> + return %1#3, %1#4 : f32, f32 +} + +// ----- + // In order to verify that the pattern is applied, // we need to make sure that the the 2d vector is used // by an operation and that extracts are not folded away. diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir index c521bf0138f98..d448377143249 100644 --- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -29,3 +29,26 @@ func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) %0:4 = vector.to_elements %arg0 : vector<2x2xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 } + +// ----- + +// COM: Here we are testing the pattern ToElementsToTargetShape +// COM: The pattern has a native shape of [4], which means +// COM: that vectors multiples of 4 will be split. In this +// COM: case, that will happen in the function's body, not the argument. + +// CHECK-LABEL: func.func @unroll_vector_8xf32 +// CHECK-SAME: (%[[ARG0:.+]]: vector<8xf32>) +func.func @unroll_vector_8xf32(%arg0: vector<8xf32>) -> (f32, f32) { + %0:8 = vector.to_elements %arg0 : vector<8xf32> + + // COM: We only return two elements, one from each of the + // COM: vectors. + return %0#3, %0#4: f32, f32 + + // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK-NEXT: %[[V1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK-NEXT: %[[ELEMS_0:.+]]:4 = vector.to_elements %[[V0]] + // CHECK-NEXT: %[[ELEMS_1:.+]]:4 = vector.to_elements %[[V1]] + // CHECK-NEXT: return %[[ELEMS_0]]#3, %[[ELEMS_1]]#0 +} From 521aec0c1644e68eca7f666f5970a04fc7b57029 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 6 Nov 2025 16:29:34 -0500 Subject: [PATCH 12/15] Update comment --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 22097f5f2cdc6..d5feafb1aa18b 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1438,7 +1438,7 @@ int mlir::spirv::getComputeVectorSize(int64_t size) { SmallVector mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) { VectorType srcVectorType = op.getSourceVectorType(); - assert(srcVectorType.getRank() == 1); // Guaranteed by semantics + assert(srcVectorType.getRank() == 1); // Guaranteed by UnrollToElements. int64_t vectorSize = mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0)); return {vectorSize}; From 5c1a19d59a4ef64e32d0ee8061dd94d196a5a87c Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 6 Nov 2025 16:36:52 -0500 Subject: [PATCH 13/15] Spell out the type for getTargetShape in the file --- .../Vector/Transforms/VectorUnroll.cpp | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fd5a8f7c89d7d..f511eb0b7b08d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -161,7 +161,8 @@ struct UnrollTransferReadPattern return failure(); if (readOp.getMask()) return failure(); - auto targetShape = getTargetShape(options, readOp); + std::optional> targetShape = + getTargetShape(options, readOp); if (!targetShape) return failure(); auto sourceVectorType = readOp.getVectorType(); @@ -216,7 +217,8 @@ struct UnrollTransferWritePattern if (writeOp.getMask()) return failure(); - auto targetShape = getTargetShape(options, writeOp); + std::optional> targetShape = + getTargetShape(options, writeOp); if (!targetShape) return failure(); auto sourceVectorType = writeOp.getVectorType(); @@ -287,7 +289,8 @@ struct UnrollContractionPattern LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, contractOp); + std::optional> targetShape = + getTargetShape(options, contractOp); if (!targetShape) return failure(); auto dstVecType = cast(contractOp.getResultType()); @@ -462,7 +465,8 @@ struct UnrollElementwisePattern : public RewritePattern { PatternRewriter &rewriter) const override { if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) return failure(); - auto targetShape = getTargetShape(options, op); + std::optional> targetShape = + getTargetShape(options, op); if (!targetShape) return failure(); int64_t targetShapeRank = targetShape->size(); @@ -590,7 +594,8 @@ struct UnrollTransposePattern : public OpRewritePattern { PatternRewriter &rewriter) const override { if (transposeOp.getResultVectorType().getRank() == 0) return failure(); - auto targetShape = getTargetShape(options, transposeOp); + std::optional> targetShape = + getTargetShape(options, transposeOp); if (!targetShape) return failure(); auto originalVectorType = transposeOp.getResultVectorType(); @@ -643,7 +648,8 @@ struct UnrollGatherPattern : public OpRewritePattern { VectorType sourceVectorType = gatherOp.getVectorType(); if (sourceVectorType.getRank() == 0) return failure(); - auto targetShape = getTargetShape(options, gatherOp); + std::optional> targetShape = + getTargetShape(options, gatherOp); if (!targetShape) return failure(); SmallVector strides(targetShape->size(), 1); @@ -697,7 +703,8 @@ struct UnrollLoadPattern : public OpRewritePattern { PatternRewriter &rewriter) const override { VectorType vecType = loadOp.getVectorType(); - auto targetShape = getTargetShape(options, loadOp); + std::optional> targetShape = + getTargetShape(options, loadOp); if (!targetShape) return failure(); @@ -741,7 +748,8 @@ struct UnrollStorePattern : public OpRewritePattern { PatternRewriter &rewriter) const override { VectorType vecType = storeOp.getVectorType(); - auto targetShape = getTargetShape(options, storeOp); + std::optional> targetShape = + getTargetShape(options, storeOp); if (!targetShape) return failure(); @@ -780,7 +788,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, broadcastOp); + std::optional> targetShape = + getTargetShape(options, broadcastOp); if (!targetShape) return failure(); @@ -863,7 +872,8 @@ struct ToElementsToTargetShape final LogicalResult matchAndRewrite(vector::ToElementsOp op, PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, op); + std::optional> targetShape = + getTargetShape(options, op); if (!targetShape) return failure(); From 46326712c83513d22597aea031e91d3a73e96881 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 6 Nov 2025 16:37:53 -0500 Subject: [PATCH 14/15] Use -- instead of - for options --- mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir index 402c539a77093..c04f3c2a4d429 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" --split-input-file %s | FileCheck %s // COM: This file tests the current behaviour of the SignatureConversion // COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V From 5bb3f9302a514e439a064c6a0315b5f15fc02d6f Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 6 Nov 2025 16:40:26 -0500 Subject: [PATCH 15/15] remove COM: --- .../ConvertToSPIRV/vector-sizes.mlir | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir index c04f3c2a4d429..26e55cac4b507 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir @@ -1,15 +1,15 @@ // RUN: mlir-opt --test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" --split-input-file %s | FileCheck %s -// COM: This file tests the current behaviour of the SignatureConversion -// COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V -// COM: sizes. +// This file tests the current behaviour of the SignatureConversion +// and the unrolling of vector.to_elements to vectors of valid SPIR-V +// sizes. -// COM: vector's of rank 1 and size 1 will be changed -// COM: to scalars. Since vector.to_elements will also produce -// COM: a scalar, we expect the vector.to_elements to be folded -// COM: away. Please note that even if run-signature-conversion=false -// COM: The pattern FuncOpConversion will still run and change parameters -// COM: which fit this constraint. +// vector's of rank 1 and size 1 will be changed +// to scalars. Since vector.to_elements will also produce +// a scalar, we expect the vector.to_elements to be folded +// away. Please note that even if run-signature-conversion=false +// The pattern FuncOpConversion will still run and change parameters +// which fit this constraint. // CHECK-LABEL: spirv.func @vec_size_1 // CHECK-SAME: (%[[ARG0:.+]]: f32) @@ -21,14 +21,14 @@ func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) { // ----- -// COM: vector's of rank 2, 3, 4 are allowed by SPIR-V. +// vector's of rank 2, 3, 4 are allowed by SPIR-V. // So they remain unchanged. FuncOpConversion will still // run, but the signature converter will not convert these vectors. // CHECK-LABEL: spirv.func @vec_size_2 // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>) func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) { - // COM: A single result type is enforced by the semantics + // A single result type is enforced by the semantics // CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> %0:2 = vector.to_elements %arg0 : vector<2xf32> @@ -39,11 +39,11 @@ func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) { // ----- -// COM: vector of rank 5 is the first one that doesn't fit -// COM: into SPIR-V's vectors. +// vector of rank 5 is the first one that doesn't fit +// into SPIR-V's vectors. -// COM: run-signature-conversion=false means that -// COM: this vector will not be unrolled. +// run-signature-conversion=false means that +// this vector will not be unrolled. // CHECK-LABEL: func.func @vec_size_5 // CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>) @@ -51,13 +51,13 @@ func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) { // CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32> - // COM: We have the following comment in VectorConvertToElementOp - // COM: - // COM: // Input vectors of size 1 are converted to scalars by the type converter. - // COM: // We cannot use `spirv::CompositeExtractOp` directly in this case. - // COM: // For a scalar source, the result is just the scalar itself. - // COM: - // COM: Which in this case means an unrealized conversion cast. + // We have the following comment in VectorConvertToElementOp + // + // // Input vectors of size 1 are converted to scalars by the type converter. + // // We cannot use `spirv::CompositeExtractOp` directly in this case. + // // For a scalar source, the result is just the scalar itself. + // + // Which in this case means an unrealized conversion cast. // CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32 %0:5 = vector.to_elements %arg0 : vector<5xf32>