listener based approach
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 012a8f1..91ccda3 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -761,6 +761,26 @@
RewriterBase(const RewriterBase &) = delete;
};
+class DiscardableAttributeConverter : public RewriterBase::Listener {
+public:
+ using DiscardableAttributeConverterFn =
+ std::function<LogicalResult(Operation *, Operation *)>;
+
+ DiscardableAttributeConverter(
+ RewriterBase &rewriter,
+ ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters)
+ : rewriter(rewriter),
+ dicardableAttributeConverters(dicardableAttributeConverters) {}
+
+protected:
+ void notifyOperationErased(Operation *op) override;
+
+ void notifyOperationReplaced(Operation *op, Operation *replacement) override;
+
+ RewriterBase &rewriter;
+ ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters;
+};
+
//===----------------------------------------------------------------------===//
// IRRewriter
//===----------------------------------------------------------------------===//
@@ -790,15 +810,7 @@
/// place.
class PatternRewriter : public RewriterBase {
public:
- using DiscardableAttributeConverterFn =
- std::function<LogicalResult(Operation *, Operation *)>;
-
explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
- PatternRewriter(
- MLIRContext *ctx,
- ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters)
- : RewriterBase(ctx),
- dicardableAttributeConverters(dicardableAttributeConverters) {}
using RewriterBase::RewriterBase;
/// A hook used to indicate if the pattern rewriter can recover from failure
@@ -806,27 +818,6 @@
/// rewriter supports rollback, it may progress smoothly even if IR was
/// changed during the rewrite.
virtual bool canRecoverFromRewriteFailure() const { return false; }
-
- /// Erase an operation that is known to have no uses. If this pattern
- /// rewriter has attribute converters, asserts the op (and its nested ops)
- /// has no discardable attributes.
- void eraseOp(Operation *op) override;
-
- /// Replace the results of the given (original) operation with the specified
- /// new op (replacement). The result types of the two ops must match. The
- /// original op is erased.
- ///
- /// If the original op has discardable attributes, try to run an attribute
- /// converter.
- void replaceOp(Operation *op, Operation *newOp) override;
- using RewriterBase::replaceOp;
-
-protected:
- ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters;
-
- bool hasAttributeConverter() const {
- return !dicardableAttributeConverters.empty();
- }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 2e3aed9..110b4f6 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -14,7 +14,6 @@
#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
namespace mlir {
@@ -99,9 +98,6 @@
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool cseConstants = true;
-
- SmallVector<PatternRewriter::DiscardableAttributeConverterFn>
- dicardableAttributeConverters;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 6826c3a..9507e7a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -425,37 +425,34 @@
}
//===----------------------------------------------------------------------===//
-// PatternRewriter
+// DiscardableAttributeConverter
//===----------------------------------------------------------------------===//
-void PatternRewriter::eraseOp(Operation *op) {
- if (hasAttributeConverter()) {
- op->walk([](Operation *op) {
- assert(op->getDiscardableAttrs().empty() &&
- "attempting to drop discardable attribute");
- });
- }
- RewriterBase::eraseOp(op);
+void DiscardableAttributeConverter::notifyOperationErased(Operation *op) {
+ op->walk([](Operation *op) {
+ assert(op->getDiscardableAttrs().empty() &&
+ "attempting to drop discardable attribute");
+ });
}
-void PatternRewriter::replaceOp(Operation *oldOp, Operation *newOp) {
- if (hasAttributeConverter() && !oldOp->getDiscardableAttrs().empty()) {
- startOpModification(oldOp);
- startOpModification(newOp);
+void DiscardableAttributeConverter::notifyOperationReplaced(Operation *oldOp,
+ Operation *newOp) {
+ if (!oldOp->getDiscardableAttrs().empty()) {
+ rewriter.startOpModification(oldOp);
+ rewriter.startOpModification(newOp);
bool success = false;
for (DiscardableAttributeConverterFn fn :
llvm::reverse(dicardableAttributeConverters)) {
if (succeeded(fn(oldOp, newOp))) {
success = true;
- finalizeOpModification(oldOp);
- finalizeOpModification(newOp);
+ rewriter.finalizeOpModification(oldOp);
+ rewriter.finalizeOpModification(newOp);
break;
}
}
if (!success) {
- cancelOpModification(oldOp);
- cancelOpModification(newOp);
+ rewriter.cancelOpModification(oldOp);
+ rewriter.cancelOpModification(newOp);
}
}
- RewriterBase::replaceOp(oldOp, newOp);
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 56311ed..fe84c61 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -411,8 +411,7 @@
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : rewriter(ctx, config.dicardableAttributeConverters), config(config),
- matcher(patterns)
+ : rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(