diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff2082df..a1c5298629e58 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp : def Vector_BroadcastOp : Vector_Op<"broadcast", [Pure, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, @@ -732,7 +732,7 @@ def Vector_ExtractOp : def Vector_FMAOp : Op, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ] # ElementwiseMappable.traits>, Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs, VectorOfAnyRankOf<[AnyFloat]>:$rhs, @@ -1245,7 +1245,7 @@ def Vector_ExtractStridedSliceOp : def Vector_TransferReadOp : Vector_Op<"transfer_read", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -1653,7 +1653,7 @@ def Vector_TransferWriteOp : } def Vector_LoadOp : Vector_Op<"load", [ - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { @@ -2057,7 +2057,7 @@ def Vector_GatherOp : Vector_Op<"gather", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins Arg, "", [MemRead]>:$base, @@ -2758,7 +2758,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ def Vector_TransposeOp : Vector_Op<"transpose", [Pure, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]> { let summary = "vector transpose operation"; diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index 6838c16fdf0fe..1223f5c0704ab 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> { let methods = [ InterfaceMethod< /*desc=*/[{ - Return the shape ratio of unrolling to the target vector shape - `targetShape`. Return `std::nullopt` if the op cannot be unrolled to the - target vector shape. + Return the shape of the vector of this operation, which may be used to decide unrolling factors. + Return std::nullopt if the op is not applicable for unrolling. }], /*retTy=*/"::std::optional<::llvm::SmallVector>", /*methodName=*/"getShapeForUnroll", diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba02100a..b030b060c6ba0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2374,14 +2374,6 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr, results.push_back(llvm::cast(attr).getInt()); } -//===----------------------------------------------------------------------===// -// FmaOp -//===----------------------------------------------------------------------===// - -std::optional> FMAOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - //===----------------------------------------------------------------------===// // ToElementsOp //===----------------------------------------------------------------------===// @@ -2782,10 +2774,6 @@ void BroadcastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } -std::optional> BroadcastOp::getShapeForUnroll() { - return llvm::to_vector<4>(getResultVectorType().getShape()); -} - /// Return the dimensions of the result vector that were formerly ones in the /// source tensor and thus correspond to "dim-1" broadcasting. static llvm::SetVector @@ -5100,10 +5088,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) { return OpFoldResult(); } -std::optional> TransferReadOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - void TransferReadOp::getEffects( SmallVectorImpl> &effects) { @@ -5778,10 +5762,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) { return OpFoldResult(); } -std::optional> LoadOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - FailureOr>> LoadOp::bubbleDownCasts(OpBuilder &builder) { return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), @@ -5986,10 +5966,6 @@ Type GatherOp::getExpectedMaskType() { vecType.getScalableDims()); } -std::optional> GatherOp::getShapeForUnroll() { - return llvm::to_vector<4>(getVectorType().getShape()); -} - /// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...] static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { auto vecType = dyn_cast(indexVec.getType()); @@ -6720,10 +6696,6 @@ LogicalResult vector::TransposeOp::verify() { return success(); } -std::optional> TransposeOp::getShapeForUnroll() { - return llvm::to_vector<4>(getResultVectorType().getShape()); -} - void TransposeOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { setResultRanges(getResult(), argRanges.front());