模式重写与规范化框架
上一章讲了 Pass 的外层——如何注册、调度和调试。这一章进入 Pass 的”心脏”——模式重写(Pattern Rewrite)。这是 MLIR 中最常用、最强大的 IR 变换机制。
一、模式重写的核心思想
模式重写的模型很简单:匹配(match)一个 IR
子图,如果匹配成功,将其替换(rewrite)为更优的等价
IR。这个”如果匹配则替换”的过程在 MLIR 中被抽象为
RewritePattern。
输入 IR
│
▼
[遍历所有 Op]
│
├──→ 尝试 Pattern 1 匹配 → 匹配成功 → 替换
├──→ 尝试 Pattern 2 匹配 → 匹配失败 → 跳过
├──→ 尝试 Pattern 3 匹配 → 匹配成功 → 替换
...
│
▼
变换后的 IR
模式重写的威力在于可组合性:一个 Pass
可以注册多个 Pattern,Pattern 之间互相独立。一次 traversal
可能触发多次 rewrite——每次 rewrite 后的 IR
是新的,可能会触发新的匹配。这种”迭代直到不动点”的行为由
GreedyPatternRewriteDriver 管理。
二、写一个 RewritePattern
以将 arith.subi %x, %x 替换为
arith.constant 0 为例:
// SimplifySubSameOperands.h
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
struct SimplifySubSameOperands : public OpRewritePattern<arith::SubIOp> {
using OpRewritePattern<arith::SubIOp>::OpRewritePattern;
LogicalResult matchAndRewrite(arith::SubIOp op,
PatternRewriter &rewriter) const override {
// 获取操作数
Value lhs = op.getLhs();
Value rhs = op.getRhs();
// 检查两个操作数是否相同
if (lhs != rhs)
return failure(); // 不匹配,跳过
// 匹配成功——创建替换
auto constZero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
op.getType(),
rewriter.getIntegerAttr(op.getType(), 0));
// 替换所有使用且删除原 Op
rewriter.replaceOp(op, constZero);
return success(); // 通知框架:IR 已修改
}
};matchAndRewrite 的协议:
- 检查匹配条件:如果 op 不满足模式,返回
failure()——框架继续尝试下一个 Op 和 Pattern。 - 创建新 IR:用
PatternRewriter(不是OpBuilder)创建新 Op——PatternRewriter在创建时会通知框架 IR 正在被修改。 - 替换和清理:
rewriter.replaceOp(oldOp, newValues)替换旧 Op 的结果值的使用,并计划删除旧 Op。删除是延迟的——在当前遍历轮次结束后批量执行。
三、PatternRewriter 的关键 API
| 方法 | 作用 |
|---|---|
create<OpType>(loc, args...) |
创建一个新 Op |
replaceOp(Operation *op, ValueRange newValues) |
替换 Op 所有结果的使用并删除 Op |
replaceOpWithNewOp<OpType>(Operation *op, args...) |
创建新 Op 并替换旧 Op(一步操作) |
eraseOp(Operation *op) |
直接删除 Op(不替换值,适用无结果的 Op) |
replaceAllUsesWith(Value from, Value to) |
替换单个值的所有使用 |
inlineRegionBefore(Region &src, Region &dest, ...) |
将一个 Region 内联到另一个 Region |
mergeBlocks(Block *src, Block *dest, ...) |
合并两个 Block |
cloneRegionBefore(Region &src, ...) |
克隆一个 Region |
四、Pattern Benefit:控制模式应用顺序
每个 Pattern 都有一个 benefit 值(默认
1)。当多个 Pattern 同时匹配同一个 Op 时,benefit 更高的
Pattern 先应用:
struct SimplifyAddWithZero : public OpRewritePattern<arith::AddIOp> {
using OpRewritePattern<arith::AddIOp>::OpRewritePattern;
SimplifyAddWithZero(MLIRContext *ctx)
: OpRewritePattern<arith::AddIOp>(ctx, /*benefit=*/2) {} // 更高优先级
LogicalResult matchAndRewrite(arith::AddIOp op,
PatternRewriter &rewriter) const override {
// 实现...
}
};benefit 的典型用途:规范化 Pattern 比通用优化 Pattern 的 benefit 更高。例如常数折叠的 benefit 设为 10,因为它总是正确且几乎零成本;但更激进的优化(如循环展开)benefit 设为 1。
五、规范化(Canonicalization)
规范化是模式重写的标准应用——将 IR 转换为”规范形式”(canonical form),使得后续 Pass 看到的是 IR 的最简、最标准表示。
每个 Op 可以通过实现
getCanonicalizationPatterns() 注册它的规范化
Pattern:
// MyOps.cpp
void MyAddOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// 注册此 Op 的规范化规则
patterns.add<SimplifyAddSameOperands>(context);
patterns.add<FoldAddWithZero>(context);
patterns.add<ReorderConstantsToRight>(context);
}然后在 Pass 中调用规范化:
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
// 从所有已加载的方言中收集规范化 Pattern
for (auto *dialect : getContext().getLoadedDialects())
dialect->getCanonicalizationPatterns(patterns);
for (RegisteredOperationName op : getContext().getRegisteredOperations())
op.getCanonicalizationPatterns(patterns, &getContext());
// 应用规范化直到不动点
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns))))
signalPassFailure();
}MLIR 内建的 -canonicalize Pass
就是上述逻辑的封装——它会加载所有方言的规范化 Pattern
并贪婪迭代到不动点。
六、折叠(Fold)与常量传播
fold() 方法是规范化的一个特例——它将 Op
的所有操作数都是常量的情况简化为单个常量。
OpFoldResult MyAddOp::fold(FoldAdaptor adaptor) {
// adaptor 提供操作数——如果所有操作数都是常量,则 adaptor 包含常量值
// 否则 adaptor.getLhs() 返回 null
// 情况 1:所有操作数都是常量
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
if (lhs && rhs) {
// 编译期计算——返回折叠后的常量值
return IntegerAttr::get(getType(), lhs.getInt() + rhs.getInt());
}
// 情况 2:x + 0 = x
if (matchPattern(getRhs(), m_Zero()))
return getLhs(); // 返回已有的 Value
// 情况 3:x + x = 2 * x
if (getLhs() == getRhs())
return nullptr; // 这里不处理,留给 SimplifyAddSameOperands Pattern
return nullptr; // 不能折叠
}matchPattern + Matchers
组合用于匹配操作数的模式:
// 匹配常量值
if (matchPattern(op.getRhs(), m_Zero())) { ... }
// 匹配特定整数
if (matchPattern(op.getRhs(), m_SpecificInt(42))) { ... }
// 组合:匹配 x + 0 或 0 + x
if (matchPattern(op.getRhs(), m_Zero()) || matchPattern(op.getLhs(), m_Zero())) { ... }七、GreedyPatternRewriteDriver 的工作机制
applyPatternsAndFoldGreedily 是 MLIR
中最常用的模式重写驱动器。它的工作流程:
1. worklist ← IR 中的所有 Op
2. while worklist 非空:
3. 从 worklist 中取出一个 Op
4. 对该 Op 尝试所有已注册的 Pattern(按 benefit 降序)
5. if 某个 Pattern 匹配成功:
6. 执行 rewrite(替换/创建/删除 Op)
7. 将被修改的 Op 及其附近的 Op 重新加入 worklist
8. break → 继续外层 while 循环
9. else:
10. 直接处理 worklist 中的下一个 Op
关键特性:
- 贪婪(Greedy):每次匹配成功立即应用,不进行回溯。
- 迭代到不动点:rewrite 产生的新 IR 可能触发新的匹配——worklist 机制保证循环持续。
- 区域敏感:修改只影响被修改 Op 的局部区域和父区域,不影响 sibling 区域。
- 收敛保证:每个有效 Pattern 应该缩小 IR(减少 Op 数量或简化操作),否则可能无限循环。MLIR 提供一个内部计数器防止真正的无限循环(超过阈值后报错)。
八、Pattern 编写最佳实践
8.1 确保 Pattern 是收敛的
坏 Pattern(可能死循环):
// 将 A→B,又将 B→A——无限循环
struct BadPattern1 : OpRewritePattern<AOp> {
LogicalResult matchAndRewrite(AOp op, PatternRewriter &r) const override {
r.replaceOpWithNewOp<BOp>(op, op.getOperand()); // A → B
return success();
}
};
struct BadPattern2 : OpRewritePattern<BOp> {
LogicalResult matchAndRewrite(BOp op, PatternRewriter &r) const override {
r.replaceOpWithNewOp<AOp>(op, op.getOperand()); // B → A
return success();
}
};确保每个 rewrite 导向一个”更规范”的形式——一个严格偏序。例如:常量折叠(减少 Op 数量)、降低操作复杂度。
8.2
使用 hasCanonicalizer 和 hasFolder
而非在 Pass 中手写
hasCanonicalizer = 1 和
hasFolder = 1 在 ODS
定义中使得规范化规则可以通过方言注册被发现,而不是隐式地嵌入在某个特定
Pass 中——这保证了规范化在所有使用该方言的 Pass
中都被统一应用。
8.3 将复杂 Pattern 拆分成多个简单 Pattern
一个 matchAndRewrite
只做一件事——匹配一个具体的 IR 模式并替换。不要在一个 Pattern
中处理五种不同情况。每个 Pattern 的 benefit
可以不同,让框架根据优先级选择应用顺序。
8.4 检查操作数是否来自不同的作用域
当 Pattern 替换跨越多个 Region 边界的 Op 时要小心——确保替换不违反 SSA 支配关系:
// 如果 op 的操作数来自外层 Region,替换必须保持这个关系
if (op.getOperand(0).getParentRegion() != op.getParentRegion()) {
// 处理跨 Region 依赖
}九、方言转换:模式重写的下一个层级
模式重写处理的是同一方言内部或等价方言间的局部替换。下一章讲方言转换(Dialect Conversion)——将一种方言系统地、整体地降阶为另一种方言的框架。这是 MLIR “渐进降阶”的工程核心。
参考资料
官方文档(A 级)
- MLIR Pattern Rewrite — https://mlir.llvm.org/docs/PatternRewriter/
- MLIR Canonicalization — https://mlir.llvm.org/docs/Canonicalization/
源码(A 级)
mlir/include/mlir/IR/PatternMatch.hmlir/include/mlir/Transforms/GreedyPatternRewriteDriver.hmlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【编译器与 MLIR】Pass 管理与分析
详解 MLIR 的 Pass 基础设施:OperationPass 与 ModulePass 的分类与适用场景、Pass 依赖管理与流水线构建、Pass 选项系统、多线程执行模型,以及 mlir-opt 的调试命令。
【编译器与 MLIR】编译器的挑战与 IR 的裂变
从三阶段编译器局限出发,串联 Halide、XLA、TVM 的 IR 裂变,说明 DSA 与 AI 编译器为何需要 MLIR 这类可组合的多层 IR 框架。
【编译器与 MLIR】MLIR 全景图与设计哲学
从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。
【编译器与 MLIR】环境搭建与第一个 MLIR 程序
从零构建 LLVM/MLIR 工程,用 mlir-opt 理解 .mlir 文本表示,运行规范化 Pass 并逐行解读转换结果,建立从命令行到 IR 变换的直觉。