土法炼钢兴趣小组的算法知识备份

【编译器与 MLIR】模式重写与规范化框架

文章导航

分类入口
compilerarchitecture
标签入口
#mlir#llvm#compiler#pattern-rewrite#canonicalization#fold#optimization#greedy-rewrite

目录

模式重写与规范化框架

上一章讲了 Pass 的外层——如何注册、调度和调试。这一章进入 Pass 的”心脏”——模式重写(Pattern Rewrite)。这是 MLIR 中最常用、最强大的 IR 变换机制。

一、模式重写的核心思想

模式重写的模型很简单:匹配(match)一个 IR 子图,如果匹配成功,将其替换(rewrite)为更优的等价 IR。这个”如果匹配则替换”的过程在 MLIR 中被抽象为 RewritePattern

输入 IR
  │
  ▼
[遍历所有 Op]
  │
  ├──→ 尝试 Pattern 1 匹配 → 匹配成功 → 替换
  ├──→ 尝试 Pattern 2 匹配 → 匹配失败 → 跳过
  ├──→ 尝试 Pattern 3 匹配 → 匹配成功 → 替换
  ...
  │
  ▼
变换后的 IR

模式重写的威力在于可组合性:一个 Pass 可以注册多个 Pattern,Pattern 之间互相独立。一次 traversal 可能触发多次 rewrite——每次 rewrite 后的 IR 是新的,可能会触发新的匹配。这种”迭代直到不动点”的行为由 GreedyPatternRewriteDriver 管理。

二、写一个 RewritePattern

以将 arith.subi %x, %x 替换为 arith.constant 0 为例:

// SimplifySubSameOperands.h
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

struct SimplifySubSameOperands : public OpRewritePattern<arith::SubIOp> {
  using OpRewritePattern<arith::SubIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SubIOp op,
                                 PatternRewriter &rewriter) const override {
    // 获取操作数
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    // 检查两个操作数是否相同
    if (lhs != rhs)
      return failure();   // 不匹配,跳过

    // 匹配成功——创建替换
    auto constZero = rewriter.create<arith::ConstantOp>(
        op.getLoc(),
        op.getType(),
        rewriter.getIntegerAttr(op.getType(), 0));

    // 替换所有使用且删除原 Op
    rewriter.replaceOp(op, constZero);

    return success();     // 通知框架:IR 已修改
  }
};

matchAndRewrite 的协议:

  1. 检查匹配条件:如果 op 不满足模式,返回 failure()——框架继续尝试下一个 Op 和 Pattern。
  2. 创建新 IR:用 PatternRewriter(不是 OpBuilder)创建新 Op——PatternRewriter 在创建时会通知框架 IR 正在被修改。
  3. 替换和清理rewriter.replaceOp(oldOp, newValues) 替换旧 Op 的结果值的使用,并计划删除旧 Op。删除是延迟的——在当前遍历轮次结束后批量执行。

三、PatternRewriter 的关键 API

方法 作用
create<OpType>(loc, args...) 创建一个新 Op
replaceOp(Operation *op, ValueRange newValues) 替换 Op 所有结果的使用并删除 Op
replaceOpWithNewOp<OpType>(Operation *op, args...) 创建新 Op 并替换旧 Op(一步操作)
eraseOp(Operation *op) 直接删除 Op(不替换值,适用无结果的 Op)
replaceAllUsesWith(Value from, Value to) 替换单个值的所有使用
inlineRegionBefore(Region &src, Region &dest, ...) 将一个 Region 内联到另一个 Region
mergeBlocks(Block *src, Block *dest, ...) 合并两个 Block
cloneRegionBefore(Region &src, ...) 克隆一个 Region

四、Pattern Benefit:控制模式应用顺序

每个 Pattern 都有一个 benefit 值(默认 1)。当多个 Pattern 同时匹配同一个 Op 时,benefit 更高的 Pattern 先应用:

struct SimplifyAddWithZero : public OpRewritePattern<arith::AddIOp> {
  using OpRewritePattern<arith::AddIOp>::OpRewritePattern;
  SimplifyAddWithZero(MLIRContext *ctx)
    : OpRewritePattern<arith::AddIOp>(ctx, /*benefit=*/2) {}  // 更高优先级

  LogicalResult matchAndRewrite(arith::AddIOp op,
                                 PatternRewriter &rewriter) const override {
    // 实现...
  }
};

benefit 的典型用途:规范化 Pattern 比通用优化 Pattern 的 benefit 更高。例如常数折叠的 benefit 设为 10,因为它总是正确且几乎零成本;但更激进的优化(如循环展开)benefit 设为 1。

五、规范化(Canonicalization)

规范化是模式重写的标准应用——将 IR 转换为”规范形式”(canonical form),使得后续 Pass 看到的是 IR 的最简、最标准表示。

每个 Op 可以通过实现 getCanonicalizationPatterns() 注册它的规范化 Pattern:

// MyOps.cpp
void MyAddOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // 注册此 Op 的规范化规则
  patterns.add<SimplifyAddSameOperands>(context);
  patterns.add<FoldAddWithZero>(context);
  patterns.add<ReorderConstantsToRight>(context);
}

然后在 Pass 中调用规范化:

void runOnOperation() override {
  RewritePatternSet patterns(&getContext());
  // 从所有已加载的方言中收集规范化 Pattern
  for (auto *dialect : getContext().getLoadedDialects())
    dialect->getCanonicalizationPatterns(patterns);
  for (RegisteredOperationName op : getContext().getRegisteredOperations())
    op.getCanonicalizationPatterns(patterns, &getContext());

  // 应用规范化直到不动点
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(patterns))))
    signalPassFailure();
}

MLIR 内建的 -canonicalize Pass 就是上述逻辑的封装——它会加载所有方言的规范化 Pattern 并贪婪迭代到不动点。

六、折叠(Fold)与常量传播

fold() 方法是规范化的一个特例——它将 Op 的所有操作数都是常量的情况简化为单个常量。

OpFoldResult MyAddOp::fold(FoldAdaptor adaptor) {
  // adaptor 提供操作数——如果所有操作数都是常量,则 adaptor 包含常量值
  // 否则 adaptor.getLhs() 返回 null

  // 情况 1:所有操作数都是常量
  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
  if (lhs && rhs) {
    // 编译期计算——返回折叠后的常量值
    return IntegerAttr::get(getType(), lhs.getInt() + rhs.getInt());
  }

  // 情况 2:x + 0 = x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();  // 返回已有的 Value

  // 情况 3:x + x = 2 * x
  if (getLhs() == getRhs())
    return nullptr;   // 这里不处理,留给 SimplifyAddSameOperands Pattern

  return nullptr;  // 不能折叠
}

matchPattern + Matchers 组合用于匹配操作数的模式:

// 匹配常量值
if (matchPattern(op.getRhs(), m_Zero())) { ... }

// 匹配特定整数
if (matchPattern(op.getRhs(), m_SpecificInt(42))) { ... }

// 组合:匹配 x + 0 或 0 + x
if (matchPattern(op.getRhs(), m_Zero()) || matchPattern(op.getLhs(), m_Zero())) { ... }

七、GreedyPatternRewriteDriver 的工作机制

applyPatternsAndFoldGreedily 是 MLIR 中最常用的模式重写驱动器。它的工作流程:

1. worklist ← IR 中的所有 Op
2. while worklist 非空:
3.   从 worklist 中取出一个 Op
4.   对该 Op 尝试所有已注册的 Pattern(按 benefit 降序)
5.   if 某个 Pattern 匹配成功:
6.     执行 rewrite(替换/创建/删除 Op)
7.     将被修改的 Op 及其附近的 Op 重新加入 worklist
8.     break → 继续外层 while 循环
9.   else:
10.     直接处理 worklist 中的下一个 Op

关键特性:

八、Pattern 编写最佳实践

8.1 确保 Pattern 是收敛的

坏 Pattern(可能死循环):

// 将 A→B,又将 B→A——无限循环
struct BadPattern1 : OpRewritePattern<AOp> {
  LogicalResult matchAndRewrite(AOp op, PatternRewriter &r) const override {
    r.replaceOpWithNewOp<BOp>(op, op.getOperand());  // A → B
    return success();
  }
};
struct BadPattern2 : OpRewritePattern<BOp> {
  LogicalResult matchAndRewrite(BOp op, PatternRewriter &r) const override {
    r.replaceOpWithNewOp<AOp>(op, op.getOperand());  // B → A
    return success();
  }
};

确保每个 rewrite 导向一个”更规范”的形式——一个严格偏序。例如:常量折叠(减少 Op 数量)、降低操作复杂度。

8.2 使用 hasCanonicalizerhasFolder 而非在 Pass 中手写

hasCanonicalizer = 1hasFolder = 1 在 ODS 定义中使得规范化规则可以通过方言注册被发现,而不是隐式地嵌入在某个特定 Pass 中——这保证了规范化在所有使用该方言的 Pass 中都被统一应用。

8.3 将复杂 Pattern 拆分成多个简单 Pattern

一个 matchAndRewrite 只做一件事——匹配一个具体的 IR 模式并替换。不要在一个 Pattern 中处理五种不同情况。每个 Pattern 的 benefit 可以不同,让框架根据优先级选择应用顺序。

8.4 检查操作数是否来自不同的作用域

当 Pattern 替换跨越多个 Region 边界的 Op 时要小心——确保替换不违反 SSA 支配关系:

// 如果 op 的操作数来自外层 Region,替换必须保持这个关系
if (op.getOperand(0).getParentRegion() != op.getParentRegion()) {
  // 处理跨 Region 依赖
}

九、方言转换:模式重写的下一个层级

模式重写处理的是同一方言内部或等价方言间的局部替换。下一章讲方言转换(Dialect Conversion)——将一种方言系统地、整体地降阶为另一种方言的框架。这是 MLIR “渐进降阶”的工程核心。

参考资料

官方文档(A 级)

源码(A 级)

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-09 · compiler / architecture

【编译器与 MLIR】Pass 管理与分析

详解 MLIR 的 Pass 基础设施:OperationPass 与 ModulePass 的分类与适用场景、Pass 依赖管理与流水线构建、Pass 选项系统、多线程执行模型,以及 mlir-opt 的调试命令。

2026-06-09 · compiler / architecture

【编译器与 MLIR】MLIR 全景图与设计哲学

从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。


By .