Clean up encoding-related code. NFC. (#19717)
Fixing misc issues before modifying the surrounding code.
Signed-off-by: Jakub Kuderski <jakub@nod-labs.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
index 56e1f16..9521a8f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
@@ -5,30 +5,27 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
-#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#define DEBUG_TYPE "iree-codegen--materialize-encoding"
+#define DEBUG_TYPE "iree-codegen-materialize-encoding"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
@@ -150,9 +147,8 @@
return executableTargetAttrs;
}
-struct MaterializeHostEncodingPass
- : public impl::MaterializeHostEncodingPassBase<
- MaterializeHostEncodingPass> {
+struct MaterializeHostEncodingPass final
+ : impl::MaterializeHostEncodingPassBase<MaterializeHostEncodingPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect,
IREE::Codegen::IREECodegenDialect,
@@ -160,7 +156,7 @@
}
void runOnOperation() override {
- auto moduleOp = getOperation();
+ ModuleOp moduleOp = getOperation();
// Run required analysis passes.
IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
@@ -211,11 +207,9 @@
// that. It should _not_ be running on both - target-specific codegen passes
// are not allowed on host programs and it's a big violation of layering that
// this exists.
-struct MaterializeDeviceEncodingPass
- : public impl::MaterializeDeviceEncodingPassBase<
- MaterializeDeviceEncodingPass> {
- using impl::MaterializeDeviceEncodingPassBase<
- MaterializeDeviceEncodingPass>::MaterializeDeviceEncodingPassBase;
+struct MaterializeDeviceEncodingPass final
+ : impl::MaterializeDeviceEncodingPassBase<MaterializeDeviceEncodingPass> {
+ using Base::Base;
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect,
@@ -224,7 +218,7 @@
}
void runOnOperation() override {
- auto funcOp = getOperation();
+ FunctionOpInterface funcOp = getOperation();
auto executableTargetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr,
testCLGPUTarget))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
index 73ffd1d..8c757c5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp
@@ -9,8 +9,6 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/PassManager.h"
@@ -36,7 +34,7 @@
void runOnOperation() override {
MLIRContext *context = &getContext();
- auto operation = getOperation();
+ FunctionOpInterface operation = getOperation();
auto materializeEncodingValueFn =
[](RankedTensorType, OpBuilder &,
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index 20444ea..c4940c1 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -21,7 +21,6 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -87,14 +86,14 @@
if (!yieldOp) {
return false;
}
- auto addOp = yieldOp->getOperand(0).getDefiningOp();
+ Operation *addOp = yieldOp->getOperand(0).getDefiningOp();
if (!addOp || !isa<arith::AddIOp, arith::AddFOp>(addOp)) {
return false;
}
- auto addLhs = addOp->getOperand(0);
- auto addRhs = addOp->getOperand(1);
- auto addLhsOp = addLhs.getDefiningOp();
- auto addRhsOp = addRhs.getDefiningOp();
+ Value addLhs = addOp->getOperand(0);
+ Value addRhs = addOp->getOperand(1);
+ Operation *addLhsOp = addLhs.getDefiningOp();
+ Operation *addRhsOp = addRhs.getDefiningOp();
if (!(addLhsOp && addRhs == outBlockArg) &&
!(addRhsOp && addLhs == outBlockArg)) {
return false;
@@ -103,8 +102,8 @@
if (!isa<arith::MulFOp, arith::MulIOp>(mulOp)) {
return false;
}
- auto mulLhs = mulOp->getOperand(0);
- auto mulRhs = mulOp->getOperand(1);
+ Value mulLhs = mulOp->getOperand(0);
+ Value mulRhs = mulOp->getOperand(1);
auto mulLhsOp = mulLhs.getDefiningOp<CastOpInterface>();
auto mulRhsOp = mulRhs.getDefiningOp<CastOpInterface>();
if (!isa<BlockArgument>(mulLhs) && !mulLhsOp && !isa<BlockArgument>(mulRhs) &&
@@ -155,11 +154,11 @@
namespace {
-class setContractionOpEncoding
+class SetContractionOpEncoding final
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
public:
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- explicit setContractionOpEncoding(MLIRContext *ctx, int64_t factor)
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+ explicit SetContractionOpEncoding(MLIRContext *ctx, int64_t factor)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), padFactor(factor) {}
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
@@ -244,8 +243,8 @@
/// Pattern to fold a `linalg.fill` -> `iree_encoding.set_encoding`
/// operation into a `linalg.fill` of the encoded type.
struct FoldFillWithSetEncoding final
- : public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
- using OpRewritePattern<IREE::Encoding::SetEncodingOp>::OpRewritePattern;
+ : OpRewritePattern<IREE::Encoding::SetEncodingOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp,
PatternRewriter &rewriter) const override {
@@ -267,15 +266,14 @@
}
};
-struct SetEncodingPass final
- : public impl::SetEncodingPassBase<SetEncodingPass> {
+struct SetEncodingPass final : impl::SetEncodingPassBase<SetEncodingPass> {
using Base::Base;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.insert<setContractionOpEncoding>(context, padFactor);
+ patterns.add<SetContractionOpEncoding>(context, padFactor);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
- patterns.insert<FoldFillWithSetEncoding>(context);
+ patterns.add<FoldFillWithSetEncoding>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
index f7aeb82..b2bc426 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
@@ -8,17 +8,9 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/GlobalOptimization/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Utils/PassUtils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
namespace mlir::iree_compiler::GlobalOptimization {
@@ -27,7 +19,7 @@
// path. This is mainly for testing.
static llvm::cl::opt<bool> clEnableExperimentalRocmDataTiling(
"iree-global-opt-experimental-rocm-data-tiling",
- llvm::cl::desc("Enables data-tiling materializatino for rocm backends "
+ llvm::cl::desc("Enables data-tiling materialization for rocm backends "
"(experimental)."),
llvm::cl::init(false));
@@ -38,10 +30,9 @@
MultiOpNest<IREE::Util::InitializerOp, IREE::Util::FuncOp>;
namespace {
-class MaterializeHomogeneousEncodingsPass
- : public impl::MaterializeHomogeneousEncodingsPassBase<
+struct MaterializeHomogeneousEncodingsPass final
+ : impl::MaterializeHomogeneousEncodingsPassBase<
MaterializeHomogeneousEncodingsPass> {
-public:
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::HAL::HALDialect, tensor::TensorDialect,
IREE::Codegen::IREECodegenDialect>();
@@ -72,7 +63,7 @@
// TODO: vmvx has its own logic about supporting dynamic tile
// sizes. It is not fully integrated into the pipeline, so we remain the
// materialization to the end.
- auto executableTarget = executableTargets[0];
+ IREE::HAL::ExecutableTargetAttr executableTarget = executableTargets[0];
if (executableTarget.getBackend() == "vmvx") {
return;
}