Skip to content
Open
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :

def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Expand Down Expand Up @@ -732,7 +732,7 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
Expand Down Expand Up @@ -1245,7 +1245,7 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
Expand Down Expand Up @@ -1653,7 +1653,7 @@ def Vector_TransferWriteOp :
}

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
Expand Down Expand Up @@ -2057,7 +2057,7 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2758,7 +2758,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
Expand Down
5 changes: 2 additions & 3 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the shape ratio of unrolling to the target vector shape
`targetShape`. Return `std::nullopt` if the op cannot be unrolled to the
target vector shape.
Return the shape of the vector of this operation, which may be used to decide unrolling factors.
Return std::nullopt if the op is not applicable for unrolling.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<int64_t, 4>>",
/*methodName=*/"getShapeForUnroll",
Expand Down
28 changes: 0 additions & 28 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2374,14 +2374,6 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

//===----------------------------------------------------------------------===//
// FmaOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2782,10 +2774,6 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
Expand Down Expand Up @@ -5100,10 +5088,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
Expand Down Expand Up @@ -5778,10 +5762,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
Expand Down Expand Up @@ -5986,10 +5966,6 @@ Type GatherOp::getExpectedMaskType() {
vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
auto vecType = dyn_cast<VectorType>(indexVec.getType());
Expand Down Expand Up @@ -6720,10 +6696,6 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
Expand Down