@@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
834834 vector::UnrollVectorOptions options;
835835};
836836
837+ // / Takes a 1 dimensional `vector.to_element` op and attempts to change it to
838+ // / the target shape.
839+ // /
840+ // / ```
841+ // / // In SPIR-V's default environment vector of size 8
842+ // / // are not allowed.
843+ // / %elements:8 = vector.to_elements %v : vector<8xf32>
844+ // /
845+ // / ===>
846+ // /
847+ // / %v_0_to_3 = vector.extract %v[0] : vector<4xf32> from vector<8xf32>
848+ // / %v_4_to_7 = vector.extract %v[4] : vector<4xf32> from vector<8xf32>
849+ // / %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
850+ // / %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
851+ // / ```
852+ // /
853+ // / This pattern may fail if the rank is not divisible by to a native shape
854+ // / or if the rank is already in the target shape and therefore it may be
855+ // / skipped.
856+ struct ToElementsToTargetShape final
857+ : public OpRewritePattern<vector::ToElementsOp> {
858+ ToElementsToTargetShape (MLIRContext *context,
859+ const vector::UnrollVectorOptions &options,
860+ PatternBenefit benefit = 1 )
861+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
862+ options (options) {}
863+
864+ LogicalResult matchAndRewrite (vector::ToElementsOp op,
865+ PatternRewriter &rewriter) const override {
866+ auto targetShape = getTargetShape (options, op);
867+ if (!targetShape)
868+ return failure ();
869+
870+ // We have
871+ // source_rank = N * target_rank
872+ int64_t source_rank = op.getSourceVectorType ().getShape ().front ();
873+ int64_t target_rank = targetShape->front ();
874+ int64_t N = source_rank / target_rank;
875+
876+ // Transformation where
877+ // s = source_rank and
878+ // t = target_rank
879+ // ```
880+ // %e:s = vector.to_elements %v : vector<sxf32>
881+ //
882+ // ===>
883+ //
884+ // // N vector.extract_strided_slice of size t
885+ // %v0 = vector.extract_strided_slice %v
886+ // {offsets = [0*t], sizes = [t], strides = [1]}
887+ // : vector<txf32> from vector<sxf32>
888+ // %v1 = vector.extract_strided_slice %v
889+ // {offsets = [1*t], sizes = [t], strides = [1]}
890+ // : vector<txf32> from vector<sxf32>
891+ // ...
892+ // %vNminus1 = vector.extract_strided_slice $v
893+ // {offsets = [(N-1)*t], sizes = [t], strides = [1]}
894+ // : vector<txf32> from vector<sxf32>
895+ //
896+ // // N vector.to_elements of size t vectors.
897+ // %e0:t = vector.to_elements %v0 : vector<txf32>
898+ // %e1:t = vector.to_elements %v1 : vector<txf32>
899+ // ...
900+ // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
901+ // ```
902+ SmallVector<Value> subVectors;
903+ SmallVector<int64_t > strides (targetShape->size (), 1 );
904+ for (int64_t i = 0 ; i < N; i++) {
905+ SmallVector<int64_t > elementOffsets = {i * target_rank};
906+ Value subVector = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
907+ op.getLoc (), op.getSource (), elementOffsets, *targetShape, strides);
908+ subVectors.push_back (subVector);
909+ }
910+
911+ SmallVector<Value> elements;
912+ for (const Value subVector : subVectors) {
913+ auto elementsOp =
914+ vector::ToElementsOp::create (rewriter, op.getLoc (), subVector);
915+ llvm::append_range (elements, elementsOp.getResults ());
916+ }
917+
918+ rewriter.replaceOp (op, elements);
919+ return success ();
920+ }
921+
922+ private:
923+ vector::UnrollVectorOptions options;
924+ };
925+
837926// / Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
838927// / outermost dimension of the operand. For example:
839928// /
840929// / ```
841- // / %0:4 = vector.to_elements %v : vector<2x2xf32 >
930+ // / %0:8 = vector.to_elements %v : vector<2x2x2xf32 >
842931// /
843932// / ==>
844933// /
@@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
865954 FailureOr<SmallVector<Value>> result =
866955 vector::unrollVectorValue (source, rewriter);
867956 if (failed (result)) {
957+ // Only fails if operand is 1-dimensional.
868958 return failure ();
869959 }
870960 SmallVector<Value> vectors = *result;
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131103 UnrollReductionPattern, UnrollMultiReductionPattern,
10141104 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151105 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016- UnrollToElements, UnrollStepPattern>(patterns. getContext (),
1017- options, benefit);
1106+ UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
1107+ patterns. getContext (), options, benefit);
10181108}
10191109
10201110void mlir::vector::populateVectorToElementsUnrollPatterns (
0 commit comments