Add a pass to adjust integer width from i1 and i64 types to i8 and i32 types.
Due to the limit, vulkan-spirv backend does not support load/store for i1 and
i64. Adjust the integer width for bunch of operations: AccessChainOp,
AddressOfOp, GlobalVariableOp, LoadOp, and StoreOp. During the translation, some
invalid SConvertOp operations will be generated (like i32 -> i32), add another
pattern RemoveNopSConvertOp to legalize it.
PiperOrigin-RevId: 284756560
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 4051195..02a5f89 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -143,6 +143,7 @@
conversionPassManager.addPass(xla_hlo::createLegalizeToStdPass());
conversionPassManager.addPass(createIndexComputationPass());
conversionPassManager.addPass(createIREEToSPIRVPass());
+ conversionPassManager.addPass(createAdjustIntegerWidthPass());
if (failed(conversionPassManager.run(moduleOp))) {
return moduleOp.emitError() << "failed to run conversion passes";
}
diff --git a/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
new file mode 100644
index 0000000..88392e5
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
@@ -0,0 +1,250 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- AdjustIntegerWidthPass.cpp ------------------------------*- C++//-*-===//
+//
+// Pass to adjust integer widths of operations.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h"
+#include "iree/compiler/Utils/TypeConversionUtils.h"
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Pass to
+/// 1) Legalize 64-bit integer values to 32-bit integers values.
+/// 2) Legalize !spv.array containing i1 type to !spv.array of i8 types.
+struct AdjustIntegerWidthPass : public OperationPass<AdjustIntegerWidthPass> {
+ void runOnOperation() override;
+};
+
+// Returns true if the type contains any IntegerType of the width specified by
+// `widths`
+bool hasIntTypeOfWidth(Type type, ArrayRef<int64_t> widths) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ return llvm::is_contained(widths, intType.getWidth());
+ } else if (auto structType = type.dyn_cast<spirv::StructType>()) {
+ for (int64_t i = 0, e = structType.getNumElements(); i != e; ++i) {
+ if (hasIntTypeOfWidth(structType.getElementType(i), widths)) return true;
+ }
+ return false;
+ } else if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+ return hasIntTypeOfWidth(arrayType.getElementType(), widths);
+ } else if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+ return hasIntTypeOfWidth(ptrType.getPointeeType(), widths);
+ }
+ return false;
+}
+
+// Legalizes the integer types in struct.
+// 1) i1 -> i8,
+// 2) i64 -> i32.
+Type legalizeIntegerType(Type type) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (intType.getWidth() == 1) {
+ return IntegerType::get(8, intType.getContext());
+ } else if (intType.getWidth() == 64) {
+ return IntegerType::get(32, intType.getContext());
+ }
+ } else if (auto structType = type.dyn_cast<spirv::StructType>()) {
+ SmallVector<Type, 1> elementTypes;
+ for (auto i : llvm::seq<unsigned>(0, structType.getNumElements())) {
+ elementTypes.push_back(legalizeIntegerType(structType.getElementType(i)));
+ }
+ // TODO(ravishankarm): Use ABI attributes to legalize the struct type.
+ spirv::StructType::LayoutInfo structSize = 0;
+ VulkanLayoutUtils::Size structAlignment = 1;
+ auto t = spirv::StructType::get(elementTypes);
+ return VulkanLayoutUtils::decorateType(t, structSize, structAlignment);
+ } else if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+ return spirv::ArrayType::get(
+ legalizeIntegerType(arrayType.getElementType()),
+ arrayType.getNumElements());
+ } else if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+ return spirv::PointerType::get(
+ legalizeIntegerType(ptrType.getPointeeType()),
+ ptrType.getStorageClass());
+ }
+ return type;
+}
+
+/// Rewrite access chain operations where the pointee type contains i1 or i64
+/// types.
+struct AdjustAccessChainOp : public OpRewritePattern<spirv::AccessChainOp> {
+ using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::AccessChainOp op,
+ PatternRewriter &rewriter) const override {
+ if (!hasIntTypeOfWidth(op.component_ptr()->getType(), {1, 64})) {
+ return matchFailure();
+ }
+ ValueRange indices(op.indices());
+ Type newType = legalizeIntegerType(op.component_ptr()->getType());
+ rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(op, newType,
+ op.base_ptr(), indices);
+ return matchSuccess();
+ }
+};
+
+/// Rewrite address of operations which refers to global variables that contain
+/// i1 or i64 types.
+struct AdjustAddressOfOp : public OpRewritePattern<spirv::AddressOfOp> {
+ using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
+ PatternRewriter &rewriter) const override {
+ if (!hasIntTypeOfWidth(op.pointer()->getType(), {1, 64})) {
+ return matchFailure();
+ }
+ rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
+ op, legalizeIntegerType(op.pointer()->getType()),
+ SymbolRefAttr::get(op.variable(), rewriter.getContext()));
+ return matchSuccess();
+ }
+};
+
+/// Rewrite global variable ops that contain i1 and i64 types to i8 and i32
+/// types respectively.
+struct AdjustGlobalVariableWidth
+ : public OpRewritePattern<spirv::GlobalVariableOp> {
+ using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op,
+ PatternRewriter &rewriter) const override {
+ if (!hasIntTypeOfWidth(op.type(), {1, 64})) {
+ return matchFailure();
+ }
+ rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
+ op, legalizeIntegerType(op.type()), op.sym_name(),
+ op.getAttr("descriptor_set").cast<IntegerAttr>().getInt(),
+ op.getAttr("binding").cast<IntegerAttr>().getInt());
+ return matchSuccess();
+ }
+};
+
+/// Rewrite loads from !spv.ptr<i64,..> to load from !spv.ptr<i32,...>
+/// Rewrite loads from !spv.ptr<i1,...> to load from !spv.ptr<i8,...> followed
+/// by a truncate to i1 type.
+struct AdjustLoadOp : public OpRewritePattern<spirv::LoadOp> {
+ using OpRewritePattern<spirv::LoadOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::LoadOp op,
+ PatternRewriter &rewriter) const override {
+ Type valueType = op.value()->getType();
+ if (!hasIntTypeOfWidth(valueType, {1, 64})) {
+ return matchFailure();
+ }
+
+ Type newType = legalizeIntegerType(valueType);
+ const auto loc = op.getLoc();
+ auto loadOp = rewriter.create<spirv::LoadOp>(
+ loc, newType, op.ptr(),
+ op.getAttrOfType<IntegerAttr>(
+ spirv::attributeName<spirv::MemoryAccess>()),
+ op.getAttrOfType<IntegerAttr>("alignment"));
+ Value *result = loadOp.getResult();
+
+ // If this is a load of a i1, replace it with a load of i8, and truncate the
+ // result. Use INotEqualOp because SConvert doesn't work for i1.
+ if (hasIntTypeOfWidth(valueType, {1})) {
+ auto zero = spirv::ConstantOp::getZero(newType, loc, &rewriter);
+ result = rewriter.create<spirv::INotEqualOp>(loc, valueType, result, zero)
+ .getResult();
+ }
+
+ rewriter.replaceOp(op, result);
+ return matchSuccess();
+ }
+};
+
+/// Rewrite store operation that contain i1 and i64 types to i8 and i32 types
+/// respectively.
+struct AdjustStoreOp : public OpRewritePattern<spirv::StoreOp> {
+ using OpRewritePattern<spirv::StoreOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::StoreOp op,
+ PatternRewriter &rewriter) const override {
+ Type valueType = op.value()->getType();
+ if (!hasIntTypeOfWidth(valueType, {1, 64})) {
+ return matchFailure();
+ }
+
+ Type newType = legalizeIntegerType(valueType);
+ const auto loc = op.getLoc();
+ Value *value;
+ if (hasIntTypeOfWidth(valueType, {1})) {
+ Value *zero =
+ spirv::ConstantOp::getZero(newType, loc, &rewriter).getResult();
+ Value *one =
+ spirv::ConstantOp::getOne(newType, loc, &rewriter).getResult();
+ value = rewriter.create<spirv::SelectOp>(loc, op.value(), one, zero)
+ .getResult();
+ } else {
+ value = rewriter.create<spirv::SConvertOp>(loc, newType, op.value())
+ .getResult();
+ }
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ op, op.ptr(), value,
+ op.getAttrOfType<IntegerAttr>(
+ spirv::attributeName<spirv::MemoryAccess>()),
+ op.getAttrOfType<IntegerAttr>("alignment"));
+ return matchSuccess();
+ }
+};
+
+/// Some Adjust* OpRewritePattern will generate useless SConvert operations,
+/// which are invalid operations. Remove the SConvert operation if this is an
+/// nop, i.e., if the source type and destination type are the same, remove the
+/// op. It relies on the furthur finialization to remove the op, and propagate
+/// right operands to other operations.
+struct RemoveNopSConvertOp : public OpRewritePattern<spirv::SConvertOp> {
+ using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(spirv::SConvertOp op,
+ PatternRewriter &rewriter) const override {
+ Type t1 = op.operand()->getType();
+ Type t2 = op.result()->getType();
+ if (t1 != t2) return matchFailure();
+ auto zero = spirv::ConstantOp::getZero(t1, op.getLoc(), &rewriter);
+ rewriter.replaceOpWithNewOp<spirv::IAddOp>(op, op.operand(), zero);
+ return matchFailure();
+ }
+};
+
+void AdjustIntegerWidthPass::runOnOperation() {
+ OwningRewritePatternList patterns;
+ // TODO(hanchung): Support for adjusting integer width for integer arithmetic
+ // operations.
+ patterns
+ .insert<AdjustAccessChainOp, AdjustAddressOfOp, AdjustGlobalVariableWidth,
+ AdjustLoadOp, AdjustStoreOp, RemoveNopSConvertOp>(&getContext());
+ Operation *op = getOperation();
+ applyPatternsGreedily(op->getRegions(), patterns);
+}
+
+static PassRegistration<AdjustIntegerWidthPass> pass(
+ "iree-spirv-adjust-integer-width",
+ "Adjust integer width from i1 and i64 types to i8 and i32 types "
+ "respectively");
+
+} // namespace
+
+std::unique_ptr<Pass> createAdjustIntegerWidthPass() {
+ return std::make_unique<AdjustIntegerWidthPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/BUILD b/iree/compiler/Translation/SPIRV/BUILD
index 9aecd6e..c4a8dd0 100644
--- a/iree/compiler/Translation/SPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/BUILD
@@ -20,6 +20,7 @@
cc_library(
name = "SPIRV",
srcs = [
+ "AdjustIntegerWidthPass.cpp",
"EmbeddedKernels.cpp",
"IREEIndexComputation.cpp",
"IREEToSPIRV.cpp",
diff --git a/iree/compiler/Translation/SPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/CMakeLists.txt
index 81e990a..2ae8217 100644
--- a/iree/compiler/Translation/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/CMakeLists.txt
@@ -31,6 +31,7 @@
"XLAIndexPropagation.h"
"XLAToSPIRV.h"
SRCS
+ "AdjustIntegerWidthPass.cpp"
"EmbeddedKernels.cpp"
"IndexComputation.cpp"
"IndexComputationAttribute.cpp"
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
index 42a03fe..5a6c065 100644
--- a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
+++ b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
@@ -32,6 +32,11 @@
// Performs analysis to compute affine maps that represent the index of the
// elements of tensor values needed within a workitem.
std::unique_ptr<OpPassBase<FuncOp>> createIndexComputationPass();
+
+// Legalizes integer width from i1 and i64 types to i8 and i32 types
+// respectively.
+std::unique_ptr<Pass> createAdjustIntegerWidthPass();
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
index 535f6d3..575816a 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
+++ b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
@@ -206,6 +206,7 @@
spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass());
spirvGenPasses->addPass(createIndexComputationPass());
spirvGenPasses->addPass(createIREEToSPIRVPass());
+ spirvGenPasses->addPass(createAdjustIntegerWidthPass());
if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) {
executableOp.emitError() << "Failed to generate spv.module";
return {};
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
index 811dbe3..2c72720 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.h
+++ b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
@@ -307,13 +307,6 @@
return emitError(loc, "unhandled element type ")
<< elementType << " while lowering to SPIR-V";
}
- if (auto intElementType = elementType.dyn_cast<IntegerType>()) {
- if (intElementType.getWidth() > 32) {
- // TODO(ravishankarm): Maybe its better to report a warning when this
- // happens.
- elementType = IntegerType::get(32, elementType.getContext());
- }
- }
if (argType.hasStaticShape()) {
int64_t stride = elementType.getIntOrFloatBitWidth() / 8;
for (auto dim : reverse(argType.getShape())) {
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
index 068d408..edcc76e 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
@@ -81,10 +81,6 @@
auto loc = convertOp.getLoc();
auto resultElemType =
convertOp.getResult()->getType().dyn_cast<ShapedType>().getElementType();
- if (auto intElemType = resultElemType.dyn_cast<IntegerType>()) {
- if (intElemType.getWidth() > 32)
- resultElemType = IntegerType::get(32, resultElemType.getContext());
- }
auto operandElemType =
convertOp.getOperand()->getType().dyn_cast<ShapedType>().getElementType();
diff --git a/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir b/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir
new file mode 100644
index 0000000..1658c18
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: iree-opt -iree-spirv-adjust-integer-width -verify-diagnostics -o - %s | IreeFileCheck %s
+
+module{
+ spv.module "Logical" "GLSL450" {
+ spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK: spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ func @foo_i64(%arg0 : i64, %arg1 : i64) -> () {
+ // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
+ // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+ // CHECK: spv.Store "StorageBuffer" %{{.*}} %{{.*}} : i32
+ %0 = spv._address_of @constant_arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ %1 = spv.constant 0 : i32
+ %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ %3 = spv.Load "StorageBuffer" %2 : i64
+ %4 = spv._address_of @constant_arg_1 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ %5 = spv.constant 0 : i32
+ %6 = spv.AccessChain %4[%5] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+ spv.Store "StorageBuffer" %6, %3 : i64
+ spv.Return
+ }
+ }
+
+ spv.module "Logical" "GLSL450" {
+ spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+ // CHECK: spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ // CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ func @foo_i1(%arg0 : i1, %arg1 : i1) -> () {
+ // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ // CHECK: spv.Load "StorageBuffer" %{{.*}} : i8
+ // CHECK-NEXT: spv.INotEqual {{.*}} : i8
+ // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+ // CHECK: spv.Select {{.*}} : i1, i8
+ // CHECK: spv.Store "StorageBuffer" {{.*}} : i8
+ %0 = spv._address_of @constant_arg_0 : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ %1 = spv.constant 0 : i32
+ %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ %3 = spv.Load "StorageBuffer" %2 : i1
+ %4 = spv._address_of @constant_arg_1 : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ %5 = spv.constant 0 : i32
+ %6 = spv.AccessChain %4[%5] : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+ spv.Store "StorageBuffer" %6, %3 : i1
+ spv.Return
+ }
+ }
+}