[Codegen][GPU] Let integer range optimization narrow GPU computations to i32 (#19473)

Add an option to -iree-util-optimize-int-arithmetic to have it perform
computations in i32 where possible, which is enabled when optimizing
arithmetic for GPU codegen. This allows LLVM to correctly conclude that
various computations don't need to be done at full 64-bit precision,
thus saving registers and instructions. (LLVM has some rewrites for
this, but they're, for example, gated on only having one use of the
potentially-truncated value, which means that shared math stays in an
over-wide data type).

This commit also marks several Vulkan tests as succeeding because they no longer need i64 arithmetic.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index c3dda57..c1ea822 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1009,7 +1009,10 @@
       .addPass(createIREELoopInvariantCodeMotionPass)
       .addPass(affine::createAffineExpandIndexOpsPass)
       .addPass(createLowerAffinePass)
-      .addPass(IREE::Util::createOptimizeIntArithmeticPass)
+      .addPass([]() {
+        return IREE::Util::createOptimizeIntArithmeticPass(
+            IREE::Util::OptimizeIntArithmeticPassOptions{/*narrowToI32=*/true});
+      })
       // Do another round of LICM now that we've lowered and optimized
       // arithmetic
       .addPass(createCSEPass)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
index ef9e587..be16390 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
@@ -32,28 +32,31 @@
 // and is contributed back to the final address with just one instruction.
 
 // Match the interesting constants.
-// CHECK-DAG: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK-DAG: %[[C6:.*]] = llvm.mlir.constant(6 : index) : i64
-// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : index) : i64
-// CHECK-DAG: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+// CHECK-DAG: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK-DAG: %[[C6:.*]] = llvm.mlir.constant(6 : i32) : i32
+// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK-DAG: %[[C64:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK-DAG: %[[C4096:.*]] = llvm.mlir.constant(4096 : index) : i64
 // CHECK-DAG: %[[C8192:.*]] = llvm.mlir.constant(8192 : index) : i64
 //
 // Match the interesting special registers.
 // CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 2> : i32
 // CHECK-DAG: %[[TID_Y_EXT:.*]] = llvm.sext %[[TID_Y]] : i32 to i64
+// CHECK-DAG: %[[TID_Y_TRUNC:.*]] = llvm.trunc %[[TID_Y_EXT]] : i64 to i32
 // CHECK-DAG: %[[LANEID:.*]] = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
 // CHECK-DAG: %[[LANEID_EXT:.*]] = llvm.sext %[[LANEID]] : i32 to i64
-// CHECK-DAG: %[[TID_Y_IDX:.*]] = llvm.mul %[[TID_Y_EXT]], %[[C64]] overflow<nsw> : i64
+// CHECK-DAG: %[[LANEID_TRUNC:.*]] = llvm.trunc %[[LANEID_EXT]] : i64 to i32
+// CHECK-DAG: %[[TID_Y_IDX:.*]] = llvm.mul %[[TID_Y_TRUNC]], %[[C64]] overflow<nsw> : i32
 //
 // Match the loop invariant math on the special registers.
-// CHECK: %[[GRP_IDX:.*]] = llvm.add %[[TID_Y_IDX]], %[[LANEID_EXT]]  : i64
-// CHECK: %[[GRP_IDX1:.*]] = llvm.add %[[GRP_IDX]], %{{.*}}  : i64
-// CHECK: %[[GRP_IDX2:.*]] = llvm.and %[[GRP_IDX1]], %[[C6]]  : i64
-// CHECK: %[[GRP_IDX3:.*]] = llvm.shl %[[GRP_IDX2]], %[[C2]]  : i64
-// CHECK: %{{.*}} = llvm.xor %[[SRC:.*]], %[[GRP_IDX3]]  : i64
-// CHECK: %[[ADJ_SRC:.*]] = llvm.add %[[SRC]], %[[C16]]  : i64
-// CHECK: %[[INV:.*]] = llvm.xor %[[ADJ_SRC]], %[[GRP_IDX3]]  : i64
+// CHECK: %[[GRP_IDX:.*]] = llvm.add %[[TID_Y_IDX]], %[[LANEID_TRUNC]]  : i32
+// CHECK: %[[GRP_IDX1:.*]] = llvm.add %[[GRP_IDX]], %{{.*}}  : i32
+// CHECK: %[[GRP_IDX2:.*]] = llvm.and %[[GRP_IDX1]], %[[C6]]  : i32
+// CHECK: %[[GRP_IDX3:.*]] = llvm.shl %[[GRP_IDX2]], %[[C2]]  : i32
+// CHECK: %{{.*}} = llvm.xor %[[SRC:.*]], %[[GRP_IDX3]]  : i32
+// CHECK: %[[ADJ_SRC:.*]] = llvm.add %[[SRC]], %[[C16]]  : i32
+// CHECK: %[[INV:.*]] = llvm.xor %[[ADJ_SRC]], %[[GRP_IDX3]]  : i32
+// CHECK: %[[INV_EXT:.*]] = llvm.zext %[[INV]] : i32 to i64
 //
 // Find the basic block boundary.
 // CHECK: llvm.br ^[[LOOP_BODY:bb[0-9]+]](
@@ -65,7 +68,7 @@
 // CHECK: %[[VAR:.*]] = llvm.mul %[[IV]], %[[C4096]]
 //
 // Add the loop invariant part.
-// CHECK: %[[OFF:.*]] = llvm.add %{{.*}}, %[[INV]]
+// CHECK: %[[OFF:.*]] = llvm.add %{{.*}}, %[[INV_EXT]]
 //
 // Store the resulting offset in the memref descriptor.
 // llvm.insert %[[OFF]], %{{.*}}[2]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 2a13558..2b8e467 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -231,7 +231,10 @@
       .addPass(createCanonicalizerPass)
       .addPass(createCSEPass)
       .addPass(createLowerAffinePass)
-      .addPass(IREE::Util::createOptimizeIntArithmeticPass)
+      .addPass([]() {
+        return IREE::Util::createOptimizeIntArithmeticPass(
+            IREE::Util::OptimizeIntArithmeticPassOptions{/*narrowToI32=*/true});
+      })
 
       // Lower ApplyScale before the i64 Emulation Pass so that new 64-bit ops
       // are also emulated if not supported by the target.
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 3b6aa5f..f63891c 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -506,17 +506,30 @@
         loopLike.getLoopLowerBounds();
     std::optional<SmallVector<OpFoldResult>> maybeUpperBounds =
         loopLike.getLoopUpperBounds();
-    if (!maybeLowerBounds || !maybeUpperBounds) {
+    std::optional<SmallVector<Value>> maybeIvs =
+        loopLike.getLoopInductionVars();
+    if (!maybeLowerBounds || !maybeUpperBounds || !maybeIvs) {
       return;
     }
 
     // If any lower + upper bound pair cannot be definitely verified as lb < ub
     // then the loop may have a zero trip count.
-    for (auto [lb, ub] :
-         llvm::zip_equal(*maybeLowerBounds, *maybeUpperBounds)) {
-      if (!ValueBoundsConstraintSet::compare(lb, ValueBoundsConstraintSet::LT,
-                                             ub)) {
-        return;
+    for (auto [lb, ub, iv] :
+         llvm::zip_equal(*maybeLowerBounds, *maybeUpperBounds, *maybeIvs)) {
+      if (iv.getType().isIndex()) {
+        if (!ValueBoundsConstraintSet::compare(lb, ValueBoundsConstraintSet::LT,
+                                               ub)) {
+          return;
+        }
+      } else {
+        // Weaker test for non-`index` operands to some loops
+        // like scf.for, since the value bounds interface requires index types.
+        auto maybeLb = getConstantIntValue(lb);
+        auto maybeUb = getConstantIntValue(ub);
+        if (!maybeLb || !maybeUb)
+          return;
+        if (*maybeLb >= *maybeUb)
+          return;
       }
     }
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index b1c17bd..97eae87 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -463,7 +463,8 @@
 
 def Util_AssumeIntOp : Util_PureOp<"assume.int", [
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-    DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface, ["inferResultRanges"]>
+    DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface, ["inferResultRanges"]>,
+    AllTypesMatch<["operands", "results"]>
 ]> {
   let summary = "memorializes assumptions about index/integer values.";
   let description = [{
@@ -497,7 +498,7 @@
     OpBuilder<(ins
       "ArrayRef<Value>":$operands,
       "ArrayRef<ArrayAttr>":$assumptions
-    )>
+    )>,
   ];
 
   let extraClassDeclaration = [{
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
index c58aeac..fdee32e 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
@@ -6,8 +6,10 @@
 
 #include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
 #include "llvm/Support/Debug.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
@@ -16,6 +18,7 @@
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -187,6 +190,131 @@
 };
 
 //===----------------------------------------------------------------------===//
+// Index -> int32 assumption narrowing
+// If we're narrowing `index` values to `i32`, a `util.assume.int` on `index`
+// introduces unnecessary zero-extensions and truncations to/from `index`
+// when introducing assumptions.
+//===----------------------------------------------------------------------===//
+struct RemoveIndexCastForAssumeOfI32
+    : public OpRewritePattern<Util::AssumeIntOp> {
+  RemoveIndexCastForAssumeOfI32(MLIRContext *context, DataFlowSolver &solver)
+      : OpRewritePattern(context), solver(solver) {}
+
+  LogicalResult matchAndRewrite(Util::AssumeIntOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallBitVector needNarrowing(op.getNumOperands(), false);
+    for (auto [idx, arg] : llvm::enumerate(op.getOperands())) {
+      if (!arg.getType().isIndex())
+        continue;
+      auto castOp = arg.getDefiningOp<arith::IndexCastUIOp>();
+      if (!castOp)
+        continue;
+      Value castIn = castOp.getIn();
+      Type intType = castIn.getType();
+      if (intType.getIntOrFloatBitWidth() > 32)
+        continue;
+
+      needNarrowing[idx] = true;
+    }
+    if (needNarrowing.none())
+      return failure();
+
+    SmallVector<Value> newArgs;
+    newArgs.reserve(op.getNumOperands());
+    for (auto [idx, arg] : llvm::enumerate(op.getOperands())) {
+      if (!needNarrowing[idx]) {
+        newArgs.push_back(arg);
+        continue;
+      }
+      newArgs.push_back(arg.getDefiningOp<arith::IndexCastUIOp>().getIn());
+    }
+    ArrayAttr assumptions = op.getAssumptionsAttr();
+    auto newOp = rewriter.create<Util::AssumeIntOp>(
+        op.getLoc(), ValueTypeRange<ArrayRef<Value>>{newArgs}, newArgs,
+        assumptions);
+    SmallVector<Value> replacements(newOp.getResults());
+    for (auto [newRes, oldRes] :
+         llvm::zip_equal(replacements, op.getResults())) {
+      if (newRes.getType() != oldRes.getType()) {
+        newRes = rewriter.create<arith::IndexCastUIOp>(
+            op.getLoc(), oldRes.getType(), newRes);
+      }
+      // Preserve assumption state.
+      auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldRes);
+      if (oldState) {
+        (void)solver.getOrCreateState<IntegerValueRangeLattice>(newRes)->join(
+            *oldState);
+      }
+    }
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+
+  DataFlowSolver &solver;
+};
+
+//===----------------------------------------------------------------------===//
+// scf.for induction variable range narrowing
+// If the induction variable of an scf.for can be represented as an I32,
+// make that change to save on registers etc.
+//===----------------------------------------------------------------------===//
+struct NarrowSCFForIvToI32 : public OpRewritePattern<scf::ForOp> {
+  NarrowSCFForIvToI32(MLIRContext *context, DataFlowSolver &solver)
+      : OpRewritePattern(context), solver(solver) {}
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = forOp.getLoc();
+    Value iv = forOp.getInductionVar();
+    Type srcType = iv.getType();
+    if (!srcType.isIndex() && !srcType.isInteger(64))
+      return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64");
+    if (!staticallyLegalToConvertToUnsigned(solver, iv))
+      return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative");
+    if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep()))
+      return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative");
+    auto *ivState = solver.lookupState<IntegerValueRangeLattice>(iv);
+    if (ivState->getValue().getValue().smax().getActiveBits() > 31)
+      return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32");
+
+    Type i32 = rewriter.getI32Type();
+    auto doCastDown = [&](Value v) -> Value {
+      if (srcType.isIndex())
+        return rewriter.create<arith::IndexCastUIOp>(loc, i32, v);
+      else
+        return rewriter.create<arith::TruncIOp>(loc, i32, v);
+    };
+    Value newLb = doCastDown(forOp.getLowerBound());
+    Value newUb = doCastDown(forOp.getUpperBound());
+    Value newStep = doCastDown(forOp.getStep());
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointToStart(&forOp.getRegion().front());
+      Value castBackOp;
+      if (srcType.isIndex()) {
+        castBackOp =
+            rewriter.create<arith::IndexCastUIOp>(iv.getLoc(), srcType, iv);
+      } else {
+        castBackOp = rewriter.create<arith::ExtUIOp>(iv.getLoc(), srcType, iv);
+      }
+      (void)solver.getOrCreateState<IntegerValueRangeLattice>(castBackOp)
+          ->join(*ivState);
+      rewriter.replaceAllUsesExcept(iv, castBackOp, castBackOp.getDefiningOp());
+    }
+    solver.eraseState(iv);
+    rewriter.modifyOpInPlace(forOp, [&]() {
+      iv.setType(i32);
+      forOp.getLowerBoundMutable().assign(newLb);
+      forOp.getUpperBoundMutable().assign(newUb);
+      forOp.getStepMutable().assign(newStep);
+    });
+    return success();
+  }
+
+  DataFlowSolver &solver;
+};
+
+//===----------------------------------------------------------------------===//
 // Divisibility
 //===----------------------------------------------------------------------===//
 
@@ -306,6 +434,8 @@
 
 class OptimizeIntArithmeticPass
     : public impl::OptimizeIntArithmeticPassBase<OptimizeIntArithmeticPass> {
+  using Base::Base;
+
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
@@ -313,6 +443,8 @@
     expandAffineOps(op);
 
     DataFlowSolver solver;
+    // Needed to make the dead code analysis not be too conservative.
+    solver.load<SparseConstantPropagation>();
     solver.load<DeadCodeAnalysis>();
     solver.load<IntegerRangeAnalysis>();
     solver.load<IntegerDivisibilityAnalysis>();
@@ -322,6 +454,12 @@
     // Populate upstream arith patterns.
     arith::populateIntRangeOptimizationsPatterns(patterns, solver);
 
+    if (narrowToI32) {
+      arith::populateIntRangeNarrowingPatterns(patterns, solver, {32});
+      patterns.add<NarrowSCFForIvToI32, RemoveIndexCastForAssumeOfI32>(ctx,
+                                                                       solver);
+    }
+
     // Populate canonicalization patterns.
     auto arithDialect = ctx->getOrLoadDialect<arith::ArithDialect>();
     for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index fb0cf7d..aa4751f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -66,6 +66,12 @@
     "::mlir::arith::ArithDialect",
     "::mlir::iree_compiler::IREE::Util::UtilDialect"
   ];
+  let options = [
+    Option<"narrowToI32", "narrow-to-i32", "bool",
+      /*default=*/"false",
+      "Flag indicating if computations that can be performed with 32 bits should be."
+      " Mainly used for GPU code generation to not waste registers.">
+  ];
 }
 
 def PropagateSubrangesPass : Pass<"iree-util-propagate-subranges", "mlir::ModuleOp"> {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index c4a4724..b6befb4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -27,6 +27,7 @@
             "integer_divisibility.mlir",
             "ipo.mlir",
             "optimize_int_arithmetic.mlir",
+            "optimize_int_arithmetic_narrowing.mlir",
             "patterns.mlir",
             "propagate_subranges.mlir",
             "simplify_global_accesses.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index 02ba4d9..b1cd45c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -25,6 +25,7 @@
     "integer_divisibility.mlir"
     "ipo.mlir"
     "optimize_int_arithmetic.mlir"
+    "optimize_int_arithmetic_narrowing.mlir"
     "patterns.mlir"
     "propagate_subranges.mlir"
     "simplify_global_accesses.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic_narrowing.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic_narrowing.mlir
new file mode 100644
index 0000000..c40dbe4
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic_narrowing.mlir
@@ -0,0 +1,68 @@
+// RUN: iree-opt --split-input-file \
+// RUN:  --iree-util-optimize-int-arithmetic=narrow-to-i32=true --cse %s \
+// RUN:  | FileCheck %s
+// We inherit a number of patterns from upstream for narrowing specific arith
+// operations. Those are not the focus of testing, but we may test some of them
+// here incidentally as part of verifying that the overall pass and local
+// patterns are effective.
+
+// CHECK-LABEL: @narrow_tid_computations
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
+// CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x upper_bound 64
+// CHECK-DAG: %[[TID_I32:.+]] = arith.index_castui %[[THREAD_ID_X]] : index to i32
+// CHECK: %[[V0:.+]] = arith.divui %[[TID_I32]], %[[C16]] : i32
+// CHECK-NEXT: %[[V1:.+]] = arith.remui %[[TID_I32]], %[[C16]] : i32
+// CHECK-NEXT: %[[V2:.+]] = arith.muli %[[V0]], %[[C32]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.addi %[[V2]], %[[V1]] : i32
+// CHECK-NEXT: %[[RET:.+]] = arith.index_castui %[[V3]] : i32 to index
+// CHECK: return %[[RET]]
+util.func @narrow_tid_computations() -> index {
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %thread_id_x = gpu.thread_id x upper_bound 64
+  %0 = arith.divui %thread_id_x, %c16 : index
+  %1 = arith.remui %thread_id_x, %c16 : index
+  %2 = arith.muli %0, %c32 : index
+  %3 = arith.addi %2, %1 : index
+  util.return %3 : index
+}
+
+// -----
+
+// CHECK-LABEL: @narrow_assumes
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK-NEXT: %[[ASSUME:.+]] = util.assume.int %[[ARG0]]<umin = 16, umax = 122, udiv = 16> : i32
+// CHECK-NEXT: %[[AS_INDEX:.+]] = arith.index_castui %[[ASSUME]] : i32 to index
+// CHECK-NEXT: util.return %[[ASSUME]], %[[AS_INDEX]]
+util.func @narrow_assumes(%arg0: i32) -> (i32, index) {
+  %0 = arith.index_castui %arg0 : i32 to index
+  %1 = util.assume.int %0<umin = 16, umax = 122, udiv = 16> : index
+  %2 = arith.index_castui %1 : index to i32
+  util.return %2, %1 : i32, index
+}
+
+// -----
+
+// CHECK-LABEL: @narrow_scf_for
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : i32
+// CHECK-DAG: %[[C96:.+]] = arith.constant 96 : i32
+// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : i32
+// CHECK-DAG: %[[TID:.+]] = gpu.thread_id x upper_bound 64
+// CHECK-DAG: %[[TID_I32:.+]] = arith.index_castui %[[TID]] : index to i32
+// CHECK: scf.for %[[ARG1:.+]] = %[[TID_I32]] to %[[C96]] step %[[C64]]
+// CHECK-NEXT: %[[V0:.+]] = arith.addi %[[ARG1]], %[[C512]]
+// CHECK-NEXT: %[[V0_IDX:.+]] = arith.index_castui %[[V0]] : i32 to index
+// CHECK-NEXT: memref.store {{.*}}[%[[V0_IDX]]]
+util.func @narrow_scf_for(%arg0: memref<?xf32>) {
+  %c0_f32 = arith.constant 0.0 : f32
+  %c64 = arith.constant 64 : index
+  %c96 = arith.constant 96 : index
+  %c512 = arith.constant 512 : index
+  %tid = gpu.thread_id x upper_bound 64
+  scf.for %arg1 = %tid to %c96 step %c64 {
+    %0 = arith.addi %arg1, %c512 : index
+    memref.store %c0_f32, %arg0[%0] : memref<?xf32>
+  }
+  util.return
+}
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index 7284fc0..5d5d248 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -40,18 +40,9 @@
     "onnx/node/generated/test_argmin_negative_axis_keepdims_random_select_last_index",
     "onnx/node/generated/test_argmin_no_keepdims_example_select_last_index",
     "onnx/node/generated/test_argmin_no_keepdims_random_select_last_index",
-    "onnx/node/generated/test_averagepool_2d_ceil",
-    "onnx/node/generated/test_averagepool_2d_default",
-    "onnx/node/generated/test_averagepool_2d_dilations",
-    "onnx/node/generated/test_averagepool_2d_pads",
-    "onnx/node/generated/test_averagepool_2d_pads_count_include_pad",
-    "onnx/node/generated/test_averagepool_2d_precomputed_pads",
-    "onnx/node/generated/test_averagepool_2d_precomputed_pads_count_include_pad",
     "onnx/node/generated/test_averagepool_2d_precomputed_same_upper",
-    "onnx/node/generated/test_averagepool_2d_precomputed_strides",
     "onnx/node/generated/test_averagepool_2d_same_lower",
     "onnx/node/generated/test_averagepool_2d_same_upper",
-    "onnx/node/generated/test_averagepool_2d_strides",
     "onnx/node/generated/test_basic_deform_conv_with_padding",
     "onnx/node/generated/test_basic_deform_conv_without_padding",
     "onnx/node/generated/test_bernoulli_seed",