Bump StableHLO to 0264c4d64c82ae74a54b85d274eec5084c2c0abf (#16561)
- This update spans a change made to StableHLO to switch op attributes
from `I64ElementsAttr` to `DenseI64ArrayAttr`, discussed here:
https://groups.google.com/a/openxla.org/g/openxla-discuss/c/hEoA4V5DZF0/m/rdNEiM20BgAJ
- Unfortunately, StableHLO generated prior to this change upstream may
not be compatible with the StableHLO input plug-in in IREE after this
PR.
- Min version of jax/jaxlib compatible with this change: v0.4.24
- Min version of TF compatible with this change: v2.16.0
- Some of the changes here were propagated from the subset of StableHLO
to LinAlg lowers that live in the StableHLO repro.
To update stored SHLO bytecode, I did the following:
1. Compiled openxla/stablehlo with current IREE repo version and
upgraded IREE repo versions to `old_stablehlo` and `new_stablehlo`
respectively.
1. `old_stablehlo/build/bin/stablehlo-opt --mlir-print-op-generic
<model>.mlirbc > <model>.mlir`
1. Manually edit <model>.mlir to bring it up to new spec
1. `new_stablehlo/build/bin/stablehlo-opt --emit-bytecode <model>.mlir >
new_<model>.mlirbc`
The following covered most of the edits that needed to be made to the
model MLIR:
```
sed -i \
-e 's/dimensions = dense<\([0-9]*\)> : tensor<1xi64>/dimensions = array<i64: \1>/g' \
-e 's/dimensions = dense<\[\([0-9]*\), \([0-9]*\)\]> : tensor<2xi64>/dimensions = array<i64: \1, \2>/g' \
-e 's/rhs_dilation = dense<\([0-9]*\)> : tensor<2xi64>/rhs_dilation = array<i64: \1, \1>/g' \
-e 's/window_strides = dense<\([0-9]*\)> : tensor<2xi64>/window_strides = array<i64: \1, \1>/g' \
model.mlir
```
ci-extra: build_e2e_test_artifacts, test_benchmark_suites
diff --git a/build_tools/python/e2e_test_framework/models/tf_models.py b/build_tools/python/e2e_test_framework/models/tf_models.py
index 3c81790..c616aa1 100644
--- a/build_tools/python/e2e_test_framework/models/tf_models.py
+++ b/build_tools/python/e2e_test_framework/models/tf_models.py
@@ -21,7 +21,7 @@
tags=["int32", "seqlen128"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_5aed9c3c3dfe8247ce76b74d518fa570b94dc0c3732631734d02ad70e4c74867.mlirbc",
entry_function="predict",
input_types=["1x128xi32", "1x128xi32", "1x128xi32"],
)
@@ -32,7 +32,7 @@
tags=["fp32", "seqlen512", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#tfbertformaskedlm
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_e757a10b24f6ff83aaae0ceb5bb05d4efe9ff3e9931f8e9a29f12bc5c2e42b5e.mlirbc",
entry_function="forward",
input_types=["1x512xi32", "1x512xi32"],
)
@@ -43,7 +43,7 @@
tags=["fp32", "cnn", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/efficientnet_v2.py
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_1af8c88f4e64e388a0c87bbeddcfb888084059df30cd631340d51794a0796e0f.mlirbc",
entry_function="forward",
input_types=["1x384x384x3xf32"],
)
@@ -56,7 +56,7 @@
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Derived from https://github.com/mlcommons/inference/tree/master/language/bert
# Instructions on how to regenerate the model: https://gist.github.com/mariecwhite/e61ccebd979d98d097946ac7725bcc29
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_000793afb016fb3afc559304bcb3ba6cdb2df1825e8976ca236c07c12e4f65fa.mlirbc",
entry_function="serving_default",
input_types=["1x384xi32", "1x384xi32", "1x384xi32"],
)
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
index 76515ce..8e2739b 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
@@ -305,7 +305,7 @@
RankedTensorType::get(resultType.getShape(),
lhsType.getElementType()),
lhs, resultExtents,
- rewriter.getI64TensorAttr(lhsBroadcastDimensions));
+ rewriter.getDenseI64ArrayAttr(lhsBroadcastDimensions));
auto rhsBroadcastDimensions = llvm::to_vector(
llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
Value broadcastedRhs =
@@ -314,7 +314,7 @@
RankedTensorType::get(resultType.getShape(),
rhsType.getElementType()),
rhs, resultExtents,
- rewriter.getI64TensorAttr(rhsBroadcastDimensions));
+ rewriter.getDenseI64ArrayAttr(rhsBroadcastDimensions));
// And generate the final non-broadcasted binary op.
Value finalResult = Adaptor::createOp(
@@ -353,7 +353,7 @@
rewriter.create<mlir::stablehlo::ConstantOp>(loc, op.getValue());
Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.getOperand());
rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
- op, resultTy, constant, shape, rewriter.getI64TensorAttr({}));
+ op, resultTy, constant, shape, rewriter.getDenseI64ArrayAttr({}));
return success();
}
};
@@ -412,7 +412,7 @@
RankedTensorType::get(resultType.getShape(),
predType.getElementType()),
pred, resultExtents,
- rewriter.getI64TensorAttr(predBroadcastDimensions));
+ rewriter.getDenseI64ArrayAttr(predBroadcastDimensions));
}
auto onTrueBroadcastDimensions = llvm::to_vector(
llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
@@ -422,7 +422,7 @@
RankedTensorType::get(resultType.getShape(),
onTrueType.getElementType()),
onTrue, resultExtents,
- rewriter.getI64TensorAttr(onTrueBroadcastDimensions));
+ rewriter.getDenseI64ArrayAttr(onTrueBroadcastDimensions));
auto onFalseBroadcastDimensions = llvm::to_vector(
llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
Value broadcastedOnFalse =
@@ -431,7 +431,7 @@
RankedTensorType::get(resultType.getShape(),
onFalseType.getElementType()),
onFalse, resultExtents,
- rewriter.getI64TensorAttr(onFalseBroadcastDimensions));
+ rewriter.getDenseI64ArrayAttr(onFalseBroadcastDimensions));
// And generate the final non-broadcasted ternary op.
Value finalResult = rewriter.create<mlir::stablehlo::SelectOp>(
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeToLinalgUtils.h b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeToLinalgUtils.h
index f311a0e..6055b8b 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeToLinalgUtils.h
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeToLinalgUtils.h
@@ -103,6 +103,16 @@
/// Extracts integer values from the attribute |elements|.
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements);
+/// Returns true if the given |values| is a splat of the given |queryValue|.
+inline bool isSplatValue(const ArrayRef<int64_t> &values, int64_t queryValue) {
+ for (auto value : values) {
+ if (value != queryValue) {
+ return false;
+ }
+ }
+ return true;
+}
+
/// Returns true if the given |attr| is a splat of the given |value|.
inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
index fe6d413..8158c21 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
@@ -764,7 +764,7 @@
return failure();
Location loc = op.getLoc();
- DenseIntElementsAttr empty = rewriter.getI64TensorAttr({});
+ DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({});
if (elemTy.hasStaticShape()) {
SmallVector<Value> broadcasts(op.getNumResults());
for (auto [bcast, init, outTy] : llvm::zip_equal(
@@ -907,8 +907,7 @@
if (!operandType || !operandType.hasStaticShape())
return failure();
- auto sliceEnd =
- llvm::to_vector(gather.getSliceSizes().getValues<int64_t>());
+ auto sliceEnd = llvm::to_vector(gather.getSliceSizes());
SmallVector<int64_t> sliceStart(sliceEnd.size(), 0);
for (auto [mapIndex, value] :
llvm::zip_equal(dnums.getStartIndexMap(), index.getValues<APInt>())) {
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/GatherToTorchIndexSelect.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/GatherToTorchIndexSelect.cpp
index 58e39ad..998c9d6 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/GatherToTorchIndexSelect.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/GatherToTorchIndexSelect.cpp
@@ -80,18 +80,17 @@
}
}
- for (auto [idx, value] :
- llvm::enumerate(gather.getSliceSizes().getValues<APInt>())) {
+ for (auto [idx, value] : llvm::enumerate(gather.getSliceSizes())) {
// First shape value must be 1.
if (idx == 0) {
- if (value.getSExtValue() != 1) {
+ if (value != 1) {
return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
}
continue;
}
// The gather needs to index the entire slice for each other dimension.
- if (value.getSExtValue() != operandTy.getDimSize(idx)) {
+ if (value != operandTy.getDimSize(idx)) {
return rewriter.notifyMatchFailure(
gather, "slice_size doesn't match operand dimension");
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
index 7fb89d1..0f3d82e 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
@@ -1244,8 +1244,8 @@
LogicalResult matchAndRewrite(mlir::stablehlo::DotOp op,
PatternRewriter &rewriter) const override {
- auto lhs = op.getLhs();
- auto rhs = op.getRhs();
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
auto resultTy = dyn_cast<RankedTensorType>(op.getType());
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/UnfuseBatchNorm.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/UnfuseBatchNorm.cpp
index de696d6..3b0ed03 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/UnfuseBatchNorm.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/UnfuseBatchNorm.cpp
@@ -27,8 +27,7 @@
Value broadcastToFeatureDim(Location loc, RankedTensorType resultType,
Value value1d, Value shapeValue, int64_t featureDim,
PatternRewriter &rewriter) {
- auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64));
- auto dims = DenseIntElementsAttr::get(dimsType, {featureDim});
+ DenseI64ArrayAttr dims = rewriter.getDenseI64ArrayAttr({featureDim});
if (shapeValue) {
return rewriter.createOrFold<mlir::stablehlo::DynamicBroadcastInDimOp>(
loc, resultType, value1d, shapeValue, dims);
@@ -73,8 +72,7 @@
auto epsilonTensorAttr =
DenseElementsAttr::get(scalarType, {cast<Attribute>(epsilonAttr)});
Value epsilon = b.create<mlir::stablehlo::ConstantOp>(epsilonTensorAttr);
- auto dimsType = RankedTensorType::get({0}, b.getIntegerType(64));
- auto dims = DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
+ DenseI64ArrayAttr dims = rewriter.getDenseI64ArrayAttr({});
if (broadcastToType.hasStaticShape()) {
return b.create<mlir::stablehlo::BroadcastInDimOp>(broadcastToType, epsilon,
/*broadcast_dims=*/dims);
@@ -159,8 +157,7 @@
auto reduceResultType = RankedTensorType::get(
{operandType.getDimSize(featureIndex)}, operandType.getElementType());
auto reduce = rewriter.create<mlir::stablehlo::ReduceOp>(
- loc, reduceResultType, operand, zero,
- rewriter.getI64TensorAttr(reduceDims));
+ loc, reduceResultType, operand, zero, reduceDims);
// setup "stablehlo.reduce"'s body
Region ®ion = reduce.getBody();
@@ -207,7 +204,7 @@
reduceSize = b.create<mlir::stablehlo::ReshapeOp>(
RankedTensorType::get({}, operandType.getElementType()), reduceSize);
return b.createOrFold<mlir::stablehlo::DynamicBroadcastInDimOp>(
- scaleType, reduceSize, scaleShape, b.getI64TensorAttr({}));
+ scaleType, reduceSize, scaleShape, b.getDenseI64ArrayAttr({}));
}
// the "operand" has static shape
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/canonicalization.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/canonicalization.mlir
index 0f990b9..d5e8f54 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/canonicalization.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/canonicalization.mlir
@@ -478,9 +478,9 @@
func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor<*xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>
%1 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {
- broadcast_dimensions = dense<0> : tensor<1xi64>,
- known_expanding_dimensions = dense<> : tensor<0xi64>,
- known_nonexpanding_dimensions = dense<0> : tensor<1xi64>
+ broadcast_dimensions = array<i64: 0>,
+ known_expanding_dimensions = array<i64>,
+ known_nonexpanding_dimensions = array<i64: 0>
} : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor<?xf32>
// CHECK: return %[[RES]] : tensor<?xf32>
@@ -499,7 +499,7 @@
start_index_map = [0, 2],
>,
indices_are_sorted = false,
- slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32>
+ slice_sizes = array<i64: 3, 6, 5>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32>
return %1 : tensor<3x6x5xf32>
// CHECK: %[[RET:.*]] = stablehlo.slice %arg0 [1:4, 0:6, 2:7]
// CHECK-SAME: : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32>
@@ -518,7 +518,7 @@
start_index_map = [2],
>,
indices_are_sorted = false,
- slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<i32>) -> tensor<5x6x4xf32>
+ slice_sizes = array<i64: 5, 6, 4>} : (tensor<5x6x7xf32>, tensor<i32>) -> tensor<5x6x4xf32>
return %1 : tensor<5x6x4xf32>
// CHECK: %[[RET:.*]] = stablehlo.slice %arg0 [0:5, 0:6, 1:5]
// CHECK-SAME: : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
@@ -538,7 +538,7 @@
start_index_map = [0, 2],
>,
indices_are_sorted = false,
- slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
+ slice_sizes = array<i64: 3, 6, 1>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
return %1 : tensor<3x6xf32>
// CHECK: %[[V0:.*]] = stablehlo.slice %arg0 [1:4, 0:6, 2:3]
// CHECK-SAME: : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32>
@@ -558,7 +558,7 @@
collapsed_slice_dims = [0],
start_index_map = [0]
>, indices_are_sorted = true,
- slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
+ slice_sizes = array<i64: 1, 2>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
return %1 : tensor<2xui32>
// CHECK: %[[V0:.*]] = stablehlo.slice %arg0 [3:4, 0:2]
// CHECK-SAME: : (tensor<4x2xui32>) -> tensor<1x2xui32>
@@ -578,7 +578,7 @@
collapsed_slice_dims = [0],
start_index_map = [0]
>, indices_are_sorted = true,
- slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
+ slice_sizes = array<i64: 1, 2>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
return %1 : tensor<2xui32>
// CHECK: %[[V0:.*]] = stablehlo.slice %arg0 [0:1, 0:2]
// CHECK-SAME: : (tensor<4x2xui32>) -> tensor<1x2xui32>
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/gather_to_torch_index_select.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/gather_to_torch_index_select.mlir
index 79b944d..c2530ae 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/gather_to_torch_index_select.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/gather_to_torch_index_select.mlir
@@ -16,7 +16,7 @@
start_index_map = [0],
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 4>
} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>
// CHECK: return [[RES]]
@@ -34,7 +34,7 @@
start_index_map = [0],
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 3]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 3>
} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
func.return %0 : tensor<1x3x3xf32>
}
@@ -50,7 +50,7 @@
start_index_map = [0, 1],
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 4>
} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
func.return %0 : tensor<1x3x4xf32>
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
index 2ebad08..972067f 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -64,10 +64,10 @@
// CHECK: stablehlo.broadcast_in_dim %[[OR]], dims = [] : (tensor<i32>) -> tensor<1x8x8x64xi32>
// CHECK: %[[XOR:.*]] = stablehlo.xor %[[ARG2]], %[[ARG3]] : tensor<i32>
// CHECK: stablehlo.broadcast_in_dim %[[XOR]], dims = [] : (tensor<i32>) -> tensor<1x8x8x64xi32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %2 = "stablehlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
- %3 = "stablehlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
+ %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
+ %2 = "stablehlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = array<i64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
+ %3 = "stablehlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = array<i64>} : (tensor<i32>) -> tensor<1x8x8x64xi32>
%4 = stablehlo.add %0, %1 : tensor<1x8x8x64xf32>
%5 = stablehlo.atan2 %0, %1 : tensor<1x8x8x64xf32>
%6 = stablehlo.divide %0, %1 : tensor<1x8x8x64xf32>
@@ -92,8 +92,8 @@
func.func @reorder_broadcast_in_dim_scalar_binary_diff_type(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<1x8x8x64xcomplex<f32>> {
// CHECK: %[[X:.+]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<complex<f32>>
// CHECK: stablehlo.broadcast_in_dim %[[X]], dims = [] : (tensor<complex<f32>>) -> tensor<1x8x8x64xcomplex<f32>>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
+ %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
%2 = "stablehlo.complex"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xcomplex<f32>>
return %2 : tensor<1x8x8x64xcomplex<f32>>
}
@@ -104,8 +104,8 @@
func.func @reorder_broadcast_in_dim_1d_binary(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<4x3xf32> {
// CHECK: %[[ATAN2:.*]] = stablehlo.atan2 %[[ARG0]], %[[ARG1]] : tensor<3xf32>
// CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[ATAN2]], dims = [1] : (tensor<3xf32>) -> tensor<4x3xf32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
- %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xf32>) -> tensor<4x3xf32>
+ %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64: 1>} : (tensor<3xf32>) -> tensor<4x3xf32>
%2 = stablehlo.atan2 %0, %1 : tensor<4x3xf32>
// CHECK: return %[[BCAST]]
return %2 : tensor<4x3xf32>
@@ -117,8 +117,8 @@
func.func @reorder_broadcast_in_dim_2d_binary(%arg0: tensor<2x4xi32>, %arg1: tensor<2x4xi32>) -> tensor<3x2x4xi32> {
// CHECK: %[[POWER:.*]] = stablehlo.power %[[ARG0]], %[[ARG1]] : tensor<2x4xi32>
// CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[POWER]], dims = [1, 2] : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
- %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+ %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
%2 = stablehlo.power %0, %1 : tensor<3x2x4xi32>
// CHECK: return %[[BCAST]]
return %2 : tensor<3x2x4xi32>
@@ -154,7 +154,7 @@
// CHECK: stablehlo.broadcast_in_dim %[[SQRT]], dims = [] : (tensor<f32>) -> tensor<1x8x8x64xf32>
// CHECK: %[[TANH:.*]] = stablehlo.tanh %[[ARG0]] : tensor<f32>
// CHECK: stablehlo.broadcast_in_dim %[[TANH]], dims = [] : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x8x8x64xf32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64> } : (tensor<f32>) -> tensor<1x8x8x64xf32>
%1 = stablehlo.abs %0 : tensor<1x8x8x64xf32>
%2 = stablehlo.ceil %0 : tensor<1x8x8x64xf32>
%3 = stablehlo.cosine %0 : tensor<1x8x8x64xf32>
@@ -177,7 +177,7 @@
func.func @reorder_broadcast_in_dim_1d_unary(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
// CHECK: %[[COS:.*]] = stablehlo.cosine %[[ARG0]] : tensor<3xf32>
// CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[COS]], dims = [1] : (tensor<3xf32>) -> tensor<4x3xf32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xf32>) -> tensor<4x3xf32>
%1 = stablehlo.cosine %0 : tensor<4x3xf32>
// CHECK: return %[[BCAST]]
return %1 : tensor<4x3xf32>
@@ -189,7 +189,7 @@
func.func @reorder_in_dim_2d_unary(%arg0: tensor<2x4xf32>) -> tensor<3x2x4xf32> {
// CHECK: %[[LOG:.*]] = stablehlo.log %[[ARG0]] : tensor<2x4xf32>
// CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[LOG]], dims = [1, 2] : (tensor<2x4xf32>) -> tensor<3x2x4xf32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32>
%1 = stablehlo.log %0 : tensor<3x2x4xf32>
// CHECK: return %[[BCAST]]
return %1 : tensor<3x2x4xf32>
@@ -203,7 +203,7 @@
// CHECK: stablehlo.broadcast_in_dim %[[REAL]], dims = [] : (tensor<f32>) -> tensor<1x8x8x64xf32>
// CHECK: %[[IMAG:.*]] = stablehlo.imag %[[ARG0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: stablehlo.broadcast_in_dim %[[IMAG]], dims = [] : (tensor<f32>) -> tensor<1x8x8x64xf32>
- %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<complex<f32>>) -> tensor<1x8x8x64xcomplex<f32>>
+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64>} : (tensor<complex<f32>>) -> tensor<1x8x8x64xcomplex<f32>>
%1 = stablehlo.real %0 : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
%2 = stablehlo.imag %0 : (tensor<1x8x8x64xcomplex<f32>>) -> tensor<1x8x8x64xf32>
return %1, %2: tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
@@ -277,7 +277,7 @@
// CHECK-LABEL: @mul_float_bool_cast_broadcast
func.func @mul_float_bool_cast_broadcast(%arg0: tensor<5xi1>, %arg1: tensor<5x6xf32>) -> tensor<5x6xf32> {
%0 = stablehlo.convert %arg0 : (tensor<5xi1>) -> tensor<5xf32>
- %1 = "stablehlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<5x6xf32>
+ %1 = "stablehlo.broadcast_in_dim"(%0) {broadcast_dimensions = array<i64: 0>} : (tensor<5xf32>) -> tensor<5x6xf32>
%2 = stablehlo.multiply %1, %arg1 : tensor<5x6xf32>
return %2 : tensor<5x6xf32>
}
@@ -290,7 +290,7 @@
func.func @mul_float_bool_cast_dyn_broadcast(%arg0: tensor<?xi1>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = stablehlo.convert %arg0 : (tensor<?xi1>) -> tensor<?xf32>
%1 = shape.shape_of %arg1 : tensor<?x?xf32> -> tensor<2xindex>
- %2 = "stablehlo.dynamic_broadcast_in_dim"(%0, %1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+ %2 = "stablehlo.dynamic_broadcast_in_dim"(%0, %1) {broadcast_dimensions = array<i64: 0>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%3 = stablehlo.multiply %2, %arg1 : tensor<?x?xf32>
return %3 : tensor<?x?xf32>
}
@@ -351,11 +351,11 @@
batch_group_count = 1 : i64,
dimension_numbers = #stablehlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
feature_group_count = 1 : i64,
- lhs_dilation = dense<1> : tensor<1xi64>,
+ lhs_dilation = array<i64: 1>,
padding = dense<0> : tensor<1x2xi64>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
- rhs_dilation = dense<1> : tensor<1xi64>,
- window_strides = dense<[1]> : tensor<1xi64>
+ rhs_dilation = array<i64: 1>,
+ window_strides = array<i64: 1>
} : (tensor<16x32x256xf32>, tensor<1x256x256xbf16>) -> tensor<16x32x256xf32>
// CHECK: return %[[CONV]]
func.return %0 : tensor<16x32x256xf32>
@@ -427,8 +427,8 @@
func.func @broadcast_iota_sort_slice_is_topk(%in : tensor<16x16x16xf32>) -> (tensor<16x16x8xf32>, tensor<16x16x8xi32>) {
%iota = "stablehlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<16xi32>
- %broadcasted_0 = "stablehlo.broadcast_in_dim"(%iota) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<16xi32>) -> tensor<16x16xi32>
- %broadcasted_1 = "stablehlo.broadcast_in_dim"(%broadcasted_0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<16x16xi32>) -> tensor<16x16x16xi32>
+ %broadcasted_0 = "stablehlo.broadcast_in_dim"(%iota) {broadcast_dimensions = array<i64: 1>} : (tensor<16xi32>) -> tensor<16x16xi32>
+ %broadcasted_1 = "stablehlo.broadcast_in_dim"(%broadcasted_0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<16x16xi32>) -> tensor<16x16x16xi32>
%0:2 = "stablehlo.sort"(%in, %broadcasted_1) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
@@ -448,8 +448,8 @@
func.func @broadcast_iota_sort_slice_incorrect_dims(%in : tensor<16x16x16xf32>) -> (tensor<16x16x8xf32>, tensor<16x16x8xi32>) {
%iota = "stablehlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<16xi32>
- %broadcasted_0 = "stablehlo.broadcast_in_dim"(%iota) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<16xi32>) -> tensor<16x16xi32>
- %broadcasted_1 = "stablehlo.broadcast_in_dim"(%broadcasted_0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<16x16xi32>) -> tensor<16x16x16xi32>
+ %broadcasted_0 = "stablehlo.broadcast_in_dim"(%iota) {broadcast_dimensions = array<i64: 1>} : (tensor<16xi32>) -> tensor<16x16xi32>
+ %broadcasted_1 = "stablehlo.broadcast_in_dim"(%broadcasted_0) {broadcast_dimensions = array<i64: 0, 1>} : (tensor<16x16xi32>) -> tensor<16x16x16xi32>
%0:2 = "stablehlo.sort"(%in, %broadcasted_1) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
index ff36eb1..e775a44 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
@@ -586,7 +586,8 @@
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- SmallVector<int64_t> broadcastDimensions = op.getBroadcastDimensions();
+ SmallVector<int64_t> broadcastDimensions =
+ llvm::to_vector(op.getBroadcastDimensions());
Value operand = adaptor.getOperand();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
@@ -721,7 +722,8 @@
if (!resultTy)
return failure();
- SmallVector<int64_t> broadcastDimensions = op.getBroadcastDimensions();
+ SmallVector<int64_t> broadcastDimensions =
+ llvm::to_vector(op.getBroadcastDimensions());
SmallVector<std::optional<bool>> expansionBehavior(
broadcastDimensions.size());
@@ -1745,7 +1747,7 @@
int64_t resultRank = resultType.getRank();
// slice_sizes has to have the same size as operand.rank, and doing it this
// way permits an unranked operand.
- int64_t operandRank = gatherOp.getSliceSizes().getNumElements();
+ int64_t operandRank = gatherOp.getSliceSizes().size();
int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim();
@@ -1960,10 +1962,8 @@
if (!op.getWindowDimensions().has_value())
return rewriter.notifyMatchFailure(op, "no window dimensions found");
- auto strides =
- llvm::to_vector(op.getWindowStridesAttr().getValues<int64_t>());
- auto window =
- llvm::to_vector(op.getWindowDimensionsAttr().getValues<int64_t>());
+ auto strides = llvm::to_vector(op.getWindowStrides().value());
+ auto window = llvm::to_vector(op.getWindowDimensions().value());
if (static_cast<int64_t>(strides.size()) != operandTy.getRank() ||
static_cast<int64_t>(window.size()) != operandTy.getRank())
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
index 08f77e6..5d68d62 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
@@ -19,13 +19,16 @@
/// Apply dilation and padding to the input of a convolution.
Value applyConvolutionPadding(Location loc, Value input,
DenseIntElementsAttr padding,
- DenseIntElementsAttr lhsDilation,
+ DenseI64ArrayAttr lhsDilation,
llvm::ArrayRef<int64_t> dimMappings,
OpBuilder &rewriter) {
- if ((!padding || isSplatValue(padding, 0)) &&
- (!lhsDilation || isSplatValue(lhsDilation, 1))) {
+ SmallVector<int64_t> lhsDilationValues;
+ if (lhsDilation)
+ lhsDilationValues = llvm::to_vector(lhsDilation.asArrayRef());
+ bool noPadding = !padding || isSplatValue(padding, 0);
+ bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1);
+ if (noPadding && noDilation)
return input;
- }
auto inputType = cast<ShapedType>(input.getType());
int64_t rank = inputType.getRank();
@@ -48,10 +51,10 @@
// Translate input dilation into interior padding.
SmallVector<int64_t, 8> padInterior(rank, 0);
if (lhsDilation) {
- assert(rank == lhsDilation.size() + 2);
- for (int64_t i : llvm::seq<int64_t>(0, lhsDilation.size())) {
+ assert(rank == static_cast<int64_t>(lhsDilationValues.size()) + 2);
+ for (int64_t i : llvm::seq<int64_t>(0, lhsDilationValues.size())) {
int64_t dim = dimMappings[i];
- padInterior[dim] = lhsDilation.getValues<int64_t>()[i] - 1;
+ padInterior[dim] = lhsDilationValues[i] - 1;
}
}
@@ -68,10 +71,8 @@
RankedTensorType::get({}, inputType.getElementType())));
}
- return rewriter.create<mlir::stablehlo::PadOp>(
- loc, input, zero, rewriter.getDenseI64ArrayAttr(padLow),
- rewriter.getDenseI64ArrayAttr(padHigh),
- rewriter.getDenseI64ArrayAttr(padInterior));
+ return rewriter.create<mlir::stablehlo::PadOp>(loc, input, zero, padLow,
+ padHigh, padInterior);
}
/// If the ConvolutionOp has a window reversal, applies it to the filter.
@@ -83,8 +84,7 @@
return filter;
}
llvm::SmallVector<int64_t> reversedDims;
- for (auto [idx, reversed] :
- llvm::enumerate(reversals.value().getValues<bool>())) {
+ for (auto [idx, reversed] : llvm::enumerate(reversals.value())) {
if (reversed) {
reversedDims.push_back(
op.getDimensionNumbers().getKernelSpatialDimensions()[idx]);
@@ -213,8 +213,12 @@
loc, resultType.getShape(), resultType.getElementType(), dynSizes);
Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
linalg::LinalgOp res;
- Attribute strides = op.getWindowStridesAttr();
- Attribute dilations = op.getRhsDilationAttr();
+ Attribute strides;
+ if (auto s = op.getWindowStrides())
+ strides = rewriter.getI64TensorAttr(*s);
+ Attribute dilations;
+ if (auto d = op.getRhsDilation())
+ dilations = rewriter.getI64TensorAttr(*d);
// Apply padding and input dilation.
llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
@@ -507,7 +511,7 @@
AffineExpr stride = dim0;
if (op.getWindowStrides().has_value())
- stride = stride * op.getWindowStrides().value().getValues<int64_t>()[i];
+ stride = stride * op.getWindowStrides().value()[i];
AffineExpr srcExpr = stride + dim1;
srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr;
@@ -596,7 +600,7 @@
Attribute windowStrides;
if (op.getWindowStrides()) {
- windowStrides = op.getWindowStrides().value();
+ windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value());
} else {
windowStrides = SplatElementsAttr::get(
VectorType::get({spatialRank}, rewriter.getI64Type()),
@@ -605,7 +609,7 @@
Attribute rhsDilation;
if (op.getRhsDilation()) {
- rhsDilation = op.getRhsDilation().value();
+ rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value());
} else {
rhsDilation = SplatElementsAttr::get(
VectorType::get({spatialRank}, rewriter.getI64Type()),
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
index 0505d26..52c287d 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
@@ -451,13 +451,11 @@
// ScanOp
//===----------------------------------------------------------------------===//
-static bool checkUnary(DenseIntElementsAttr attr) {
- llvm::SmallVector<int64_t> values;
- values = extract1DVector(attr);
-
- bool result = true;
+static bool checkUnary(const ArrayRef<int64_t> &values) {
for (auto value : values) {
- result = result && (value == 1);
+ if (value != 1) {
+ return false;
+ }
}
return true;
}
@@ -490,7 +488,7 @@
auto init0 = op.getInitValues().front();
auto init0Ty = init0.getType().cast<ShapedType>();
- auto window = extract1DVector(op.getWindowDimensions());
+ auto window = llvm::to_vector(op.getWindowDimensions());
llvm::SmallVector<int64_t, 4> reduceAxes;
for (int i = 0, s = window.size(); i < s; ++i) {
if (window[i] == 1)
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
index 5474459..632dbb5 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
@@ -123,7 +123,7 @@
}
auto srcRank = cast<ShapedType>(adaptor.getInputs()[0].getType()).getRank();
- SmallVector<int64_t> reductionDims = extract1DVector(op.getDimensions());
+ SmallVector<int64_t> reductionDims = llvm::to_vector(op.getDimensions());
SmallVector<Type> resultTypes;
if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
@@ -220,8 +220,7 @@
"unsupported reduce (noop or empty)");
}
- auto reductionDims =
- llvm::to_vector(op.getDimensions().getValues<int64_t>());
+ auto reductionDims = llvm::to_vector(op.getDimensions());
// stablehlo.reduce doesn't specify the order of the reduction dimensions.
llvm::sort(reductionDims);
@@ -332,8 +331,7 @@
return failure();
auto numOperands = initValues.size();
- llvm::SmallVector<int64_t> windowDimensions =
- extract1DVector(op.getWindowDimensions());
+ llvm::SmallVector<int64_t> windowDimensions(op.getWindowDimensions());
llvm::SmallVector<int64_t> padding;
if (op.getPadding()) {
@@ -342,17 +340,17 @@
llvm::SmallVector<int64_t> baseDilations;
if (op.getBaseDilations()) {
- baseDilations = extract1DVector(*op.getBaseDilations());
+ baseDilations = llvm::to_vector(*op.getBaseDilations());
}
llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
if (op.getWindowStrides()) {
- windowStrides = extract1DVector(*op.getWindowStrides());
+ windowStrides = llvm::to_vector(*op.getWindowStrides());
}
llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
if (op.getWindowDilations()) {
- windowDilations = extract1DVector(*op.getWindowDilations());
+ windowDilations = llvm::to_vector(*op.getWindowDilations());
}
auto rank = static_cast<int64_t>(windowDimensions.size());
@@ -426,13 +424,9 @@
staticInteriors[idx] = dilation - 1;
}
- auto padLows = rewriter.getDenseI64ArrayAttr(staticLows);
- auto padHighs = rewriter.getDenseI64ArrayAttr(staticHighs);
- auto padInteriors = rewriter.getDenseI64ArrayAttr(staticInteriors);
-
for (auto [input, initValue] : llvm::zip(inputs, initValues)) {
input = rewriter.create<mlir::stablehlo::PadOp>(
- loc, input, initValue, padLows, padHighs, padInteriors);
+ loc, input, initValue, staticLows, staticHighs, staticInteriors);
}
}
@@ -567,19 +561,17 @@
int lastDim = rank - 1;
SmallVector<int64_t, 2> fakeWindowShapes;
for (int i = 1; i < lastDim; ++i) {
- fakeWindowShapes.push_back(
- op.getWindowDimensions().getValues<int64_t>()[i]);
+ fakeWindowShapes.push_back(op.getWindowDimensions()[i]);
}
if (op.getWindowStrides() &&
- (op.getWindowStrides().value().getValues<int64_t>()[0] != 1 ||
- op.getWindowStrides().value().getValues<int64_t>()[lastDim] != 1)) {
+ (op.getWindowStrides().value()[0] != 1 ||
+ op.getWindowStrides().value()[lastDim] != 1)) {
return rewriter.notifyMatchFailure(
op, "expected window_strides to be [1,x,y,(z),1]");
}
- if (op.getWindowDimensions() &&
- (op.getWindowDimensions().getValues<int64_t>()[0] != 1 ||
- op.getWindowDimensions().getValues<int64_t>()[lastDim] != 1)) {
+ if (op.getWindowDimensions()[0] != 1 ||
+ op.getWindowDimensions()[lastDim] != 1) {
return rewriter.notifyMatchFailure(
op, "expected window_dimensions to be [1,x,y,(z),1]");
}
@@ -588,7 +580,7 @@
SmallVector<int64_t> vec;
if (op.getWindowStridesAttr()) {
for (int i = 1; i < lastDim; ++i) {
- vec.push_back(op.getWindowStrides().value().getValues<int64_t>()[i]);
+ vec.push_back(op.getWindowStrides().value()[i]);
}
} else {
vec.assign(rank - 2, 1);
@@ -599,7 +591,7 @@
vec.clear();
if (op.getWindowDilations()) {
for (int i = 1; i < lastDim; ++i) {
- vec.push_back(op.getWindowDilations().value().getValues<int64_t>()[i]);
+ vec.push_back(op.getWindowDilations().value()[i]);
}
} else {
vec.assign(rank - 2, 1);
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_no_broadcast.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_no_broadcast.mlir
index 9e6192c..3a6fbd1 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_no_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_no_broadcast.mlir
@@ -157,7 +157,7 @@
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: stablehlo.add
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 1>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -167,7 +167,7 @@
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: stablehlo.add
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -177,7 +177,7 @@
func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -187,7 +187,7 @@
func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 2>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_with_broadcast.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_with_broadcast.mlir
index 95ea14a..6ee69ca 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_with_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_chlo_with_broadcast.mlir
@@ -153,7 +153,7 @@
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: stablehlo.add
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 1>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -162,7 +162,7 @@
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: stablehlo.add
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -171,7 +171,7 @@
func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -180,7 +180,7 @@
func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
- %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+ %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array<i64: 2>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_control_flow.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_control_flow.mlir
index 069f67f..3c158b1 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_control_flow.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/legalize_control_flow.mlir
@@ -71,7 +71,7 @@
%5 = stablehlo.constant dense<1> : tensor<i32>
%6 = stablehlo.add %arg1, %5 : tensor<i32>
%7 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<i32>
- %8 = "stablehlo.broadcast_in_dim"(%7) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<3xi32>
+ %8 = "stablehlo.broadcast_in_dim"(%7) {broadcast_dimensions = array<i64>} : (tensor<i32>) -> tensor<3xi32>
%9 = stablehlo.add %arg2, %8 : tensor<3xi32>
"stablehlo.return"(%6, %9) : (tensor<i32>, tensor<3xi32>) -> ()
}) : (tensor<i32>, tensor<3xi32>) -> (tensor<i32>, tensor<3xi32>)
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
index bcf0c43..52f7c58 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
@@ -365,7 +365,7 @@
// CHECK: func @broadcast_in_dim
func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
%0 = "stablehlo.broadcast_in_dim"(%operand)
- {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
+ {broadcast_dimensions = array<i64: 4,0,2>}
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
func.return %0 : tensor<7x10x6x4x5xf32>
}
@@ -389,7 +389,7 @@
// CHECK: func @broadcast_in_dim_ui32
func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> {
%0 = "stablehlo.broadcast_in_dim"(%operand)
- {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
+ {broadcast_dimensions = array<i64: 4,0,2>}
: (tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32>
func.return %0 : tensor<7x10x6x4x5xui32>
}
@@ -440,7 +440,7 @@
func.func @broadcast_in_dim_with_transpose(
%operand: tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> {
%0 = "stablehlo.broadcast_in_dim"(%operand)
- {broadcast_dimensions = dense<[2, 0, 1]> : tensor<3xi64>}
+ {broadcast_dimensions = array<i64: 2, 0, 1>}
: (tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32>
func.return %0 : tensor<3x4x2x5xf32>
}
@@ -605,9 +605,9 @@
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @dynamic_iota_f32
-// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
-func.func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
- %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
+func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor<?x?x8xf32> {
+ %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xf32>)
func.return %result : tensor<?x?x8xf32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
@@ -626,10 +626,10 @@
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @dyanmic_iota_ui32
-// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
-func.func @dyanmic_iota_ui32(%shape: tensor<?xi32>) -> tensor<?x?x8xui32> {
- %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xui32>)
+// CHECK: func @dynamic_iota_ui32
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
+func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor<?x?x8xui32> {
+ %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xui32>)
func.return %result : tensor<?x?x8xui32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
@@ -655,7 +655,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<[0]> : tensor<1xi64>}
+ }) {dimensions = array<i64: 0>}
: (tensor<?xf32>, tensor<4xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
@@ -675,7 +675,7 @@
^bb0(%arg2: tensor<f32>):
%1 = stablehlo.add %arg2, %arg2 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<[0]> : tensor<1xi64>}
+ }) {dimensions = array<i64: 0>}
: (tensor<?xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
@@ -706,7 +706,7 @@
{comparison_direction = #stablehlo<comparison_direction EQ>}
: (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%3) : (tensor<i1>) -> ()
- }) {dimensions = dense<[0]> : tensor<1xi64>}
+ }) {dimensions = array<i64: 0>}
: (tensor<?xcomplex<f32>>, tensor<?xcomplex<f32>>) -> tensor<?xi1>
func.return %0 : tensor<?xi1>
}
@@ -1093,8 +1093,8 @@
stablehlo.return %9 : tensor<f32>
}) {
padding = dense<0> : tensor<4x2xi64>,
- window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
+ window_dimensions = array<i64: 1, 2, 2, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>
} : (tensor<2x8x8x1xf32>, tensor<2x4x4x1xf32>, tensor<f32>) -> tensor<2x8x8x1xf32>
return %0 : tensor<2x8x8x1xf32>
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_convolution.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_convolution.mlir
index 2b20586..2e361bd 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_convolution.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_convolution.mlir
@@ -38,8 +38,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[0, 0]]> : tensor<1x2xi64>,
- rhs_dilation = dense<1> : tensor<1xi64>,
- window_strides = dense<1> : tensor<1xi64>,
+ rhs_dilation = array<i64: 1>,
+ window_strides = array<i64: 1>,
someattr
} : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32>
func.return %0 : tensor<?x7x?xf32>
@@ -81,8 +81,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x4x?xf32>
func.return %0 : tensor<?x2x4x?xf32>
}
@@ -224,8 +224,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
- rhs_dilation = dense<1> : tensor<3xi64>,
- window_strides = dense<1> : tensor<3xi64>
+ rhs_dilation = array<i64: 1, 1, 1>,
+ window_strides = array<i64: 1, 1, 1>
} : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
func.return %0 : tensor<?x7x7x7x?xf32>
}
@@ -264,8 +264,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 2, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
func.return %0 : tensor<1x2x4x3xf32>
}
@@ -350,8 +350,8 @@
>,
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>,
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>,
someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
func.return %0 : tensor<2x3x4x6xf32>
}
@@ -391,8 +391,8 @@
>,
feature_group_count = 2 : i64,
padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>,
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>,
someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32>
func.return %0 : tensor<2x3x6x4xf32>
}
@@ -439,8 +439,8 @@
>,
feature_group_count = 96 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 2, 2>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
func.return %0 : tensor<1x56x56x96xf32>
}
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
@@ -477,8 +477,8 @@
>,
feature_group_count = 96 : i64,
padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 2, 2>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32>
func.return %0 : tensor<1x57x58x96xf32>
}
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
index ebd0019..38792b3 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
@@ -625,7 +625,7 @@
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%787 = "stablehlo.add"(%arg2, %arg3) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%787) : (tensor<i32>) -> ()
- }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 0], [4, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[1, 5]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<7x5xi32>, tensor<i32>) -> tensor<7x5xi32>
+ }) {base_dilations = array<i64: 1, 1>, padding = dense<[[0, 0], [4, 0]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 1>, window_dimensions = array<i64: 1, 5>, window_strides = array<i64: 1, 1>} : (tensor<7x5xi32>, tensor<i32>) -> tensor<7x5xi32>
return %reduce : tensor<7x5xi32>
}
// CHECK: %extracted = tensor.extract %[[ARG1]][] : tensor<i32>
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_gather.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_gather.mlir
index c8003fc..2156cd9 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_gather.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_gather.mlir
@@ -15,7 +15,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>,
+ slice_sizes = array<i64: 1, 1, 8>,
someattr
} : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32>
func.return %res : tensor<1x8x8xi32>
@@ -60,7 +60,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>,
+ slice_sizes = array<i64: 1, 1, 8>,
someattr
} : (tensor<1x4x8xi32>, tensor<1x8x2xui32>) -> tensor<1x8x8xi32>
func.return %res : tensor<1x8x8xi32>
@@ -85,7 +85,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 1, 8>
} : (tensor<1x4x8xui32>, tensor<1x8x2xi32>) -> tensor<1x8x8xui32>
func.return %res : tensor<1x8x8xui32>
}
@@ -108,7 +108,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[4, 2]> : tensor<2xi64>
+ slice_sizes = array<i64: 4, 2>
} : (tensor<6x3xi32>, tensor<5x2xi32>) -> tensor<5x4x2xi32>
func.return %res : tensor<5x4x2xi32>
}
@@ -149,7 +149,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[2, 3, 4]> : tensor<3xi64>
+ slice_sizes = array<i64: 2, 3, 4>
} : (tensor<?x?x?xi32>, tensor<5x2xi32>) -> tensor<2x3x4x5xi32>
func.return %res : tensor<2x3x4x5xi32>
}
@@ -199,7 +199,7 @@
start_index_map = [3, 1, 2, 0]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 2, 1, 4]> : tensor<4xi64>
+ slice_sizes = array<i64: 1, 2, 1, 4>
} : (tensor<6x3x2x7xi32>, tensor<5x4xi32>) -> tensor<5x2x4xi32>
func.return %res : tensor<5x2x4xi32>
}
@@ -257,7 +257,7 @@
start_index_map = [0]
>,
indices_are_sorted = false,
- slice_sizes = dense<[3, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 3, 4>
} : (tensor<?x?xi32>, tensor<5x2xi32>) -> tensor<3x4x5x2xi32>
func.return %res : tensor<3x4x5x2xi32>
}
@@ -298,7 +298,7 @@
start_index_map = [0]
>,
indices_are_sorted = false,
- slice_sizes = dense<[3, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 3, 4>
} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<3x4x?xi32>
func.return %res : tensor<3x4x?xi32>
}
@@ -339,7 +339,7 @@
start_index_map = [0]
>,
indices_are_sorted = false,
- slice_sizes = dense<[3, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 3, 4>
} : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?x?xi32>
func.return %res : tensor<?x?x?xi32>
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_reduce.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_reduce.mlir
index de2fd82..f61474b 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_reduce.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_reduce.mlir
@@ -14,7 +14,7 @@
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>, someattr} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
+ }) {dimensions = array<i64: 1>, someattr} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
func.return %0 : tensor<5xi32>
}
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
@@ -47,7 +47,7 @@
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>, someattr} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+ }) {dimensions = array<i64: 1>, someattr} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}
// CHECK: stablehlo.reduce
@@ -64,7 +64,7 @@
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = stablehlo.maximum %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<4xi32>
+ }) {dimensions = array<i64: 0>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<4xi32>
func.return %0 : tensor<4xi32>
}
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
@@ -94,7 +94,7 @@
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = stablehlo.maximum %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<?xi32>
+ }) {dimensions = array<i64: 0>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
@@ -116,7 +116,7 @@
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%1 = stablehlo.add %arg1, %arg2 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
func.return %0 : tensor<1xf32>
}
// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
@@ -141,7 +141,7 @@
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<5x4x3xi32>, tensor<i32>) -> tensor<4xi32>
+ }) {dimensions = array<i64: 0, 2>} : (tensor<5x4x3xi32>, tensor<i32>) -> tensor<4xi32>
func.return %0 : tensor<4xi32>
}
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
@@ -220,7 +220,7 @@
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
@@ -259,7 +259,7 @@
%673 = "stablehlo.select"(%669, %arg3, %arg16) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%674 = "stablehlo.select"(%671, %672, %673) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%670, %674) : (tensor<i32>, tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
+ }) {dimensions = array<i64: 0>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
func.return %res0, %res1 : tensor<2xi32>, tensor<2xi32>
}
// CHECK-DAG: %[[CST0:.*]] = arith.constant -2147483648 : i32
@@ -320,7 +320,7 @@
%1 = "stablehlo.select"(%0, %arg7, %arg9) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = "stablehlo.select"(%0, %arg8, %arg10) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%1, %2) : (tensor<f32>, tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<128x10xf32>, tensor<128x10xi32>, tensor<f32>, tensor<i32>) ->(tensor<128xf32>, tensor<128xi32>)
+ }) {dimensions = array<i64: 1>} : (tensor<128x10xf32>, tensor<128x10xi32>, tensor<f32>, tensor<i32>) ->(tensor<128xf32>, tensor<128xi32>)
func.return %res0, %res1 : tensor<128xf32>, tensor<128xi32>
}
// CHECK-DAG: %[[CST0:.*]] = arith.constant 1.000000e+00 : f32
@@ -403,8 +403,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>,
someattr} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
func.return %0 : tensor<1x8x8x64xf32>
}
@@ -430,8 +430,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
func.return %0 : tensor<1x8x8x64xf32>
}
// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
@@ -455,8 +455,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
func.return %0 : tensor<1x8x8x64xf32>
}
// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
@@ -479,8 +479,8 @@
^bb0(%arg1: tensor<f32>, %arg2 : tensor<f32>):
%2 = stablehlo.maximum %arg1, %arg2 : tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
func.return %1 : tensor<1x8x8x64xf32>
}
@@ -506,8 +506,8 @@
%1 = stablehlo.add %arg2, %arg4 : tensor<f32>
%2 = stablehlo.maximum %arg3, %arg5 : tensor<f32>
"stablehlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
func.return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
}
@@ -541,8 +541,8 @@
^bb0(%arg1: tensor<ui32>, %arg2: tensor<ui32>):
stablehlo.return %arg1 : tensor<ui32>
}) {
- window_dimensions = dense<[1, 1]> : tensor<2xi64>,
- window_strides = dense<[1, 1]> : tensor<2xi64>
+ window_dimensions = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x1xui32>, tensor<ui32>) -> tensor<1x1xui32>
return %1 : tensor<1x1xui32>
}
@@ -558,8 +558,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
func.return %0 : tensor<?x?x?x?xf32>
}
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -597,8 +597,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
- window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
func.return %0 : tensor<1x8x8x8x64xf32>
}
// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
@@ -622,8 +622,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
- window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
func.return %0 : tensor<1x8x8x8x64xf32>
}
// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
@@ -647,8 +647,8 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
- window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
func.return %0 : tensor<1x8x8x8x64xf32>
}
// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
@@ -672,9 +672,9 @@
^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {base_dilations = dense<[1, 1, 1, 2, 1]> : tensor<5xi64>,
- window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
- window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x16x64xf32>
+ }) {base_dilations = array<i64: 1, 1, 1, 2, 1>,
+ window_dimensions = array<i64: 1, 3, 3, 3, 1>,
+ window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x16x64xf32>
func.return %0 : tensor<1x8x8x16x64xf32>
}
@@ -694,7 +694,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ }) {base_dilations = array<i64: 1, 1>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
func.return %0 : tensor<4x7xf32>
}
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7xf32>
@@ -731,7 +731,7 @@
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
%2 = stablehlo.multiply %1, %c2 : tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
- }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ }) {base_dilations = array<i64: 1, 1>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
func.return %0 : tensor<4x7xf32>
}
@@ -751,7 +751,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x7xf32>
+ }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x7xf32>
func.return %0 : tensor<3x7xf32>
}
// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
@@ -768,7 +768,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x4xf32>
+ }) {base_dilations = array<i64: 2, 1>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x4xf32>
func.return %0 : tensor<3x4xf32>
}
// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
@@ -786,7 +786,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ }) {base_dilations = array<i64: 2, 1>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
func.return %0 : tensor<4x7xf32>
}
// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
@@ -803,7 +803,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {base_dilations = dense<> : tensor<0xi64>, padding = dense<> : tensor<0x2xi64>, window_dilations = dense<> : tensor<0xi64>, window_dimensions = dense<> : tensor<0xi64>, window_strides = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ }) {base_dilations = array<i64>, padding = dense<> : tensor<0x2xi64>, window_dilations = array<i64>, window_dimensions = array<i64>, window_strides = array<i64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
}
// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
diff --git a/tests/e2e/regression/dynamic_reduce_min.mlir b/tests/e2e/regression/dynamic_reduce_min.mlir
index 577b47d..382d268 100644
--- a/tests/e2e/regression/dynamic_reduce_min.mlir
+++ b/tests/e2e/regression/dynamic_reduce_min.mlir
@@ -6,7 +6,7 @@
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = stablehlo.minimum %arg1, %arg2 : tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
- }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<f32>
+ }) {dimensions = array<i64: 0, 1>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<f32>
check.expect_almost_eq_const(%1, dense<-4.0> : tensor<f32>) : tensor<f32>
return
}
diff --git a/tests/e2e/regression/lowering_config.mlir b/tests/e2e/regression/lowering_config.mlir
index 259eb01..766fbdf 100644
--- a/tests/e2e/regression/lowering_config.mlir
+++ b/tests/e2e/regression/lowering_config.mlir
@@ -39,8 +39,8 @@
dimension_numbers = #stablehlo.conv<raw input_batch_dimension = 0, input_feature_dimension = 3, input_spatial_dimensions = [1, 2], kernel_input_feature_dimension = 2, kernel_output_feature_dimension = 3, kernel_spatial_dimensions = [0, 1], output_batch_dimension = 0, output_feature_dimension = 3, output_spatial_dimensions = [1, 2]>,
feature_group_count = 1 : i64,
padding = dense<1> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<36x7x7x512xf32>, tensor<3x3x512x512xf32>) -> tensor<36x7x7x512xf32>
%1 = "stablehlo.convolution"(%input, %filter) {
compilation_info = #conv_compilation1,
@@ -48,8 +48,8 @@
dimension_numbers = #stablehlo.conv<raw input_batch_dimension = 0, input_feature_dimension = 3, input_spatial_dimensions = [1, 2], kernel_input_feature_dimension = 2, kernel_output_feature_dimension = 3, kernel_spatial_dimensions = [0, 1], output_batch_dimension = 0, output_feature_dimension = 3, output_spatial_dimensions = [1, 2]>,
feature_group_count = 1 : i64,
padding = dense<1> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<36x7x7x512xf32>, tensor<3x3x512x512xf32>) -> tensor<36x7x7x512xf32>
check.expect_almost_eq(%0, %1) : tensor<36x7x7x512xf32>
return
diff --git a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
index 2e5bcf9..f68c5d9 100644
--- a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
+++ b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
@@ -19,7 +19,7 @@
from iree.compiler.tools import InputType, compile_file
from iree.runtime import load_vm_flatbuffer_file
-MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.45208053dcd69ebd7428fe5b785249a7bdff2d62d55fb81b815889c4e1b993bb.tar"
+MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.2bec0cb356ae7c059e04624a627eb3b15b0a556cbd781bbed9f8d32e80a4311d.tar"
Tensor = TypeVar("Tensor")
diff --git a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
index 980cbdd..420cddd 100644
--- a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
+++ b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
@@ -68,11 +68,11 @@
cf.cond_br %extracted, ^bb2(%26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39 : tensor<i64>, tensor<i64>, tensor<40xf32>, tensor<i64>, tensor<74x40xf32>, tensor<i64>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<5x1x64xf32>, tensor<5x1x1xf32>, tensor<5x1x1xf32>, tensor<5xi64>, tensor<5x1x10xf32>, tensor<5x1x10xf32>), ^bb3(%26, %31, %32, %33, %37, %38, %39 : tensor<i64>, tensor<i64>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<5xi64>, tensor<5x1x10xf32>, tensor<5x1x10xf32>)
^bb2(%41: tensor<i64>, %42: tensor<i64>, %43: tensor<40xf32>, %44: tensor<i64>, %45: tensor<74x40xf32>, %46: tensor<i64>, %47: tensor<1x10xf32>, %48: tensor<1x10xf32>, %49: tensor<5x1x64xf32>, %50: tensor<5x1x1xf32>, %51: tensor<5x1x1xf32>, %52: tensor<5xi64>, %53: tensor<5x1x10xf32>, %54: tensor<5x1x10xf32>): // pred: ^bb1
%55 = stablehlo.add %41, %cst_5 : tensor<i64>
- %56 = "stablehlo.gather"(%50, %41) {dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], collapsed_slice_dims = [0], start_index_map = [0]>, slice_sizes = dense<1> : tensor<3xi64>} : (tensor<5x1x1xf32>, tensor<i64>) -> tensor<1x1xf32>
+ %56 = "stablehlo.gather"(%50, %41) {dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], collapsed_slice_dims = [0], start_index_map = [0]>, slice_sizes = array<i64: 1, 1, 1>} : (tensor<5x1x1xf32>, tensor<i64>) -> tensor<1x1xf32>
%57 = stablehlo.reshape %56 : (tensor<1x1xf32>) -> tensor<1xf32>
%58 = stablehlo.broadcast_in_dim %57, dims = [0] : (tensor<1xf32>) -> tensor<1x10xf32>
%59 = stablehlo.compare GT, %58, %7 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xi1>
- %60 = "stablehlo.gather"(%49, %41) {dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], collapsed_slice_dims = [0], start_index_map = [0]>, slice_sizes = dense<[1, 1, 64]> : tensor<3xi64>} : (tensor<5x1x64xf32>, tensor<i64>) -> tensor<1x64xf32>
+ %60 = "stablehlo.gather"(%49, %41) {dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], collapsed_slice_dims = [0], start_index_map = [0]>, slice_sizes = array<i64: 1, 1, 64>} : (tensor<5x1x64xf32>, tensor<i64>) -> tensor<1x64xf32>
%61 = stablehlo.concatenate %60, %48, dim = 1 : (tensor<1x64xf32>, tensor<1x10xf32>) -> tensor<1x74xf32>
%62 = stablehlo.dot %61, %45, precision = [DEFAULT] : (tensor<1x74xf32>, tensor<74x40xf32>) -> tensor<1x40xf32>
%63 = stablehlo.reshape %43 : (tensor<40xf32>) -> tensor<1x40xf32>
diff --git a/tests/e2e/stablehlo_ops/broadcast_in_dim.mlir b/tests/e2e/stablehlo_ops/broadcast_in_dim.mlir
index 0f92c52..e78bd10 100644
--- a/tests/e2e/stablehlo_ops/broadcast_in_dim.mlir
+++ b/tests/e2e/stablehlo_ops/broadcast_in_dim.mlir
@@ -1,7 +1,7 @@
func.func @broadcast_in_dim_2D_3D() {
%input = util.unfoldable_constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi32>
- %res = "stablehlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+ %res = "stablehlo.broadcast_in_dim"(%input) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
check.expect_eq_const(%res, dense<[
[[1, 2, 3, 4], [5, 6, 7, 8]],
[[1, 2, 3, 4], [5, 6, 7, 8]],
@@ -11,7 +11,7 @@
func.func @broadcast_in_dim_3D_scalar() {
%input = util.unfoldable_constant dense<42> : tensor<i32>
- %res = "stablehlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
+ %res = "stablehlo.broadcast_in_dim"(%input) {broadcast_dimensions = array<i64>} : (tensor<i32>) -> tensor<3x2x4xi32>
check.expect_eq_const(%res, dense<42> : tensor<3x2x4xi32>) : tensor<3x2x4xi32>
return
}
diff --git a/tests/e2e/stablehlo_ops/convolution.mlir b/tests/e2e/stablehlo_ops/convolution.mlir
index b7fb713..57e223a 100644
--- a/tests/e2e/stablehlo_ops/convolution.mlir
+++ b/tests/e2e/stablehlo_ops/convolution.mlir
@@ -22,8 +22,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
check.expect_almost_eq_const(%res, dense<[[
[[1310.0],[1466.0],[1622.0]],
[[2090.0],[2246.0],[2402.0]]
@@ -60,8 +60,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x4x1xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<2x4x4x1xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
check.expect_almost_eq_const(%res, dense<[[
[[1310.0],[1466.0],[1622.0]],
[[2090.0],[2246.0],[2402.0]]
@@ -93,8 +93,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
check.expect_almost_eq_const(%res, dense<[[
[[1310.0],[1466.0],[1622.0]],
[[2090.0],[2246.0],[2402.0]]
@@ -126,8 +126,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<1x3x2x2xf32>) -> tensor<1x2x3x1xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<1x4x4x2xf32>, tensor<1x3x2x2xf32>) -> tensor<1x2x3x1xf32>
check.expect_almost_eq_const(%res, dense<[[
[[1310.0],[1466.0],[1622.0]],
[[2090.0],[2246.0],[2402.0]]
@@ -159,8 +159,8 @@
output_spatial_dimensions = [3, 1]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x3x1x2xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x3x1x2xf32>
check.expect_almost_eq_const(%res, dense<[[
[[1310.0, 2090.0]],
[[1466.0, 2246.0]],
@@ -194,8 +194,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} :
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} :
(tensor<1x4x5x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x4x5x1xf32>
check.expect_almost_eq_const(%res, dense<[[
[[ 600.0], [ 736.0], [ 872.0], [1008.0], [ 476.0]],
@@ -233,8 +233,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} :
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} :
(tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32>
check.expect_almost_eq_const(%res, dense<[
[[[ 80.0], [121.0], [142.0], [163.0], [100.0]],
@@ -323,8 +323,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} :
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} :
(tensor<2x4x5x3xf32>, tensor<2x3x3x6xf32>) -> tensor<2x3x3x6xf32>
check.expect_almost_eq_const(%res, dense<[
[[[16065.0, 16290.0, 16515.0, 16740.0, 16965.0, 17190.0],
@@ -394,8 +394,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 2, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
check.expect_almost_eq_const(%res, dense<
[[[[-0.45181108, -0.37253797, -1.1074474 ],
@@ -427,8 +427,8 @@
>,
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
check.expect_almost_eq_const(%res, dense<4.0> : tensor<2x3x4x6xf32>) : tensor<2x3x4x6xf32>
return
}
diff --git a/tests/e2e/stablehlo_ops/gather.mlir b/tests/e2e/stablehlo_ops/gather.mlir
index af3cd0e..6b95536 100644
--- a/tests/e2e/stablehlo_ops/gather.mlir
+++ b/tests/e2e/stablehlo_ops/gather.mlir
@@ -13,7 +13,7 @@
offset_dims = [0, 1],
start_index_map = [0],
>,
- slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 1, 5>
} : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
return
@@ -34,7 +34,7 @@
offset_dims = [0, 1],
start_index_map = [0],
>,
- slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 1, 5>
} : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
return
@@ -64,7 +64,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 1, 8>
} : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32>
check.expect_eq_const(%result, dense<[[
[ 8, 9, 10, 11, 12, 13, 14, 15],
@@ -97,7 +97,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 1, 3]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 1, 3>
} : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x4x3xi32>
check.expect_eq_const(%result, dense<[[
[ 8, 9, 10],
@@ -126,7 +126,7 @@
start_index_map = [0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ slice_sizes = array<i64: 1, 2, 3>
} : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x2x3x4xi32>
check.expect_eq_const(%result, dense<[[
[[ 8, 16, 16, 0],
@@ -157,7 +157,7 @@
start_index_map = [3, 2, 0, 1]
>,
indices_are_sorted = false,
- slice_sizes = dense<[1, 2, 1, 3]> : tensor<4xi64>
+ slice_sizes = array<i64: 1, 2, 1, 3>
} : (tensor<1x3x2x4xi32>, tensor<2x4xi32>) -> tensor<2x2x3xi32>
check.expect_eq_const(%result, dense<[
diff --git a/tests/e2e/stablehlo_ops/reduce.mlir b/tests/e2e/stablehlo_ops/reduce.mlir
index fdad897..07e6aae 100644
--- a/tests/e2e/stablehlo_ops/reduce.mlir
+++ b/tests/e2e/stablehlo_ops/reduce.mlir
@@ -6,7 +6,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
check.expect_eq_const(%res, dense<55> : tensor<1xi32>) : tensor<1xi32>
return
}
@@ -19,7 +19,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
check.expect_eq_const(%res, dense<10> : tensor<1xi32>) : tensor<1xi32>
return
}
@@ -32,7 +32,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<5x1x1xi32>, tensor<i32>) -> tensor<5xi32>
+ }) {dimensions = array<i64: 1, 2>} : (tensor<5x1x1xi32>, tensor<i32>) -> tensor<5xi32>
check.expect_eq_const(%res, dense<[1, 2, 3, 4, 5]> : tensor<5xi32>) : tensor<5xi32>
return
}
@@ -50,7 +50,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<3xi32>
+ }) {dimensions = array<i64: 0>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<3xi32>
check.expect_eq_const(%res, dense<[5, 7, 9]> : tensor<3xi32>) : tensor<3xi32>
return
}
@@ -64,7 +64,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<2xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<2xi32>
check.expect_eq_const(%res, dense<[6, 15]> : tensor<2xi32>) : tensor<2xi32>
return
}
@@ -80,7 +80,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<2x3xi32>
+ }) {dimensions = array<i64: 0>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<2x3xi32>
check.expect_eq_const(%res, dense<[[4, 8, 12],[16, 20, 24]]> : tensor<2x3xi32>) : tensor<2x3xi32>
return
}
@@ -96,7 +96,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<4x2xi32>
+ }) {dimensions = array<i64: 2>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<4x2xi32>
check.expect_eq_const(%res, dense<[[6, 15],[6, 15],[6, 15],[6, 15]]> : tensor<4x2xi32>) : tensor<4x2xi32>
return
}
@@ -112,7 +112,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<3xi32>
+ }) {dimensions = array<i64: 0, 1>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<3xi32>
check.expect_eq_const(%res, dense<[20, 28, 36]> : tensor<3xi32>) : tensor<3xi32>
return
}
@@ -128,7 +128,7 @@
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<i32>
+ }) {dimensions = array<i64: 0, 1, 2>} : (tensor<4x2x3xi32>, tensor<i32>) -> tensor<i32>
check.expect_eq_const(%res, dense<84> : tensor<i32>) : tensor<i32>
return
}
@@ -141,7 +141,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
check.expect_almost_eq_const(%res, dense<55.0> : tensor<1xf32>) : tensor<1xf32>
return
}
@@ -156,7 +156,7 @@
%3 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
})
- {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+ {dimensions = array<i64: 1>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
check.expect_almost_eq_const(%res, dense<10.0> : tensor<1xf32>) : tensor<1xf32>
return
}
@@ -169,7 +169,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<5x1x1xf32>, tensor<f32>) -> tensor<5xf32>
+ }) {dimensions = array<i64: 1, 2>} : (tensor<5x1x1xf32>, tensor<f32>) -> tensor<5xf32>
check.expect_almost_eq_const(%res, dense<[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<5xf32>) : tensor<5xf32>
return
}
@@ -184,7 +184,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<3xf32>
+ }) {dimensions = array<i64: 0>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<3xf32>
check.expect_almost_eq_const(%res, dense<[5.0, 7.0, 9.0]> : tensor<3xf32>) : tensor<3xf32>
return
}
@@ -196,7 +196,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
check.expect_almost_eq_const(%res, dense<[6.0, 15.0]> : tensor<2xf32>) : tensor<2xf32>
return
}
@@ -212,7 +212,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
+ }) {dimensions = array<i64: 0>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
check.expect_almost_eq_const(%res, dense<[[4.0, 8.0, 12.0],[16.0, 20.0, 24.0]]> : tensor<2x3xf32>) : tensor<2x3xf32>
return
}
@@ -228,7 +228,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<4x3xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<4x3xf32>
check.expect_almost_eq_const(%res, dense<[
[5.0, 7.0, 9.0],
[5.0, 7.0, 9.0],
@@ -248,7 +248,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<4x2xf32>
+ }) {dimensions = array<i64: 2>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<4x2xf32>
check.expect_almost_eq_const(%res, dense<[
[6.0, 15.0],
[6.0, 15.0],
@@ -268,7 +268,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<3xf32>
+ }) {dimensions = array<i64: 0, 1>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<3xf32>
check.expect_almost_eq_const(%res, dense<[20.0, 28.0, 36.0]> : tensor<3xf32>) : tensor<3xf32>
return
}
@@ -284,7 +284,7 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<f32>
+ }) {dimensions = array<i64: 0, 1, 2>} : (tensor<4x2x3xf32>, tensor<f32>) -> tensor<f32>
check.expect_almost_eq_const(%res, dense<84.0> : tensor<f32>) : tensor<f32>
return
}
@@ -303,7 +303,7 @@
%4 = "stablehlo.select"(%0, %arg3, %arg5) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%5 = "stablehlo.select"(%2, %3, %4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%1, %5) : (tensor<i32>, tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
+ }) {dimensions = array<i64: 0>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
check.expect_eq_const(%res0, dense<[17, 18]> : tensor<2xi32>) : tensor<2xi32>
check.expect_eq_const(%res1, dense<[16, 17]> : tensor<2xi32>) : tensor<2xi32>
return
@@ -316,7 +316,7 @@
^bb0(%arg0 : tensor<i32>, %arg1 : tensor<i32>):
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x5xi32>, tensor<i32>) -> tensor<2xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<2x5xi32>, tensor<i32>) -> tensor<2xi32>
check.expect_eq_const(%2, dense<[25, 50]> : tensor<2xi32>) : tensor<2xi32>
return
}
@@ -330,7 +330,7 @@
^bb0(%arg0 : tensor<i32>, %arg1 : tensor<i32>):
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x5xi32>, tensor<i32>) -> tensor<2xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<2x5xi32>, tensor<i32>) -> tensor<2xi32>
check.expect_eq_const(%2, dense<[25, 50]> : tensor<2xi32>) : tensor<2xi32>
return
}
@@ -342,7 +342,7 @@
^bb0(%arg0 : tensor<i32>, %arg1 : tensor<i32>):
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
+ }) {dimensions = array<i64: 1>} : (tensor<1x10xi32>, tensor<i32>) -> tensor<1xi32>
check.expect_eq_const(%2, dense<[65]> : tensor<1xi32>) : tensor<1xi32>
return
}
@@ -354,7 +354,7 @@
^bb0(%arg0 : tensor<i32>, %arg1 : tensor<i32>):
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%3) : (tensor<i32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
+ }) {dimensions = array<i64: 0>} : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
check.expect_eq_const(%2, dense<65> : tensor<i32>) : tensor<i32>
return
}
diff --git a/tests/e2e/stablehlo_ops/reduce_window.mlir b/tests/e2e/stablehlo_ops/reduce_window.mlir
index 568d51c..2d21ca7 100644
--- a/tests/e2e/stablehlo_ops/reduce_window.mlir
+++ b/tests/e2e/stablehlo_ops/reduce_window.mlir
@@ -8,8 +8,8 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
+ }) {window_dimensions = array<i64: 1, 2, 3, 1>,
+ window_strides = array<i64: 1, 2, 3, 1>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
check.expect_eq_const(%res, dense<[[[[30.0], [48.0]],[[102.0], [120.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32>
return
}
@@ -24,8 +24,8 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 1, 1, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x3x4x1xf32>
+ }) {window_dimensions = array<i64: 1, 2, 3, 1>,
+ window_strides = array<i64: 1, 1, 1, 1>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x3x4x1xf32>
check.expect_eq_const(%res, dense<[[
[[ 30.0], [ 36.0], [ 42.0], [ 48.0]],
[[ 66.0], [ 72.0], [ 78.0], [ 84.0]],
@@ -43,8 +43,8 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
+ }) {window_dimensions = array<i64: 1, 2, 3, 1>,
+ window_strides = array<i64: 1, 2, 3, 1>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
check.expect_almost_eq_const(%res, dense<[[[[9.0], [12.0]], [[21.0], [24.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32>
return
}
@@ -59,8 +59,8 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
+ }) {window_dimensions = array<i64: 1, 2, 3, 1>,
+ window_strides = array<i64: 1, 2, 3, 1>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
check.expect_almost_eq_const(%res, dense<[[[[1.0], [4.0]], [[13.0], [14.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32>
return
}
@@ -75,8 +75,8 @@
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
%3 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>,
+ }) {window_dimensions = array<i64: 1, 2, 3, 1>,
+ window_strides = array<i64: 1, 2, 3, 1>,
padding = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x3x2x1xf32>
check.expect_almost_eq_const(%res, dense<[[[[3.0], [6.0]], [[15.0], [18.0]], [[21.0], [24.0]]]]> : tensor<1x3x2x1xf32>) : tensor<1x3x2x1xf32>
return
@@ -90,8 +90,8 @@
%4 = stablehlo.add %arg1, %arg2 : tensor<f32>
"stablehlo.return"(%4) : (tensor<f32>) -> ()
}) {padding = dense<[[1, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
- window_dimensions = dense<[2, 1, 1]> : tensor<3xi64>,
- window_strides = dense<1> : tensor<3xi64>
+ window_dimensions = array<i64: 2, 1, 1>,
+ window_strides = array<i64: 1, 1, 1>
} : (tensor<2x2x2xf32>, tensor<f32>) -> tensor<2x2x2xf32>
check.expect_almost_eq_const(%res, dense<[[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]> : tensor<2x2x2xf32>) : tensor<2x2x2xf32>
return
diff --git a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
index bbcc10f..e26b098 100644
--- a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
+++ b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
@@ -84,15 +84,15 @@
iree_fetch_artifact(
NAME "model-EfficientNetV2STF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_1af8c88f4e64e388a0c87bbeddcfb888084059df30cd631340d51794a0796e0f.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-MiniLML12H384Uncased"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_5aed9c3c3dfe8247ce76b74d518fa570b94dc0c3732631734d02ad70e4c74867.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
UNPACK
)
@@ -112,15 +112,15 @@
iree_fetch_artifact(
NAME "model-BertForMaskedLMTF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_e757a10b24f6ff83aaae0ceb5bb05d4efe9ff3e9931f8e9a29f12bc5c2e42b5e.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-BertLargeTF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_000793afb016fb3afc559304bcb3ba6cdb2df1825e8976ca236c07c12e4f65fa.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
UNPACK
)
diff --git a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
index 6639a69..57d2d49 100644
--- a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
+++ b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
@@ -246,7 +246,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -260,7 +260,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -302,7 +302,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -316,7 +316,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -552,7 +552,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -567,7 +567,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -612,7 +612,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -627,7 +627,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -867,7 +867,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -882,7 +882,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -927,7 +927,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -942,7 +942,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -1167,7 +1167,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1179,7 +1179,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1191,7 +1191,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1203,7 +1203,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1349,7 +1349,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2194,7 +2194,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2212,7 +2212,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2266,7 +2266,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2284,7 +2284,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2584,7 +2584,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2603,7 +2603,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2660,7 +2660,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2679,7 +2679,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2983,7 +2983,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -3002,7 +3002,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -3059,7 +3059,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -3078,7 +3078,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -3363,7 +3363,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3379,7 +3379,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3395,7 +3395,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3411,7 +3411,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3601,7 +3601,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
diff --git a/tests/e2e/vulkan_specific/conv.mlir b/tests/e2e/vulkan_specific/conv.mlir
index 0cce02e..571a6ee 100644
--- a/tests/e2e/vulkan_specific/conv.mlir
+++ b/tests/e2e/vulkan_specific/conv.mlir
@@ -61,8 +61,8 @@
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>}
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>}
: (tensor<1x4x6x2xf32>, tensor<2x3x2x3xf32>) -> (tensor<1x3x4x3xf32>)
check.expect_almost_eq_const(%2, dense<
[[[[ 8.39452888, 8.62796353, 8.86139818],
@@ -140,7 +140,7 @@
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
- >, feature_group_count = 1 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x3x3x4xf32>, tensor<2x2x4x32xf32>) -> tensor<1x2x2x32xf32>
+ >, feature_group_count = 1 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = array<i64: 1, 1>, window_strides = array<i64: 1, 1>} : (tensor<1x3x3x4xf32>, tensor<2x2x4x32xf32>) -> tensor<1x2x2x32xf32>
check.expect_almost_eq_const(%0, dense<
[[[[113.25, 127.0, 198.0, 173.25, 159.5, 190.75, 135.5, 160.0,
@@ -185,7 +185,7 @@
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
- >, feature_group_count = 16 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x1x4x16xf32>, tensor<1x1x1x16xf32>) -> tensor<1x1x4x16xf32>
+ >, feature_group_count = 16 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = array<i64: 1, 1>, window_strides = array<i64: 1, 1>} : (tensor<1x1x4x16xf32>, tensor<1x1x1x16xf32>) -> tensor<1x1x4x16xf32>
check.expect_almost_eq_const(%0, dense<
[[[[12.0, 15.0, 0.0, 3.0, 2.25, 17.5, 15.75, 5.0, 7.5, 0.0, 0.25, 7.5, 15.75, 10.5, 0.0, 16.25],
diff --git a/tests/microbenchmarks/stablehlo_conv.mlir b/tests/microbenchmarks/stablehlo_conv.mlir
index b320a2f..0318b6c 100644
--- a/tests/microbenchmarks/stablehlo_conv.mlir
+++ b/tests/microbenchmarks/stablehlo_conv.mlir
@@ -26,8 +26,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 2, 2>
} : (tensor<1x224x224x3xf32>, tensor<3x3x3x32xf32>) -> tensor<1x112x112x32xf32>
return %0 : tensor<1x112x112x32xf32>
}
@@ -50,8 +50,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x112x112x32xf32>, tensor<1x1x32x64xf32>) -> tensor<1x112x112x64xf32>
return %0 : tensor<1x112x112x64xf32>
}
@@ -74,8 +74,8 @@
>,
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x7x7x1024xf32>, tensor<1x1x1024x1024xf32>) -> tensor<1x7x7x1024xf32>
return %0 : tensor<1x7x7x1024xf32>
}
@@ -109,8 +109,8 @@
>,
feature_group_count = 1024 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x15x1x1024xf32>, tensor<15x1x1x1024xf32>) -> tensor<1x1x1x1024xf32>
return %res : tensor<1x1x1x1024xf32>
}
@@ -133,8 +133,8 @@
>,
feature_group_count = 512 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x15x1x512xf32>, tensor<15x1x1x512xf32>) -> tensor<1x1x1x512xf32>
return %res : tensor<1x1x1x512xf32>
}
@@ -157,8 +157,8 @@
>,
feature_group_count = 512 : i64,
padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ rhs_dilation = array<i64: 1, 1>,
+ window_strides = array<i64: 1, 1>
} : (tensor<1x16x1x512xf32>, tensor<15x1x1x512xf32>) -> tensor<1x2x1x512xf32>
return %res : tensor<1x2x1x512xf32>
}
diff --git a/third_party/stablehlo b/third_party/stablehlo
index f8dcebf..0264c4d 160000
--- a/third_party/stablehlo
+++ b/third_party/stablehlo
@@ -1 +1 @@
-Subproject commit f8dcebfa1ec166806974f6ae0dfb902d36b47238
+Subproject commit 0264c4d64c82ae74a54b85d274eec5084c2c0abf