Allowing flow.tensor.constant to be used for constants. (#17024)

Previously it had only been used for dynamically shaped constants in
testing. It still supports that mode (both for testing and for programs
that want to ensure certain portions are treated dynamically, but mostly
testing) and now also supports being used like arith.constant.
arith.constant is still preferred for now as it composes better with
upstream passes, but our own op lets us support constant types arith
cannot, such as parameters. With this, frontends can pass in IR with
constant parameters inline instead of needing to make them globals.
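
For example (mirroring the new conversion tests below; the
#stream.parameter.named attribute is the parameter form used elsewhere in
IREE):

```mlir
// Plain constant, used like arith.constant but through the flow op.
%dense = flow.tensor.constant dense<2> : tensor<4x2xi32>
// Parameter-backed constant passed inline instead of via a global.
%param = flow.tensor.constant #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32>
```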

(I still don't really like that this op has the dynamic behavior and may
remove it in the future - the few uses of it could switch to using casts
with optimization barriers on them as a more verbose but less nasty way
of doing the same thing.)
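
A rough sketch of what that more verbose form could look like, based on
what the removed ExpandDynamicShapeConstant canonicalization used to emit
(static constant, dims hidden behind barriers, then reshaped to the
dynamic type):

```mlir
// Static constant with a fully known shape.
%cst = arith.constant dense<2> : tensor<2x4xi32>
// Hide the dynamic dimension's value from the optimizer.
%c4 = arith.constant 4 : index
%d1 = util.optimization_barrier %c4 : index
// Reinterpret as a dynamically shaped tensor using the barrier'd dim.
%dyn = flow.tensor.reshape %cst : tensor<2x4xi32> -> tensor<2x?xi32>{%d1}
```
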
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index acc4812..d37b28c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -787,40 +787,6 @@
   return {};
 }
 
-namespace {
-
-struct ExpandDynamicShapeConstant : public OpRewritePattern<TensorConstantOp> {
-  using OpRewritePattern<TensorConstantOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(TensorConstantOp op,
-                                PatternRewriter &rewriter) const override {
-    auto constantOp =
-        rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue());
-    auto dynamicType = op.getType();
-    auto staticType = llvm::cast<ShapedType>(constantOp.getType());
-    SmallVector<Value> dynamicDims;
-    for (int64_t i = 0; i < dynamicType.getNumDynamicDims(); ++i) {
-      auto dimValue = rewriter
-                          .create<arith::ConstantIndexOp>(
-                              op.getLoc(), staticType.getDimSize(i))
-                          .getResult();
-      dynamicDims.push_back(
-          rewriter
-              .create<IREE::Util::OptimizationBarrierOp>(op.getLoc(), dimValue)
-              .getResult(0));
-    }
-    rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
-        op, dynamicType, constantOp.getResult(), dynamicDims);
-    return success();
-  }
-};
-
-} // namespace
-
-void TensorConstantOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                   MLIRContext *context) {
-  results.insert<ExpandDynamicShapeConstant>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // flow.tensor.tie_shape
 //===----------------------------------------------------------------------===//
@@ -986,6 +952,7 @@
     }
     auto idx = op.getConstantIndex().value();
 
+    // Fold static dims from the type.
     auto shapedType = llvm::cast<ShapedType>(op.getSource().getType());
     if (!shapedType.isDynamicDim(idx)) {
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
@@ -993,19 +960,40 @@
       return success();
     }
 
+    // Find dims captured on shape-aware ops.
     auto dynamicDims = IREE::Util::findDynamicDims(
         op.getSource(), op->getBlock(), Block::iterator(op.getOperation()));
-    if (!dynamicDims.has_value()) {
-      return rewriter.notifyMatchFailure(op, "no dynamic dims found/usable");
+    if (dynamicDims.has_value()) {
+      unsigned dimOffset = 0;
+      for (unsigned i = 0; i < idx; ++i) {
+        if (shapedType.isDynamicDim(i))
+          ++dimOffset;
+      }
+      rewriter.replaceOp(op, dynamicDims.value()[dimOffset]);
+      return success();
     }
-    unsigned dimOffset = 0;
-    for (unsigned i = 0; i < idx; ++i) {
-      if (shapedType.isDynamicDim(i))
-        ++dimOffset;
-    }
-    rewriter.replaceOp(op, dynamicDims.value()[dimOffset]);
 
-    return success();
+    // Special handling of flow.tensor.constant ops that act as dynamically
+    // shaped values: we want to remove the tensor.dim but still treat the
+    // shape as dynamic, so we insert an optimization barrier between the
+    // constant and its consumers. Note that this use case is very specific
+    // and generally only applicable to tests/benchmarks.
+    if (auto constantOp = dyn_cast_if_present<IREE::Flow::TensorConstantOp>(
+            op.getShapedValue().getDefiningOp())) {
+      auto valueType = dyn_cast<ShapedType>(constantOp.getValue().getType());
+      if (valueType && valueType != constantOp.getType()) {
+        // Constant op is acting as a cast. If the dimension being queried
+        // were static it would have been resolved above, so we know it's
+        // dynamic here.
+        Value staticValue = rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), valueType.getDimSize(idx));
+        rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(
+            op, staticValue);
+        return success();
+      }
+    }
+
+    return rewriter.notifyMatchFailure(op, "no dynamic dims found/usable");
   }
 };
 
@@ -1031,8 +1019,6 @@
       context);
   results.insert<ReplaceOpIfTensorOperandEmpty<TensorBitCastOp, 0, 0>>(context);
   results.insert<FlattenTensorCastLikeChain<TensorBitCastOp>>(context);
-  results.insert<ResolveShapedRank>(context);
-  results.insert<ResolveShapedDim>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 7d00257..0cb7297 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1594,6 +1594,41 @@
 }
 
 //===----------------------------------------------------------------------===//
+// flow.tensor.constant
+//===----------------------------------------------------------------------===//
+
+ParseResult TensorConstantOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  TypedAttr valueAttr;
+  if (failed(parser.parseAttribute(valueAttr)))
+    return failure();
+  result.addAttribute("value", valueAttr);
+  if (succeeded(parser.parseOptionalArrow())) {
+    Type resultType;
+    if (failed(parser.parseType(resultType)))
+      return failure();
+    result.addTypes(resultType);
+  } else {
+    result.addTypes(valueAttr.getType());
+  }
+  return success();
+}
+
+void TensorConstantOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
+  p.printAttribute(getValue());
+  auto attrType = getValue().getType();
+  auto resultType = getType();
+  if (attrType != resultType) {
+    p << " -> ";
+    p.printType(resultType);
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // flow.tensor.tie_shape
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index cb21754..c797c03 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -11,6 +11,7 @@
 include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td"
 include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
@@ -1039,28 +1040,26 @@
 
 let opDocGroup = OpGroupTensorOps in {
 
-// TODO(benvanik): make this behave like std.constant if not dynamic.
-def FLOW_TensorConstantOp : FLOW_Op<"tensor.constant"> {
+def FLOW_TensorConstantOp : FLOW_PureOp<"tensor.constant"> {
   let summary = [{tensor constant that can have dynamic dimensions}];
   let description = [{
-    Allows specifying a constant where the return value can erase shape
-    information. This operation is declared as having side effects and has no
-    folder, so will not be optimized away by the compiler. The underlying shape
-    information should be hidden from the compiler and resolved at runtime.
+    Allows specifying a tensor constant using IREE-specific types/attributes
+    or one where the result type erases shape information.
 
     ```mlir
-    %c = flow.tensor.constant tensor<2x2xf32> -> tensor<?x?xf32>
-    %res = math.absf %c : tensor<?x?xf32>
+    %cst = flow.tensor.constant #something_tensor_like : tensor<2x2xf32>
+    %res = math.absf %cst : tensor<2x2xf32>
+    ```
+
+    ```mlir
+    %cst = flow.tensor.constant dense<4.0> : tensor<2x2xf32> -> tensor<?x2xf32>
+    %res = math.absf %cst : tensor<?x2xf32>
     ```
   }];
-  let arguments = (ins ElementsAttr:$value);
+  let arguments = (ins TypedAttrInterface:$value);
   let results = (outs AnyTensor:$result);
-  let assemblyFormat = [{
-    $value attr-dict `->` type($result)
-  }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 def FLOW_TensorTieShapeOp : FLOW_PureOp<"tensor.tie_shape", [
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index a07f17c..de5a8d4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -16,21 +16,19 @@
 
 // -----
 
-// CHECK-LABEL: @expandDynamicShapeConstant
-util.func public @expandDynamicShapeConstant() -> (tensor<?x?xi32>, index, index) {
+// CHECK-LABEL: @tensorDimOfDynamicConstant
+util.func public @tensorDimOfDynamicConstant() -> (index, index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  // CHECK-DAG: %[[CST:.+]] = arith.constant dense<2> : tensor<2x4xi32>
+  // CHECK-NOT: flow.tensor.constant
+  %0 = flow.tensor.constant dense<2> : tensor<2x4xi32> -> tensor<2x?xi32>
   // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  %d0 = tensor.dim %0, %c0 : tensor<2x?xi32>
   // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-  // CHECK-DAG: %[[D0:.+]] = util.optimization_barrier %[[C2]] : index
-  // CHECK-DAG: %[[D1:.+]] = util.optimization_barrier %[[C4]] : index
-  // CHECK: %[[T:.+]] = flow.tensor.reshape %[[CST]] : tensor<2x4xi32> -> tensor<?x?xi32>{%[[D0]], %[[D1]]}
-  %0 = flow.tensor.constant dense<2> : tensor<2x4xi32> -> tensor<?x?xi32>
-  %d0 = tensor.dim %0, %c0 : tensor<?x?xi32>
-  %d1 = tensor.dim %0, %c1 : tensor<?x?xi32>
-  // CHECK: util.return %[[T]], %[[D0]], %[[D1]]
-  util.return %0, %d0, %d1 : tensor<?x?xi32>, index, index
+  // CHECK-DAG: %[[C4_DYNAMIC:.+]] = util.optimization_barrier %[[C4]]
+  %d1 = tensor.dim %0, %c1 : tensor<2x?xi32>
+  // CHECK: util.return %[[C2]], %[[C4_DYNAMIC]]
+  util.return %d0, %d1 : index, index
 }
 
 // -----
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 3d1916f..867af80 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -12,8 +12,8 @@
 include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
 include "iree/compiler/Dialect/Util/IR/UtilAttrs.td"
 include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
-include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index a772f98..b80b75f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -34,6 +34,55 @@
       IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp()));
 }
 
+struct ConvertTensorConstantOp
+    : public OpConversionPattern<IREE::Flow::TensorConstantOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto attrType = dyn_cast<RankedTensorType>(constantOp.getValue().getType());
+    if (!attrType)
+      return failure();
+    auto resultType = constantOp.getType();
+
+    // If the op is acting as a dynamic value then preserve that behavior by
+    // calculating the shape through optimization barriers.
+    SmallVector<Value> dynamicDims;
+    if (!resultType.hasStaticShape()) {
+      for (unsigned i = 0; i < resultType.getRank(); ++i) {
+        if (resultType.isDynamicDim(i)) {
+          Value staticDim = rewriter.create<arith::ConstantIndexOp>(
+              constantOp.getLoc(), attrType.getDimSize(i));
+          Value dynamicDim = rewriter
+                                 .create<IREE::Util::OptimizationBarrierOp>(
+                                     constantOp.getLoc(), staticDim)
+                                 .getResult(0);
+          dynamicDims.push_back(dynamicDim);
+        }
+      }
+    }
+
+    // Capture the tensor constant strongly typed with constant lifetime.
+    Type constantType = IREE::Stream::ResourceType::get(
+        getContext(), IREE::Stream::Lifetime::Constant);
+    auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
+    auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
+        constantOp.getLoc(), constantType, constantOp.getValue(),
+        TypeAttr::get(resultType), dynamicDims, affinityAttr);
+
+    // Transfer to unknown lifetime.
+    Type unknownType = IREE::Stream::ResourceType::get(getContext());
+    auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+        constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult());
+    rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+        constantOp, unknownType, newOp.getResult(), constantSize, constantSize,
+        /*source_affinity=*/affinityAttr,
+        /*result_affinity=*/affinityAttr);
+    return success();
+  }
+};
+
 // Reshapes and bitcasts become clones here to preserve shape/element type
 // information (which may become actual transfers depending on source/target
 // shape) - they'll be elided if not needed.
@@ -888,7 +937,8 @@
                                             TypeConverter &typeConverter,
                                             RewritePatternSet &patterns) {
   patterns
-      .insert<ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
+      .insert<ConvertTensorConstantOp,
+              ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>,
               ConvertTensorCastLikeOp<IREE::Flow::TensorBitCastOp>,
               ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
               ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index 7de878f..756f319 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -1,5 +1,45 @@
 // RUN: iree-opt --split-input-file --iree-stream-conversion %s | FileCheck %s
 
+// CHECK-LABEL: @tensorConstantStatic
+util.func public @tensorConstantStatic() -> tensor<4x2xi32> {
+  // CHECK-DAG: %[[CST:.+]] = stream.tensor.constant : tensor<4x2xi32> in !stream.resource<constant> = dense<2> : tensor<4x2xi32>
+  // CHECK-DAG: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<constant>
+  // CHECK-DAG: %[[TRANSFER:.+]] = stream.async.transfer %[[CST]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+  %cst = flow.tensor.constant dense<2> : tensor<4x2xi32>
+  // CHECK: util.return %[[TRANSFER]], %[[SIZE]]
+  util.return %cst : tensor<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorConstantDynamic
+util.func public @tensorConstantDynamic() -> tensor<?x?xi32> {
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[D0:.+]] = util.optimization_barrier %[[C2]] : index
+  // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK-DAG: %[[D1:.+]] = util.optimization_barrier %[[C4]] : index
+  // CHECK-DAG: %[[CST:.+]] = stream.tensor.constant : tensor<?x?xi32>{%[[D0]], %[[D1]]} in !stream.resource<constant> = dense<2> : tensor<2x4xi32>
+  // CHECK-DAG: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<constant>
+  // CHECK-DAG: %[[TRANSFER:.+]] = stream.async.transfer %[[CST]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+  %cst = flow.tensor.constant dense<2> : tensor<2x4xi32> -> tensor<?x?xi32>
+  // CHECK: util.return %[[TRANSFER]], %[[SIZE]]
+  util.return %cst : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorConstantParameter
+util.func public @tensorConstantParameter() -> tensor<4x2xi32> {
+  // CHECK-DAG: %[[CST:.+]] = stream.tensor.constant : tensor<4x2xi32> in !stream.resource<constant> = #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32>
+  // CHECK-DAG: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<constant>
+  // CHECK-DAG: %[[TRANSFER:.+]] = stream.async.transfer %[[CST]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+  %cst = flow.tensor.constant #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32>
+  // CHECK: util.return %[[TRANSFER]], %[[SIZE]]
+  util.return %cst : tensor<4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @tensorReshapePassThrough
 //  CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
 util.func public @tensorReshapePassThrough(%input: tensor<5x24x48xf32>) -> tensor<30x2x96xf32> {
diff --git a/tests/e2e/regression/dynamic_add.mlir b/tests/e2e/regression/dynamic_add.mlir
index 6f26e88..318513b 100644
--- a/tests/e2e/regression/dynamic_add.mlir
+++ b/tests/e2e/regression/dynamic_add.mlir
@@ -1,6 +1,6 @@
 func.func @main() {
-  %lhs = flow.tensor.constant  dense<[[1.0,2.0,3.0,4.0],[-1.0,-2.0,-3.0,-4.0]]> : tensor<2x4xf32> -> tensor<?x4xf32>
-  %rhs = flow.tensor.constant  dense<[[5.0,6.0,7.0,8.0],[-5.0,-6.0,-7.0,-8.0]]> : tensor<2x4xf32> -> tensor<?x4xf32>
+  %lhs = flow.tensor.constant dense<[[1.0,2.0,3.0,4.0],[-1.0,-2.0,-3.0,-4.0]]> : tensor<2x4xf32> -> tensor<?x4xf32>
+  %rhs = flow.tensor.constant dense<[[5.0,6.0,7.0,8.0],[-5.0,-6.0,-7.0,-8.0]]> : tensor<2x4xf32> -> tensor<?x4xf32>
   %2 = stablehlo.add %lhs, %rhs : tensor<?x4xf32>
   %3 = util.optimization_barrier %2 : tensor<?x4xf32>
   %result = tensor.cast %3 : tensor<?x4xf32> to tensor<2x4xf32>
diff --git a/tests/e2e/regression/dynamic_torch_index_select_high_rank.mlir b/tests/e2e/regression/dynamic_torch_index_select_high_rank.mlir
index 77d90de..501b242 100644
--- a/tests/e2e/regression/dynamic_torch_index_select_high_rank.mlir
+++ b/tests/e2e/regression/dynamic_torch_index_select_high_rank.mlir
@@ -1,6 +1,6 @@
 func.func @torch_index_select1() {
-  %lhs = flow.tensor.constant  dense<[[6,7],[8,9]]> : tensor<2x2xi32> -> tensor<?x?xi32>
-  %rhs = flow.tensor.constant  dense<[[[[0,1],[1,0]],[[0,0],[1,1]]],[[[1,1],[0,0]],[[0,1],[1,0]]]]> : tensor<2x2x2x2xi32> -> tensor<?x?x?x?xi32>
+  %lhs = flow.tensor.constant dense<[[6,7],[8,9]]> : tensor<2x2xi32> -> tensor<?x?xi32>
+  %rhs = flow.tensor.constant dense<[[[[0,1],[1,0]],[[0,0],[1,1]]],[[[1,1],[0,0]],[[0,1],[1,0]]]]> : tensor<2x2x2x2xi32> -> tensor<?x?x?x?xi32>
   %0 = "stablehlo.torch_index_select"(%lhs, %rhs) {batch_dims = 1 : i64, dim = 1 : i64} : (tensor<?x?xi32>, tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
   %dshape = util.optimization_barrier %0 : tensor<?x?x?x?xi32>
   %result = tensor.cast %dshape : tensor<?x?x?x?xi32> to tensor<2x2x2x2xi32>
@@ -13,8 +13,8 @@
 }
 
 func.func @torch_index_select2() {
-  %lhs = flow.tensor.constant  dense<[[6,7],[8,9]]> : tensor<2x2xi32> -> tensor<?x?xi32>
-  %rhs = flow.tensor.constant  dense<[[[[0,1],[1,0]],[[0,0],[1,1]]],[[[1,1],[0,0]],[[0,1],[1,0]]]]> : tensor<2x2x2x2xi32> -> tensor<?x?x?x?xi32>
+  %lhs = flow.tensor.constant dense<[[6,7],[8,9]]> : tensor<2x2xi32> -> tensor<?x?xi32>
+  %rhs = flow.tensor.constant dense<[[[[0,1],[1,0]],[[0,0],[1,1]]],[[[1,1],[0,0]],[[0,1],[1,0]]]]> : tensor<2x2x2x2xi32> -> tensor<?x?x?x?xi32>
   %0 = "stablehlo.torch_index_select"(%lhs, %rhs) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor<?x?xi32>, tensor<?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
   %dshape = util.optimization_barrier %0 : tensor<?x?x?x?x?xi32>
   %result = tensor.cast %dshape : tensor<?x?x?x?x?xi32> to tensor<2x2x2x2x2xi32>
diff --git a/tests/e2e/regression/dynamic_torch_index_select_negative.mlir b/tests/e2e/regression/dynamic_torch_index_select_negative.mlir
index 684109f..c2a4a60 100644
--- a/tests/e2e/regression/dynamic_torch_index_select_negative.mlir
+++ b/tests/e2e/regression/dynamic_torch_index_select_negative.mlir
@@ -1,6 +1,6 @@
 func.func @torch_index_select1() {
-  %lhs = flow.tensor.constant  dense<[[[100, 101],[110, 111]],[[200, 201],[210, 211]]]> : tensor<2x2x2xi32> -> tensor<?x?x?xi32>
-  %rhs = flow.tensor.constant  dense<[[[0, 1],[1, 0]],[[0, 0],[1, 1]]]> : tensor<2x2x2xi32> -> tensor<?x?x?xi32>
+  %lhs = flow.tensor.constant dense<[[[100, 101],[110, 111]],[[200, 201],[210, 211]]]> : tensor<2x2x2xi32> -> tensor<?x?x?xi32>
+  %rhs = flow.tensor.constant dense<[[[0, 1],[1, 0]],[[0, 0],[1, 1]]]> : tensor<2x2x2xi32> -> tensor<?x?x?xi32>
   %0 = "stablehlo.torch_index_select"(%lhs, %rhs) {batch_dims = -1 : i64, dim = -1 : i64} : (tensor<?x?x?xi32>, tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
   %dshape = util.optimization_barrier %0 : tensor<?x?x?xi32>
   %result = tensor.cast %dshape : tensor<?x?x?xi32> to tensor<2x2x2xi32>
diff --git a/tests/microbenchmarks/dynamic_shape_vectorization.mlir b/tests/microbenchmarks/dynamic_shape_vectorization.mlir
index c988a28..f7ce096 100644
--- a/tests/microbenchmarks/dynamic_shape_vectorization.mlir
+++ b/tests/microbenchmarks/dynamic_shape_vectorization.mlir
@@ -14,9 +14,9 @@
   %dim1 = util.unfoldable_constant 513 : index
   %dim2 = util.unfoldable_constant 385 : index
 
-  %A = flow.tensor.constant  dense<1.0> : tensor<513x257xf32> -> tensor<?x?xf32>
-  %B = flow.tensor.constant  dense<2.0> : tensor<257x385xf32> -> tensor<?x?xf32>
-  %C = flow.tensor.constant  dense<0.0> : tensor<513x385xf32> -> tensor<?x?xf32>
+  %A = flow.tensor.constant dense<1.0> : tensor<513x257xf32> -> tensor<?x?xf32>
+  %B = flow.tensor.constant dense<2.0> : tensor<257x385xf32> -> tensor<?x?xf32>
+  %C = flow.tensor.constant dense<0.0> : tensor<513x385xf32> -> tensor<?x?xf32>
 
   %gemm = linalg.matmul
       ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -26,9 +26,9 @@
 
 func.func @dynamic_elw() -> tensor<?x?xf32> {
   %c0 = arith.constant 0.000000e+00 : f32
-  %A = flow.tensor.constant  dense<1.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
-  %B = flow.tensor.constant  dense<2.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
-  %C = flow.tensor.constant  dense<0.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
+  %A = flow.tensor.constant dense<1.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
+  %B = flow.tensor.constant dense<2.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
+  %C = flow.tensor.constant dense<0.0> : tensor<513x1025xf32> -> tensor<?x?xf32>
 
   %gen = linalg.generic {
       indexing_maps = [