Skip to content

Commit 8fe386a

Browse files
committed
[mlir] Update unrollToElements tests
1 parent 228d0b1 commit 8fe386a

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,8 +1109,9 @@ void mlir::vector::populateVectorUnrollPatterns(
11091109

11101110
void mlir::vector::populateVectorToElementsUnrollPatterns(
11111111
RewritePatternSet &patterns, PatternBenefit benefit) {
1112-
patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1113-
benefit);
1112+
auto options = UnrollVectorOptions().setNativeShape(SmallVector<int64_t>{4});
1113+
patterns.add<UnrollToElements, ToElementsToTargetShape>(patterns.getContext(),
1114+
options, benefit);
11141115
}
11151116

11161117
void mlir::vector::populateVectorFromElementsUnrollPatterns(

mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,22 @@ func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
120120

121121
// -----
122122

123+
// CHECK-LABEL: @unroll_to_elements_8xf32
124+
func.func @unroll_to_elements_8xf32() -> (f32, f32) {
125+
126+
// CHECK: %[[VEC:.+]] = "test.op"
127+
// CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [0]
128+
// CHECK: %[[V1:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [4]
129+
// CHECK: %[[ELEMS0:.+]]:4 = vector.to_elements %[[V0]]
130+
// CHECK: %[[ELEMS1:.+]]:4 = vector.to_elements %[[V1]]
131+
// CHECK: return %[[ELEMS0]]#3, %[[ELEMS1]]#0
132+
%0 = "test.op"() : () -> (vector<8xf32>)
133+
%1:8 = vector.to_elements %0 : vector<8xf32>
134+
return %1#3, %1#4 : f32, f32
135+
}
136+
137+
// -----
138+
123139
// In order to verify that the pattern is applied,
124140
// we need to make sure that the the 2d vector is used
125141
// by an operation and that extracts are not folded away.

mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,26 @@ func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32)
2929
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
3030
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
3131
}
32+
33+
// -----
34+
35+
// COM: Here we are testing the pattern ToElementsToTargetShape
36+
// COM: The pattern has a native shape of [4], which means
37+
// COM: that vectors multiples of 4 will be split. In this
38+
// COM: case, that will happen in the function's body, not the argument.
39+
40+
// CHECK-LABEL: func.func @unroll_vector_8xf32
41+
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xf32>)
42+
func.func @unroll_vector_8xf32(%arg0: vector<8xf32>) -> (f32, f32) {
43+
%0:8 = vector.to_elements %arg0 : vector<8xf32>
44+
45+
// COM: We only return two elements, one from each of the
46+
// COM: vectors.
47+
return %0#3, %0#4: f32, f32
48+
49+
// CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
50+
// CHECK-NEXT: %[[V1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
51+
// CHECK-NEXT: %[[ELEMS_0:.+]]:4 = vector.to_elements %[[V0]]
52+
// CHECK-NEXT: %[[ELEMS_1:.+]]:4 = vector.to_elements %[[V1]]
53+
// CHECK-NEXT: return %[[ELEMS_0]]#3, %[[ELEMS_1]]#0
54+
}

0 commit comments

Comments
 (0)