Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
// the target shape.
int getComputeVectorSize(int64_t size);

// GetNativeVectorShape implementation for to_elements ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ToElementsOp op);

// GetNativeVectorShape implementation for reduction ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);

Expand Down
20 changes: 14 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :

def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Expand Down Expand Up @@ -732,7 +732,7 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
Expand Down Expand Up @@ -762,6 +762,7 @@ def Vector_FMAOp :

def Vector_ToElementsOp : Vector_Op<"to_elements", [
InferTypeOpAdaptor, Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
Expand Down Expand Up @@ -808,6 +809,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = [{

VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}

}];
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Expand Down Expand Up @@ -1245,7 +1253,7 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
Expand Down Expand Up @@ -1653,7 +1661,7 @@ def Vector_TransferWriteOp :
}

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
Expand Down Expand Up @@ -2057,7 +2065,7 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2758,7 +2766,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
Expand Down
5 changes: 2 additions & 3 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the shape ratio of unrolling to the target vector shape
`targetShape`. Return `std::nullopt` if the op cannot be unrolled to the
target vector shape.
Return the shape of the vector of this operation, which may be used to decide unrolling factors.
Return std::nullopt if the op is not applicable for unrolling.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<int64_t, 4>>",
/*methodName=*/"getShapeForUnroll",
Expand Down
11 changes: 10 additions & 1 deletion mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) {
return 1;
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) {
VectorType srcVectorType = op.getSourceVectorType();
assert(srcVectorType.getRank() == 1); // Guaranteed by UnrollToElements.
int64_t vectorSize =
mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
return {vectorSize};
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
VectorType srcVectorType = op.getSourceVectorType();
Expand Down Expand Up @@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
}

return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
.Case<vector::ReductionOp, vector::TransposeOp, vector::ToElementsOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
.Default(std::nullopt);
}
Expand Down
31 changes: 3 additions & 28 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2374,17 +2374,12 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

//===----------------------------------------------------------------------===//
// FmaOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t, 4>> ToElementsOp::getShapeForUnroll() {
return llvm::to_vector<4>(getSourceVectorType().getShape());
}

/// Returns true if all the `operands` are defined by `defOp`.
/// Otherwise, returns false.
Expand Down Expand Up @@ -2782,10 +2777,6 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
Expand Down Expand Up @@ -5100,10 +5091,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
Expand Down Expand Up @@ -5778,10 +5765,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
Expand Down Expand Up @@ -5986,10 +5969,6 @@ Type GatherOp::getExpectedMaskType() {
vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
auto vecType = dyn_cast<VectorType>(indexVec.getType());
Expand Down Expand Up @@ -6720,10 +6699,6 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
Expand Down
Loading