-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Adds ToElementsToTargetShape pattern. #166476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
fbbf0e4
1964d16
a0c6e4f
a6cbe0b
cd648da
5103187
71e53e7
200773d
aa4906a
228d0b1
8fe386a
521aec0
5c1a19d
4632671
5bb3f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> { | |||||||||
| vector::UnrollVectorOptions options; | ||||||||||
| }; | ||||||||||
|
|
||||||||||
| /// Takes a 1 dimensional `vector.to_element` op and attempts to change it to | ||||||||||
| /// the target shape. | ||||||||||
| /// | ||||||||||
| /// ``` | ||||||||||
| /// // In SPIR-V's default environment vector of size 8 | ||||||||||
| /// // are not allowed. | ||||||||||
|
Comment on lines
+850
to
+851
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: reflow this comment |
||||||||||
| /// %elements:8 = vector.to_elements %v : vector<8xf32> | ||||||||||
| /// | ||||||||||
| /// ===> | ||||||||||
| /// | ||||||||||
| /// %v_0_to_3 = vector.extract %v[0] : vector<4xf32> from vector<8xf32> | ||||||||||
| /// %v_4_to_7 = vector.extract %v[4] : vector<4xf32> from vector<8xf32> | ||||||||||
| /// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32> | ||||||||||
| /// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32> | ||||||||||
| /// ``` | ||||||||||
| /// | ||||||||||
| /// This pattern may fail if the rank is not divisible by to a native shape | ||||||||||
| /// or if the rank is already in the target shape and therefore it may be | ||||||||||
| /// skipped. | ||||||||||
| struct ToElementsToTargetShape final | ||||||||||
| : public OpRewritePattern<vector::ToElementsOp> { | ||||||||||
| ToElementsToTargetShape(MLIRContext *context, | ||||||||||
| const vector::UnrollVectorOptions &options, | ||||||||||
| PatternBenefit benefit = 1) | ||||||||||
| : OpRewritePattern<vector::ToElementsOp>(context, benefit), | ||||||||||
| options(options) {} | ||||||||||
|
|
||||||||||
| LogicalResult matchAndRewrite(vector::ToElementsOp op, | ||||||||||
| PatternRewriter &rewriter) const override { | ||||||||||
| auto targetShape = getTargetShape(options, op); | ||||||||||
amd-eochoalo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| if (!targetShape) | ||||||||||
| return failure(); | ||||||||||
|
|
||||||||||
| // We have | ||||||||||
| // source_rank = N * target_rank | ||||||||||
| int64_t source_rank = op.getSourceVectorType().getShape().front(); | ||||||||||
| int64_t target_rank = targetShape->front(); | ||||||||||
| int64_t N = source_rank / target_rank; | ||||||||||
|
|
||||||||||
| // Transformation where | ||||||||||
| // s = source_rank and | ||||||||||
| // t = target_rank | ||||||||||
| // ``` | ||||||||||
| // %e:s = vector.to_elements %v : vector<sxf32> | ||||||||||
| // | ||||||||||
| // ===> | ||||||||||
| // | ||||||||||
| // // N vector.extract_strided_slice of size t | ||||||||||
| // %v0 = vector.extract_strided_slice %v | ||||||||||
| // {offsets = [0*t], sizes = [t], strides = [1]} | ||||||||||
| // : vector<txf32> from vector<sxf32> | ||||||||||
| // %v1 = vector.extract_strided_slice %v | ||||||||||
| // {offsets = [1*t], sizes = [t], strides = [1]} | ||||||||||
| // : vector<txf32> from vector<sxf32> | ||||||||||
| // ... | ||||||||||
| // %vNminus1 = vector.extract_strided_slice $v | ||||||||||
| // {offsets = [(N-1)*t], sizes = [t], strides = [1]} | ||||||||||
| // : vector<txf32> from vector<sxf32> | ||||||||||
| // | ||||||||||
| // // N vector.to_elements of size t vectors. | ||||||||||
| // %e0:t = vector.to_elements %v0 : vector<txf32> | ||||||||||
| // %e1:t = vector.to_elements %v1 : vector<txf32> | ||||||||||
| // ... | ||||||||||
| // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32> | ||||||||||
| // ``` | ||||||||||
| SmallVector<Value> subVectors; | ||||||||||
| SmallVector<int64_t> strides(targetShape->size(), 1); | ||||||||||
| for (int64_t i = 0; i < N; i++) { | ||||||||||
| SmallVector<int64_t> elementOffsets = {i * target_rank}; | ||||||||||
| Value subVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>( | ||||||||||
| op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides); | ||||||||||
| subVectors.push_back(subVector); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| SmallVector<Value> elements; | ||||||||||
| for (const Value subVector : subVectors) { | ||||||||||
| auto elementsOp = | ||||||||||
|
Comment on lines
+922
to
+923
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| vector::ToElementsOp::create(rewriter, op.getLoc(), subVector); | ||||||||||
| llvm::append_range(elements, elementsOp.getResults()); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| rewriter.replaceOp(op, elements); | ||||||||||
| return success(); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| private: | ||||||||||
| vector::UnrollVectorOptions options; | ||||||||||
| }; | ||||||||||
|
|
||||||||||
| /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the | ||||||||||
| /// outermost dimension of the operand. For example: | ||||||||||
| /// | ||||||||||
| /// ``` | ||||||||||
| /// %0:4 = vector.to_elements %v : vector<2x2xf32> | ||||||||||
| /// %0:8 = vector.to_elements %v : vector<2x2x2xf32> | ||||||||||
| /// | ||||||||||
| /// ==> | ||||||||||
| /// | ||||||||||
|
|
@@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> { | |||||||||
| FailureOr<SmallVector<Value>> result = | ||||||||||
| vector::unrollVectorValue(source, rewriter); | ||||||||||
| if (failed(result)) { | ||||||||||
| // Only fails if operand is 1-dimensional. | ||||||||||
| return failure(); | ||||||||||
| } | ||||||||||
| SmallVector<Value> vectors = *result; | ||||||||||
|
|
@@ -1013,14 +1103,15 @@ void mlir::vector::populateVectorUnrollPatterns( | |||||||||
| UnrollReductionPattern, UnrollMultiReductionPattern, | ||||||||||
| UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, | ||||||||||
| UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, | ||||||||||
| UnrollToElements, UnrollStepPattern>(patterns.getContext(), | ||||||||||
| options, benefit); | ||||||||||
| UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>( | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you chose a different name from all the other patterns? (I'm not saying this is a bad name, just that it creates some asymmetry)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is already an UnrollToElements pattern. I prefer small simple patterns over larger ones. I'll turn this back to a draft for the time being. I'm thinking about maybe moving this pattern somewhere else. While working on adding the analogous to This only happens when the targetShape is {1} which can happen when the vector in the IR is not divisible by one of the native sizes. For example, the native vector size is 4 and the vector in the IR is This I believe comes from: int mlir::spirv::getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
}
return 1;
}When I go around this issue, I get I'm still thinking about whether there is another option. (For example, replace |
||||||||||
| patterns.getContext(), options, benefit); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| void mlir::vector::populateVectorToElementsUnrollPatterns( | ||||||||||
| RewritePatternSet &patterns, PatternBenefit benefit) { | ||||||||||
| patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(), | ||||||||||
| benefit); | ||||||||||
| auto options = UnrollVectorOptions().setNativeShape(SmallVector<int64_t>{4}); | ||||||||||
| patterns.add<UnrollToElements, ToElementsToTargetShape>(patterns.getContext(), | ||||||||||
| options, benefit); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| void mlir::vector::populateVectorFromElementsUnrollPatterns( | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| // RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s | ||
kuhar marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // COM: This file tests the current behaviour of the SignatureConversion | ||
kuhar marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V | ||
| // COM: sizes. | ||
|
|
||
| // COM: vector's of rank 1 and size 1 will be changed | ||
| // COM: to scalars. Since vector.to_elements will also produce | ||
| // COM: a scalar, we expect the vector.to_elements to be folded | ||
| // COM: away. Please note that even if run-signature-conversion=false | ||
| // COM: The pattern FuncOpConversion will still run and change parameters | ||
| // COM: which fit this constraint. | ||
|
|
||
| // CHECK-LABEL: spirv.func @vec_size_1 | ||
| // CHECK-SAME: (%[[ARG0:.+]]: f32) | ||
| func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) { | ||
| // CHECK-NEXT: spirv.ReturnValue %[[ARG0]] : f32 | ||
| %0:1 = vector.to_elements %arg0 : vector<1xf32> | ||
| return %0#0 : f32 | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // COM: vector's of rank 2, 3, 4 are allowed by SPIR-V. | ||
| // So they remain unchanged. FuncOpConversion will still | ||
| // run, but the signature converter will not convert these vectors. | ||
|
|
||
| // CHECK-LABEL: spirv.func @vec_size_2 | ||
| // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>) | ||
| func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) { | ||
| // COM: A single result type is enforced by the semantics | ||
|
|
||
| // CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> | ||
| %0:2 = vector.to_elements %arg0 : vector<2xf32> | ||
|
|
||
| // CHECK-NEXT: spirv.ReturnValue %[[VAL]] | ||
| return %0#0 : f32 | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // COM: vector of rank 5 is the first one that doesn't fit | ||
| // COM: into SPIR-V's vectors. | ||
|
|
||
| // COM: run-signature-conversion=false means that | ||
| // COM: this vector will not be unrolled. | ||
|
|
||
| // CHECK-LABEL: func.func @vec_size_5 | ||
| // CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>) | ||
| func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) { | ||
|
|
||
| // CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32> | ||
|
|
||
| // COM: We have the following comment in VectorConvertToElementOp | ||
| // COM: | ||
| // COM: // Input vectors of size 1 are converted to scalars by the type converter. | ||
| // COM: // We cannot use `spirv::CompositeExtractOp` directly in this case. | ||
| // COM: // For a scalar source, the result is just the scalar itself. | ||
| // COM: | ||
| // COM: Which in this case means an unrealized conversion cast. | ||
|
|
||
| // CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32 | ||
| %0:5 = vector.to_elements %arg0 : vector<5xf32> | ||
|
|
||
| // CHECK-NEXT: spirv.ReturnValue %[[RETVAL]] : f32 | ||
| return %0#0 : f32 | ||
|
Comment on lines
+63
to
+66
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe show what happens when some other result is returned -- this way we could check that we extract the collect element |
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.