Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
// the target shape.
int getComputeVectorSize(int64_t size);

// GetNativeVectorShape implementation for to_elements ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ToElementsOp op);

// GetNativeVectorShape implementation for reduction ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);

Expand Down
20 changes: 14 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :

def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Expand Down Expand Up @@ -732,7 +732,7 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
Expand Down Expand Up @@ -762,6 +762,7 @@ def Vector_FMAOp :

def Vector_ToElementsOp : Vector_Op<"to_elements", [
InferTypeOpAdaptor, Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
Expand Down Expand Up @@ -808,6 +809,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = [{

VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}

}];
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Expand Down Expand Up @@ -1245,7 +1253,7 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
Expand Down Expand Up @@ -1653,7 +1661,7 @@ def Vector_TransferWriteOp :
}

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
Expand Down Expand Up @@ -2057,7 +2065,7 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2758,7 +2766,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
Expand Down
5 changes: 2 additions & 3 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the shape ratio of unrolling to the target vector shape
`targetShape`. Return `std::nullopt` if the op cannot be unrolled to the
target vector shape.
Return the shape of the vector of this operation, which may be used to decide unrolling factors.
Return std::nullopt if the op is not applicable for unrolling.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<int64_t, 4>>",
/*methodName=*/"getShapeForUnroll",
Expand Down
11 changes: 10 additions & 1 deletion mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) {
return 1;
}

/// Computes the native vector shape for a `vector.to_elements` op.
///
/// The result is a one-element shape holding whatever `getComputeVectorSize`
/// returns for the size of dimension 0 of the op's source vector.
// NOTE(review): the assert below requires a rank-1 source, but the unroll
// patterns elsewhere in this change deal with `vector.to_elements` on n-D
// vectors -- confirm this function is never reached with a rank > 1 source.
SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) {
  VectorType sourceType = op.getSourceVectorType();
  // Precondition: only dimension 0 is inspected, so the source must be 1-D.
  assert(sourceType.getRank() == 1);
  const int64_t numElements = sourceType.getDimSize(0);
  int64_t nativeSize = mlir::spirv::getComputeVectorSize(numElements);
  return {nativeSize};
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
VectorType srcVectorType = op.getSourceVectorType();
Expand Down Expand Up @@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
}

return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
.Case<vector::ReductionOp, vector::TransposeOp, vector::ToElementsOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
.Default(std::nullopt);
}
Expand Down
31 changes: 3 additions & 28 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2374,17 +2374,12 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

//===----------------------------------------------------------------------===//
// FmaOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//
/// Returns the shape of the source vector; this is the shape unrolling
/// decisions for this op are based on.
std::optional<SmallVector<int64_t, 4>> ToElementsOp::getShapeForUnroll() {
  ArrayRef<int64_t> sourceShape = getSourceVectorType().getShape();
  return llvm::to_vector<4>(sourceShape);
}

/// Returns true if all the `operands` are defined by `defOp`.
/// Otherwise, returns false.
Expand Down Expand Up @@ -2782,10 +2777,6 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
Expand Down Expand Up @@ -5100,10 +5091,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
Expand Down Expand Up @@ -5778,10 +5765,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
Expand Down Expand Up @@ -5986,10 +5969,6 @@ Type GatherOp::getExpectedMaskType() {
vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
auto vecType = dyn_cast<VectorType>(indexVec.getType());
Expand Down Expand Up @@ -6720,10 +6699,6 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
Expand Down
101 changes: 96 additions & 5 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};

/// Takes a 1-dimensional `vector.to_elements` op and attempts to unroll it to
/// the target shape.
///
/// ```
/// // In SPIR-V's default environment, vectors of size 8 are not allowed.
Comment on lines +850 to +851
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: reflow this comment

/// %elements:8 = vector.to_elements %v : vector<8xf32>
///
/// ===>
///
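/// %v_0_to_3 = vector.extract_strided_slice %v
///     {offsets = [0], sizes = [4], strides = [1]}
///     : vector<8xf32> to vector<4xf32>
/// %v_4_to_7 = vector.extract_strided_slice %v
///     {offsets = [4], sizes = [4], strides = [1]}
///     : vector<8xf32> to vector<4xf32>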
/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
/// ```
///
/// This pattern may fail to match if the vector size is not divisible by a
/// native size, or if the vector already has the target shape, in which case
/// the op is left unchanged.
struct ToElementsToTargetShape final
    : public OpRewritePattern<vector::ToElementsOp> {
  /// Creates the pattern. `options` supplies the target (native) shape that
  /// `getTargetShape` consults in `matchAndRewrite`.
  ToElementsToTargetShape(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  /// Splits a 1-D `vector.to_elements` into several `vector.to_elements` ops
  /// on slices of the source vector, each slice having the target shape.
  /// Returns failure (leaving the IR untouched) when the source is not 1-D or
  /// when no 1-D target shape is available for `op`.
  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    // The slicing below only computes offsets along a single dimension, so
    // rank-0 and n-D sources must be rejected before touching the shape.
    VectorType sourceType = op.getSourceVectorType();
    if (sourceType.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "expected a 1-D source vector");

    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    // `front()` below requires a non-empty shape, and the offsets built in
    // the loop are 1-D, so the target shape must be 1-D as well.
    if (targetShape->size() != 1)
      return rewriter.notifyMatchFailure(op, "expected a 1-D target shape");

    // We have
    //   sourceSize = numSlices * targetSize
    const int64_t sourceSize = sourceType.getDimSize(0);
    const int64_t targetSize = targetShape->front();
    const int64_t numSlices = sourceSize / targetSize;

    // Transformation where
    //   s = sourceSize,
    //   t = targetSize and
    //   N = numSlices
    // ```
    // %e:s = vector.to_elements %v : vector<sxf32>
    //
    // ===>
    //
    // // N vector.extract_strided_slice of size t
    // %v0 = vector.extract_strided_slice %v
    //     {offsets = [0*t], sizes = [t], strides = [1]}
    //     : vector<sxf32> to vector<txf32>
    // %v1 = vector.extract_strided_slice %v
    //     {offsets = [1*t], sizes = [t], strides = [1]}
    //     : vector<sxf32> to vector<txf32>
    // ...
    // %vNminus1 = vector.extract_strided_slice %v
    //     {offsets = [(N-1)*t], sizes = [t], strides = [1]}
    //     : vector<sxf32> to vector<txf32>
    //
    // // N vector.to_elements of size t vectors.
    // %e0:t = vector.to_elements %v0 : vector<txf32>
    // %e1:t = vector.to_elements %v1 : vector<txf32>
    // ...
    // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
    // ```
    Location loc = op.getLoc();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // All slices are created before any `vector.to_elements`, matching the
    // op order sketched above.
    SmallVector<Value> subVectors;
    subVectors.reserve(numSlices);
    for (int64_t i = 0; i < numSlices; ++i) {
      SmallVector<int64_t> offsets = {i * targetSize};
      subVectors.push_back(rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), offsets, *targetShape, strides));
    }

    // Concatenating the per-slice results in slice order yields the results
    // of the original op in their original order.
    SmallVector<Value> elements;
    for (Value subVector : subVectors) {
      auto elementsOp = vector::ToElementsOp::create(rewriter, loc, subVector);
      llvm::append_range(elements, elementsOp.getResults());
    }

    rewriter.replaceOp(op, elements);
    return success();
  }

private:
  /// Unrolling options; the source of the target shape.
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
///
/// ==>
///
Expand All @@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
FailureOr<SmallVector<Value>> result =
vector::unrollVectorValue(source, rewriter);
if (failed(result)) {
// Only fails if operand is 1-dimensional.
return failure();
}
SmallVector<Value> vectors = *result;
Expand Down Expand Up @@ -1013,14 +1103,15 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
options, benefit);
UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you choose a different name from all the other patterns? (I'm not saying this is a bad name, just that it creates some asymmetry)

Copy link
Contributor Author

@amd-eochoalo amd-eochoalo Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already an UnrollToElements pattern. I prefer small simple patterns over larger ones. I'll turn this back to a draft for the time being.

I'm thinking about maybe moving this pattern somewhere else. While working on adding the analogous pattern for FromElementsOp I found that there's an issue. The issue is that I was attempting to break down FromElementsOp into multiple FromElementsOp with a suitable vector length and then reconstruct the target type with InsertOp, but the canonicalizer reverts this via the InsertChainFullyInitialized pattern.

This only happens when the targetShape is {1} which can happen when the vector in the IR is not divisible by one of the native sizes. For example, the native vector size is 4 and the vector in the IR is vector<5xf32> then the current getTargetShape will suggest rewriting it to 5 vector<1xf32>.

This I believe comes from:

int mlir::spirv::getComputeVectorSize(int64_t size) {
  for (int i : {4, 3, 2}) {
    if (size % i == 0)
      return i;
  }
  return 1;
}

When I go around this issue, I get builtin.unrealized_conversion_casts.

I'm still thinking about whether there is another option. (For example, replace vector<5xf32> with two vector<4xf32> and then extract what's needed from the other one.)

patterns.getContext(), options, benefit);
}

/// Adds the patterns that unroll `vector.to_elements` to `patterns`:
/// `UnrollToElements` and `ToElementsToTargetShape`. Both are registered with
/// `benefit` and with unroll options whose native shape is `{4}`.
// NOTE(review): the native shape is hard-coded here; confirm that 4 is the
// intended size for every user of this entry point.
void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Register each pattern exactly once, sharing a single set of options, so
  // `UnrollToElements` is not added a second time with default options.
  auto options = UnrollVectorOptions().setNativeShape(SmallVector<int64_t>{4});
  patterns.add<UnrollToElements, ToElementsToTargetShape>(
      patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
Expand Down
67 changes: 67 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s

// COM: This file tests the current behaviour of the SignatureConversion
// COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V
// COM: sizes.

// COM: Vectors of rank 1 and size 1 will be changed
// COM: to scalars. Since vector.to_elements will also produce
// COM: a scalar, we expect the vector.to_elements to be folded
// COM: away. Please note that even if run-signature-conversion=false,
// COM: the pattern FuncOpConversion will still run and change parameters
// COM: which fit this constraint.

// CHECK-LABEL: spirv.func @vec_size_1
// CHECK-SAME: (%[[ARG0:.+]]: f32)
func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) {
// CHECK-NEXT: spirv.ReturnValue %[[ARG0]] : f32
%0:1 = vector.to_elements %arg0 : vector<1xf32>
return %0#0 : f32
}

// -----

// COM: Vectors of size 2, 3, 4 are allowed by SPIR-V,
// COM: so they remain unchanged. FuncOpConversion will still
// COM: run, but the signature converter will not convert these vectors.

// CHECK-LABEL: spirv.func @vec_size_2
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) {
// COM: A single result type is enforced by the semantics

// CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
%0:2 = vector.to_elements %arg0 : vector<2xf32>

// CHECK-NEXT: spirv.ReturnValue %[[VAL]]
return %0#0 : f32
}

// -----

// COM: A vector of size 5 is the first one that doesn't fit
// COM: into SPIR-V's vectors.

// COM: run-signature-conversion=false means that the function
// COM: argument itself will not be unrolled.

// CHECK-LABEL: func.func @vec_size_5
// CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>)
func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) {

// CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32>

// COM: We have the following comment in VectorConvertToElementOp
// COM:
// COM: // Input vectors of size 1 are converted to scalars by the type converter.
// COM: // We cannot use `spirv::CompositeExtractOp` directly in this case.
// COM: // For a scalar source, the result is just the scalar itself.
// COM:
// COM: Which in this case means an unrealized conversion cast.

// CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32
%0:5 = vector.to_elements %arg0 : vector<5xf32>

// CHECK-NEXT: spirv.ReturnValue %[[RETVAL]] : f32
return %0#0 : f32
Comment on lines +63 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe show what happens when some other result is returned -- this way we could check that we extract the correct element

}
Loading
Loading