diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index f4606ea111dd..3127aa11dd08 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6454,6 +6454,11 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
     for (auto s : shape)
       size *= s;
 
+    if (size == 0) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Shape must not have zero dimensions");
+    }
+
     SmallVector<int32_t> values(size, fillVal);
     auto constOp =
         tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();
diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
index 3c1ae66aac3c..daacadc7727e 100644
--- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
+++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
@@ -146,3 +146,17 @@ func.func @torch.aten.size.int(%arg0: !torch.vtensor<[4,2],f32>) -> !torch.int {
   %0 = torch.aten.size.int %arg0, %c2 : !torch.vtensor<[4,2],f32>, !torch.int -> !torch.int
   return %0 : !torch.int
 }
+
+// -----
+func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
+  %c1 = torch.constant.int 1
+  %c0 = torch.constant.int 0
+  %c256 = torch.constant.int 256
+  %2452 = torch.prim.ListConstruct %c1, %c0, %c256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  // expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
+  %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
+  return %out : !torch.vtensor<[1,0,256],f32>
+}
diff --git a/test/Conversion/TorchToTosaLinalg/torch-backend-to-tosa-linalg-backend-pipeline.mlir b/test/Conversion/TorchToTosaLinalg/torch-backend-to-tosa-linalg-backend-pipeline.mlir
index 93cc9241c27b..24f2a81af738 100644
--- a/test/Conversion/TorchToTosaLinalg/torch-backend-to-tosa-linalg-backend-pipeline.mlir
+++ b/test/Conversion/TorchToTosaLinalg/torch-backend-to-tosa-linalg-backend-pipeline.mlir
@@ -53,3 +53,20 @@ func.func @tm_scan(%arg0: tensor<1x512xi64>) -> (tensor<1x512xi64>, tensor<1xi64
   } -> tensor<1x512xi64>, tensor<1xi64>
   return %2#0, %2#1 : tensor<1x512xi64>, tensor<1xi64>
 }
+
+//-----
+// CHECK-LABEL: func.func @torch.aten.empty.memory_format() -> tensor<1x0x256xf32> {
+// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1x0x256xf32>
+// CHECK: return %[[EMPTY_TENSOR]] : tensor<1x0x256xf32>
+// CHECK: }
+func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
+  %c1 = torch.constant.int 1
+  %c0 = torch.constant.int 0
+  %c256 = torch.constant.int 256
+  %2452 = torch.prim.ListConstruct %c1, %c0, %c256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
+  return %out : !torch.vtensor<[1,0,256],f32>
+}