Add a pass to adjust integer width from i1 and i64 types to i8 and i32 types.

Due to the limit, vulkan-spirv backend does not support load/store for i1 and
i64. Adjust the integer width for bunch of operations: AccessChainOp,
AddressOfOp, GlobalVariableOp, LoadOp, and StoreOp. During the translation, some
invalid SConvertOp operations will be generated (like i32 -> i32), add another
pattern RemoveNopSConvertOp to legalize it.

PiperOrigin-RevId: 284756560
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 4051195..02a5f89 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -143,6 +143,7 @@
     conversionPassManager.addPass(xla_hlo::createLegalizeToStdPass());
     conversionPassManager.addPass(createIndexComputationPass());
     conversionPassManager.addPass(createIREEToSPIRVPass());
+    conversionPassManager.addPass(createAdjustIntegerWidthPass());
     if (failed(conversionPassManager.run(moduleOp))) {
       return moduleOp.emitError() << "failed to run conversion passes";
     }
diff --git a/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
new file mode 100644
index 0000000..88392e5
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
@@ -0,0 +1,250 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- AdjustIntegerWidthPass.cpp ------------------------------*- C++//-*-===//
+//
+// Pass to adjust integer widths of operations.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h"
+#include "iree/compiler/Utils/TypeConversionUtils.h"
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Pass to
+/// 1) Legalize 64-bit integer values to 32-bit integers values.
+/// 2) Legalize !spv.array containing i1 type to !spv.array of i8 types.
+struct AdjustIntegerWidthPass : public OperationPass<AdjustIntegerWidthPass> {
+  void runOnOperation() override;
+};
+
+// Returns true if the type contains any IntegerType of the width specified by
+// `widths`
+bool hasIntTypeOfWidth(Type type, ArrayRef<int64_t> widths) {
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    return llvm::is_contained(widths, intType.getWidth());
+  } else if (auto structType = type.dyn_cast<spirv::StructType>()) {
+    for (int64_t i = 0, e = structType.getNumElements(); i != e; ++i) {
+      if (hasIntTypeOfWidth(structType.getElementType(i), widths)) return true;
+    }
+    return false;
+  } else if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+    return hasIntTypeOfWidth(arrayType.getElementType(), widths);
+  } else if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+    return hasIntTypeOfWidth(ptrType.getPointeeType(), widths);
+  }
+  return false;
+}
+
+// Legalizes the integer types in struct.
+// 1) i1 -> i8,
+// 2) i64 -> i32.
+Type legalizeIntegerType(Type type) {
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    if (intType.getWidth() == 1) {
+      return IntegerType::get(8, intType.getContext());
+    } else if (intType.getWidth() == 64) {
+      return IntegerType::get(32, intType.getContext());
+    }
+  } else if (auto structType = type.dyn_cast<spirv::StructType>()) {
+    SmallVector<Type, 1> elementTypes;
+    for (auto i : llvm::seq<unsigned>(0, structType.getNumElements())) {
+      elementTypes.push_back(legalizeIntegerType(structType.getElementType(i)));
+    }
+    // TODO(ravishankarm): Use ABI attributes to legalize the struct type.
+    spirv::StructType::LayoutInfo structSize = 0;
+    VulkanLayoutUtils::Size structAlignment = 1;
+    auto t = spirv::StructType::get(elementTypes);
+    return VulkanLayoutUtils::decorateType(t, structSize, structAlignment);
+  } else if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+    return spirv::ArrayType::get(
+        legalizeIntegerType(arrayType.getElementType()),
+        arrayType.getNumElements());
+  } else if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+    return spirv::PointerType::get(
+        legalizeIntegerType(ptrType.getPointeeType()),
+        ptrType.getStorageClass());
+  }
+  return type;
+}
+
+/// Rewrite access chain operations where the pointee type contains i1 or i64
+/// types.
+struct AdjustAccessChainOp : public OpRewritePattern<spirv::AccessChainOp> {
+  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::AccessChainOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!hasIntTypeOfWidth(op.component_ptr()->getType(), {1, 64})) {
+      return matchFailure();
+    }
+    ValueRange indices(op.indices());
+    Type newType = legalizeIntegerType(op.component_ptr()->getType());
+    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(op, newType,
+                                                      op.base_ptr(), indices);
+    return matchSuccess();
+  }
+};
+
+/// Rewrite address of operations which refers to global variables that contain
+/// i1 or i64 types.
+struct AdjustAddressOfOp : public OpRewritePattern<spirv::AddressOfOp> {
+  using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!hasIntTypeOfWidth(op.pointer()->getType(), {1, 64})) {
+      return matchFailure();
+    }
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
+        op, legalizeIntegerType(op.pointer()->getType()),
+        SymbolRefAttr::get(op.variable(), rewriter.getContext()));
+    return matchSuccess();
+  }
+};
+
+/// Rewrite global variable ops that contain i1 and i64 types to i8 and i32
+/// types respectively.
+struct AdjustGlobalVariableWidth
+    : public OpRewritePattern<spirv::GlobalVariableOp> {
+  using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!hasIntTypeOfWidth(op.type(), {1, 64})) {
+      return matchFailure();
+    }
+    rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
+        op, legalizeIntegerType(op.type()), op.sym_name(),
+        op.getAttr("descriptor_set").cast<IntegerAttr>().getInt(),
+        op.getAttr("binding").cast<IntegerAttr>().getInt());
+    return matchSuccess();
+  }
+};
+
+/// Rewrite loads from !spv.ptr<i64,..> to load from !spv.ptr<i32,...>
+/// Rewrite loads from !spv.ptr<i1,...> to load from !spv.ptr<i8,...> followed
+/// by a truncate to i1 type.
+struct AdjustLoadOp : public OpRewritePattern<spirv::LoadOp> {
+  using OpRewritePattern<spirv::LoadOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::LoadOp op,
+                                     PatternRewriter &rewriter) const override {
+    Type valueType = op.value()->getType();
+    if (!hasIntTypeOfWidth(valueType, {1, 64})) {
+      return matchFailure();
+    }
+
+    Type newType = legalizeIntegerType(valueType);
+    const auto loc = op.getLoc();
+    auto loadOp = rewriter.create<spirv::LoadOp>(
+        loc, newType, op.ptr(),
+        op.getAttrOfType<IntegerAttr>(
+            spirv::attributeName<spirv::MemoryAccess>()),
+        op.getAttrOfType<IntegerAttr>("alignment"));
+    Value *result = loadOp.getResult();
+
+    // If this is a load of a i1, replace it with a load of i8, and truncate the
+    // result. Use INotEqualOp because SConvert doesn't work for i1.
+    if (hasIntTypeOfWidth(valueType, {1})) {
+      auto zero = spirv::ConstantOp::getZero(newType, loc, &rewriter);
+      result = rewriter.create<spirv::INotEqualOp>(loc, valueType, result, zero)
+                   .getResult();
+    }
+
+    rewriter.replaceOp(op, result);
+    return matchSuccess();
+  }
+};
+
+/// Rewrite store operation that contain i1 and i64 types to i8 and i32 types
+/// respectively.
+struct AdjustStoreOp : public OpRewritePattern<spirv::StoreOp> {
+  using OpRewritePattern<spirv::StoreOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::StoreOp op,
+                                     PatternRewriter &rewriter) const override {
+    Type valueType = op.value()->getType();
+    if (!hasIntTypeOfWidth(valueType, {1, 64})) {
+      return matchFailure();
+    }
+
+    Type newType = legalizeIntegerType(valueType);
+    const auto loc = op.getLoc();
+    Value *value;
+    if (hasIntTypeOfWidth(valueType, {1})) {
+      Value *zero =
+          spirv::ConstantOp::getZero(newType, loc, &rewriter).getResult();
+      Value *one =
+          spirv::ConstantOp::getOne(newType, loc, &rewriter).getResult();
+      value = rewriter.create<spirv::SelectOp>(loc, op.value(), one, zero)
+                  .getResult();
+    } else {
+      value = rewriter.create<spirv::SConvertOp>(loc, newType, op.value())
+                  .getResult();
+    }
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        op, op.ptr(), value,
+        op.getAttrOfType<IntegerAttr>(
+            spirv::attributeName<spirv::MemoryAccess>()),
+        op.getAttrOfType<IntegerAttr>("alignment"));
+    return matchSuccess();
+  }
+};
+
+/// Some Adjust* OpRewritePattern will generate useless SConvert operations,
+/// which are invalid operations. Remove the SConvert operation if this is an
+/// nop, i.e., if the source type and destination type are the same, remove the
+/// op. It relies on the furthur finialization to remove the op, and propagate
+/// right operands to other operations.
+struct RemoveNopSConvertOp : public OpRewritePattern<spirv::SConvertOp> {
+  using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(spirv::SConvertOp op,
+                                     PatternRewriter &rewriter) const override {
+    Type t1 = op.operand()->getType();
+    Type t2 = op.result()->getType();
+    if (t1 != t2) return matchFailure();
+    auto zero = spirv::ConstantOp::getZero(t1, op.getLoc(), &rewriter);
+    rewriter.replaceOpWithNewOp<spirv::IAddOp>(op, op.operand(), zero);
+    return matchFailure();
+  }
+};
+
+void AdjustIntegerWidthPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  // TODO(hanchung): Support for adjusting integer width for integer arithmetic
+  // operations.
+  patterns
+      .insert<AdjustAccessChainOp, AdjustAddressOfOp, AdjustGlobalVariableWidth,
+              AdjustLoadOp, AdjustStoreOp, RemoveNopSConvertOp>(&getContext());
+  Operation *op = getOperation();
+  applyPatternsGreedily(op->getRegions(), patterns);
+}
+
+static PassRegistration<AdjustIntegerWidthPass> pass(
+    "iree-spirv-adjust-integer-width",
+    "Adjust integer width from i1 and i64 types to i8 and i32 types "
+    "respectively");
+
+}  // namespace
+
+std::unique_ptr<Pass> createAdjustIntegerWidthPass() {
+  return std::make_unique<AdjustIntegerWidthPass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/BUILD b/iree/compiler/Translation/SPIRV/BUILD
index 9aecd6e..c4a8dd0 100644
--- a/iree/compiler/Translation/SPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/BUILD
@@ -20,6 +20,7 @@
 cc_library(
     name = "SPIRV",
     srcs = [
+        "AdjustIntegerWidthPass.cpp",
         "EmbeddedKernels.cpp",
         "IREEIndexComputation.cpp",
         "IREEToSPIRV.cpp",
diff --git a/iree/compiler/Translation/SPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/CMakeLists.txt
index 81e990a..2ae8217 100644
--- a/iree/compiler/Translation/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/CMakeLists.txt
@@ -31,6 +31,7 @@
     "XLAIndexPropagation.h"
     "XLAToSPIRV.h"
   SRCS
+    "AdjustIntegerWidthPass.cpp"
     "EmbeddedKernels.cpp"
     "IndexComputation.cpp"
     "IndexComputationAttribute.cpp"
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
index 42a03fe..5a6c065 100644
--- a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
+++ b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
@@ -32,6 +32,11 @@
 // Performs analysis to compute affine maps that represent the index of the
 // elements of tensor values needed within a workitem.
 std::unique_ptr<OpPassBase<FuncOp>> createIndexComputationPass();
+
+// Legalizes integer width from i1 and i64 types to i8 and i32 types
+// respectively.
+std::unique_ptr<Pass> createAdjustIntegerWidthPass();
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
index 535f6d3..575816a 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
+++ b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
@@ -206,6 +206,7 @@
   spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass());
   spirvGenPasses->addPass(createIndexComputationPass());
   spirvGenPasses->addPass(createIREEToSPIRVPass());
+  spirvGenPasses->addPass(createAdjustIntegerWidthPass());
   if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) {
     executableOp.emitError() << "Failed to generate spv.module";
     return {};
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
index 811dbe3..2c72720 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.h
+++ b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
@@ -307,13 +307,6 @@
       return emitError(loc, "unhandled element type ")
              << elementType << " while lowering to SPIR-V";
     }
-    if (auto intElementType = elementType.dyn_cast<IntegerType>()) {
-      if (intElementType.getWidth() > 32) {
-        // TODO(ravishankarm): Maybe its better to report a warning when this
-        // happens.
-        elementType = IntegerType::get(32, elementType.getContext());
-      }
-    }
     if (argType.hasStaticShape()) {
       int64_t stride = elementType.getIntOrFloatBitWidth() / 8;
       for (auto dim : reverse(argType.getShape())) {
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
index 068d408..edcc76e 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
@@ -81,10 +81,6 @@
   auto loc = convertOp.getLoc();
   auto resultElemType =
       convertOp.getResult()->getType().dyn_cast<ShapedType>().getElementType();
-  if (auto intElemType = resultElemType.dyn_cast<IntegerType>()) {
-    if (intElemType.getWidth() > 32)
-      resultElemType = IntegerType::get(32, resultElemType.getContext());
-  }
   auto operandElemType =
       convertOp.getOperand()->getType().dyn_cast<ShapedType>().getElementType();
 
diff --git a/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir b/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir
new file mode 100644
index 0000000..1658c18
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/test/adjust_integer_width.mlir
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: iree-opt -iree-spirv-adjust-integer-width -verify-diagnostics -o - %s | IreeFileCheck %s
+
+module{
+  spv.module "Logical" "GLSL450" {
+    spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK: spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+    // CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+    spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+    spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+    func @foo_i64(%arg0 : i64, %arg1 : i64) -> () {
+      // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+      // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+      // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
+      // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+      // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+      // CHECK: spv.Store "StorageBuffer" %{{.*}} %{{.*}} : i32
+      %0 = spv._address_of @constant_arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+      %1 = spv.constant 0 : i32
+      %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+      %3 = spv.Load "StorageBuffer" %2 : i64
+      %4 = spv._address_of @constant_arg_1 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+      %5 = spv.constant 0 : i32
+      %6 = spv.AccessChain %4[%5] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
+      spv.Store "StorageBuffer" %6, %3 : i64
+      spv.Return
+    }
+  }
+
+  spv.module "Logical" "GLSL450" {
+    spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK: spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+    // CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+    spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+    spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+    func @foo_i1(%arg0 : i1, %arg1 : i1) -> () {
+      // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+      // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+      // CHECK: spv.Load "StorageBuffer" %{{.*}} : i8
+      // CHECK-NEXT: spv.INotEqual {{.*}} : i8
+      // CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+      // CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
+      // CHECK: spv.Select {{.*}} : i1, i8
+      // CHECK: spv.Store "StorageBuffer" {{.*}} : i8
+      %0 = spv._address_of @constant_arg_0 : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+      %1 = spv.constant 0 : i32
+      %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+      %3 = spv.Load "StorageBuffer" %2 : i1
+      %4 = spv._address_of @constant_arg_1 : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+      %5 = spv.constant 0 : i32
+      %6 = spv.AccessChain %4[%5] : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
+      spv.Store "StorageBuffer" %6, %3 : i1
+      spv.Return
+    }
+  }
+}