Bump StableHLO to f8dcebfa1ec166806974f6ae0dfb902d36b47238 (#16049)
Updated the serialized models too, but only manually, so used a naming
convention to capture that they were manually updated rather than newly
generated.
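Most of the churn below is mechanical: StableHLO migrated a batch of op attributes (slice bounds, transpose permutations, broadcast dimensions, pad amounts, reverse dimensions, FFT lengths, dynamic-slice sizes) from `DenseIntElementsAttr` to `DenseI64ArrayAttr`, and the getters on migrated ops now return `ArrayRef<int64_t>` directly. A minimal before/after sketch of the builder side (helper names are illustrative, not from this patch):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Before: values lived in a rank-1 i64 tensor attribute.
DenseIntElementsAttr makeI64TensorAttr(OpBuilder &b, ArrayRef<int64_t> values) {
  auto type = RankedTensorType::get({static_cast<int64_t>(values.size())},
                                    b.getIntegerType(64));
  return DenseIntElementsAttr::get(type, values);
}

// After: a plain array attribute, no tensor type required, and no
// .getValues<int64_t>() round trip on the reading side.
DenseI64ArrayAttr makeI64ArrayAttr(OpBuilder &b, ArrayRef<int64_t> values) {
  return b.getDenseI64ArrayAttr(values);
}
```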
diff --git a/build_tools/python/e2e_test_framework/models/jax_models.py b/build_tools/python/e2e_test_framework/models/jax_models.py
index 874d894..6613683 100644
--- a/build_tools/python/e2e_test_framework/models/jax_models.py
+++ b/build_tools/python/e2e_test_framework/models/jax_models.py
@@ -11,7 +11,7 @@
from e2e_test_framework.definitions import common_definitions
import e2e_test_framework.models.utils as model_utils
-GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180"
+GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180j"
ID_FORMAT = string.Template("${model_id}-batch${batch_size}")
NAME_FORMAT = string.Template("${name}_BATCH${batch_size}")
diff --git a/build_tools/python/e2e_test_framework/models/tf_models.py b/build_tools/python/e2e_test_framework/models/tf_models.py
index 071aec9..3c81790 100644
--- a/build_tools/python/e2e_test_framework/models/tf_models.py
+++ b/build_tools/python/e2e_test_framework/models/tf_models.py
@@ -21,7 +21,7 @@
tags=["int32", "seqlen128"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="predict",
input_types=["1x128xi32", "1x128xi32", "1x128xi32"],
)
@@ -32,7 +32,7 @@
tags=["fp32", "seqlen512", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#tfbertformaskedlm
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="forward",
input_types=["1x512xi32", "1x512xi32"],
)
@@ -43,7 +43,7 @@
tags=["fp32", "cnn", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/efficientnet_v2.py
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="forward",
input_types=["1x384x384x3xf32"],
)
@@ -56,7 +56,7 @@
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Derived from https://github.com/mlcommons/inference/tree/master/language/bert
# Instructions on how to regenerate the model: https://gist.github.com/mariecwhite/e61ccebd979d98d097946ac7725bcc29
- source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc",
+ source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="serving_default",
input_types=["1x384xi32", "1x384xi32", "1x384xi32"],
)
@@ -81,7 +81,7 @@
input_types=["1x1xi32", "12x2x1x12x4x64xf32"],
)
-TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975"
+TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j"
ID_FORMAT = string.Template("${model_id}-batch-${batch_size}")
NAME_FORMAT = string.Template("${name}Batch${batch_size}")
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
index a4ec6a8..2d5726d 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Utils/IndexSet.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
@@ -448,7 +449,7 @@
llvm::to_vector(llvm::seq<int64_t>(0, inputShape.size()));
std::swap(permutation[srcDim], permutation[dstDim]);
std::swap(inputShape[srcDim], inputShape[dstDim]);
- DenseIntElementsAttr permutationAttr = rewriter.getI64VectorAttr(permutation);
+ auto permutationAttr = rewriter.getDenseI64ArrayAttr(permutation);
return rewriter.create<mlir::stablehlo::TransposeOp>(
loc, RankedTensorType::get(inputShape, inputType.getElementType()), input,
permutationAttr);
@@ -705,7 +706,7 @@
result = rewriter.create<mlir::stablehlo::TransposeOp>(
loc,
RankedTensorType::get(transposeResultShape, inputType.getElementType()),
- result, rewriter.getI64VectorAttr(permutation));
+ result, rewriter.getDenseI64ArrayAttr(permutation));
// Reshape
llvm::SmallVector<int64_t> finalShape(inputShape);
@@ -852,7 +853,7 @@
auto inputType = cast<RankedTensorType>(op.getOperand().getType());
SmallVector<int64_t> reduceInputShape(inputType.getShape());
Value reduceInput = adaptor.getOperand();
- DenseIntElementsAttr permutationAttr;
+ DenseI64ArrayAttr permutationAttr;
SmallVector<int64_t> scatterResultShape(resultType.getShape());
auto elemType = getElementTypeOrSelf(reduceInput.getType());
@@ -861,7 +862,7 @@
auto permutation =
llvm::to_vector(llvm::seq<int64_t>(0, scatterResultShape.size()));
std::swap(permutation[0], permutation[scatterDim]);
- permutationAttr = rewriter.getI64VectorAttr(permutation);
+ permutationAttr = rewriter.getDenseI64ArrayAttr(permutation);
std::swap(reduceInputShape[0], reduceInputShape[scatterDim]);
std::swap(scatterResultShape[0], scatterResultShape[scatterDim]);
// Transpose the input.
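For the transposes built in this file, the permutation is just the identity `[0, 1, ..., rank-1]` with the source and destination dimensions exchanged. A standalone sketch of that construction (plain C++, illustrative names):

```cpp
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Identity permutation with two positions swapped: applying it to a shape
// exchanges srcDim and dstDim and leaves every other dimension in place.
std::vector<int64_t> swapDimsPermutation(int64_t rank, int64_t srcDim,
                                         int64_t dstDim) {
  std::vector<int64_t> permutation(rank);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[srcDim], permutation[dstDim]);
  return permutation;
}
```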
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
index b22f318..b26589a 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp
@@ -2144,14 +2144,14 @@
} else {
values = rewriter.create<mlir::stablehlo::SliceOp>(
op.getLoc(), tupleFirstElement,
- DenseIntElementsAttr::get(indicesTy, beginIndices),
- DenseIntElementsAttr::get(indicesTy, endIndices),
- DenseIntElementsAttr::get(indicesTy, strides));
+ rewriter.getDenseI64ArrayAttr(beginIndices),
+ rewriter.getDenseI64ArrayAttr(endIndices),
+ rewriter.getDenseI64ArrayAttr(strides));
indices = rewriter.create<mlir::stablehlo::SliceOp>(
op.getLoc(), tupleSecondElement,
- DenseIntElementsAttr::get(indicesTy, beginIndices),
- DenseIntElementsAttr::get(indicesTy, endIndices),
- DenseIntElementsAttr::get(indicesTy, strides));
+ rewriter.getDenseI64ArrayAttr(beginIndices),
+ rewriter.getDenseI64ArrayAttr(endIndices),
+ rewriter.getDenseI64ArrayAttr(strides));
}
rewriter.replaceOp(op, {values, indices});
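With `stablehlo.slice` now taking `start_indices`/`limit_indices`/`strides` as `DenseI64ArrayAttr`, the `indicesTy` tensor type above becomes dead. The bounds themselves are unchanged: one entry per dimension, half-open `[start, limit)` ranges. A tiny sketch of the bounds this top-k style legalization produces for a `[rows, cols]` sort result (standalone, not patch code):

```cpp
#include <array>
#include <cstdint>

struct SliceBounds {
  std::array<int64_t, 2> start, limit, stride;
};

// Keep every row, take the first k columns: full range on dim 0,
// [0, k) on dim 1, unit strides on both.
SliceBounds topKBounds(int64_t rows, int64_t k) {
  return {{0, 0}, {rows, k}, {1, 1}};
}
```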
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
index 63ce62e..160c048 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp
@@ -41,6 +41,16 @@
// allowed to materialize as new constants.
constexpr int64_t kFoldOpEltLimit = 65536;
+static bool isIotaRange(ArrayRef<int64_t> dims) {
+ for (auto [idx, value] : llvm::enumerate(dims)) {
+ if (idx != value) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
static bool isIotaRange(ElementsAttr attr) {
auto elems = attr.tryGetValues<APInt>();
if (!elems)
@@ -469,7 +479,7 @@
return failure();
// Fold when broadcast is a noop.
- DenseIntElementsAttr dims = op.getBroadcastDimensions();
+ auto dims = op.getBroadcastDimensions();
bool isDimsIota = isIotaRange(dims);
if (type == operandTy && isDimsIota) {
rewriter.replaceOp(op, operand);
@@ -485,7 +495,7 @@
return success();
}
- auto bsDimIndices = dims.getValues<int64_t>();
+ auto bsDimIndices = dims;
if (operandTy.hasStaticShape() && type.hasStaticShape() &&
type.getNumElements() == operandTy.getNumElements()) {
// BroadcastInDim equivalent to reshape.
@@ -505,12 +515,10 @@
// Eliminate redundant nested BroadcastInDim.
if (auto broadcastInDimOp =
operand.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>()) {
- auto newIndices = cast<DenseIntElementsAttr>(
- broadcastInDimOp.getBroadcastDimensions().mapValues(
- dims.getElementType(), [&bsDimIndices](const APInt &dim) {
- return APInt(dim.getBitWidth(),
- bsDimIndices[dim.getSExtValue()], true);
- }));
+ auto newIndices =
+ rewriter.getDenseI64ArrayAttr(llvm::to_vector(llvm::map_range(
+ broadcastInDimOp.getBroadcastDimensions(),
+ [&bsDimIndices](int64_t dim) { return bsDimIndices[dim]; })));
rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
op, type, broadcastInDimOp.getOperand(), newIndices);
return success();
@@ -631,7 +639,7 @@
// output has static shape, replace with broadcast_in_dim
if (type.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
- op, type, op.getOperand(), op.getBroadcastDimensions());
+ op, type, op.getOperand(), op.getBroadcastDimensionsAttr());
return success();
}
@@ -648,7 +656,7 @@
refineOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
rewriter, op,
RankedTensorType::get(outputShape, type.getElementType()),
- op.getOperand(), op.getBroadcastDimensions());
+ op.getOperand(), op.getBroadcastDimensionsAttr());
return success();
}
}
@@ -670,16 +678,11 @@
return failure();
// Compose broadcast dimensions.
- DenseIntElementsAttr precedingBcastDims =
- precedingBcast.getBroadcastDimensions();
- DenseIntElementsAttr bcastDims = bcast.getBroadcastDimensions();
- SmallVector<APInt> composition;
- for (APInt precedingDim : precedingBcastDims) {
- composition.push_back(
- *(bcastDims.value_begin<APInt>() + precedingDim.getZExtValue()));
+ SmallVector<int64_t> composition;
+ for (int64_t precedingDim : precedingBcast.getBroadcastDimensions()) {
+ composition.push_back(bcast.getBroadcastDimensions()[precedingDim]);
}
- auto composedBcastDims =
- DenseIntElementsAttr::get(precedingBcastDims.getType(), composition);
+ auto composedBcastDims = rewriter.getDenseI64ArrayAttr(composition);
rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
bcast, bcast.getType(), precedingBcast.getOperand(),
@@ -928,9 +931,9 @@
auto sliceType = RankedTensorType::get(sliceShape, elementType);
Value result = rewriter.create<mlir::stablehlo::SliceOp>(
gather.getLoc(), sliceType, gather.getOperand(),
- rewriter.getI64TensorAttr(sliceStart),
- rewriter.getI64TensorAttr(sliceEnd),
- rewriter.getI64TensorAttr(sliceStride));
+ rewriter.getDenseI64ArrayAttr(sliceStart),
+ rewriter.getDenseI64ArrayAttr(sliceEnd),
+ rewriter.getDenseI64ArrayAttr(sliceStride));
ArrayRef<int64_t> collapsedSliceDims = dnums.getCollapsedSliceDims();
if (!collapsedSliceDims.empty()) {
@@ -1030,7 +1033,7 @@
"tensor type");
}
- SmallVector<int64_t> permValues(permutation.getValues<int64_t>());
+ SmallVector<int64_t> permValues(permutation);
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
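Two readability wins in this file: a new `isIotaRange(ArrayRef<int64_t>)` overload checks the array form directly, and composing nested `broadcast_in_dim` ops no longer needs `APInt` gymnastics. The composition rule is plain function composition of the dimension maps, sketched standalone (illustrative names):

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// If the inner broadcast sends operand dim i to inner[i], and the outer
// broadcast sends dim j to outer[j], the fused broadcast sends operand
// dim i to outer[inner[i]].
llvm::SmallVector<int64_t> composeBroadcastDims(llvm::ArrayRef<int64_t> inner,
                                                llvm::ArrayRef<int64_t> outer) {
  llvm::SmallVector<int64_t> composition;
  composition.reserve(inner.size());
  for (int64_t d : inner)
    composition.push_back(outer[d]);
  return composition;
}
```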
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp
index ed28321..50e8e8f 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp
@@ -47,13 +47,8 @@
auto transposePermutation =
llvm::to_vector<5>(llvm::concat<const int64_t>(leftDims, rightDims));
- TensorType transposePermutationType =
- RankedTensorType::get({static_cast<int64_t>(transposePermutation.size())},
- rewriter.getIntegerType(64));
-
auto transposePermutationAttr =
- llvm::cast<DenseIntElementsAttr>(DenseIntElementsAttr::get(
- transposePermutationType, llvm::ArrayRef(transposePermutation)));
+ rewriter.getDenseI64ArrayAttr(transposePermutation);
// Compute the resulting shape.
llvm::SmallVector<int64_t, 5> transposedShape;
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp
index e671fed..e33b75b 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp
@@ -149,7 +149,7 @@
} else {
// Generate a transpose.
rewriter.replaceOpWithNewOp<mlir::stablehlo::TransposeOp>(
- einsum, dotGeneralOp, rewriter.getI64TensorAttr(resultPerms));
+ einsum, dotGeneralOp, rewriter.getDenseI64ArrayAttr(resultPerms));
}
return success();
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
index 53321b6..fff4726 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
@@ -39,11 +40,8 @@
return true;
}
-DenseIntElementsAttr make1DElementsAttr(OpBuilder &b,
- ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- b.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
+DenseI64ArrayAttr make1DElementsAttr(OpBuilder &b, ArrayRef<int64_t> integers) {
+ return b.getDenseI64ArrayAttr(integers);
}
Value getF32Const(ImplicitLocOpBuilder b, ArrayRef<int64_t> shapes,
@@ -90,7 +88,7 @@
auto transposed = rewriter.create<mlir::stablehlo::TransposeOp>(
op.getLoc(),
RankedTensorType::get(transposeShape, lhsType.getElementType()),
- op.getLhs(), rewriter.getI64TensorAttr(permutations));
+ op.getLhs(), rewriter.getDenseI64ArrayAttr(permutations));
llvm::SmallVector<int64_t> newSpatialDimensions(spatialDims.size());
std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1);
@@ -158,7 +156,7 @@
auto transposeKernel = rewriter.create<mlir::stablehlo::TransposeOp>(
op.getLoc(),
RankedTensorType::get(transposeShape, kernelType.getElementType()),
- kernel, rewriter.getI64TensorAttr(permutation));
+ kernel, rewriter.getDenseI64ArrayAttr(permutation));
auto newDimensionNumbers = mlir::stablehlo::ConvDimensionNumbersAttr::get(
op.getContext(), dimensionNumbers.getInputBatchDimension(),
@@ -246,7 +244,7 @@
auto transposed = rewriter.create<mlir::stablehlo::TransposeOp>(
op.getLoc(), resultType, newConv,
- rewriter.getI64TensorAttr(invertPermutation));
+ rewriter.getDenseI64ArrayAttr(invertPermutation));
rewriter.replaceOp(op, transposed.getResult());
return success();
@@ -286,7 +284,7 @@
}
return b.create<mlir::stablehlo::TransposeOp>(
loc, RankedTensorType::get(transposeShape, type.getElementType()), src,
- b.getI64TensorAttr(targetOrder));
+ b.getDenseI64ArrayAttr(targetOrder));
}
Value ReshapeIfNonStandard(OpBuilder &b, Location loc, Value src,
@@ -748,7 +746,8 @@
}
indices = builder.create<mlir::stablehlo::TransposeOp>(
- indicesTy.clone(newShape), indices, builder.getI64TensorAttr(perm));
+ indicesTy.clone(newShape), indices,
+ builder.getDenseI64ArrayAttr(perm));
indicesTy = llvm::cast<RankedTensorType>(indices.getType());
indexVectorDim = indicesTy.getRank() - 1;
}
@@ -792,7 +791,7 @@
newShape.push_back(updateTy.getDimSize(updatePerm[i]));
update = builder.create<mlir::stablehlo::TransposeOp>(
updateTy.clone(newShape), update,
- builder.getI64TensorAttr(updatePerm));
+ builder.getDenseI64ArrayAttr(updatePerm));
}
}
@@ -1025,7 +1024,7 @@
llvm::seq<int64_t>(resultRank - valueTy.getRank(), resultRank));
return rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
op.getLoc(), newTy, value, lhsShape,
- rewriter.getI64TensorAttr(dimensions));
+ rewriter.getDenseI64ArrayAttr(dimensions));
};
zero = broadcast(zero);
@@ -1184,7 +1183,7 @@
Value result =
rewriter.create<ElementwiseOpT>(op.getLoc(), resultType, bcastOperands);
rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
- op, op.getType(), result, bcastOps[0].getBroadcastDimensions());
+ op, op.getType(), result, bcastOps[0].getBroadcastDimensionsAttr());
for (auto bcastOp : bcastOps) {
if (bcastOp.getOperation()->use_empty()) {
@@ -1283,11 +1282,11 @@
lhs = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
op.getLoc(), resultTy.clone(lhsTy.getElementType()), lhs, outSize,
- rewriter.getI64TensorAttr({0, 1}));
+ rewriter.getDenseI64ArrayAttr({0, 1}));
rhs = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
op.getLoc(), resultTy.clone(rhsTy.getElementType()), rhs, outSize,
- rewriter.getI64TensorAttr({0, 1}));
+ rewriter.getDenseI64ArrayAttr({0, 1}));
auto computeETy = lhsTy.getElementType();
if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth())
@@ -1451,12 +1450,12 @@
// Transpose the left hand side and the right hand side.
lhs = builder.create<mlir::stablehlo::TransposeOp>(
RankedTensorType::get(lhsTransposeShape, lhsTy.getElementType()), lhs,
- builder.getI64TensorAttr(permLhs));
+ builder.getDenseI64ArrayAttr(permLhs));
lhsTy = llvm::cast<RankedTensorType>(lhs.getType());
rhs = builder.create<mlir::stablehlo::TransposeOp>(
RankedTensorType::get(rhsTransposeShape, rhsTy.getElementType()), rhs,
- builder.getI64TensorAttr(permRhs));
+ builder.getDenseI64ArrayAttr(permRhs));
rhsTy = llvm::cast<RankedTensorType>(rhs.getType());
auto dimI32Ty = RankedTensorType::get({1}, builder.getI32Type());
@@ -1512,7 +1511,7 @@
RankedTensorType::get(resultTy.getShape(), lhsTy.getElementType());
lhs = builder.createOrFold<mlir::stablehlo::DynamicBroadcastInDimOp>(
lhsBroadcastTy, lhs, outputShape,
- rewriter.getI64TensorAttr(lhsDimMapping));
+ rewriter.getDenseI64ArrayAttr(lhsDimMapping));
// Broadcast the right hand side to match the expected output shape.
llvm::SmallVector<int64_t> rhsDimMapping(rhsTy.getRank());
@@ -1524,7 +1523,7 @@
RankedTensorType::get(resultTy.getShape(), rhsTy.getElementType());
rhs = builder.createOrFold<mlir::stablehlo::DynamicBroadcastInDimOp>(
rhsBroadcastTy, rhs, outputShape,
- rewriter.getI64TensorAttr(rhsDimMapping));
+ rewriter.getDenseI64ArrayAttr(rhsDimMapping));
lhs = builder.createOrFold<mlir::stablehlo::ConvertOp>(resultTy, lhs);
rhs = builder.createOrFold<mlir::stablehlo::ConvertOp>(resultTy, rhs);
@@ -1651,7 +1650,7 @@
return true;
}
- (void)rewriter.notifyMatchFailure(iotaOp, "Iota must be on last dimension");
+ (void)rewriter.notifyMatchFailure(iotaOp, "iota must be on last dimension");
return false;
}
@@ -1659,11 +1658,9 @@
input.getDefiningOp())) {
auto broadcastLastDim =
cast<ShapedType>(broadcastOp.getType()).getRank() - 1;
- SmallVector<int64_t> broadcastDimensions = llvm::to_vector(
- broadcastOp.getBroadcastDimensions().getValues<int64_t>());
- if (broadcastDimensions.back() != broadcastLastDim) {
+ if (broadcastOp.getBroadcastDimensions().back() != broadcastLastDim) {
(void)rewriter.notifyMatchFailure(
- broadcastOp, "Last dimension must be maintained in broadcast");
+ broadcastOp, "last dimension must be maintained in broadcast");
return false;
}
return isIotaOrIotaBroadcast(rewriter, broadcastOp.getOperand());
@@ -1682,7 +1679,7 @@
Value topKInput;
if (opOperands.size() != 2 || opResults.size() != 2) {
return rewriter.notifyMatchFailure(
- op, "Slice that maps to TopK must have exactly two inputs/outputs");
+ op, "slice that maps to TopK must have exactly two inputs/outputs");
}
Value inputIota;
@@ -1697,7 +1694,7 @@
}
if (!inputIota) {
- return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota.");
+ return rewriter.notifyMatchFailure(op, "sort isn't called from Iota");
}
Block &block = op.getRegion().front();
@@ -1713,7 +1710,7 @@
if (!getTop) {
return rewriter.notifyMatchFailure(op,
- "Unsupported comparison direction");
+ "unsupported comparison direction");
}
Value topV, topI;
@@ -1722,27 +1719,25 @@
for (auto [idx, result] : llvm::enumerate(opResults)) {
if (result.getUsers().empty())
return rewriter.notifyMatchFailure(
- op, "Sort isn't calling into a slice op.");
+ op, "sort isn't calling into a slice op");
auto sliceOp =
dyn_cast<mlir::stablehlo::SliceOp>(*result.getUsers().begin());
if (!sliceOp) {
return rewriter.notifyMatchFailure(
- op, "Sort isn't calling into a slice op.");
+ op, "sort isn't calling into a slice op");
}
- for (auto stride : sliceOp.getStrides().getValues<int64_t>()) {
+ for (auto stride : sliceOp.getStrides()) {
if (stride != 1) {
return rewriter.notifyMatchFailure(
- op, "All slice strides must be 1 in order to match to TopK.");
+ op, "all slice strides must be 1 in order to match to TopK");
}
}
// Treat the first slice as inputs, the second as indices.
if (idx == 0) {
topV = sliceOp.getResult();
- SmallVector<int64_t> limitIndices =
- llvm::to_vector(sliceOp.getLimitIndices().getValues<int64_t>());
- k = limitIndices.back();
+ k = sliceOp.getLimitIndices().back();
} else {
topI = sliceOp.getResult();
}
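The TopK matcher also gets simpler: `getStrides()` and `getLimitIndices()` now yield `ArrayRef<int64_t>`, so `k` is read straight off the last limit index and the iota check compares `.back()` values directly. The invariant being tested is just that the broadcast keeps the iota on the last dimension (standalone sketch, illustrative names):

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"

// An iota along the operand's last dimension is still an iota along the
// result's last dimension only if the broadcast maps last dim to last dim.
bool lastDimPreserved(llvm::ArrayRef<int64_t> broadcastDims,
                      int64_t resultRank) {
  return !broadcastDims.empty() && broadcastDims.back() == resultRank - 1;
}
```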
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
index 8e36b84..2ebad08 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -355,7 +355,7 @@
padding = dense<0> : tensor<1x2xi64>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
rhs_dilation = dense<1> : tensor<1xi64>,
- window_strides = dense<1> : tensor<1xi64>
+ window_strides = dense<[1]> : tensor<1xi64>
} : (tensor<16x32x256xf32>, tensor<1x256x256xbf16>) -> tensor<16x32x256xf32>
// CHECK: return %[[CONV]]
func.return %0 : tensor<16x32x256xf32>
@@ -413,8 +413,8 @@
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
- %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32>
- %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32>
+ %1 = "stablehlo.slice"(%0#0) { start_indices = array<i64: 0, 0>, limit_indices = array<i64: 16, 8>, strides = array<i64: 1, 1> } : (tensor<16x16xf32>) -> tensor<16x8xf32>
+ %2 = "stablehlo.slice"(%0#1) { start_indices = array<i64: 0, 0>, limit_indices = array<i64: 16, 8>, strides = array<i64: 1, 1> } : (tensor<16x16xi32>) -> tensor<16x8xi32>
return %1, %2 : tensor<16x8xf32>, tensor<16x8xi32>
}
@@ -434,8 +434,8 @@
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 2 : i64, is_stable = true} : (tensor<16x16x16xf32>, tensor<16x16x16xi32>) -> (tensor<16x16x16xf32>, tensor<16x16x16xi32>)
- %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32>
- %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32>
+ %1 = "stablehlo.slice"(%0#0) { start_indices = array<i64: 0, 0, 0>, limit_indices = array<i64: 16, 16, 8>, strides = array<i64: 1, 1, 1> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32>
+ %2 = "stablehlo.slice"(%0#1) { start_indices = array<i64: 0, 0, 0>, limit_indices = array<i64: 16, 16, 8>, strides = array<i64: 1, 1, 1> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32>
return %1, %2 : tensor<16x16x8xf32>, tensor<16x16x8xi32>
}
@@ -455,8 +455,8 @@
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 2 : i64, is_stable = true} : (tensor<16x16x16xf32>, tensor<16x16x16xi32>) -> (tensor<16x16x16xf32>, tensor<16x16x16xi32>)
- %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32>
- %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32>
+ %1 = "stablehlo.slice"(%0#0) { start_indices = array<i64: 0, 0, 0>, limit_indices = array<i64: 16, 16, 8>, strides = array<i64: 1, 1, 1> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32>
+ %2 = "stablehlo.slice"(%0#1) { start_indices = array<i64: 0, 0, 0>, limit_indices = array<i64: 16, 16, 8>, strides = array<i64: 1, 1, 1> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32>
return %1, %2 : tensor<16x16x8xf32>, tensor<16x16x8xi32>
}
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp
index a6d8cfa..f82fb9a 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp
@@ -178,8 +178,7 @@
int64_t rank = inputType.getRank();
int64_t n = inputType.getDimSize(rank - 1);
- int64_t fftLength =
- op.getFftLength().getSplatValue<IntegerAttr>().getInt() / 2 + 1;
+ int64_t fftLength = op.getFftLength().front() / 2 + 1;
Location loc = op.getLoc();
auto matrixType =
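`getFftLength()` now returns `ArrayRef<int64_t>`, so the old splat-value dance is replaced by `.front()`. The computed quantity is the standard RFFT output width, for reference (plain C++ sketch):

```cpp
#include <cstdint>

// A real-to-complex FFT over fftLength real samples produces
// fftLength / 2 + 1 complex bins (DC through Nyquist).
int64_t rfftOutputLength(int64_t fftLength) { return fftLength / 2 + 1; }
```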
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
index d26afec..ff36eb1 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp
@@ -492,15 +492,12 @@
SmallVector<AffineExpr> dimExprs;
dimExprs.reserve(nloops);
- if (broadcastOp.getBroadcastDimensions()) {
- for (auto [idx, broadcastDim] : llvm::enumerate(
- broadcastOp.getBroadcastDimensions().getValues<APInt>())) {
- int size = broadcastDim.getSExtValue();
- bool expansionNeeded =
- operandShape[idx] == 1 && resultType.getShape()[size] != 1;
- dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
- : b->getAffineDimExpr(size));
- }
+ for (auto [idx, size] :
+ llvm::enumerate(broadcastOp.getBroadcastDimensions())) {
+ bool expansionNeeded =
+ operandShape[idx] == 1 && resultType.getShape()[size] != 1;
+ dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
+ : b->getAffineDimExpr(size));
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
@@ -577,7 +574,7 @@
return rewriter.create<mlir::stablehlo::TransposeOp>(
loc,
RankedTensorType::get(transposedOperandShape, operandTy.getElementType()),
- operand, rewriter.getI64VectorAttr(permutation));
+ operand, rewriter.getDenseI64ArrayAttr(permutation));
}
struct BroadcastInDimOpToBroadcastConverter final
@@ -589,8 +586,7 @@
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- SmallVector<int64_t> broadcastDimensions =
- llvm::to_vector(op.getBroadcastDimensions().getValues<int64_t>());
+ SmallVector<int64_t> broadcastDimensions = op.getBroadcastDimensions();
Value operand = adaptor.getOperand();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
@@ -658,9 +654,8 @@
// Use static type info.
auto bcastDims =
- llvm::map_to_vector(op.getBroadcastDimensions(), [](const APInt &d) {
- return static_cast<int64_t>(d.getLimitedValue());
- });
+ llvm::map_to_vector(op.getBroadcastDimensions(),
+ [](int64_t d) { return static_cast<int64_t>(d); });
for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) {
if (ShapedType::isDynamic(dim))
continue;
@@ -671,17 +666,13 @@
}
// Use annotated expansion behavior, if available.
- if (op.getKnownExpandingDimensions()) {
- for (const auto &it :
- op.getKnownExpandingDimensions()->getValues<APInt>()) {
- auto i = it.getLimitedValue();
+ if (auto dims = op.getKnownExpandingDimensions()) {
+ for (int i : *dims) {
dimExprs[i] = rewriter.getAffineConstantExpr(0);
}
}
- if (op.getKnownNonexpandingDimensions()) {
- for (const auto &it :
- op.getKnownNonexpandingDimensions()->getValues<APInt>()) {
- auto i = it.getLimitedValue();
+ if (auto dims = op.getKnownNonexpandingDimensions()) {
+ for (int i : *dims) {
dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]);
}
}
@@ -730,8 +721,7 @@
if (!resultTy)
return failure();
- SmallVector<int64_t> broadcastDimensions =
- llvm::to_vector(op.getBroadcastDimensions().getValues<int64_t>());
+ SmallVector<int64_t> broadcastDimensions = op.getBroadcastDimensions();
SmallVector<std::optional<bool>> expansionBehavior(
broadcastDimensions.size());
@@ -745,14 +735,14 @@
// Use annotated expansion behavior, if available.
if (op.getKnownExpandingDimensions()) {
- for (const auto &it :
- op.getKnownExpandingDimensions()->getValues<int64_t>()) {
+ auto dims = op.getKnownExpandingDimensions().value();
+ for (int it : dims) {
expansionBehavior[it] = true;
}
}
if (op.getKnownNonexpandingDimensions()) {
- for (const auto &it :
- op.getKnownNonexpandingDimensions()->getValues<int64_t>()) {
+ auto dims = op.getKnownNonexpandingDimensions().value();
+ for (int it : dims) {
expansionBehavior[it] = false;
}
}
@@ -853,7 +843,7 @@
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank());
for (auto [idx, value] : llvm::enumerate(op.getPermutation())) {
- inputExprs[value.getZExtValue()] = b->getAffineDimExpr(idx);
+ inputExprs[value] = b->getAffineDimExpr(idx);
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
@@ -876,8 +866,7 @@
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());
- auto permutation = rewriter.getDenseI64ArrayAttr(
- llvm::to_vector(op.getPermutation().getValues<int64_t>()));
+ auto permutation = op.getPermutationAttr();
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
op, adaptor.getOperand(), emptyTensor, permutation,
@@ -1453,8 +1442,7 @@
inputExprs.reserve(nloops);
for (int64_t i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i));
- for (const APInt &dim : op.getDimensions()) {
- int i = dim.getZExtValue();
+ for (int i : op.getDimensions()) {
if (resultType.isDynamicDim(i))
return {};
int n = resultType.getShape()[i];
@@ -1479,9 +1467,9 @@
}
SmallVector<OpFoldResult, 3> offsets, sizes, strides;
- auto startIndices = sliceOp.getStartIndices().getValues<int64_t>();
- auto limitIndices = sliceOp.getLimitIndices().getValues<int64_t>();
- auto sliceStrides = sliceOp.getStrides().getValues<int64_t>();
+ auto startIndices = sliceOp.getStartIndices();
+ auto limitIndices = sliceOp.getLimitIndices();
+ auto sliceStrides = sliceOp.getStrides();
for (int64_t i = 0, e = argType.getRank(); i < e; ++i) {
int64_t start = startIndices[i];
@@ -1526,9 +1514,8 @@
SmallVector<OpFoldResult, 3> startIndices, sizes;
auto originalStartIndexType = llvm::cast<ShapedType>(
dynamicSliceOp.getStartIndices().front().getType());
- for (auto [idx, start, size] :
- llvm::enumerate(adaptor.getStartIndices(),
- dynamicSliceOp.getSliceSizes().getValues<int64_t>())) {
+ for (auto [idx, start, size] : llvm::enumerate(
+ adaptor.getStartIndices(), dynamicSliceOp.getSliceSizes())) {
sizes.push_back(rewriter.getI64IntegerAttr(size));
// By stablehlo.DynamicSlice definition:
@@ -2305,7 +2292,7 @@
SmallVector<OpFoldResult> sliceStarts;
bool hasNegativePadding = false;
- for (int64_t low : op.getEdgePaddingLow().getValues<int64_t>()) {
+ for (int64_t low : op.getEdgePaddingLow()) {
if (low >= 0) {
padLow.push_back(low);
sliceStarts.push_back(rewriter.getIndexAttr(0));
@@ -2316,7 +2303,7 @@
}
}
- for (int64_t high : op.getEdgePaddingHigh().getValues<int64_t>()) {
+ for (int64_t high : op.getEdgePaddingHigh()) {
if (high >= 0) {
padHigh.push_back(high);
} else {
@@ -2332,8 +2319,8 @@
// Create a new pad op with the positive values.
Value pad = rewriter.create<mlir::stablehlo::PadOp>(
op.getLoc(), adaptor.getOperand(), adaptor.getPaddingValue(),
- rewriter.getI64TensorAttr(padLow), rewriter.getI64TensorAttr(padHigh),
- op.getInteriorPadding());
+ rewriter.getDenseI64ArrayAttr(padLow),
+ rewriter.getDenseI64ArrayAttr(padHigh), op.getInteriorPadding());
// Then slice according to the negative edge padding. Static shapes only for
// now.
@@ -2365,24 +2352,26 @@
return rewriter.notifyMatchFailure(op, "type conversion failed");
// Negative edge padding is decomposed separately.
- auto isNegative = [](const APInt &intVal) { return intVal.isNegative(); };
- if (llvm::any_of(op.getEdgePaddingLow().getValues<APInt>(), isNegative) ||
- llvm::any_of(op.getEdgePaddingHigh().getValues<APInt>(), isNegative))
+ auto isNegative = [](int64_t intVal) { return intVal < 0; };
+ if (llvm::any_of(op.getEdgePaddingLow(), isNegative) ||
+ llvm::any_of(op.getEdgePaddingHigh(), isNegative))
return failure();
Value paddingVal = rewriter.createOrFold<tensor::ExtractOp>(
loc, adaptor.getPaddingValue());
- SmallVector<OpFoldResult> low(
- op.getEdgePaddingLow().getValues<IntegerAttr>());
+ auto i64ToFoldResult = [&](const int64_t &i) -> OpFoldResult {
+ return rewriter.getIntegerAttr(rewriter.getI64Type(), i);
+ };
// If there is no interior padding lower to tensor.pad directly.
- if (llvm::all_of(op.getInteriorPadding().getValues<APInt>(),
- [](const APInt &intVal) { return intVal.isZero(); })) {
- SmallVector<OpFoldResult> high(
- op.getEdgePaddingHigh().getValues<IntegerAttr>());
+ if (llvm::all_of(op.getInteriorPadding(),
+ [](const int64_t &i) { return i == 0; })) {
auto padTensorOp = rewriter.create<tensor::PadOp>(
- loc, resultType, adaptor.getOperand(), low, high, paddingVal);
+ loc, resultType, adaptor.getOperand(),
+ llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult),
+ llvm::map_to_vector(op.getEdgePaddingHigh(), i64ToFoldResult),
+ paddingVal);
rewriter.replaceOp(op, padTensorOp.getResult());
return success();
}
@@ -2405,15 +2394,15 @@
.getResult();
});
// Map interior padding to strides.
- auto strides =
- llvm::map_to_vector(op.getInteriorPadding().getValues<IntegerAttr>(),
- [&](IntegerAttr stride) -> OpFoldResult {
- return rewriter.getIntegerAttr(
- stride.getType(), stride.getValue() + 1);
- });
+ auto strides = llvm::map_to_vector(
+ op.getInteriorPadding(), [&](const int64_t &stride) -> OpFoldResult {
+ return rewriter.getIntegerAttr(rewriter.getI64Type(), stride + 1);
+ });
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
- op, adaptor.getOperand(), fill, low, sizes, strides);
+ op, adaptor.getOperand(), fill,
+ llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), sizes,
+ strides);
return success();
}
};
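The optional `known_expanding_dimensions` / `known_nonexpanding_dimensions` attributes appear to surface as optional arrays now, which is what enables the `if (auto dims = ...)` binding above. A sketch of that idiom, under the assumption that the getters return `std::optional<ArrayRef<int64_t>>` (illustrative helper, not patch code):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

#include "llvm/ADT/ArrayRef.h"

// Bind-and-test in one step: no separate "has attribute" check followed by
// a second accessor call, as the DenseIntElementsAttr version needed.
void markExpanding(std::optional<llvm::ArrayRef<int64_t>> knownExpanding,
                   std::vector<bool> &isExpanding) {
  if (auto dims = knownExpanding) {
    for (int64_t i : *dims)
      isExpanding[i] = true;
  }
}
```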
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
index 971115b..353852d 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp
@@ -55,9 +55,6 @@
}
}
- IntegerType indexType = rewriter.getIntegerType(64);
- auto attrType = RankedTensorType::get({rank}, indexType);
-
Value zero;
if (auto complexType = dyn_cast<ComplexType>(inputType.getElementType())) {
auto zeroElement = rewriter.getZeroAttr(complexType.getElementType());
@@ -72,9 +69,9 @@
}
return rewriter.create<mlir::stablehlo::PadOp>(
- loc, input, zero, DenseIntElementsAttr::get(attrType, padLow),
- DenseIntElementsAttr::get(attrType, padHigh),
- DenseIntElementsAttr::get(attrType, padInterior));
+ loc, input, zero, rewriter.getDenseI64ArrayAttr(padLow),
+ rewriter.getDenseI64ArrayAttr(padHigh),
+ rewriter.getDenseI64ArrayAttr(padInterior));
}
/// If the ConvolutionOp has a window reversal, applies it to the filter.
@@ -95,10 +92,7 @@
}
return b.create<mlir::stablehlo::ReverseOp>(
- loc, filter,
- mlir::DenseIntElementsAttr::get(
- RankedTensorType::get(reversedDims.size(), b.getI64Type()),
- reversedDims));
+ loc, filter, b.getDenseI64ArrayAttr(reversedDims));
}
/// Returns true if the given `dimensionNumbers` from a stablehlo.convolution op
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
index e924809..cc27b85 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp
@@ -374,10 +374,10 @@
if (!operandType || !operandType.hasStaticShape()) {
return failure();
}
- if (!op.getFftLength().isSplat()) {
+ if (!llvm::all_equal(op.getFftLength())) {
return rewriter.notifyMatchFailure(op, "non-splat length");
}
- int fftLength = op.getFftLength().getSplatValue<IntegerAttr>().getInt();
+ int fftLength = op.getFftLength().front();
if (fftLength & (fftLength - 1)) {
return rewriter.notifyMatchFailure(
op, "expected FFT length to be a power of two");
@@ -442,7 +442,7 @@
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, ty.getElementType());
rewriter.replaceOpWithNewOp<IREE::LinalgExt::ReverseOp>(
op, typeConverter->convertType(op.getType()), adaptor.getOperands(),
- emptyTensor, op.getDimensions());
+ emptyTensor, rewriter.getI64TensorAttr(op.getDimensions()));
return success();
}
};
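Two details worth noting here: the splat check on `fft_length` becomes `llvm::all_equal` over the array, and `IREE::LinalgExt::ReverseOp` still expects a tensor-typed attribute, so the new array is bridged back with `rewriter.getI64TensorAttr(...)`. A standalone sketch of the validity check (illustrative; the real pattern reports match failures instead of returning bool):

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// fft_length must be a "splat" (all entries equal) and a power of two;
// with ArrayRef the splat test is llvm::all_equal.
bool isSupportedFftLength(llvm::ArrayRef<int64_t> fftLength) {
  if (fftLength.empty() || !llvm::all_equal(fftLength))
    return false;
  int64_t n = fftLength.front();
  return n > 0 && (n & (n - 1)) == 0;
}
```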
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp
index 592d6d5..033dd69 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp
@@ -426,9 +426,9 @@
llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
Value slice = builder.create<mlir::stablehlo::SliceOp>(
- loc, resultTy, reshape, builder.getI64TensorAttr(offset),
- builder.getI64TensorAttr(resultTy.getShape()),
- builder.getI64TensorAttr(stride));
+ loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset),
+ builder.getDenseI64ArrayAttr(resultTy.getShape()),
+ builder.getDenseI64ArrayAttr(stride));
// Set the new tensor values.
store = setState64(builder, loc, store, newState);
@@ -636,12 +636,14 @@
// Slice to only the required results.
collapseShape[0] = resultTy.getNumElements();
- llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
- llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
+ auto sliceResultTy = intermediateType.clone(collapseShape);
+ llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0);
+ llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1);
Value slice = builder.create<mlir::stablehlo::SliceOp>(
- loc, intermediateType.clone(collapseShape), reshapeIntermediate,
- builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape),
- builder.getI64TensorAttr(stride));
+ loc, sliceResultTy, reshapeIntermediate,
+ builder.getDenseI64ArrayAttr(offset),
+ builder.getDenseI64ArrayAttr(collapseShape),
+ builder.getDenseI64ArrayAttr(stride));
Value reshapeResult =
builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice);
@@ -727,12 +729,14 @@
// Slice to only the required results.
collapseShape[0] = resultTy.getNumElements();
- llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
- llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
+ auto sliceResultTy = intermediateType.clone(collapseShape);
+ llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0);
+ llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1);
Value slice = builder.create<mlir::stablehlo::SliceOp>(
- loc, intermediateType.clone(collapseShape), reshapeIntermediate,
- builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape),
- builder.getI64TensorAttr(stride));
+ loc, sliceResultTy, reshapeIntermediate,
+ builder.getDenseI64ArrayAttr(offset),
+ builder.getDenseI64ArrayAttr(collapseShape),
+ builder.getDenseI64ArrayAttr(stride));
Value reshapeResult =
builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice);
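Beyond the attribute swap, these hunks also tie the slice parameters to the sliced value itself: the slice is taken from the collapsed intermediate (`sliceResultTy`), so its offset and stride vectors are now sized by that type's rank rather than by the final result's rank. Sketch of the intended sizing (illustrative):

```cpp
#include <cstdint>

#include "llvm/ADT/SmallVector.h"

// Offsets and strides of a slice need exactly one entry per dimension of
// the value being sliced; sizing them from any other type risks a rank
// mismatch. (Illustrative helper, not patch code.)
struct SliceParams {
  llvm::SmallVector<int64_t> offsets, strides;
};

SliceParams fullStrideOneSlice(int64_t slicedRank) {
  return {llvm::SmallVector<int64_t>(slicedRank, 0),
          llvm::SmallVector<int64_t>(slicedRank, 1)};
}
```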
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
index a579fe7..c261d03 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp
@@ -404,7 +404,7 @@
if (!resultTy.hasStaticShape())
return failure();
- auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
+ auto broadcastSizes = rewriter.getDenseI64ArrayAttr(resultTy.getShape());
broadcastValues.push_back(rewriter.create<mlir::stablehlo::BroadcastOp>(
loc, resultTy, initValue, broadcastSizes));
}
@@ -426,12 +426,9 @@
staticInteriors[idx] = dilation - 1;
}
- auto padAttrType =
- RankedTensorType::get({rank}, rewriter.getIntegerType(64));
- auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
- auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
- auto padInteriors =
- DenseIntElementsAttr::get(padAttrType, staticInteriors);
+ auto padLows = rewriter.getDenseI64ArrayAttr(staticLows);
+ auto padHighs = rewriter.getDenseI64ArrayAttr(staticHighs);
+ auto padInteriors = rewriter.getDenseI64ArrayAttr(staticInteriors);
for (auto [input, initValue] : llvm::zip(inputs, initValues)) {
input = rewriter.create<mlir::stablehlo::PadOp>(
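The base-dilation handling above pads each input with `interior = dilation - 1` zeros between adjacent elements before reducing. For reference, the resulting size along one dimension follows the usual `stablehlo.pad` formula (standalone sketch):

```cpp
#include <cstdint>

// stablehlo.pad output size on one dimension: edge padding on both ends
// plus `interior` elements inserted between each adjacent pair.
int64_t paddedSize(int64_t size, int64_t low, int64_t high, int64_t interior) {
  int64_t gaps = size > 0 ? size - 1 : 0;
  return low + high + size + gaps * interior;
}
```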
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
index f8b5c10..bcf0c43 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir
@@ -417,7 +417,7 @@
func.func @broadcast_in_dim_with_one_to_one(
%operand: tensor<1xf32>) -> tensor<1x5xf32> {
%0 = "stablehlo.broadcast_in_dim"(%operand)
- {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
+ {broadcast_dimensions = array<i64: 0>}
: (tensor<1xf32>) -> tensor<1x5xf32>
func.return %0 : tensor<1x5xf32>
}
@@ -464,7 +464,7 @@
// CHECK: func @broadcast_in_dim_scalar
func.func @broadcast_in_dim_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
%0 = "stablehlo.broadcast_in_dim"(%operand)
- {broadcast_dimensions = dense<[]> : tensor<0xi64>}
+ {broadcast_dimensions = array<i64>}
: (tensor<f32>) -> tensor<7x10x6xf32>
func.return %0 : tensor<7x10x6xf32>
}
@@ -484,7 +484,7 @@
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @broadcast_scalar
func.func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
- %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
+ %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array<i64: 4, 2, 1>} : (tensor<f32>) -> tensor<4x2x1xf32>
func.return %0: tensor<4x2x1xf32>
}
// CHECK: tensor.empty() : tensor<4x2x1xf32>
@@ -505,7 +505,7 @@
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK: func @broadcast
func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
- %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
+ %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array<i64: 4, 2, 1>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
func.return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -655,7 +655,7 @@
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>}
+ }) {dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<?xf32>, tensor<4xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
@@ -675,7 +675,7 @@
^bb0(%arg2: tensor<f32>):
%1 = stablehlo.add %arg2, %arg2 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>}
+ }) {dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<?xf32>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}
@@ -706,7 +706,7 @@
{comparison_direction = #stablehlo<comparison_direction EQ>}
: (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%3) : (tensor<i1>) -> ()
- }) {dimensions = dense<0> : tensor<1xi64>}
+ }) {dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<?xcomplex<f32>>, tensor<?xcomplex<f32>>) -> tensor<?xi1>
func.return %0 : tensor<?xi1>
}
@@ -741,9 +741,9 @@
func.func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
%0 = arith.constant dense<0.0> : tensor<f32>
%1 = "stablehlo.pad"(%arg0, %0) {
- edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
- edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
- interior_padding = dense<0> : tensor<2xi64>
+ edge_padding_high = array<i64: 2, 3>,
+ edge_padding_low = array<i64: 4, 5>,
+ interior_padding = array<i64: 0, 0>
} : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
func.return %1 : tensor<18x12xf32>
}
@@ -758,9 +758,9 @@
func.func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf32> {
%0 = "stablehlo.pad"(%arg0, %arg1) {
- edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
- edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
- interior_padding = dense<0> : tensor<2xi64>
+ edge_padding_high = array<i64: 2, 3>,
+ edge_padding_low = array<i64: 4, 5>,
+ interior_padding = array<i64: 0, 0>
} : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
func.return %0 : tensor<18x12xf32>
}
@@ -777,9 +777,9 @@
func.func @pad_interior(%arg0: tensor<12x4xui32>, %arg1: tensor<ui32>) -> tensor<29x15xui32> {
%0 = arith.constant dense<0> : tensor<ui32>
%1 = "stablehlo.pad"(%arg0, %arg1) {
- edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
- edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
- interior_padding = dense<[1, 1]> : tensor<2xi64>
+ edge_padding_high = array<i64: 2, 3>,
+ edge_padding_low = array<i64: 4, 5>,
+ interior_padding = array<i64: 1, 1>
} : (tensor<12x4xui32>, tensor<ui32>) -> tensor<29x15xui32>
func.return %1 : tensor<29x15xui32>
}
@@ -798,9 +798,9 @@
func.func @pad_interior_negative(%arg0: tensor<12x4xui32>, %arg1: tensor<ui32>) -> tensor<25x9xui32> {
%0 = arith.constant dense<0> : tensor<ui32>
%1 = "stablehlo.pad"(%arg0, %arg1) {
- edge_padding_high = dense<[-2, 3]> : tensor<2xi64>,
- edge_padding_low = dense<[4, -1]> : tensor<2xi64>,
- interior_padding = dense<[1, 1]> : tensor<2xi64>
+ edge_padding_high = array<i64: -2, 3>,
+ edge_padding_low = array<i64: 4, -1>,
+ interior_padding = array<i64: 1, 1>
} : (tensor<12x4xui32>, tensor<ui32>) -> tensor<25x9xui32>
func.return %1 : tensor<25x9xui32>
}
@@ -1066,7 +1066,7 @@
// CHECK: func @reverse
func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "stablehlo.reverse"(%input) {
- dimensions = dense<1> : tensor<1xi64>, someattr
+ dimensions = array<i64: 1>, someattr
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
func.return %result : tensor<2x3xf32>
}
@@ -1332,9 +1332,9 @@
// CHECK: tensor.extract_slice %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32>
func.func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
%0 = "stablehlo.slice"(%arg0) {
- start_indices = dense<[1, 0]> : tensor<2xi64>,
- limit_indices = dense<[2, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 1, 0>,
+ limit_indices = array<i64: 2, 4>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<1x4xi32>
func.return %0 : tensor<1x4xi32>
}
@@ -1345,9 +1345,9 @@
// CHECK: tensor.extract_slice %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
func.func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "stablehlo.slice"(%arg0) {
- start_indices = dense<[1, 1]> : tensor<2xi64>,
- limit_indices = dense<[2, 3]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 1, 1>,
+ limit_indices = array<i64: 2, 3>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
func.return %0 : tensor<1x2xi32>
}
@@ -1358,9 +1358,9 @@
// CHECK: tensor.extract_slice %{{.*}}[0] [6] [2] : tensor<13xi32> to tensor<6xi32>
func.func @slice_with_strides(%arg0: tensor<13xi32>) -> tensor<6xi32> {
%0 = "stablehlo.slice"(%arg0) {
- limit_indices = dense<12> : tensor<1xi64>,
- start_indices = dense<0> : tensor<1xi64>,
- strides = dense<2> : tensor<1xi64>
+ limit_indices = array<i64: 12>,
+ start_indices = array<i64: 0>,
+ strides = array<i64: 2>
} : (tensor<13xi32>) -> tensor<6xi32>
func.return %0 : tensor<6xi32>
}
@@ -1371,9 +1371,9 @@
// CHECK: tensor.extract_slice %{{.*}}[0] [3] [2] : tensor<6xi32> to tensor<3xi32>
func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> {
%0 = "stablehlo.slice"(%arg0) {
- limit_indices = dense<5> : tensor<1xi64>,
- start_indices = dense<0> : tensor<1xi64>,
- strides = dense<2> : tensor<1xi64>
+ limit_indices = array<i64: 5>,
+ start_indices = array<i64: 0>,
+ strides = array<i64: 2>
} : (tensor<6xi32>) -> tensor<3xi32>
func.return %0 : tensor<3xi32>
}
@@ -1384,9 +1384,9 @@
// CHECK: tensor.extract_slice %{{.*}}[0, 2, 0] [3, 0, 5] [1, 2, 1] : tensor<3x3x5xf64> to tensor<3x0x5xf64>
func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64> {
%0 = "stablehlo.slice"(%arg0) {
- limit_indices = dense<[3, 2, 5]> : tensor<3xi64>,
- start_indices = dense<[0, 2, 0]> : tensor<3xi64>,
- strides = dense<[1, 2, 1]> : tensor<3xi64>
+ limit_indices = array<i64: 3, 2, 5>,
+ start_indices = array<i64: 0, 2, 0>,
+ strides = array<i64: 1, 2, 1>
} : (tensor<3x3x5xf64>) -> tensor<3x0x5xf64>
func.return %0 : tensor<3x0x5xf64>
}
@@ -1399,7 +1399,7 @@
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xf32> {
%0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) {
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 4>
} : (tensor<3x4xf32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}
@@ -1422,7 +1422,7 @@
%arg: tensor<3x4xui32>, %start1: tensor<ui64>, %start2: tensor<ui64>)
-> tensor<1x4xui32> {
%0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) {
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 4>
} : (tensor<3x4xui32>, tensor<ui64>, tensor<ui64>) -> tensor<1x4xui32>
func.return %0 : tensor<1x4xui32>
}
@@ -1438,7 +1438,7 @@
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
func.func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xui32> {
%0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) {
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 4>
} : (tensor<3x4xui32>, tensor<i64>, tensor<i64>) -> tensor<1x4xui32>
func.return %0 : tensor<1x4xui32>
}
@@ -1559,7 +1559,7 @@
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @transpose
func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
- %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
+ %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 1, 0, 3, 2>}
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
func.return %0 : tensor<3x2x5x9xi32>
}
@@ -1574,7 +1574,7 @@
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @transpose_dynamic
func.func @transpose_dynamic(%arg0: tensor<?x?x9x?xi32>) -> tensor<?x?x?x9xi32> {
- %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>, someattr}
+ %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 1, 0, 3, 2>, someattr}
: (tensor<?x?x9x?xi32>) -> tensor<?x?x?x9xi32>
func.return %0 : tensor<?x?x?x9xi32>
}
@@ -1607,7 +1607,7 @@
func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> {
%0 = "stablehlo.transpose"(%arg0) {
- permutation = dense<[1, 0]> : tensor<2xi64>,
+ permutation = array<i64: 1, 0>,
result_layout = dense<[0, 1]> : tensor<2xindex>
} : (tensor<2x2xui32>) -> tensor<2x2xui32>
return %0 : tensor<2x2xui32>
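The test updates above are the textual side of the same migration: `dense<...> : tensor<Nxi64>` is the syntax for `DenseIntElementsAttr`, while `array<i64: ...>` is `DenseI64ArrayAttr`. A minimal program showing the two printed forms (assumes an MLIR build; `dump()` writes to stderr):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::OpBuilder b(&ctx);
  b.getDenseI64ArrayAttr({1, 0}).dump();  // array<i64: 1, 0>
  b.getI64TensorAttr({1, 0}).dump();      // dense<[1, 0]> : tensor<2xi64>
}
```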
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
index 17270e2..5f4b338 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir
@@ -394,7 +394,7 @@
// CHECK: func.func @rfft_1d
func.func @rfft_1d(%input: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%0 = "stablehlo.fft"(%input) {
- fft_length = dense<8> : tensor<1xi64>, fft_type = #stablehlo<fft_type RFFT>
+ fft_length = array<i64: 8>, fft_type = #stablehlo<fft_type RFFT>
} : (tensor<8xf32>) -> tensor<5xcomplex<f32>>
%1 = "stablehlo.real"(%0) : (tensor<5xcomplex<f32>>) -> tensor<5xf32>
%2 = "stablehlo.imag"(%0) : (tensor<5xcomplex<f32>>) -> tensor<5xf32>
@@ -442,7 +442,7 @@
// CHECK: func.func @rfft_2d
func.func @rfft_2d(%input: tensor<4x8xf32>) -> (tensor<4x5xf32>, tensor<4x5xf32>) {
%0 = "stablehlo.fft"(%input) {
- fft_length = dense<8> : tensor<1xi64>, fft_type = #stablehlo<fft_type RFFT>
+ fft_length = array<i64: 8>, fft_type = #stablehlo<fft_type RFFT>
} : (tensor<4x8xf32>) -> tensor<4x5xcomplex<f32>>
%1 = "stablehlo.real"(%0) : (tensor<4x5xcomplex<f32>>) -> tensor<4x5xf32>
%2 = "stablehlo.imag"(%0) : (tensor<4x5xcomplex<f32>>) -> tensor<4x5xf32>
@@ -490,7 +490,7 @@
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
%0 = "stablehlo.reverse"(%arg0) {
- dimensions = dense<1> : tensor<1xi64>
+ dimensions = array<i64: 1>
} : (tensor<3x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}
@@ -505,7 +505,7 @@
func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
%0 = "stablehlo.reverse"(%arg0) {
- dimensions = dense<1> : tensor<1xi64>
+ dimensions = array<i64: 1>
} : (tensor<3x5xui32>) -> tensor<3x5xui32>
return %0 : tensor<3x5xui32>
}
@@ -526,7 +526,7 @@
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
func.func @reverse_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
%0 = "stablehlo.reverse"(%arg0) {
- dimensions = dense<[0, 1]> : tensor<2xi64>
+ dimensions = array<i64: 0, 1>
} : (tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
diff --git a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
index 47334ac..2e5bcf9 100644
--- a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
+++ b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py
@@ -19,7 +19,7 @@
from iree.compiler.tools import InputType, compile_file
from iree.runtime import load_vm_flatbuffer_file
-MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.a49ba1535a45ac0f3e6be22a7ed5dddf4a53cd1f41126af938f0667b998f8e11.tar"
+MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.45208053dcd69ebd7428fe5b785249a7bdff2d62d55fb81b815889c4e1b993bb.tar"
Tensor = TypeVar("Tensor")
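
The mnist_train URL swap above points the test at a regenerated tarball. The 64-hex-digit component in the filename looks like a content hash of the archive (presumably SHA-256), so a rebuilt artifact necessarily gets a fresh URL rather than silently overwriting the old one.
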
diff --git a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
index e16251c..a2a378c 100644
--- a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
+++ b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir
@@ -79,18 +79,18 @@
%62 = stablehlo.dot %61, %45, precision = [DEFAULT] : (tensor<1x74xf32>, tensor<74x40xf32>) -> tensor<1x40xf32>
%63 = stablehlo.reshape %43 : (tensor<40xf32>) -> tensor<1x40xf32>
%64 = stablehlo.add %62, %63 : tensor<1x40xf32>
- %65 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 30]> : tensor<2xi64>, start_indices = dense<[0, 20]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
+ %65 = "stablehlo.slice"(%64) {limit_indices = array<i64: 1, 30>, start_indices = array<i64: 0, 20>, strides = array<i64: 1, 1>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
%66 = stablehlo.multiply %65, %8 : tensor<1x10xf32>
%67 = stablehlo.tanh %66 : tensor<1x10xf32>
%68 = stablehlo.multiply %67, %8 : tensor<1x10xf32>
%69 = stablehlo.add %68, %8 : tensor<1x10xf32>
%70 = stablehlo.multiply %69, %47 : tensor<1x10xf32>
- %71 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 20]> : tensor<2xi64>, start_indices = dense<[0, 10]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
+ %71 = "stablehlo.slice"(%64) {limit_indices = array<i64: 1, 20>, start_indices = array<i64: 0, 10>, strides = array<i64: 1, 1>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
%72 = stablehlo.multiply %71, %8 : tensor<1x10xf32>
%73 = stablehlo.tanh %72 : tensor<1x10xf32>
%74 = stablehlo.multiply %73, %8 : tensor<1x10xf32>
%75 = stablehlo.add %74, %8 : tensor<1x10xf32>
- %76 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 10]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
+ %76 = "stablehlo.slice"(%64) {limit_indices = array<i64: 1, 10>, start_indices = array<i64: 0, 0>, strides = array<i64: 1, 1>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
%77 = stablehlo.tanh %76 : tensor<1x10xf32>
%78 = stablehlo.multiply %75, %77 : tensor<1x10xf32>
%79 = stablehlo.add %70, %78 : tensor<1x10xf32>
@@ -100,7 +100,7 @@
%83 = stablehlo.reshape %56 : (tensor<1x1xf32>) -> tensor<1xf32>
%84 = stablehlo.broadcast_in_dim %83, dims = [0] : (tensor<1xf32>) -> tensor<1x10xf32>
%85 = stablehlo.compare GT, %84, %7 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xi1>
- %86 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 40]> : tensor<2xi64>, start_indices = dense<[0, 30]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
+ %86 = "stablehlo.slice"(%64) {limit_indices = array<i64: 1, 40>, start_indices = array<i64: 0, 30>, strides = array<i64: 1, 1>} : (tensor<1x40xf32>) -> tensor<1x10xf32>
%87 = stablehlo.multiply %86, %8 : tensor<1x10xf32>
%88 = stablehlo.tanh %87 : tensor<1x10xf32>
%89 = stablehlo.multiply %88, %8 : tensor<1x10xf32>
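
One wrinkle worth calling out: DenseIntElementsAttr had a splat form (dense<1> : tensor<2xi64> means the value 1 repeated for both elements), while dense array attributes have no splat, so every element must now be written out. That is why the slice updates above change shape as well as syntax:

    strides = dense<1> : tensor<2xi64>   // old: splat over 2 implied elements
    strides = array<i64: 1, 1>           // new: both elements explicit

The same expansion appears for start_indices here and for interior_padding in the pad tests further down.
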
diff --git a/tests/e2e/stablehlo_ops/broadcast.mlir b/tests/e2e/stablehlo_ops/broadcast.mlir
index b723466..ad4e72d 100644
--- a/tests/e2e/stablehlo_ops/broadcast.mlir
+++ b/tests/e2e/stablehlo_ops/broadcast.mlir
@@ -1,7 +1,7 @@
func.func @broadcast_2D_3D() {
%input = util.unfoldable_constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi32>
- %result = "stablehlo.broadcast"(%input) {broadcast_sizes = dense<3> : tensor<1xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+ %result = "stablehlo.broadcast"(%input) {broadcast_sizes = array<i64: 3>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
check.expect_eq_const(%result, dense<[
[[1, 2, 3, 4], [5, 6, 7, 8]],
[[1, 2, 3, 4], [5, 6, 7, 8]],
@@ -11,7 +11,7 @@
func.func @broadcast_3D_scalar() {
%input = util.unfoldable_constant dense<42> : tensor<i32>
- %result = "stablehlo.broadcast"(%input) {broadcast_sizes = dense<[3, 2, 4]> : tensor<3xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
+ %result = "stablehlo.broadcast"(%input) {broadcast_sizes = array<i64: 3, 2, 4>} : (tensor<i32>) -> tensor<3x2x4xi32>
check.expect_eq_const(%result, dense<[
[[42, 42, 42, 42], [42, 42, 42, 42]],
[[42, 42, 42, 42], [42, 42, 42, 42]],
diff --git a/tests/e2e/stablehlo_ops/dynamic_slice.mlir b/tests/e2e/stablehlo_ops/dynamic_slice.mlir
index dc0a201..a423b23 100644
--- a/tests/e2e/stablehlo_ops/dynamic_slice.mlir
+++ b/tests/e2e/stablehlo_ops/dynamic_slice.mlir
@@ -6,7 +6,7 @@
%start1 = util.unfoldable_constant dense<1> : tensor<i64>
%start2 = util.unfoldable_constant dense<2> : tensor<i64>
%result = "stablehlo.dynamic_slice"(%input, %start1, %start2) {
- slice_sizes = dense<[2, 2]> : tensor<2xi64>
+ slice_sizes = array<i64: 2, 2>
} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
check.expect_eq_const(%result, dense<[
[7, 8],
@@ -22,7 +22,7 @@
%start1 = util.unfoldable_constant dense<1> : tensor<i64>
%start2 = util.unfoldable_constant dense<2> : tensor<i64>
%result = "stablehlo.dynamic_slice"(%input, %start1, %start2) {
- slice_sizes = dense<[1, 2]> : tensor<2xi64>
+ slice_sizes = array<i64: 1, 2>
} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x2xi32>
check.expect_eq_const(%result, dense<[
[7, 8]]> : tensor<1x2xi32>) : tensor<1x2xi32>
@@ -33,7 +33,7 @@
%input = util.unfoldable_constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%start1 = util.unfoldable_constant dense<1> : tensor<i64>
%result = "stablehlo.dynamic_slice"(%input, %start1) {
- slice_sizes = dense<[2]> : tensor<1xi64>
+ slice_sizes = array<i64: 2>
} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
check.expect_eq_const(%result, dense<[2, 3]> : tensor<2xi32>) : tensor<2xi32>
return
diff --git a/tests/e2e/stablehlo_ops/pad.mlir b/tests/e2e/stablehlo_ops/pad.mlir
index 9774bbc..18ccf78 100644
--- a/tests/e2e/stablehlo_ops/pad.mlir
+++ b/tests/e2e/stablehlo_ops/pad.mlir
@@ -2,9 +2,9 @@
%input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
%c0 = arith.constant dense<0> : tensor<i32>
%res = "stablehlo.pad"(%input, %c0) {
- edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
- edge_padding_high = dense<[1, 5]> : tensor<2xi64>,
- interior_padding = dense<0> : tensor<2xi64>
+ edge_padding_low = array<i64: 0, 1>,
+ edge_padding_high = array<i64: 1, 5>,
+ interior_padding = array<i64: 0, 0>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<3x9xi32>
check.expect_eq_const(%res, dense<[
[0, 1, 2, 3, 0, 0, 0, 0, 0],
@@ -16,7 +16,7 @@
func.func @pad_no_op() {
%input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
%c0 = arith.constant dense<0> : tensor<i32>
- %res = "stablehlo.pad"(%input, %c0) {edge_padding_high = dense<[0, 0]> : tensor<2xi64>, edge_padding_low = dense<[0, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<2x3xi32>
+ %res = "stablehlo.pad"(%input, %c0) {edge_padding_high = array<i64: 0, 0>, edge_padding_low = array<i64: 0, 0>, interior_padding = array<i64: 0, 0>} : (tensor<2x3xi32>, tensor<i32>) -> tensor<2x3xi32>
check.expect_eq(%res, %input) : tensor<2x3xi32>
return
}
diff --git a/tests/e2e/stablehlo_ops/reverse.mlir b/tests/e2e/stablehlo_ops/reverse.mlir
index 11d53e6..336065c 100644
--- a/tests/e2e/stablehlo_ops/reverse.mlir
+++ b/tests/e2e/stablehlo_ops/reverse.mlir
@@ -1,19 +1,19 @@
func.func @xla_reverse() {
%t1 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
- %dim0 = "stablehlo.reverse"(%t1) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
+ %dim0 = "stablehlo.reverse"(%t1) {dimensions = array<i64: 0>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
check.expect_almost_eq_const(
%dim0,
dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>
) : tensor<2x3xf32>
- %dim1 = "stablehlo.reverse"(%t1) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
+ %dim1 = "stablehlo.reverse"(%t1) {dimensions = array<i64: 1>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
check.expect_almost_eq_const(
%dim1,
dense<[[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]> : tensor<2x3xf32>
) : tensor<2x3xf32>
- %both_dims = "stablehlo.reverse"(%t1) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
+ %both_dims = "stablehlo.reverse"(%t1) {dimensions = array<i64: 0, 1>} : (tensor<2x3xf32>) -> tensor<2x3xf32>
check.expect_almost_eq_const(
%both_dims,
dense<[[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]> : tensor<2x3xf32>
diff --git a/tests/e2e/stablehlo_ops/slice.mlir b/tests/e2e/stablehlo_ops/slice.mlir
index 2f0120d..f52c5a6 100644
--- a/tests/e2e/stablehlo_ops/slice.mlir
+++ b/tests/e2e/stablehlo_ops/slice.mlir
@@ -4,9 +4,9 @@
[05, 06, 07, 08],
[09, 10, 11, 12]]> : tensor<3x4xi32>
%result = "stablehlo.slice"(%input) {
- start_indices = dense<[0, 0]> : tensor<2xi64>,
- limit_indices = dense<[3, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 0, 0>,
+ limit_indices = array<i64: 3, 4>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<3x4xi32>
check.expect_eq_const(%result, dense<[
[1, 2, 3, 4],
@@ -21,9 +21,9 @@
[05, 06, 07, 08],
[09, 10, 11, 12]]> : tensor<3x4xi32>
%result = "stablehlo.slice"(%input) {
- start_indices = dense<[1, 0]> : tensor<2xi64>,
- limit_indices = dense<[2, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 1, 0>,
+ limit_indices = array<i64: 2, 4>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<1x4xi32>
check.expect_eq_const(%result, dense<[[5, 6, 7, 8]]> : tensor<1x4xi32>) : tensor<1x4xi32>
return
@@ -35,9 +35,9 @@
[05, 06, 07, 08],
[09, 10, 11, 12]]> : tensor<3x4xi32>
%result = "stablehlo.slice"(%input) {
- start_indices = dense<[1, 1]> : tensor<2xi64>,
- limit_indices = dense<[2, 3]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 1, 1>,
+ limit_indices = array<i64: 2, 3>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
check.expect_eq_const(%result, dense<[[6, 7]]> : tensor<1x2xi32>) : tensor<1x2xi32>
return
@@ -49,9 +49,9 @@
[05, 06, 07, 08],
[09, 10, 11, 12]]> : tensor<3x4xi32>
%result = "stablehlo.slice"(%input) {
- start_indices = dense<[1, 0]> : tensor<2xi64>,
- limit_indices = dense<[3, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
+ start_indices = array<i64: 1, 0>,
+ limit_indices = array<i64: 3, 4>,
+ strides = array<i64: 1, 1>
} : (tensor<3x4xi32>) -> tensor<2x4xi32>
check.expect_eq_const(%result, dense<[
[5, 6, 7, 8],
diff --git a/tests/e2e/stablehlo_ops/transpose.mlir b/tests/e2e/stablehlo_ops/transpose.mlir
index d709f2e..8b9b238 100644
--- a/tests/e2e/stablehlo_ops/transpose.mlir
+++ b/tests/e2e/stablehlo_ops/transpose.mlir
@@ -2,7 +2,7 @@
%input = util.unfoldable_constant dense<[[1, 2, 3],
[4, 5, 6]]> : tensor<2x3xi32>
%0 = "stablehlo.transpose"(%input) {
- permutation = dense<[1, 0]> : tensor<2xi64>
+ permutation = array<i64: 1, 0>
} : (tensor<2x3xi32>) -> tensor<3x2xi32>
check.expect_eq_const(%0, dense<[[1, 4],
[2, 5],
@@ -16,7 +16,7 @@
[[ 7, 8, 9],
[10, 11, 12]]]> : tensor<2x2x3xi32>
%0 = "stablehlo.transpose"(%input) {
- permutation = dense<[0, 2, 1]> : tensor<3xi64>
+ permutation = array<i64: 0, 2, 1>
} : (tensor<2x2x3xi32>) -> tensor<2x3x2xi32>
check.expect_eq_const(%0, dense<[
[[ 1, 4],
diff --git a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
index 334b996..64a78f0 100644
--- a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
+++ b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake
@@ -84,29 +84,29 @@
iree_fetch_artifact(
NAME "model-EfficientNetV2STF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-MiniLML12H384Uncased"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-BertForMaskedLMTF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-BertLargeTF"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc"
- OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc"
+ OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
UNPACK
)
@@ -140,63 +140,63 @@
iree_fetch_artifact(
NAME "model-BertLargeTFBatch1"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH1/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH1/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch1.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-BertLargeTFBatch32"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH32/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH32/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch32.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-BertLargeTFBatch64"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH64/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH64/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch64.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-Resnet50TFBatch1"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH1/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH1/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch1.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-Resnet50TFBatch64"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH64/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH64/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch64.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-Resnet50TFBatch128"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH128/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH128/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch128.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-T5LargeTFBatch1"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH1/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH1/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch1.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-T5LargeTFBatch16"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH16/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH16/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch16.mlirbc"
UNPACK
)
iree_fetch_artifact(
NAME "model-T5LargeTFBatch32"
- SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH32/stablehlo.mlirbc"
+ SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH32/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch32.mlirbc"
UNPACK
)
diff --git a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
index 1cbbde3..dafca14 100644
--- a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
+++ b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
@@ -246,7 +246,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -260,7 +260,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -274,7 +274,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -288,7 +288,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -538,7 +538,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -553,7 +553,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -568,7 +568,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -583,7 +583,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -838,7 +838,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -853,7 +853,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -868,7 +868,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -883,7 +883,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -1093,7 +1093,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1105,7 +1105,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1117,7 +1117,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1129,7 +1129,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -1275,7 +1275,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2120,7 +2120,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2138,7 +2138,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2156,7 +2156,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2174,7 +2174,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2492,7 +2492,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2511,7 +2511,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2530,7 +2530,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2549,7 +2549,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2872,7 +2872,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2891,7 +2891,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2910,7 +2910,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -2929,7 +2929,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
@@ -3195,7 +3195,7 @@
iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3211,7 +3211,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3227,7 +3227,7 @@
iree_bytecode_module(
NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3243,7 +3243,7 @@
iree_bytecode_module(
NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=cuda"
@@ -3433,7 +3433,7 @@
iree_bytecode_module(
NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_"
- SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc"
+ SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
diff --git a/tests/microbenchmarks/stablehlo_fft_abs.mlir b/tests/microbenchmarks/stablehlo_fft_abs.mlir
index 3ba9274..cdb5f7c 100644
--- a/tests/microbenchmarks/stablehlo_fft_abs.mlir
+++ b/tests/microbenchmarks/stablehlo_fft_abs.mlir
@@ -5,7 +5,7 @@
func.func @rfft_abs_6x1024() -> tensor<6x513xf32> {
%input = util.unfoldable_constant dense<1.0> : tensor<6x1024xf32>
%0 = "stablehlo.fft"(%input) {
- fft_length = dense<1024> : tensor<1xi64>,
+ fft_length = array<i64: 1024>,
fft_type = #stablehlo<fft_type RFFT>
} : (tensor<6x1024xf32>) -> tensor<6x513xcomplex<f32>>
%1 = "stablehlo.abs"(%0) : (tensor<6x513xcomplex<f32>>) -> tensor<6x513xf32>
diff --git a/third_party/stablehlo b/third_party/stablehlo
index 6b1ebdb..f8dcebf 160000
--- a/third_party/stablehlo
+++ b/third_party/stablehlo
@@ -1 +1 @@
-Subproject commit 6b1ebdbfa70ef9ce794f41e4fd1b0839191164c9
+Subproject commit f8dcebfa1ec166806974f6ae0dfb902d36b47238
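
The submodule pointer update is the bump itself: the assembly-format test updates, the regenerated mnist_train tarball, and the renamed serialized model artifacts above are all fallout from moving third_party/stablehlo to this commit.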