Merge pull request #7265 from MaheshRavishankar:main-to-google
PiperOrigin-RevId: 401101859
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index ae39185..d49cf01 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -690,6 +690,30 @@
.def_property_readonly(
"stashed_flatbuffer_blob",
[](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
+ .def_property_readonly(
+ "function_names",
+ [](VmModule& self) {
+ py::list names;
+ iree_vm_module_signature_t sig =
+ iree_vm_module_signature(self.raw_ptr());
+ for (size_t ordinal = 0; ordinal < sig.export_function_count;
+ ++ordinal) {
+ iree_vm_function_t f;
+ iree_string_view_t linkage_name;
+ auto status = iree_vm_module_lookup_function_by_ordinal(
+ self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f,
+ &linkage_name);
+ if (iree_status_is_not_found(status)) {
+ iree_status_ignore(status);
+ break;
+ }
+ CheckApiStatus(status, "Error enumerating module");
+ iree_string_view_t fname = iree_vm_function_name(&f);
+ py::str name(fname.data, fname.size);
+ names.append(name);
+ }
+ return names;
+ })
.def("__repr__", [](VmModule& self) {
std::string repr("<VmModule ");
iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
diff --git a/build_tools/benchmarks/run_benchmarks_on_android.py b/build_tools/benchmarks/run_benchmarks_on_android.py
index fa51dae..1c13a9c 100755
--- a/build_tools/benchmarks/run_benchmarks_on_android.py
+++ b/build_tools/benchmarks/run_benchmarks_on_android.py
@@ -216,6 +216,7 @@
def filter_benchmarks_for_category(benchmark_category_dir: str,
cpu_target_arch: str,
gpu_target_arch: str,
+ driver_filter: Optional[str],
verbose: bool = False) -> Sequence[str]:
"""Filters benchmarks in a specific category for the given device.
@@ -223,6 +224,7 @@
- benchmark_category_dir: the directory to a specific benchmark category.
- cpu_target_arch: CPU target architecture.
- gpu_target_arch: GPU target architecture.
+ - driver_filter: only run benchmarks for the given driver if not None.
Returns:
- A list containing all matched benchmark cases' directories.
@@ -242,10 +244,16 @@
continue
iree_driver, target_arch, bench_mode = segments
+ iree_driver = iree_driver[len("iree-"):].lower()
target_arch = target_arch.lower()
- # We can choose this benchmark if it matches the CPU/GPU architecture.
- should_choose = (target_arch == cpu_target_arch or
- target_arch == gpu_target_arch)
+
+ # We can choose this benchmark if it matches the driver and CPU/GPU
+ # architecture.
+ matched_driver = (
+ driver_filter is None or iree_driver == driver_filter.lower())
+ matched_arch = (
+ target_arch == cpu_target_arch or target_arch == gpu_target_arch)
+ should_choose = matched_driver and matched_arch
if should_choose:
matched_benchmarks.append(root)
@@ -373,6 +381,7 @@
def filter_and_run_benchmarks(
device_info: AndroidDeviceInfo,
root_build_dir: str,
+ driver_filter: Optional[str],
normal_benchmark_tool: str,
traced_benchmark_tool: Optional[str],
trace_capture_tool: Optional[str],
@@ -382,6 +391,7 @@
Args:
- device_info: an AndroidDeviceInfo object.
- root_build_dir: the root build directory.
+ - driver_filter: only run benchmarks for the given driver if not None.
- normal_benchmark_tool: the path to the normal benchmark tool.
- traced_benchmark_tool: the path to the tracing-enabled benchmark tool.
- trace_capture_tool: the path to the tool for collecting captured traces.
@@ -400,6 +410,7 @@
benchmark_category_dir=benchmark_category_dir,
cpu_target_arch=cpu_target_arch,
gpu_target_arch=gpu_target_arch,
+ driver_filter=driver_filter,
verbose=verbose)
run_results = run_benchmarks_for_category(
device_info=device_info,
@@ -454,6 +465,11 @@
type=check_exe_path,
default=None,
help="Path to the tool for collecting captured traces")
+ parser.add_argument(
+ "--driver",
+ type=str,
+ default=None,
+ help="Only run benchmarks for a specific driver, e.g., 'vulkan'")
parser.add_argument("-o",
dest="output",
default=None,
@@ -464,9 +480,8 @@
parser.add_argument(
"--no-clean",
action="store_true",
- help=
- "Do not clean up the temporary directory used for benchmarking on the Android device"
- )
+ help="Do not clean up the temporary directory used for "
+ "benchmarking on the Android device")
parser.add_argument("--verbose",
action="store_true",
help="Print internal information during execution")
@@ -499,10 +514,17 @@
(args.trace_capture_tool is not None):
execute_cmd_and_get_output(["adb", "forward", "tcp:8086", "tcp:8086"])
+ args.traced_benchmark_tool = os.path.realpath(args.traced_benchmark_tool)
+ args.trace_capture_tool = os.path.realpath(args.trace_capture_tool)
+
results, captures = filter_and_run_benchmarks(
- device_info, args.build_dir, os.path.realpath(args.normal_benchmark_tool),
- os.path.realpath(args.traced_benchmark_tool),
- os.path.realpath(args.trace_capture_tool), args.verbose)
+ device_info=device_info,
+ root_build_dir=args.build_dir,
+ driver_filter=args.driver,
+ normal_benchmark_tool=os.path.realpath(args.normal_benchmark_tool),
+ traced_benchmark_tool=args.traced_benchmark_tool,
+ trace_capture_tool=args.trace_capture_tool,
+ verbose=args.verbose)
if args.output is not None:
with open(args.output, "w") as f:
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 8297bdd..40f9e05 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -543,7 +543,11 @@
cast<IREE::HAL::InterfaceBindingSubspanOp>(op).queryBindingOp();
IREE::HAL::InterfaceBindingSubspanOpAdaptor newOperands(
operands, op->getAttrDictionary());
- MemRefType memRefType = op->getResult(0).getType().cast<MemRefType>();
+ MemRefType memRefType = op->getResult(0).getType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return rewriter.notifyMatchFailure(
+ op,
+ "failed to convert interface.binding.subspan result to memref type");
auto memRefDesc = abi.loadBinding(
op->getLoc(), interfaceBindingOp.binding().getZExtValue(),
newOperands.byte_offset(), memRefType, newOperands.dynamic_dims(),
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
index a95d79c..aa1a3c6 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
@@ -158,13 +158,15 @@
}
}
- // TODO: This should be a folding of Add into Contract in core but while
- // they live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
+ {
+ // Fold consumer add ops into the contraction op itself.
+ RewritePatternSet canonicalizationPatterns(context);
+ vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
+ context);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns));
+ }
+
// Apply vector specific operation lowering.
{
vector::VectorTransformsOptions vectorTransformsOptions =
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
index b9b2a97..e0df89c 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
@@ -238,13 +238,14 @@
std::move(vectorizationPatterns));
}
- // TODO: This should be a folding of Add into Contract in core but while they
- // live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
+ {
+ // Fold consumer add ops into the contraction op itself.
+ RewritePatternSet canonicalizationPatterns(context);
+ vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
+ context);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns));
+ }
if (enableVectorContractToAarch64Asm) {
RewritePatternSet vectorToAArch64AsmPatterns(context);
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
index 23b8590..ab5b9e5 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
@@ -86,13 +86,13 @@
populateVectorizationPatterns(vectorizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
- // TODO: This should be a folding of Add into Contract in core but while
- // they live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
+
+ // Fold consumer add ops into the contraction op itself.
+ RewritePatternSet canonicalizationPatterns(context);
+ vector::ContractionOp::getCanonicalizationPatterns(
+ canonicalizationPatterns, context);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns));
RewritePatternSet vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns);
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 99243bf..514d916 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -107,11 +107,13 @@
pm.addPass(createTensorConstantBufferizePass());
pm.addPass(createFoldTensorExtractOpPass());
+ pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
+
// SCF -> STD
pm.addNestedPass<FuncOp>(createLowerToCFGPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
- pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
+
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
pm.addPass(createLowerAffinePass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 525b0cc..8557c11 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -226,6 +226,13 @@
vectorizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
+
+ // Fold consumer add ops into the contraction op itself.
+ RewritePatternSet canonicalizationPatterns(context);
+ vector::ContractionOp::getCanonicalizationPatterns(
+ canonicalizationPatterns, context);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizationPatterns));
}
LLVM_DEBUG({
@@ -234,13 +241,6 @@
llvm::dbgs() << "\n\n";
});
- // TODO: This should be a folding of Add into Contract in core but while
- // they live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
{
RewritePatternSet vectorUnrollPatterns(funcOp.getContext());
diff --git a/iree/compiler/Codegen/Transforms/Transforms.cpp b/iree/compiler/Codegen/Transforms/Transforms.cpp
index 5a4778f..b4b6056 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -69,50 +69,5 @@
return success();
}
-/// Return a fused vector::ContractionOp which represents a patterns such as:
-///
-/// ```mlir
-/// %c0 = vector.constant 0: ...
-/// %c = vector.contract %a, %b, %c0: ...
-/// %e = add %c, %d: ...
-/// ```
-///
-/// by:
-///
-/// ```mlir
-/// %e = vector.contract %a, %b, %d: ...
-/// ```
-///
-/// Return null if the canonicalization does not apply.
-// TODO: This should be a folding of Add into Contract in core but while they
-// live in different dialects, it is not possible without unnatural
-// dependencies.
-vector::ContractionOp canonicalizeContractionAdd(Operation *op) {
- if (!isa<AddIOp, AddFOp>(op)) return nullptr;
-
- OpBuilder builder(op);
- auto canonicalize = [](OpBuilder &b, Value maybeContraction,
- Value otherOperand) -> vector::ContractionOp {
- vector::ContractionOp contractionOp =
- dyn_cast_or_null<vector::ContractionOp>(
- maybeContraction.getDefiningOp());
- if (!contractionOp) return nullptr;
- if (auto maybeZero =
- dyn_cast_or_null<ConstantOp>(contractionOp.acc().getDefiningOp())) {
- if (maybeZero.value() == b.getZeroAttr(contractionOp.acc().getType())) {
- BlockAndValueMapping bvm;
- bvm.map(contractionOp.acc(), otherOperand);
- return cast<vector::ContractionOp>(b.clone(*contractionOp, bvm));
- }
- }
- return nullptr;
- };
-
- Value a = op->getOperand(0), b = op->getOperand(1);
- vector::ContractionOp contract = canonicalize(builder, a, b);
- contract = contract ? contract : canonicalize(builder, b, a);
- return contract;
-}
-
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Codegen/Transforms/Transforms.h b/iree/compiler/Codegen/Transforms/Transforms.h
index 1b30ff2..e8270de 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/iree/compiler/Codegen/Transforms/Transforms.h
@@ -32,26 +32,6 @@
OpBuilder &builder, FuncOp funcOp,
WorkgroupCountRegionBuilder regionBuilder);
-/// Return a fused vector::ContractionOp which represents a patterns such as:
-///
-/// ```mlir
-/// %c0 = vector.constant 0: ...
-/// %c = vector.contract %a, %b, %c0: ...
-/// %e = add %c, %d: ...
-/// ```
-///
-/// by:
-///
-/// ```mlir
-/// %e = vector.contract %a, %b, %d: ...
-/// ```
-///
-/// Return null if the canonicalization does not apply.
-// TODO: This should be a folding of Add into Contract in core but while they
-// live in different dialects, it is not possible without unnatural
-// dependencies.
-vector::ContractionOp canonicalizeContractionAdd(Operation *op);
-
/// Insert patterns to perform folding of AffineMinOp by matching the pattern
/// generated by tile and distribute. Try to fold a affine.min op by matching
/// the following form:
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 30c7932..b375029 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -239,11 +239,31 @@
static std::pair<IREE::Flow::DispatchWorkgroupsOp, Operation *>
buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc,
ArrayRef<Value> count, Operation *op) {
+ SmallVector<Value> operands, operandDims;
+ SmallVector<int64_t> tiedOperands;
+ if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
+ // Handle tensor.insert_slice in a special manner. This op is actually two
+ // steps:
+ // 1) Copy over the dest tensor to the result,
+ // 2) Update the overwritten part of the result with the destination.
+ // To actually make this work, the dispatch region needs the `dest` and
+ // result to be tied operands. This is somehow special. It might fall out
+ // naturally, but not sure how. For now, just do it by construction.
+ operands.push_back(insertSliceOp.dest());
+ ReifiedRankedShapedTypeDims resultShapes;
+ (void)insertSliceOp.reifyResultShapes(rewriter, resultShapes);
+ auto destType = insertSliceOp.dest().getType().cast<ShapedType>();
+ for (auto shape : enumerate(destType.getShape())) {
+ if (shape.value() != ShapedType::kDynamicSize) {
+ continue;
+ }
+ operandDims.push_back(resultShapes[0][shape.index()]);
+ }
+ tiedOperands.push_back(0);
+ }
auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
- loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{},
- /*operands=*/ValueRange{},
- /*operand_dims=*/ValueRange{},
- /*tied_operands=*/ArrayRef<int64_t>{});
+ loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{}, operands,
+ operandDims, tiedOperands);
Region ®ion = dispatchOp.body();
Block *block = ®ion.front();
Operation *clonedOp;
@@ -423,15 +443,13 @@
return orderedOps;
}
-/// Computes the values that will be eventually be used within the dispatch
+/// Computes the values that will eventually be used within the dispatch
/// workgroup op but defined outside the op after all clonable operations are
-/// cloned into the region. Returns (by reference) the clonable operations too,
-/// in order in which they can be cloned within the region to satisfy use-def
-/// relationships between them.
+/// cloned into the region.
static void getUsedValuesDefinedAboveAfterCloningOps(
- IREE::Flow::DispatchWorkgroupsOp dispatchOp,
- llvm::SetVector<Value> &valuesDefinedAbove,
- llvm::SmallVector<Operation *> &clonedOps) {
+ OpBuilder &builder, IREE::Flow::DispatchWorkgroupsOp dispatchOp,
+ llvm::SetVector<Value> &valuesDefinedAbove) {
+ llvm::SmallVector<Operation *> clonedOps;
llvm::SetVector<Value> visited;
SmallVector<Value, 4> worklist;
worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
@@ -452,6 +470,21 @@
// The cloned operations form a DAG. Return the cloned operations so the
// leaves come first, and can be cloned in-order into the dispatch region.
clonedOps = orderOperations(clonedOps);
+
+ for (auto clonedOp : reverse(clonedOps)) {
+ Operation *clone = builder.clone(*clonedOp);
+ for (auto result : llvm::enumerate(clonedOp->getResults())) {
+ result.value().replaceUsesWithIf(
+ clone->getResult(result.index()), [&](OpOperand &use) {
+ return use.getOwner()
+ ->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ==
+ dispatchOp;
+ });
+ valuesDefinedAbove.remove(result.value());
+ }
+ builder.setInsertionPoint(clone);
+ }
+
// Reverse the values. This is not for correctness, but more for readability
// of the IR.
llvm::SetVector<Value> reversedValues;
@@ -459,86 +492,115 @@
std::swap(reversedValues, valuesDefinedAbove);
}
+/// Returns the tied operand for the given `resultArg`. Returns nullptr if error
+/// or not found.
+static BlockArgument getTiedOperandBlockArgument(BlockArgument resultArg) {
+ auto resultArgType =
+ resultArg.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+ if (!resultArgType ||
+ resultArgType.getAccess() != IREE::Flow::TensorAccess::WriteOnly) {
+ return nullptr;
+ }
+ // Each output block argument should just have one use.
+ if (!resultArg.hasOneUse()) return nullptr;
+
+ // And that's a flow.dispatch.output.store op.
+ auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
+ (*resultArg.getUses().begin()).getOwner());
+ if (!storeOp) return nullptr;
+
+ Operation *tieOp = storeOp.value().getDefiningOp();
+ if (!tieOp) return nullptr;
+
+ // TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
+ // when it's available in MLIR core in some form.
+ BlockArgument tiedArg =
+ TypeSwitch<Operation *, BlockArgument>(tieOp)
+ .Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp insertOp)
+ -> BlockArgument {
+ auto loadOp =
+ insertOp.dest()
+ .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!loadOp) return nullptr;
+ return loadOp.source().dyn_cast<BlockArgument>();
+ })
+ .Case<IREE::Flow::DispatchTensorLoadOp>(
+ [&](auto loadOp) -> BlockArgument {
+ // Check that there is a single use and that the source is
+ // block argument. Single use can potentially be relaxed.
+ auto loadArg =
+ loadOp.source().template dyn_cast<BlockArgument>();
+ if (!loadArg || !loadArg.hasOneUse()) {
+ return nullptr;
+ }
+ return loadArg;
+ })
+ .Case<linalg::LinalgOp,
+ linalg_ext::LinalgExtOp>([&](auto linalgLikeOp)
+ -> BlockArgument {
+ unsigned resultIndex =
+ storeOp.value().cast<OpResult>().getResultNumber();
+ auto loadOp =
+ linalgLikeOp.getOutputTensorOperands()[resultIndex]
+ ->get()
+ .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!loadOp) return nullptr;
+ return loadOp.source().template dyn_cast<BlockArgument>();
+ })
+ .Default([&](Operation *) -> BlockArgument { return nullptr; });
+
+ if (!tiedArg) {
+ return nullptr;
+ }
+
+  // Check that the type of the tied argument candidate and type of the output
+ // match and that the tied argument is readonly.
+ auto type = tiedArg.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+ if (!type || type.getAccess() != IREE::Flow::TensorAccess::ReadOnly ||
+ type.getElementType() != resultArgType.getElementType() ||
+ llvm::any_of(llvm::zip(type.getShape(), resultArgType.getShape()),
+ [](std::tuple<int64_t, int64_t> sizes) {
+ return std::get<0>(sizes) !=
+ IREE::Flow::DispatchTensorType::kDynamicSize &&
+ std::get<1>(sizes) !=
+ IREE::Flow::DispatchTensorType::kDynamicSize &&
+ std::get<0>(sizes) != std::get<1>(sizes);
+ })) {
+ return nullptr;
+ }
+ return tiedArg;
+}
+
/// Modifies `dispatchOp` to attach operand-result tie information when
/// possible.
static void tryToTieOperandsAndResults(
IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
Block *block = dispatchOp.getBody(0);
- unsigned numResults = dispatchOp.getNumResults();
- auto inputs = block->getArguments().drop_back(numResults);
- auto outputs = block->getArguments().take_back(numResults);
+ unsigned numOperands = dispatchOp.getODSOperandIndexAndLength(1).second;
- // Returns the tied operand for the given `resultArg`. Returns nullptr
- // if error or not found.
- auto getTiedOperandBlockArgument =
- [](BlockArgument resultArg) -> BlockArgument {
- // Each output block argument should just have one use.
- if (!llvm::hasSingleElement(resultArg.getUses())) return nullptr;
-
- // And that's a flow.dispatch.output.store op.
- auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
- (*resultArg.getUses().begin()).getOwner());
- if (!storeOp) return nullptr;
-
- Operation *tieOp = storeOp.value().getDefiningOp();
- if (!tieOp) return nullptr;
-
- // TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
- // when it's available in MLIR core in some form.
- BlockArgument tiedArg =
- TypeSwitch<Operation *, BlockArgument>(tieOp)
- .Case<tensor::InsertSliceOp>(
- [&](tensor::InsertSliceOp insertOp) -> BlockArgument {
- auto loadOp = insertOp.dest()
- .template getDefiningOp<
- IREE::Flow::DispatchTensorLoadOp>();
- if (!loadOp) return nullptr;
- return loadOp.source().dyn_cast<BlockArgument>();
- })
- .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>(
- [&](auto linalgLikeOp) -> BlockArgument {
- unsigned resultIndex =
- storeOp.value().cast<OpResult>().getResultNumber();
- auto loadOp =
- linalgLikeOp.getOutputTensorOperands()[resultIndex]
- ->get()
- .template getDefiningOp<
- IREE::Flow::DispatchTensorLoadOp>();
- if (!loadOp) return nullptr;
- return loadOp.source().template dyn_cast<BlockArgument>();
- })
- .Default([&](Operation *) -> BlockArgument { return nullptr; });
-
- return tiedArg;
- };
-
- SmallVector<BlockArgument, 4> tiedOperands;
- tiedOperands.reserve(numResults);
-
- // Collect all result argument's tied operand arguments.
- for (BlockArgument &arg : outputs) {
- tiedOperands.push_back(getTiedOperandBlockArgument(arg));
- }
-
+ SmallVector<unsigned> eraseArguments;
// Go over each result to tie operand when possible, by:
// 1. Update the tied operand argument to take readwrite tensors.
// 2. Erase the result argument.
// 3. Attach the tie information to the DispatchWorkgroupsOp.
- for (int i = outputs.size() - 1; i >= 0; --i) {
- BlockArgument inputArg = tiedOperands[i];
- if (!inputArg) continue;
-
- auto oldType = inputArg.getType().cast<IREE::Flow::DispatchTensorType>();
- inputArg.setType(IREE::Flow::DispatchTensorType::get(
+ for (auto result : llvm::enumerate(dispatchOp.getResults())) {
+ if (dispatchOp.getTiedResultOperand(result.value())) continue;
+ BlockArgument outputArgument =
+ block->getArgument(numOperands + result.index());
+ BlockArgument tiedOperandArgument =
+ getTiedOperandBlockArgument(outputArgument);
+ if (!tiedOperandArgument) continue;
+ auto oldType =
+ tiedOperandArgument.getType().cast<IREE::Flow::DispatchTensorType>();
+ tiedOperandArgument.setType(IREE::Flow::DispatchTensorType::get(
IREE::Flow::TensorAccess::ReadWrite, oldType.getShape(),
oldType.getElementType()));
-
- BlockArgument outputArg = block->getArgument(inputs.size() + i);
- outputArg.replaceAllUsesWith(inputArg);
- block->eraseArgument(inputs.size() + i);
-
- dispatchOp.setTiedResultOperandIndex(i, inputArg.getArgNumber());
+ outputArgument.replaceAllUsesWith(tiedOperandArgument);
+ eraseArguments.push_back(outputArgument.getArgNumber());
+ dispatchOp.setTiedResultOperandIndex(result.index(),
+ tiedOperandArgument.getArgNumber());
}
+ block->eraseArguments(eraseArguments);
}
static void replaceAllUsesWithinDispatchOp(
@@ -563,78 +625,92 @@
Location loc = dispatchOp.getLoc();
Region ®ion = dispatchOp.body();
Block &block = region.front();
- unsigned numOldBBArgs = block.getNumArguments();
OpBuilder b = OpBuilder::atBlockBegin(&block);
llvm::SetVector<Value> valuesDefinedAbove;
- llvm::SmallVector<Operation *> clonedOps;
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
if (valuesDefinedAbove.empty()) return success();
- getUsedValuesDefinedAboveAfterCloningOps(dispatchOp, valuesDefinedAbove,
- clonedOps);
+ getUsedValuesDefinedAboveAfterCloningOps(b, dispatchOp, valuesDefinedAbove);
+ b.setInsertionPointToStart(&block);
- BlockAndValueMapping map;
- SmallVector<Value> toReplaceWithinRegion;
- // Replace valuesDefinedAbove by new BB args (including the op's operands).
- for (Value operand : valuesDefinedAbove) {
- if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
- block.addArgument(IREE::Flow::DispatchTensorType::get(
- TensorAccess::ReadOnly, rt.getShape(), rt.getElementType()));
- } else {
- block.addArgument(operand.getType());
+ // Build a map from current operands to arguments.
+ std::pair<unsigned, unsigned> operandsIndexAndLength =
+ dispatchOp.getODSOperandIndexAndLength(1);
+ std::pair<unsigned, unsigned> operandDimsIndexAndLength =
+ dispatchOp.getODSOperandIndexAndLength(2);
+ llvm::DenseMap<Value, BlockArgument> operandToBBArg;
+ for (auto operand : llvm::enumerate(dispatchOp.operands())) {
+ operandToBBArg[operand.value()] = block.getArgument(operand.index());
+ }
+
+ // Of the values defined above and used in the region, add values that are not
+ // operands to the region already.
+ unsigned numOperands = operandsIndexAndLength.second;
+ unsigned numOperandDims = operandDimsIndexAndLength.second;
+ for (auto value : valuesDefinedAbove) {
+ BlockArgument bbArg = operandToBBArg.lookup(value);
+ auto tensorType = value.getType().dyn_cast<RankedTensorType>();
+ if (!bbArg) {
+ // Create a new basic block argument for this value.
+ Type bbArgType = value.getType();
+ if (tensorType) {
+ bbArgType = IREE::Flow::DispatchTensorType::get(
+ TensorAccess::ReadOnly, tensorType.getShape(),
+ tensorType.getElementType());
+ }
+ bbArg = block.insertArgument(numOperands, bbArgType, value.getLoc());
}
- Value bbArg = block.getArguments().back();
Value repl = bbArg;
if (bbArg.getType().isa<IREE::Flow::DispatchTensorType>()) {
+ // For arguments of type flow.dispatch.tensor, create a
+ // flow.dispatch.tensor.load to get the replacement values.
repl = b.create<IREE::Flow::DispatchTensorLoadOp>(
- loc, operand.getType().cast<RankedTensorType>(), bbArg);
+ loc, value.getType().cast<RankedTensorType>(), bbArg);
}
- map.map(operand, repl);
- toReplaceWithinRegion.push_back(operand);
- }
- // The only existing arguments are for the outputs. Just need to add a new
- // argument for the outputs and remap the value to use the new argument.
- for (unsigned argNum : llvm::seq<unsigned>(0, numOldBBArgs)) {
- BlockArgument arg = block.getArgument(argNum);
- assert(arg.getType().isa<IREE::Flow::DispatchTensorType>());
- arg.replaceAllUsesWith(block.addArgument(arg.getType()));
- }
- // Drop old BB args.
- block.eraseArguments(
- llvm::to_vector<4>(llvm::seq<unsigned>(0, numOldBBArgs)));
+ value.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ return use.getOwner()
+ ->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ==
+ dispatchOp;
+ });
- // Clone the marked operations.
- for (Operation *op : clonedOps) {
- b.clone(*op, map);
- toReplaceWithinRegion.append(op->result_begin(), op->result_end());
- }
-
- // Make the region isolated from above.
- for (auto value : toReplaceWithinRegion) {
- replaceAllUsesWithinDispatchOp(dispatchOp, value, map.lookup(value));
- }
-
- // Gather the dynamic dimensions for all operands.
- SmallVector<Value, 4> operandDynamicDims;
- OpBuilder builder(dispatchOp);
- for (Value operand : valuesDefinedAbove) {
- if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
- for (unsigned i = 0; i < rt.getRank(); ++i) {
- if (!rt.isDynamicDim(i)) continue;
- auto dim = builder.createOrFold<tensor::DimOp>(dispatchOp.getLoc(),
- operand, i);
- operandDynamicDims.push_back(dim);
+ // Insert the operand if this is not already one. Also need to account for
+ // dynamic dim values for the operands.
+ if (!operandToBBArg.count(value)) {
+ dispatchOp->insertOperands(operandsIndexAndLength.first + numOperands,
+ {value});
+ numOperands++;
+ if (tensorType) {
+        // The dims for this operand do not exist yet. Add them.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(dispatchOp);
+ SmallVector<Value> dynamicDims;
+ for (auto dim : llvm::enumerate(tensorType.getShape())) {
+ if (dim.value() != ShapedType::kDynamicSize) continue;
+ dynamicDims.push_back(b.createOrFold<tensor::DimOp>(
+ dispatchOp.getLoc(), value, dim.index()));
+ }
+ dispatchOp->insertOperands(
+ operandsIndexAndLength.first + numOperands + numOperandDims,
+ dynamicDims);
+ numOperandDims += dynamicDims.size();
}
}
}
- // Set the values captured from above as the new operands.
- dispatchOp.operandsMutable().assign(llvm::to_vector<4>(valuesDefinedAbove));
- dispatchOp.operand_dimsMutable().assign(operandDynamicDims);
-
+ // Update the `operand_segment_sizes`.
+ auto operandSegmentSizes = dispatchOp->getAttrOfType<DenseIntElementsAttr>(
+ dispatchOp.operand_segment_sizesAttrName());
+ auto newValues = llvm::to_vector<4>(llvm::map_range(
+ operandSegmentSizes.getValues<APInt>(),
+ [&](APInt val) -> int32_t { return val.getSExtValue(); }));
+ newValues[1] = numOperands;
+ newValues[2] = numOperandDims;
+ auto newAttr =
+ DenseIntElementsAttr::get(operandSegmentSizes.getType(), newValues);
+ dispatchOp->setAttr(dispatchOp.operand_segment_sizesAttrName(), newAttr);
return success();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 11a1a55..732502e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -457,30 +457,59 @@
// -----
-// A subsequent pass is expected to convert linalg.fill into DMA ops.
-func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32> {
- %cst = constant 0.000000e+00 : f32
- %0 = linalg.init_tensor [1, 225, 225, 3] : tensor<1x225x225x3xf32>
- %1 = linalg.fill(%cst, %0) : f32, tensor<1x225x225x3xf32> -> tensor<1x225x225x3xf32>
- %2 = tensor.insert_slice %arg0 into %1[0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1] : tensor<1x224x224x3xf32> into tensor<1x225x225x3xf32>
- return %2 : tensor<1x225x225x3xf32>
+func @subtensor_insert(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) -> tensor<?x?xf32> {
+ %0 = tensor.insert_slice %arg0 into
+ %arg1[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func @subtensor_insert
-// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x224x224x3xf32>)
-//
-// CHECK-NOT: flow.dispatch.workgroups
-// CHECK: %[[FILL:.+]] = linalg.fill
-//
-// CHECK: %[[PAD:.+]] = flow.dispatch.workgroups[{{.+}}](%[[INPUT]], %[[FILL]]) : (tensor<1x224x224x3xf32>, tensor<1x225x225x3xf32>) -> %[[FILL]] =
-// CHECK-NEXT: (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
-// CHECK-NEXT: %[[SRC_TENSOR:.+]] = flow.dispatch.tensor.load %[[SRC]], {{.*}} : !flow.dispatch.tensor<readonly:1x224x224x3xf32> -> tensor<1x224x224x3xf32>
-// CHECK-NEXT: %[[DST_TENSOR:.+]] = flow.dispatch.tensor.load %[[DST]], {{.*}} : !flow.dispatch.tensor<readwrite:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
-// CHECK-NEXT: %[[INSERT:.+]] = tensor.insert_slice %[[SRC_TENSOR]] into %[[DST_TENSOR]][0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1]
-// CHECK-NEXT: flow.dispatch.tensor.store %[[INSERT]], %[[DST]], {{.*}} : tensor<1x225x225x3xf32> -> !flow.dispatch.tensor<readwrite:1x225x225x3xf32>
-// CHECK-NEXT: flow.return
-//
-// CHECK: return %[[PAD]] : tensor<1x225x225x3xf32>
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]]
+// CHECK-SAME: (%[[ARG1]], %[[ARG0]], %[[ARG2]], %[[ARG3]])
+// CHECK-SAME: tensor<?x?xf32>{%[[D2]], %[[D3]]}
+// CHECK-SAME: tensor<?x?xf32>{%[[D0]], %[[D1]]}
+// CHECK-SAME: -> %[[ARG1]]{%[[D2]], %[[D3]]}
+// CHECK-NEXT: %[[ARG6:.+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
+// CHECK-SAME: %[[ARG7:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
+// CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
+// CHECK-DAG: %[[SHAPE:.+]] = flow.dispatch.shape %[[ARG7]]
+// CHECK: %[[UB_Y:.+]] = shapex.ranked_dim %[[SHAPE]][0]
+// CHECK: %[[UB_X:.+]] = shapex.ranked_dim %[[SHAPE]][1]
+// CHECK-DAG: %[[WGID_X:.+]] = flow.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[WGCOUNT_X:.+]] = flow.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[WGID_Y:.+]] = flow.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[WGCOUNT_Y:.+]] = flow.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]](%[[WGID_Y]])[%[[WGSIZE_Y]]]
+// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MAP0]](%[[WGCOUNT_Y]])[%[[WGSIZE_Y]]]
+// CHECK: scf.for %[[ARG10:.+]] = %[[OFFSET_Y]] to %[[UB_Y]] step %[[STEP_Y]]
+// CHECK-DAG: %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[ARG10]])[%[[WGSIZE_Y]], %[[UB_Y]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP0]](%[[WGID_X]])[%[[WGSIZE_X]]]
+// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MAP0]](%[[WGCOUNT_X]])[%[[WGSIZE_X]]]
+// CHECK: scf.for %[[ARG11:.+]] = %[[OFFSET_X]] to %[[UB_X]] step %[[STEP_X]]
+// CHECK: %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[ARG11]])[%[[WGSIZE_X]], %[[UB_X]]]
+// CHECK: %[[LOAD_TILE:.+]] = flow.dispatch.tensor.load %[[ARG7]]
+// CHECK-SAME: offsets = [%[[ARG10]], %[[ARG11]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+// CHECK-DAG: %[[STORE_OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[ARG10]])[%[[ARG8]]]
+// CHECK-DAG: %[[STORE_OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[ARG11]])[%[[ARG9]]]
+// CHECK: flow.dispatch.tensor.store %[[LOAD_TILE]], %[[ARG6]]
+// CHECK: return %[[RESULT]]
// -----
@@ -1096,3 +1125,40 @@
// CHECK: %[[FILL:.+]] = linalg.fill_rng_2d
// CHECK: linalg.matmul
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?xf32>)
+
+// -----
+
+func @dynamic_slice(%arg0 : i32, %arg1 : i32, %arg2 : tensor<?xi32>,
+ %arg3 : tensor<?x?xi32>) -> tensor<?x?xi32>{
+ %c0 = constant 0 : index
+ %c0_i32 = constant 0 : i32
+ %c2_i32 = constant 2 : i32
+ %5 = cmpi slt, %arg0, %c2_i32 : i32
+ %6 = select %5, %arg0, %c2_i32 : i32
+ %7 = cmpi sgt, %6, %c0_i32 : i32
+ %8 = select %7, %6, %c0_i32 : i32
+ %9 = index_cast %8 : i32 to index
+ %11 = cmpi slt, %arg1, %c0_i32 : i32
+ %12 = select %11, %arg1, %c0_i32 : i32
+ %13 = cmpi sgt, %12, %c0_i32 : i32
+ %14 = select %13, %12, %c0_i32 : i32
+ %15 = index_cast %14 : i32 to index
+ %d0 = tensor.dim %arg2, %c0 : tensor<?xi32>
+ %17 = tensor.insert_slice %arg2 into
+ %arg3[%9, %15] [1, %d0] [1, 1] : tensor<?xi32> into tensor<?x?xi32>
+ return %17 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @dynamic_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?xi32>
+// CHECK-SAME: %[[ARG3:.+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG3]], %[[C0]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG3]], %[[C1]]
+// CHECK: flow.dispatch.workgroups[%[[D0]], %[[C1]], %[[C1]]]
+// CHECK-SAME: tensor<?x?xi32>{%[[D1]], %[[D2]]}, tensor<?xi32>{%[[D0]]}
+// CHECK-NEXT: %[[ARG4:.+]]: !flow.dispatch.tensor<readwrite:?x?xi32>
+// CHECK-SAME: %[[ARG5:.+]]: !flow.dispatch.tensor<readonly:?xi32>
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
index fd8c3a5..2cd8a14 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -31,16 +31,117 @@
return valueOrAttr.get<Value>();
}
+//===----------------------------------------------------------------------===//
+// Interface implementations for external operations.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct InsertSliceTiledOpInterface
+ : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
+ tensor::InsertSliceOp> {
+ SmallVector<Value> getDestinationOperands(Operation *op) const {
+ SmallVector<Value> dest;
+ dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
+ return dest;
+ }
+
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
+ getParallelIteratorTypeName());
+ }
+
+ SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ Value source = insertSliceOp.source();
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ Location loc = op->getLoc();
+ Value zero = b.create<ConstantIndexOp>(loc, 0);
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Range> loopBounds(sourceType.getRank(),
+ Range{zero, nullptr, one});
+ for (auto dim :
+ llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
+ loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
+ }
+ return loopBounds;
+ }
+
+ Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) const {
+ auto insertOp = cast<tensor::InsertSliceOp>(op);
+ // Compute a subtensor of the source based on the offsets.
+ auto opStrides = insertOp.getMixedStrides();
+ if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 1;
+ })) {
+ op->emitOpError("unable to tile operation with non-unit stride");
+ return nullptr;
+ }
+ Location loc = insertOp.getLoc();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+ auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+ loc, insertOp.source(), offsets, sizes, strides);
+
+ // The offsets for the insert is based on the op offsets plus the offsets of
+ // the loops passed in.
+ auto opOffsets = insertOp.getMixedOffsets();
+ auto opSizes = insertOp.getMixedSizes();
+ unsigned offsetIndex = 0;
+ ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
+ int64_t destRank = insertOp.getType().getRank();
+ SmallVector<OpFoldResult> resultOffsets(destRank);
+ SmallVector<OpFoldResult> resultSizes(destRank);
+ for (auto opOffset : llvm::enumerate(opOffsets)) {
+ // Check for rank-reducing by checking that
+ // 1) The corresponding opSize value is 1
+ // 2) The current rank of the source is not 1.
+ // Then the opOffset is for the rank-reduced dimension. Skip.
+ unsigned opOffsetIndex = opOffset.index();
+ OpFoldResult opOffsetVal = opOffset.value();
+ Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
+ if (offsetIndex >= sourceShape.size() ||
+ (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1)) {
+ resultOffsets[opOffsetIndex] = opOffsetVal;
+ resultSizes[opOffsetIndex] = oneAttr;
+ continue;
+ }
+ OpFoldResult offset = offsets[offsetIndex];
+ if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
+ resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
+ *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
+ } else {
+ AffineMap map = AffineMap::get(
+ 1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
+ resultOffsets[opOffsetIndex] =
+ b.create<AffineApplyOp>(loc, map,
+ ValueRange{getValue(b, loc, offset),
+ getValue(b, loc, opOffsetVal)})
+ .getResult();
+ }
+ resultSizes[opOffsetIndex] = sizes[offsetIndex];
+ offsetIndex++;
+ }
+ SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+ auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+ loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
+ resultStrides);
+ results.push_back(tiledInsertOp.result());
+ return extractSliceOp;
+ }
+};
+} // namespace
+
void registerTiledOpInterfaceExternalModels(DialectRegistry ®istry) {
LLVM_DEBUG({
llvm::dbgs() << "Adding tiled op interface for tensor.insert_slice\n";
});
- // TODO(ravishankarm): For now this is commented out since there are a lot of
- // upstream bugs exposed by this. Leaving the restructuring in place, but
- // avoiding the interface hook till those are addressed.
- //
- // registry.addOpInterface<tensor::InsertSliceOp,
- // InsertSliceTiledOpInterface>();
+ registry.addOpInterface<tensor::InsertSliceOp, InsertSliceTiledOpInterface>();
}
} // namespace linalg_ext
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
index 16aca74..ae62abe 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -235,112 +235,6 @@
loopBounds, 0, offsets, distributionInfo);
}
-//===----------------------------------------------------------------------===//
-// Interface implementations for external operations.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct InsertSliceTiledOpInterface
- : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
- tensor::InsertSliceOp> {
- SmallVector<Value> getDestinationOperands(Operation *op) const {
- SmallVector<Value> dest;
- dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
- return dest;
- }
-
- SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
- getParallelIteratorTypeName());
- }
-
- SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- Value source = insertSliceOp.source();
- RankedTensorType sourceType = insertSliceOp.getSourceType();
- Location loc = op->getLoc();
- Value zero = b.create<ConstantIndexOp>(loc, 0);
- Value one = b.create<ConstantIndexOp>(loc, 1);
- SmallVector<Range> loopBounds(sourceType.getRank(),
- Range{zero, nullptr, one});
- for (auto dim :
- llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
- loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
- }
- return loopBounds;
- }
-
- Operation *getTiledImplementation(Operation *op, OpBuilder &b,
- ValueRange outputs,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<Value> &results) const {
- auto insertOp = cast<tensor::InsertSliceOp>(op);
- // Compute a subtensor of the source based on the offsets.
- auto opStrides = insertOp.getMixedStrides();
- if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
- Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
- return intVal && *intVal == 1;
- })) {
- op->emitOpError("unable to tile operation with non-unit stride");
- return nullptr;
- }
- Location loc = insertOp.getLoc();
- auto oneAttr = b.getI64IntegerAttr(1);
- SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
- auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
- loc, insertOp.source(), offsets, sizes, strides);
-
- // The offsets for the insert is based on the op offsets plus the offsets of
- // the loops passed in.
- auto opOffsets = insertOp.getMixedOffsets();
- auto opSizes = insertOp.getMixedSizes();
- unsigned offsetIndex = 0;
- ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
- int64_t destRank = insertOp.getType().getRank();
- SmallVector<OpFoldResult> resultOffsets(destRank);
- SmallVector<OpFoldResult> resultSizes(destRank);
- auto zeroAttr = b.getI64IntegerAttr(0);
- for (auto opOffset : llvm::enumerate(opOffsets)) {
- // Check for rank-reducing by checking that
- // 1) The corresponding opSize value is 1
- // 2) The current rank of the source is not 1.
- // Then the opOffset is for the rank-reduced dimension. Skip.
- unsigned opOffsetIndex = opOffset.index();
- Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
- if (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1) {
- resultOffsets[opOffsetIndex] = zeroAttr;
- resultSizes[opOffsetIndex] = oneAttr;
- continue;
- }
- OpFoldResult opOffsetVal = opOffset.value();
- OpFoldResult offset = offsets[offsetIndex];
- if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
- resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
- *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
- } else {
- AffineMap map = AffineMap::get(
- 1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
- resultOffsets[opOffsetIndex] =
- b.create<AffineApplyOp>(loc, map,
- ValueRange{getValue(b, loc, offset),
- getValue(b, loc, opOffsetVal)})
- .getResult();
- }
- resultSizes[opOffsetIndex] = sizes[offsetIndex];
- offsetIndex++;
- }
- SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
- auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
- loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
- resultStrides);
- results.push_back(tiledInsertOp.result());
- return extractSliceOp;
- }
-};
-} // namespace
-
LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
TiledOpInterface tilableOp, PatternRewriter &rewriter,
TiledOp &result) const {
@@ -373,7 +267,6 @@
linalg_ext::LinalgExtDialect, memref::MemRefDialect,
StandardOpsDialect, tensor::TensorDialect, scf::SCFDialect>();
}
- LogicalResult initialize(MLIRContext *context) override;
void runOnOperation() override;
};
} // namespace
@@ -383,13 +276,6 @@
return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
}
-LogicalResult TiledOpInterfaceTilingPass::initialize(MLIRContext *context) {
- // TODO(ravishankarm): When the interface is added during registration, remove
- // this initialization.
- tensor::InsertSliceOp::attachInterface<InsertSliceTiledOpInterface>(*context);
- return success();
-}
-
void TiledOpInterfaceTilingPass::runOnOperation() {
FuncOp funcOp = getOperation();
MLIRContext *context = funcOp.getContext();
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 21c3596..f632cf6 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -683,3 +683,65 @@
// CHECK: scf.yield %[[RES3]]
// CHECK: scf.yield %[[RES2]]
// CHECK: return %[[RES]]
+
+// -----
+
+func @dynamic_insert_slice(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+ %c0 = constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3] [1, %d0] [1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @dynamic_insert_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = constant 10 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]]
+// CHECK-SAME: step %[[C10]] iter_args(%[[ARG5:.+]] = %[[ARG1]])
+// CHECK: %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[ARG4]])[%[[C10]], %[[D0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]]] [%[[TILESIZE]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG3]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT]] into %[[ARG5]]
+// CHECK-SAME: [%[[ARG2]], %[[OFFSET]]] [1, %[[TILESIZE]]]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: return %[[RESULT]]
+
+
+// -----
+
+func @insert_slice_rank_reduced_inner(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?x?x?xf32>, %arg2: index, %arg3 : index, %arg4 : index) -> tensor<?x?x?xf32> {
+ %c0 = constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [1, %d0, 1] [1, 1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @insert_slice_rank_reduced_inner(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[LB:.+]] = constant 0 : index
+// CHECK-DAG: %[[STEP:.+]] = constant 10 : index
+// CHECK: %[[UB:.+]] = tensor.dim %[[ARG0]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = %[[LB]]
+// CHECK-SAME: to %[[UB]] step %[[STEP]] iter_args(%[[ARG6:.+]] = %[[ARG1]])
+// CHECK: %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[STEP]], %[[UB]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]]] [%[[TILESIZE]]]
+// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[IV0]])[%[[ARG3]]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG6]]
+// CHECK-SAME: [%[[ARG2]], %[[APPLY]], %[[ARG4]]] [1, %[[TILESIZE]], 1]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c
index ed751db..7873498 100644
--- a/iree/samples/static_library/static_library_demo.c
+++ b/iree/samples/static_library/static_library_demo.c
@@ -82,11 +82,6 @@
if (iree_status_is_ok(status)) {
status = create_device_with_static_loader(&device);
}
- iree_vm_module_t* hal_module = NULL;
- if (iree_status_is_ok(status)) {
- status =
- iree_hal_module_create(device, iree_allocator_system(), &hal_module);
- }
// Session configuration (one per loaded module to hold module state).
iree_runtime_session_options_t session_options;
diff --git a/iree/test/e2e/xla_ops/gather.mlir b/iree/test/e2e/xla_ops/gather.mlir
index c0417c6..465c58c 100644
--- a/iree/test/e2e/xla_ops/gather.mlir
+++ b/iree/test/e2e/xla_ops/gather.mlir
@@ -18,3 +18,152 @@
check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
return
}
+
+func @via_torch_index_select() {
+ %input = util.unfoldable_constant dense<[
+ [[01, 02, 03, 04, 05]],
+ [[06, 07, 08, 09, 10]],
+ [[11, 12, 13, 14, 15]],
+ [[16, 17, 18, 19, 20]],
+ [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
+ %start_indices = util.unfoldable_constant dense<2> : tensor<i64>
+ %res = "mhlo.gather"(%input, %start_indices) {
+ dimension_numbers = #mhlo.gather<
+ collapsed_slice_dims = [0],
+ index_vector_dim = 0,
+ offset_dims = [0, 1],
+ start_index_map = [0]
+ >,
+ slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+ } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
+ check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
+ return
+}
+
+
+func @general_but_just_index_select() {
+ %operand = util.unfoldable_constant dense<[[
+ [ 0, 1, 2, 3, 4, 5, 6, 7],
+ [ 8, 9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+ %start_indices = util.unfoldable_constant dense<[[
+ [0, 1],
+ [0, 2],
+ [0, 3],
+ [0, 0],
+ [0, 0],
+ [0, 1],
+ [0, 2],
+ [0, 3]]]> : tensor<1x8x2xi32>
+ %result = "mhlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #mhlo.gather<
+ collapsed_slice_dims = [0, 1],
+ index_vector_dim = 2,
+ offset_dims = [2],
+ start_index_map = [0, 1]
+ >,
+ indices_are_sorted = false,
+ slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>
+ } : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32>
+ check.expect_eq_const(%result, dense<[[
+ [ 8, 9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30, 31],
+ [ 0, 1, 2, 3, 4, 5, 6, 7],
+ [ 0, 1, 2, 3, 4, 5, 6, 7],
+ [ 8, 9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x8x8xi32>) : tensor<1x8x8xi32>
+ return
+}
+
+func @small_slices() {
+ %operand = util.unfoldable_constant dense<[[
+ [ 0, 1, 2, 3, 4, 5, 6, 7],
+ [ 8, 9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+ %start_indices = util.unfoldable_constant dense<[[
+ [0, 1],
+ [0, 2],
+ [0, 3],
+ [0, 0]]]> : tensor<1x4x2xi32>
+ %result = "mhlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #mhlo.gather<
+ collapsed_slice_dims = [0, 1],
+ index_vector_dim = 2,
+ offset_dims = [2],
+ start_index_map = [0, 1]
+ >,
+ indices_are_sorted = false,
+ slice_sizes = dense<[1, 1, 3]> : tensor<3xi64>
+ } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x4x3xi32>
+ check.expect_eq_const(%result, dense<[[
+ [ 8, 9, 10],
+ [16, 17, 18],
+ [24, 25, 26],
+ [ 0, 1, 2]]]> : tensor<1x4x3xi32>) : tensor<1x4x3xi32>
+ return
+}
+
+func @nonstandard_offset_dims() {
+ %operand = util.unfoldable_constant dense<[[
+ [ 0, 1, 2, 3, 4, 5, 6, 7],
+ [ 8, 9, 10, 11, 12, 13, 14, 15],
+ [16, 17, 18, 19, 20, 21, 22, 23],
+ [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+ %start_indices = util.unfoldable_constant dense<[[
+ [0, 1],
+ [0, 2],
+ [0, 2],
+ [0, 0]]]> : tensor<1x4x2xi32>
+ %result = "mhlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #mhlo.gather<
+ collapsed_slice_dims = [0],
+ index_vector_dim = 2,
+ offset_dims = [1, 2],
+ start_index_map = [0, 1]
+ >,
+ indices_are_sorted = false,
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x2x3x4xi32>
+ check.expect_eq_const(%result, dense<[[
+ [[ 8, 16, 16, 0],
+ [ 9, 17, 17, 1],
+ [10, 18, 18, 2]],
+ [[16, 24, 24, 8],
+ [17, 25, 25, 9],
+ [18, 26, 26, 10]]]]> : tensor<1x2x3x4xi32>) : tensor<1x2x3x4xi32>
+ return
+}
+
+func @reordered_start_index() {
+ %operand = util.unfoldable_constant dense<[[
+ [[ 0, 1, 2, 3],
+ [ 4, 5, 6, 7]],
+ [[ 8, 9, 10, 11],
+ [12, 13, 14, 15]],
+ [[16, 17, 18, 19],
+ [20, 21, 22, 23]]]]> : tensor<1x3x2x4xi32>
+ %start_indices = util.unfoldable_constant dense<[
+ [0, 1, 0, 0],
+ [1, 0, 0, 0]]> : tensor<2x4xi32>
+ %result = "mhlo.gather"(%operand, %start_indices) {
+ dimension_numbers = #mhlo.gather<
+ collapsed_slice_dims = [0, 2],
+ index_vector_dim = 1,
+ offset_dims = [1, 2],
+ start_index_map = [3, 2, 0, 1]
+ >,
+ indices_are_sorted = false,
+ slice_sizes = dense<[1, 2, 1, 3]> : tensor<4xi64>
+ } : (tensor<1x3x2x4xi32>, tensor<2x4xi32>) -> tensor<2x2x3xi32>
+
+ check.expect_eq_const(%result, dense<[
+ [[ 4, 5, 6],
+ [12, 13, 14]],
+ [[ 1, 2, 3],
+ [ 9, 10, 11]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32>
+ return
+}
diff --git a/iree/tools/utils/trace_replay.c b/iree/tools/utils/trace_replay.c
index f7e95c8..cba6e5f 100644
--- a/iree/tools/utils/trace_replay.c
+++ b/iree/tools/utils/trace_replay.c
@@ -427,6 +427,8 @@
yaml_node_t* encoding_type_node,
iree_hal_encoding_type_t* out_encoding_type) {
*out_encoding_type = IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+ if (!encoding_type_node) {
+ return iree_ok_status();
+ }
-if (!encoding_type_node) return iree_ok_status();
iree_string_view_t encoding_type_str =
diff --git a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
index 6e946e8..2be688c 100644
--- a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
+++ b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
@@ -21,6 +21,7 @@
typedef struct name name
DEFINE_C_API_STRUCT(IreeCompilerOptions, void);
+#undef DEFINE_C_API_STRUCT
//===----------------------------------------------------------------------===//
// Registration.
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
index 3b6ae8d..d13028d 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
@@ -44,7 +44,10 @@
void ireeCompilerRegisterTargetBackends() { registerHALTargetBackends(); }
IreeCompilerOptions ireeCompilerOptionsCreate() {
- return wrap(new CompilerOptions);
+ auto options = new CompilerOptions;
+ // TODO: Make configurable.
+ options->vmTargetOptions.f32Extension = true;
+ return wrap(options);
}
void ireeCompilerOptionsDestroy(IreeCompilerOptions options) {
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index df10347..90340be 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -275,6 +275,7 @@
cc_library(
name = "IREEPyDMTransforms",
srcs = glob([
+ "lib/Dialect/IREEPyDM/Transforms/*.cpp",
"lib/Dialect/IREEPyDM/Transforms/RTL/*.cpp",
"lib/Dialect/IREEPyDM/Transforms/ToIREE/*.cpp",
]),
@@ -291,6 +292,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
@@ -312,5 +314,6 @@
":IREEPyDMTransforms",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index 613df95..da8701a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -27,6 +27,28 @@
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm);
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(IREEPyDMSourceBundle, void);
+DEFINE_C_API_STRUCT(IREEPyDMLoweringOptions, void);
+#undef DEFINE_C_API_STRUCT
+
+/// Creates a PyDM source bundle from an ASM string.
+MLIR_CAPI_EXPORTED IREEPyDMSourceBundle
+ireePyDMSourceBundleCreateAsm(MlirStringRef asmString);
+
+/// Creates a PyDM source bundle from a file path.
+MLIR_CAPI_EXPORTED IREEPyDMSourceBundle
+ireePyDMSourceBundleCreateFile(MlirStringRef filePath);
+
+/// Destroys a created source bundle.
+MLIR_CAPI_EXPORTED void ireePyDMSourceBundleDestroy(
+ IREEPyDMSourceBundle bundle);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type);
#define IREEPYDM_DECLARE_NULLARY_TYPE(Name) \
@@ -51,9 +73,20 @@
MLIR_CAPI_EXPORTED MlirType mlirIREEPyDMObjectTypeGet(MlirContext context,
MlirType primitive);
+/// Creates a lowering options struct.
+MLIR_CAPI_EXPORTED IREEPyDMLoweringOptions ireePyDMLoweringOptionsCreate();
+
+/// Sets the RTL link source bundle to the lowering options.
+MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsLinkRtl(
+ IREEPyDMLoweringOptions options, IREEPyDMSourceBundle source);
+
+/// Destroys a created lowering options struct.
+MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsDestroy(
+ IREEPyDMLoweringOptions options);
+
/// Builds a pass pipeline which lowers the iree_pydm dialect to IREE.
MLIR_CAPI_EXPORTED void mlirIREEPyDMBuildLowerToIREEPassPipeline(
- MlirOpPassManager passManager);
+ MlirOpPassManager passManager, IREEPyDMLoweringOptions options);
#ifdef __cplusplus
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
index c3670e0..a169051 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
@@ -7,8 +7,11 @@
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
+#include <memory>
+
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
namespace mlir {
@@ -18,9 +21,25 @@
namespace iree_pydm {
+/// References sources, either passed literally or by reference to a file.
+/// One of `asmBlob` or `asmFilePath` should be populated.
+struct SourceBundle {
+ std::shared_ptr<std::string> asmBlob;
+ Optional<std::string> asmFilePath;
+};
+
+/// Options for lowering to IREE.
+struct LowerToIREEOptions {
+ Optional<SourceBundle> linkRtlSource;
+};
+
std::unique_ptr<OperationPass<ModuleOp>> createConvertIREEPyDMToIREEPass();
std::unique_ptr<OperationPass<ModuleOp>> createLowerIREEPyDMToRTLPass();
-std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass();
+std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass(
+ Optional<SourceBundle> linkRtlSourceBundle = None);
+
+void buildLowerToIREEPassPipeline(OpPassManager& passManager,
+ const LowerToIREEOptions& options);
#define GEN_PASS_REGISTRATION
#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
index 69dd0e1..daef405 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
@@ -26,6 +26,10 @@
Type getWeakIntegerType(Builder b) const;
Type getWeakFloatType(Builder b) const;
+ // Whether the given type is a valid lowered type.
+ bool isTypeLegal(Type t) const;
+ bool areTypesLegal(TypeRange types) const;
+
private:
bool boolBits = 32;
int weakIntegerBits = 32;
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 66682e5..75be3ec 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -5,7 +5,7 @@
MLIRIR
IREEDialectsIREEDialect
IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMToIREEPasses
+ IREEDialectsIREEPyDMPasses
)
iree_dialects_target_includes(IREEDialectsCAPI)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index ad5e546..d316fbe 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -15,6 +15,9 @@
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
//===----------------------------------------------------------------------===//
// IREEDialect
@@ -29,6 +32,10 @@
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm,
mlir::iree_pydm::IREEPyDMDialect)
+DEFINE_C_API_PTR_METHODS(IREEPyDMSourceBundle, mlir::iree_pydm::SourceBundle)
+DEFINE_C_API_PTR_METHODS(IREEPyDMLoweringOptions,
+ mlir::iree_pydm::LowerToIREEOptions)
+
bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type) {
return unwrap(type).isa<mlir::iree_pydm::PrimitiveType>();
}
@@ -66,8 +73,40 @@
return wrap(mlir::iree_pydm::ObjectType::get(unwrap(ctx), cppType));
}
-void mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager) {
+void mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager,
+ IREEPyDMLoweringOptions options) {
auto *passManagerCpp = unwrap(passManager);
- // TODO: Should be a pass pipeline, not loose passes in the C impl.
- passManagerCpp->addPass(mlir::iree_pydm::createConvertIREEPyDMToIREEPass());
+ mlir::iree_pydm::buildLowerToIREEPassPipeline(*passManagerCpp,
+ *unwrap(options));
+}
+
+// SourceBundle
+IREEPyDMSourceBundle ireePyDMSourceBundleCreateAsm(MlirStringRef asmString) {
+ auto bundle = std::make_unique<mlir::iree_pydm::SourceBundle>();
+ bundle->asmBlob = std::make_shared<std::string>(unwrap(asmString));
+ return wrap(bundle.release());
+}
+
+IREEPyDMSourceBundle ireePyDMSourceBundleCreateFile(MlirStringRef filePath) {
+ auto bundle = std::make_unique<mlir::iree_pydm::SourceBundle>();
+ bundle->asmFilePath = std::string(unwrap(filePath));
+ return wrap(bundle.release());
+}
+
+void ireePyDMSourceBundleDestroy(IREEPyDMSourceBundle bundle) {
+ delete unwrap(bundle);
+}
+
+// LoweringOptions
+IREEPyDMLoweringOptions ireePyDMLoweringOptionsCreate() {
+ return wrap(new mlir::iree_pydm::LowerToIREEOptions);
+}
+
+void ireePyDMLoweringOptionsLinkRtl(IREEPyDMLoweringOptions options,
+ IREEPyDMSourceBundle source) {
+ unwrap(options)->linkRtlSource = *unwrap(source);
+}
+
+void ireePyDMLoweringOptionsDestroy(IREEPyDMLoweringOptions options) {
+ delete unwrap(options);
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
index 37db9a5..44fe3b2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
@@ -1,2 +1,15 @@
add_subdirectory(RTL)
add_subdirectory(ToIREE)
+
+add_mlir_library(IREEDialectsIREEPyDMPasses
+ Passes.cpp
+
+ DEPENDS
+ MLIRIREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEDialectsIREEPyDMRTLPasses
+ IREEDialectsIREEPyDMToIREEPasses
+)
+
+iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
new file mode 100644
index 0000000..caaaf14
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
@@ -0,0 +1,22 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+void mlir::iree_pydm::buildLowerToIREEPassPipeline(
+ OpPassManager& passManager, const LowerToIREEOptions& options) {
+ // TODO: Needs to be iterative, support optimization passes, etc.
+ passManager.addPass(createLowerIREEPyDMToRTLPass());
+ if (options.linkRtlSource) {
+ passManager.addPass(createLinkIREEPyDMRTLPass(options.linkRtlSource));
+ }
+ passManager.addPass(createConvertIREEPyDMToIREEPass());
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
index 9608edf..7803d65 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -35,22 +35,38 @@
namespace {
class LinkIREEPyDMRTLPass : public LinkIREEPyDMRTLBase<LinkIREEPyDMRTLPass> {
+ public:
+ LinkIREEPyDMRTLPass() = default;
+ LinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle)
+ : linkRtlSourceBundle(std::move(linkRtlSourceBundle)) {}
+
+ private:
LogicalResult initialize(MLIRContext *context) override {
- // Already initialized in some way.
- if (!rtlModule && rtlFile.empty()) {
+ SourceBundle localSource;
+ if (linkRtlSourceBundle) {
+ localSource = *linkRtlSourceBundle;
+ } else {
+ // Get it from the cli options.
+ localSource.asmFilePath = rtlFile;
+ }
+
+ if (localSource.asmBlob) {
+ // Parse from inline asm.
+ auto owningOp = parseSourceString(*localSource.asmBlob, context);
+ if (!owningOp) return failure();
+ rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+ } else if (localSource.asmFilePath) {
+ // Parse from a file.
+ auto owningOp = parseSourceFile(*localSource.asmFilePath, context);
+ if (!owningOp) return failure();
+ rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+ } else {
return emitError(UnknownLoc::get(context))
<< "pass " << getArgument()
<< "must be initialized with an RTL module (did you mean to "
"add an rtl-file option?)";
}
- if (!rtlFile.empty()) {
- // Parse from a file.
- auto owningOp = parseSourceFile(rtlFile, context);
- if (!owningOp) return failure();
- rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
- }
-
ModuleOp parentModule = rtlModule->get();
// Walk the module and build a SymbolTable for each sub-module.
parentModule->walk([&](ModuleOp importModule) {
@@ -187,11 +203,15 @@
// A SymbolTable for each sub module.
SmallVector<SymbolTable> importModules;
+
+ // ASM source of RTL modules to link (otherwise will use pass options).
+ Optional<SourceBundle> linkRtlSourceBundle;
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
-mlir::iree_pydm::createLinkIREEPyDMRTLPass() {
- return std::make_unique<LinkIREEPyDMRTLPass>();
+mlir::iree_pydm::createLinkIREEPyDMRTLPass(
+ Optional<SourceBundle> linkRtlSourceBundle) {
+ return std::make_unique<LinkIREEPyDMRTLPass>(std::move(linkRtlSourceBundle));
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
index 661a237..2d4d88d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -40,6 +40,16 @@
target.addLegalDialect<mlir::iree::IREEDialect>();
target.addLegalDialect<mlir::math::MathDialect>();
target.addLegalDialect<StandardOpsDialect>();
+
+ // Some CFG ops can be present in the original pydm program. Need to
+ // verify legality based on types.
+ target.addDynamicallyLegalOp<BranchOp>([&](BranchOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
+ target.addDynamicallyLegalOp<CondBranchOp>([&](CondBranchOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
return signalPassFailure();
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 6214709..9b3e1fd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -17,6 +17,7 @@
using namespace mlir;
using namespace mlir::iree_pydm;
+namespace arith_d = mlir;
namespace iree_d = mlir::iree;
namespace builtin_d = mlir;
namespace std_d = mlir;
@@ -101,6 +102,60 @@
return list;
}
+static Value castIntegerValue(Location loc, Value input,
+ builtin_d::IntegerType resultType,
+ OpBuilder &builder) {
+ builtin_d::IntegerType inputType =
+ input.getType().cast<builtin_d::IntegerType>();
+ if (inputType.getWidth() == resultType.getWidth()) {
+ return input;
+ } else if (inputType.getWidth() < resultType.getWidth()) {
+ return builder.create<arith_d::SignExtendIOp>(loc, resultType, input);
+ } else {
+ return builder.create<arith_d::TruncateIOp>(loc, resultType, input);
+ }
+}
+
+static Optional<arith_d::CmpIPredicate> convertIntegerComparePredicate(
+ StringAttr dunderName, bool isSigned, Builder &builder) {
+ StringRef v = dunderName.getValue();
+ if (v == "lt") {
+ return isSigned ? arith_d::CmpIPredicate::slt : arith_d::CmpIPredicate::ult;
+ } else if (v == "le") {
+ return isSigned ? arith_d::CmpIPredicate::sle : arith_d::CmpIPredicate::ule;
+ } else if (v == "eq" || v == "is") {
+ return arith_d::CmpIPredicate::eq;
+ } else if (v == "ne" || v == "isnot") {
+ return arith_d::CmpIPredicate::ne;
+ } else if (v == "gt") {
+ return isSigned ? arith_d::CmpIPredicate::sgt : arith_d::CmpIPredicate::ugt;
+ } else if (v == "ge") {
+ return isSigned ? arith_d::CmpIPredicate::sge : arith_d::CmpIPredicate::uge;
+ }
+
+ return {};
+}
+
+static Optional<arith_d::CmpFPredicate> convertFpComparePredicate(
+ StringAttr dunderName, Builder &builder) {
+ StringRef v = dunderName.getValue();
+ if (v == "lt") {
+ return arith_d::CmpFPredicate::OLT;
+ } else if (v == "le") {
+ return arith_d::CmpFPredicate::OLE;
+ } else if (v == "eq" || v == "is") {
+ return arith_d::CmpFPredicate::OEQ;
+ } else if (v == "ne" || v == "isnot") {
+ return arith_d::CmpFPredicate::ONE;
+ } else if (v == "gt") {
+ return arith_d::CmpFPredicate::OGT;
+ } else if (v == "ge") {
+ return arith_d::CmpFPredicate::OGE;
+ }
+
+ return {};
+}
+
namespace {
class AllocFreeVarOpConversion
@@ -120,6 +175,53 @@
}
};
+class ApplyCompareNumericConversion
+ : public OpConversionPattern<pydm_d::ApplyCompareOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ pydm_d::ApplyCompareOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type leftType = adaptor.left().getType();
+ Type rightType = adaptor.right().getType();
+ if (leftType != rightType) {
+ return rewriter.notifyMatchFailure(srcOp, "not same type operands");
+ }
+ if (leftType.isa<builtin_d::IntegerType>()) {
+ bool isSigned = true; // TODO: Unsigned.
+ auto predicate = convertIntegerComparePredicate(adaptor.dunder_name(),
+ isSigned, rewriter);
+ if (!predicate)
+ return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+ rewriter.replaceOpWithNewOp<arith_d::CmpIOp>(
+ srcOp, *predicate, adaptor.left(), adaptor.right());
+ return success();
+ } else if (leftType.isa<builtin_d::FloatType>()) {
+ auto predicate =
+ convertFpComparePredicate(adaptor.dunder_name(), rewriter);
+ if (!predicate)
+ return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+ rewriter.replaceOpWithNewOp<arith_d::CmpFOp>(
+ srcOp, *predicate, adaptor.left(), adaptor.right());
+ return success();
+ }
+
+ return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+ }
+};
+
+class BoolToPredConversion : public OpConversionPattern<pydm_d::BoolToPredOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ pydm_d::BoolToPredOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ pydm_d::BoolToPredOp::Adaptor adaptor(operands);
+ rewriter.replaceOp(srcOp, adaptor.value());
+ return success();
+ }
+};
+
class BoxOpConversion : public OpConversionPattern<pydm_d::BoxOp> {
using OpConversionPattern::OpConversionPattern;
@@ -140,6 +242,95 @@
}
};
+class BranchConversion : public OpConversionPattern<std_d::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ std_d::BranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, srcOp.dest(),
+ adaptor.destOperands());
+ return success();
+ }
+};
+
+class CallOpConversion : public OpConversionPattern<pydm_d::CallOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ pydm_d::CallOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ pydm_d::CallOp::Adaptor adaptor(operands);
+ SmallVector<Type> resultTypes;
+ if (failed(getTypeConverter()->convertTypes(srcOp.getResultTypes(),
+ resultTypes))) {
+ return rewriter.notifyMatchFailure(srcOp,
+ "result types could not be converted");
+ }
+ rewriter.replaceOpWithNewOp<std_d::CallOp>(srcOp, srcOp.callee(),
+ resultTypes, adaptor.operands());
+ return success();
+ }
+};
+
+class CondBranchConversion : public OpConversionPattern<std_d::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ std_d::CondBranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
+ srcOp, adaptor.condition(), srcOp.trueDest(),
+ adaptor.trueDestOperands(), srcOp.falseDest(),
+ adaptor.falseDestOperands());
+ return success();
+ }
+};
+
+class ConstantOpConversion : public OpConversionPattern<pydm_d::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ pydm_d::ConstantOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = srcOp.getLoc();
+ Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(
+ srcOp, "constant type could not be converted");
+ Attribute newValue = adaptor.value();
+ // Fixup widths of integer types that may be wider/narrower than the
+ // stored attribute (which tends to be stored in high precision in pydm
+ // constants).
+ TypeSwitch<Type>(resultType)
+ .Case([&](builtin_d::IntegerType t) {
+ APInt intValue =
+ newValue.cast<IntegerAttr>().getValue().sextOrTrunc(t.getWidth());
+ newValue = rewriter.getIntegerAttr(t, intValue);
+ })
+ .Case([&](builtin_d::FloatType t) {
+ APFloat fpValue = newValue.cast<FloatAttr>().getValue();
+ if (APFloat::SemanticsToEnum(fpValue.getSemantics()) !=
+ APFloat::SemanticsToEnum(t.getFloatSemantics())) {
+ // Convert.
+ APFloat newFpValue = fpValue;
+ bool losesInfo;
+ newFpValue.convert(t.getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ if (losesInfo) {
+ emitWarning(loc) << "conversion of " << newValue << " to " << t
+ << " loses information";
+ }
+ newValue = rewriter.getFloatAttr(t, newFpValue);
+ }
+ });
+
+ if (!newValue)
+ return rewriter.notifyMatchFailure(
+ srcOp, "constant cannot be represented as a standard constant");
+ rewriter.replaceOpWithNewOp<std_d::ConstantOp>(srcOp, resultType, newValue);
+ return success();
+ }
+};
+
class FuncOpConversion : public OpConversionPattern<pydm_d::FuncOp> {
using OpConversionPattern::OpConversionPattern;
@@ -172,6 +363,7 @@
convertedResultTypes);
auto newFuncOp = rewriter.create<mlir::FuncOp>(
srcOp.getLoc(), srcOp.getName(), newFuncType);
+ newFuncOp.setVisibility(srcOp.getVisibility());
rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
@@ -187,6 +379,33 @@
}
};
+class GetTypeCodeConversion
+ : public OpConversionPattern<pydm_d::GetTypeCodeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ pydm_d::GetTypeCodeOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = srcOp.getLoc();
+ // Gets the 0'th element of the object list, optionally casting it to the
+ // converted integer type.
+ Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(srcOp,
+ "result type could not be converted");
+ Type i32Type = rewriter.getIntegerType(32);
+ Value index0 =
+ rewriter.create<std_d::ConstantOp>(loc, rewriter.getIndexAttr(0));
+ Value typeCode = rewriter.create<iree_d::ListGetOp>(
+ loc, i32Type, adaptor.value(), index0);
+ rewriter.replaceOp(
+ srcOp,
+ castIntegerValue(loc, typeCode,
+ resultType.cast<builtin_d::IntegerType>(), rewriter));
+ return success();
+ }
+};
+
class LoadVarOpConversion : public OpConversionPattern<pydm_d::LoadVarOp> {
using OpConversionPattern::OpConversionPattern;
@@ -237,8 +456,8 @@
auto loc = srcOp.getLoc();
Value status = operands[0];
- // Get the containing function return type so that we can create a suitable
- // null return value.
+ // Get the containing function return type so that we can create a
+ // suitable null return value.
auto parentFunc = srcOp->getParentOfType<builtin_d::FuncOp>();
if (!parentFunc)
return rewriter.notifyMatchFailure(srcOp, "not contained by a func");
@@ -390,10 +609,13 @@
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Structural.
- patterns.insert<AllocFreeVarOpConversion, BoxOpConversion, FuncOpConversion,
- LoadVarOpConversion, RaiseOnFailureOpConversion,
- ReturnOpConversion, StoreVarOpConversion, UnboxOpConversion>(
- typeConverter, context);
+ patterns.insert<AllocFreeVarOpConversion, ApplyCompareNumericConversion,
+ BoolToPredConversion, BoxOpConversion, BranchConversion,
+ CallOpConversion, CondBranchConversion, ConstantOpConversion,
+ FuncOpConversion, GetTypeCodeConversion, LoadVarOpConversion,
+ RaiseOnFailureOpConversion, ReturnOpConversion,
+ StoreVarOpConversion, UnboxOpConversion>(typeConverter,
+ context);
// Constants and constructors.
patterns.insert<NoneOpConversion>(typeConverter, context);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
index 9416b2a..0ee4fb7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -37,12 +37,23 @@
return getVariantListType(b);
});
+ // Bool.
+ addConversion([&](pydm_d::BoolType t) -> Optional<Type> {
+ return builtin_d::IntegerType::get(t.getContext(), 1);
+ });
+
// Integer type hierarchy.
addConversion([&](pydm_d::IntegerType t) -> Optional<Type> {
Builder b(t.getContext());
return getWeakIntegerType(b);
});
+ // Real type hierarchy.
+ addConversion([&](pydm_d::RealType t) -> Optional<Type> {
+ Builder b(t.getContext());
+ return getWeakFloatType(b);
+ });
+
// Variable references.
addConversion([](pydm_d::FreeVarRefType t) -> Optional<Type> {
// Just an object record.
@@ -73,3 +84,15 @@
return b.getF64Type();
}
}
+
+bool LoweringTypeConverter::isTypeLegal(Type t) const {
+ return t.isa<builtin_d::IntegerType, builtin_d::FloatType,
+ builtin_d::IndexType, iree_d::ListType>();
+}
+
+bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
+ for (Type t : types) {
+ if (!isTypeLegal(t)) return false;
+ }
+ return true;
+}
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 0b235ee..067c565 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -15,6 +15,36 @@
namespace py = pybind11;
using namespace mlir::python::adaptors;
+namespace {
+
+struct PyIREEPyDMSourceBundle {
+ PyIREEPyDMSourceBundle(IREEPyDMSourceBundle wrapped) : wrapped(wrapped) {}
+ PyIREEPyDMSourceBundle(PyIREEPyDMSourceBundle &&other)
+ : wrapped(other.wrapped) {
+ other.wrapped.ptr = nullptr;
+ }
+ PyIREEPyDMSourceBundle(const PyIREEPyDMSourceBundle &) = delete;
+ ~PyIREEPyDMSourceBundle() {
+ if (wrapped.ptr) ireePyDMSourceBundleDestroy(wrapped);
+ }
+ IREEPyDMSourceBundle wrapped;
+};
+
+struct PyIREEPyDMLoweringOptions {
+ PyIREEPyDMLoweringOptions() : wrapped(ireePyDMLoweringOptionsCreate()) {}
+ PyIREEPyDMLoweringOptions(PyIREEPyDMLoweringOptions &&other)
+ : wrapped(other.wrapped) {
+ other.wrapped.ptr = nullptr;
+ }
+ PyIREEPyDMLoweringOptions(const PyIREEPyDMLoweringOptions &) = delete;
+ ~PyIREEPyDMLoweringOptions() {
+ if (wrapped.ptr) ireePyDMLoweringOptionsDestroy(wrapped);
+ }
+ IREEPyDMLoweringOptions wrapped;
+};
+
+} // namespace
+
PYBIND11_MODULE(_ireeDialects, m) {
m.doc() = "iree-dialects main python extension";
@@ -63,6 +93,37 @@
//===--------------------------------------------------------------------===//
auto iree_pydm_m = m.def_submodule("iree_pydm");
+ py::class_<PyIREEPyDMSourceBundle>(
+ iree_pydm_m, "SourceBundle", py::module_local(),
+ "Contains raw assembly source or a reference to a file")
+ .def_static(
+ "from_asm",
+ [](std::string asmBlob) {
+ return PyIREEPyDMSourceBundle(ireePyDMSourceBundleCreateAsm(
+ {asmBlob.data(), asmBlob.size()}));
+ },
+ py::arg("asm_blob"),
+ "Creates a SourceBundle from an ASM blob (string or bytes)")
+ .def_static(
+ "from_file",
+ [](std::string asmFile) {
+ return PyIREEPyDMSourceBundle(ireePyDMSourceBundleCreateFile(
+ {asmFile.data(), asmFile.size()}));
+ },
+ py::arg("asm_file"),
+ "Creates a SourceBundle from a file containing ASM");
+ py::class_<PyIREEPyDMLoweringOptions>(iree_pydm_m, "LoweringOptions",
+ py::module_local(),
+ "Lowering options to compile to IREE")
+ .def(py::init<>())
+ .def(
+ "link_rtl",
+ [](PyIREEPyDMLoweringOptions &self,
+ PyIREEPyDMSourceBundle &sourceBundle) {
+ ireePyDMLoweringOptionsLinkRtl(self.wrapped, sourceBundle.wrapped);
+ },
+ "Enables linking against a runtime-library module");
+
iree_pydm_m.def(
"register_dialect",
[](MlirContext context, bool load) {
@@ -76,12 +137,13 @@
iree_pydm_m.def(
"build_lower_to_iree_pass_pipeline",
- [](MlirPassManager passManager) {
+ [](MlirPassManager passManager, PyIREEPyDMLoweringOptions &options) {
MlirOpPassManager opPassManager =
mlirPassManagerGetAsOpPassManager(passManager);
- mlirIREEPyDMBuildLowerToIREEPassPipeline(opPassManager);
+ mlirIREEPyDMBuildLowerToIREEPassPipeline(opPassManager,
+ options.wrapped);
},
- py::arg("pass_manager"));
+ py::arg("pass_manager"), py::arg("link_rtl_asm") = py::none());
#define DEFINE_IREEPYDM_NULLARY_TYPE(Name) \
mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
index 8eb41b1..d8bfd1c 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
@@ -361,8 +361,10 @@
f"implemented for this compiler")
-def create_context() -> ir.Context:
+def create_context(*, debug: bool = False) -> ir.Context:
context = ir.Context()
+ if debug:
+ context.enable_multithreading(False)
d.register_dialect(context)
return context
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
index 7bea5c5..1255fc0 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
@@ -3,3 +3,22 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import functools as _functools
+
+from typing import Sequence
+from .base import RtlBuilder, RtlModule
+
+
+def _get_std_rtl_modules() -> Sequence[RtlModule]:
+ from .modules import (
+ booleans,
+ numerics,
+ )
+ return [m.RTL_MODULE for m in (booleans, numerics)]
+
+
+STD_RTL_MODULES = _get_std_rtl_modules()
+
+# Source bundle for the standard RTL.
+get_std_rtl_source_bundle = RtlBuilder.lazy_build_source_bundle(STD_RTL_MODULES)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
index 240317a..ae23206 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
@@ -5,9 +5,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Base helpers for the RTL DSL."""
-from typing import List, Optional
+from typing import Callable, List, Optional, Sequence
import functools
+import threading
from ..importer import (
create_context,
@@ -18,8 +19,15 @@
ImportStage,
)
-from ... import builtin as builtin_d
-from .... import (ir, passmanager, transforms as unused_transforms)
+from ... import (
+ builtin as builtin_d,
+ iree_pydm as pydm_d,
+)
+from .... import (
+ ir,
+ passmanager,
+ transforms as unused_transforms,
+)
class RtlModule:
@@ -88,6 +96,36 @@
ir.Location.unknown(context=self.context))
self.module_op = self.root_module.operation
+ @staticmethod
+ def build_modules(rtl_modules: Sequence[RtlModule]) -> bytes:
+ """One shot build modules and return assembly."""
+ b = RtlBuilder()
+ b.emit_modules(rtl_modules)
+ b.optimize()
+ return b.root_module.operation.get_asm(binary=True, enable_debug_info=True)
+
+ @staticmethod
+ def lazy_build_source_bundle(
+ rtl_modules: Sequence[RtlModule]) -> Callable[[], pydm_d.SourceBundle]:
+ """Returns a function to lazily build RTL modules.
+
+ Modules will only be built once and cached for the life of the function.
+ Since RTL asm is typically passed unsafely to compiler passes, caching
+ forever is important.
+ """
+ rtl_modules = tuple(rtl_modules)
+ cache = []
+ lock = threading.Lock()
+
+ def get() -> pydm_d.SourceBundle:
+ with lock:
+ if not cache:
+ asm_blob = RtlBuilder.build_modules(rtl_modules)
+ cache.append(pydm_d.SourceBundle.from_asm(asm_blob))
+ return cache[0]
+
+ return get
+
def emit_module(self, rtl_module: RtlModule):
root_body = self.module_op.regions[0].blocks[0]
with ir.InsertionPoint(root_body), ir.Location.unknown():
@@ -101,6 +139,10 @@
# Getting the symbol implies exporting it into the module.
f.get_or_create_provided_func_symbol(stage)
+ def emit_modules(self, rtl_modules: Sequence[RtlModule]):
+ for rtl_module in rtl_modules:
+ self.emit_module(rtl_module)
+
def optimize(self):
"""Optimizes the RTL modules by running through stage 1 compilation."""
with self.context:
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir
new file mode 100644
index 0000000..d8fcb91
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-dialects-opt -split-input-file -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @none_constant
+iree_pydm.func @none_constant() -> (!iree_pydm.exception_result, !iree_pydm.none) {
+ // CHECK: %[[CST0:.*]] = constant 0 : i32
+ // CHECK: %[[CST1:.*]] = constant 0 : i32
+ // CHECK: return %[[CST1]], %[[CST0]]
+ %0 = none
+ return %0 : !iree_pydm.none
+}
+
+// CHECK-LABEL: @constant_integer_trunc
+iree_pydm.func @constant_integer_trunc() -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+ // CHECK: constant -10 : i32
+ %0 = constant -10 : i64 -> !iree_pydm.integer
+ return %0 : !iree_pydm.integer
+}
+
+// CHECK-LABEL: @constant_real_trunc
+iree_pydm.func @constant_real_trunc() -> (!iree_pydm.exception_result, !iree_pydm.real) {
+ // CHECK: constant -2.000000e+00 : f32
+ %0 = constant -2.0 : f64 -> !iree_pydm.real
+ return %0 : !iree_pydm.real
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir
new file mode 100644
index 0000000..dfd387c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-dialects-opt -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @lt
+iree_pydm.func @lt(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi slt, %arg0, %arg1 : i32
+ %0 = apply_compare "lt", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @le
+iree_pydm.func @le(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi sle, %arg0, %arg1 : i32
+ %0 = apply_compare "le", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @eq
+iree_pydm.func @eq(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi eq, %arg0, %arg1 : i32
+ %0 = apply_compare "eq", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @is
+iree_pydm.func @is(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi eq, %arg0, %arg1 : i32
+ %0 = apply_compare "is", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ne
+iree_pydm.func @ne(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi ne, %arg0, %arg1 : i32
+ %0 = apply_compare "ne", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @isnot
+iree_pydm.func @isnot(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi ne, %arg0, %arg1 : i32
+ %0 = apply_compare "isnot", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @gt
+iree_pydm.func @gt(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi sgt, %arg0, %arg1 : i32
+ %0 = apply_compare "gt", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ge
+iree_pydm.func @ge(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpi sge, %arg0, %arg1 : i32
+ %0 = apply_compare "ge", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir
new file mode 100644
index 0000000..83601c1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-dialects-opt -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @lt
+iree_pydm.func @lt(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf olt, %arg0, %arg1 : f32
+ %0 = apply_compare "lt", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @le
+iree_pydm.func @le(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf ole, %arg0, %arg1 : f32
+ %0 = apply_compare "le", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @eq
+iree_pydm.func @eq(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf oeq, %arg0, %arg1 : f32
+ %0 = apply_compare "eq", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @is
+iree_pydm.func @is(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf oeq, %arg0, %arg1 : f32
+ %0 = apply_compare "is", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ne
+iree_pydm.func @ne(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf one, %arg0, %arg1 : f32
+ %0 = apply_compare "ne", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @isnot
+iree_pydm.func @isnot(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf one, %arg0, %arg1 : f32
+ %0 = apply_compare "isnot", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @gt
+iree_pydm.func @gt(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf ogt, %arg0, %arg1 : f32
+ %0 = apply_compare "gt", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ge
+iree_pydm.func @ge(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+ // CHECK: %[[R:.*]] = cmpf oge, %arg0, %arg1 : f32
+ %0 = apply_compare "ge", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+ // CHECK: return {{.*}}, %[[R]]
+ return %0 : !iree_pydm.bool
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
index 35566e8..46cafc7 100644
--- a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
@@ -1,12 +1,27 @@
// RUN: iree-dialects-opt -split-input-file -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
-// CHECK-LABEL: @none_constant
-iree_pydm.func @none_constant() -> (!iree_pydm.exception_result, !iree_pydm.none) {
- // CHECK: %[[CST0:.*]] = constant 0 : i32
- // CHECK: %[[CST1:.*]] = constant 0 : i32
- // CHECK: return %[[CST1]], %[[CST0]]
+// CHECK-LABEL: @bool_to_pred
+// NOTE: Also tests cond_br conversion.
+iree_pydm.func @bool_to_pred(%arg0 : !iree_pydm.bool) -> (!iree_pydm.exception_result, !iree_pydm.none) {
+ %0 = bool_to_pred %arg0
+ %1 = none
+ // CHECK: cond_br %arg0
+ cond_br %0, ^bb1, ^bb2
+^bb1:
+ return %1 : !iree_pydm.none
+^bb2:
+ return %1 : !iree_pydm.none
+}
+
+// -----
+// CHECK-LABEL: @br
+iree_pydm.func @br() -> (!iree_pydm.exception_result, !iree_pydm.none) {
%0 = none
- return %0 : !iree_pydm.none
+ // CHECK: br ^bb1({{.*}} : i32)
+ br ^bb1(%0 : !iree_pydm.none)
+ // CHECK: ^bb1(%0: i32):
+^bb1(%1 : !iree_pydm.none):
+ return %1 : !iree_pydm.none
}
// -----
@@ -109,3 +124,25 @@
raise_on_failure %arg0 : !iree_pydm.exception_result
return %arg1 : !iree_pydm.integer
}
+
+// -----
+// CHECK-LABEL: @call_and_visibility
+iree_pydm.func @call_and_visibility(%arg0 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+ // CHECK: %[[R:.*]]:2 = call @callee(%arg0) : (i32) -> (i32, i32)
+ %0:2 = call @callee(%arg0) : (!iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer)
+ return %0#1 : !iree_pydm.integer
+}
+
+// CHECK: func private @callee
+iree_pydm.func private @callee(%arg0 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+ return %arg0 : !iree_pydm.integer
+}
+
+// -----
+// CHECK-LABEL: @get_type_code
+iree_pydm.func @get_type_code(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+ // CHECK: %[[c0:.*]] = constant 0 : index
+ // CHECK: %[[R:.*]] = iree.list.get %arg0[%[[c0]]] : !iree.list<!iree.variant> -> i32
+ %0 = get_type_code %arg0 : !iree_pydm.object
+ return %0 : !iree_pydm.integer
+}
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py
new file mode 100644
index 0000000..9337da9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py
@@ -0,0 +1,6 @@
+# RUN: %PYTHON %s
+
+from iree.compiler.dialects.iree_pydm import rtl
+
+# Ensures that we can compile the standard library to a SourceModule.
+print(rtl.get_std_rtl_source_bundle())