土法炼钢兴趣小组的算法知识备份

【编译器与 MLIR】方言转换与渐进降阶策略

文章导航

分类入口
compilerarchitecture
标签入口
#mlir#llvm#compiler#dialect-conversion#type-converter#lowering#progressive-lowering#bufferization

目录

方言转换与渐进降阶策略

前面两章分别讲了 Pass 管理(外层框架)和模式重写(内层引擎)。这一章讲两者叠加的最终形态——方言转换(Dialect Conversion),它把一个 Module 从一个方言集合系统性地降阶到另一个方言集合。

这是 MLIR “渐进降阶”的工程核心。没有方言转换框架,渐进降阶就只是概念——有了它,每一步降阶才成为可定义、可验证、可复用的编译 Pass。

一、方言转换与模式重写的区别

模式重写(前一章)和方言转换都基于 RewritePattern,但动机和机制不同:

维度 模式重写 方言转换
目标 局部优化(规范化、折叠、简化) 全局降阶(将一种方言变成另一种方言)
是否修改类型 一般不修改类型 通常修改类型(tensor → memref, index → i64)
是否保留方言 保留(在同方言内优化) 消除(源方言 Op 全部转换)
合法性判断 不需判断(所有 IR 始终合法) 需要检查:转换后 IR 是否只包含目标方言的 Op
执行方式 Greedy 迭代到不动点 一次性转换或分阶段的部分转换

方言转换解决的问题是:给定一个 IR,其中包含方言 A 和 B 的 Op,如何系统性地将方言 A 的 Op 全部替换为方言 B 的 Op,并确保转换后的 IR 类型一致、语义等价。

二、方言转换的三个组件

2.1 TypeConverter:类型映射

类型转换指定源方言中的类型变成目标方言中的什么类型:

TypeConverter converter;

// 直接映射:tensor → memref
converter.addConversion([](TensorType tensorType) {
  return MemRefType::get(tensorType.getShape(),
                         tensorType.getElementType());
});

// 1 对 1 映射:index → i64
converter.addConversion([](IndexType indexType) {
  return IntegerType::get(indexType.getContext(), 64);
});

// 添加"合法"的目标类型(不做转换,但会被接受)
converter.addConversion([](IntegerType intType) { return intType; });
converter.addConversion([](FloatType floatType) { return floatType; });

2.2 ConversionTarget:合法性规则

ConversionTarget 指定转换后的 IR 中哪些方言/Op 是合法的:

ConversionTarget target(getContext());

// 目标方言是"合法"的——转换后应该只有这些方言的 Op
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<scf::SCFDialect>();

// 源方言是"非法"的——所有 tensor 方言的 Op 必须被转换
target.addIllegalDialect<tensor::TensorDialect>();

// 但允许某些特定的 Op 在特定条件下保持为源方言
target.addDynamicallyLegalOp<tensor::EmptyOp>(
    [](tensor::EmptyOp op) {
      // 如果 tensor.empty 的 result 已经被 bufferization 使用,则允许
      return op.getResult().use_empty();
    });

addDynamicallyLegalOp 是最灵活的机制——它允许某些源方言 Op 在满足条件时保留。这在部分转换中非常有用。

2.3 Conversion Patterns:转换模式

转换 Pattern 是实际的 RewritePattern,但用的是 ConversionPatternOpConversionPattern(提供了 TypeConverter 调用):

struct BufferizeTensorEmpty : public OpConversionPattern<tensor::EmptyOp> {
  using OpConversionPattern<tensor::EmptyOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      tensor::EmptyOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // 通过 TypeConverter 获取目标类型
    auto memrefType = getTypeConverter()->convertType(op.getType())
                          .cast<MemRefType>();

    // 创建目标方言的等价 Op
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType);

    return success();
  }
};

OpAdaptor 是方言转换特有的——它与普通的操作数访问器不同:OpAdaptor 返回的是已经转换过的操作数(它们的类型可能已经被底层转换 Pattern 改了),而 op.getLhs() 返回的是原始操作数。

三、完整转换 vs. 部分转换

MLIR 方言转换支持两种模式:

完整转换(Full Conversion)

所有源方言的 Op 都必须被转换——转换后 IR 中不允许出现源方言:

if (failed(applyFullConversion(moduleOp, target,
                                std::move(patterns)))) {
  // 转换失败——存在无法转换的源方言 Op
  signalPassFailure();
}

部分转换(Partial Conversion)

只转换”能转换的”,其余源方言 Op 保留:

if (failed(applyPartialConversion(moduleOp, target,
                                   std::move(patterns)))) {
  // 转换部分失败(通常是内部错误)
  signalPassFailure();
}
// 转换成功,但 IR 中可能仍存在源方言 Op

何时用部分转换: - 降阶管线中的中间步骤——先用部分转换消除某种特定的 Op 模式,再处理剩余的。 - 另一个 Pass 会在后续处理未被转换的 Op。 - 存在递归或间接的 Op 关系,需要多轮转换。

四、类型转换的歧义处理

当源类型到目标类型不是一一映射时,需要额外的逻辑:

// 场景:tensor<?xf32> → memref<?xf32>(动态形状)
// 需要运行时 shape 信息

// 解决方式:在 convertType 中接受额外的动态参数

// 方式 1:为每个 TensorType 生成签名字段
converter.addConversion(
    [](TensorType t, SmallVectorImpl<Value> &dynamicDims) {
      // dynamicDims 是由框架填充的动态维度值
      auto memrefType = MemRefType::get(
          t.getShape(), t.getElementType());
      return memrefType;
    });

// 方式 2:在 ConversionPattern 中处理
struct MyPattern : public OpConversionPattern<SourceOp> {
  LogicalResult matchAndRewrite(
      SourceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    // 获取转换后的类型
    auto resultType = getTypeConverter()
        ->convertType(op.getResult().getType())
        .cast<MemRefType>();

    // 提取动态维度值
    SmallVector<Value> dynamicDims;
    // 从 IR 中获取 shape 值的逻辑
    // ...

    rewriter.replaceOpWithNewOp<TargetOp>(
        op, resultType, adaptor.getOperands(), dynamicDims);

    return success();
  }
};

五、一个完整的方言转换示例

tensor 方言的 IR 转换为 memref 方言(bufferization)。以下输入/输出 IR 与 Pass 框架为教学示意——真实 one-shot-bufferizepopulateTensorBufferizationPatterns 的输出在函数签名、alloc 位置和 alias 决策上会更复杂,请以你环境中的 mlir-opt dump 为准。

输入的 IR(tensor 方言):

func.func @add_tensors(%a: tensor<256xf32>, %b: tensor<256xf32>) -> tensor<256xf32> {
  %0 = tensor.empty() : tensor<256xf32>
  %1 = linalg.add ins(%a, %b : tensor<256xf32>, tensor<256xf32>)
                   outs(%0 : tensor<256xf32>) -> tensor<256xf32>
  return %1 : tensor<256xf32>
}

Pass 代码框架

class BufferizeAddTensorsPass : public PassWrapper<BufferizeAddTensorsPass,
                                                    OperationPass<func::FuncOp>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BufferizeAddTensorsPass)

  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    ConversionTarget target(*context);

    // 1. 设置类型转换
    converter.addConversion([](TensorType t) -> std::optional<Type> {
      return MemRefType::get(t.getShape(), t.getElementType());
    });

    // 2. 设置合法性
    target.addLegalDialect<memref::MemRefDialect>();
    target.addLegalDialect<arith::ArithDialect>();
    target.addLegalDialect<linalg::LinalgDialect>();
    target.addIllegalDialect<tensor::TensorDialect>();
    // 允许 func.func 的参数/返回值类型符合转换后也为 memref
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });

    // 3. 设置转换 Pattern
    RewritePatternSet patterns(context);
    populateTensorBufferizationPatterns(converter, patterns);
    // 处理 func 签名中的 tensor → memref 转换
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, converter);

    // 4. 运行完整转换
    if (failed(applyFullConversion(getOperation(), target,
                                    std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

输出 IR(memref 方言):

func.func @add_tensors(%arg0: memref<256xf32>, %arg1: memref<256xf32>,
                        %arg2: memref<256xf32>) {
  linalg.add ins(%arg0, %arg1 : memref<256xf32>, memref<256xf32>)
             outs(%arg2 : memref<256xf32>)
  return
}

注意几点变化:

  1. 类型全变了tensor<256xf32>memref<256xf32>
  2. 函数签名变了:输出 tensor 变成了输入 memrefouts 参数变成函数参数)。
  3. tensor.empty 消失了:memref 不需要显式分配——分配由调用者在函数外部管理。或者 bufferization 生成了 memref.alloc 在函数内部分配。

六、Progressive Lowering 的组合模式

方言转换的威力不在于单个转换,而在于组合:

tensor → linalg → affine → scf → cf → llvm
   ↑         ↑        ↑      ↑     ↑
   ├─── TensorToLinalg
   │
   ├────────────── LinalgToAffineLoops
   │
   ├─────────────────────── LowerAffine
   │
   ├────────────────────────────── ConvertSCFToCF
   │
   └─────────────────────────────────── ConvertToLLVM

每一步都是独立的方言转换 Pass:

PassManager pm(&ctx);

// Step 1: tensor → linalg
pm.addPass(createTensorToLinalgPass());

// Step 2: linalg → affine loops
pm.addPass(createConvertLinalgToAffineLoopsPass());

// Step 3: affine → scf
pm.addPass(createLowerAffinePass());

// Step 4: scf → cf
pm.addPass(createConvertSCFToCFPass());

// Step 5: cf/arith/func → llvm
pm.addPass(createConvertToLLVMPass());

pm.run(moduleOp);

每步都可以单独运行和调试——这是渐进降阶的工程价值:编译链不是黑箱,是显式的、可检查的步骤序列

七、部分转换的实际用途

部分转换在 MLIR 的标准管线中很常见。例如,linalg 方言的 tiling 和 fusion 在 tensor 语义上做(因为纯函数式的 tensor 做数据流分析更容易),只有 tiling 完成后才做 bufferization(转换为 memref):

tensor + linalg (高层 IR)
  │
  ├── linalg tiling (仍在 tensor 域)      ← 部分转换:只 tile 不 bufferize
  ├── linalg fusion (仍在 tensor 域)      ← 部分转换
  │
  ├── bufferization (tensor → memref)     ← 完整转换:消除 tensor 方言
  │
  └── lowering to llvm

部分转换使这种”先优化再降阶”的策略成为可能——tiling 和 fusion 在更高层的抽象上做分析和决策,bufferization 在决策完毕后执行。

八、方言转换的常见陷阱

8.1 类型转换不匹配

如果 TypeConverter 返回 std::nullopt(即”不会转换这个类型”),ConversionPattern 中的对应 Op 会失败。确保 TypeConverter 覆盖了所有可能出现在源方言 Op 操作数和结果中的类型。

8.2 OpAdaptor 与 Operation 的混用

// 错误:在 ConversionPattern 中使用 op.getOperand() 而非 adaptor.getOperand()
LogicalResult matchAndRewrite(MyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const override {
  auto type = op.getOperand(0).getType();  // 这是原始(已转换)的类型!
  // 正确做法是用 adaptor
  auto type = adaptor.getOperand(0).getType();  // 这是转换后的类型
}

8.3 1 对 N 的类型映射

某些方言转换中一个 source Op 可能产生 N 个 target Op(例如一个 tensor 级别的 matmul 降阶后产生多个 memref + scf 循环 + arith 操作)。在 ConversionPatternRewriter 中,通常先 create 多个 Op,再用 rewriter.replaceOp(op, newValues) 或分步 eraseOp 完成一对多替换——具体 API 因 MLIR 版本而异,以 mlir/include/mlir/Transforms/DialectConversion.h 为准。

8.4 合法性检查过于严格或过于宽松

addIllegalDialect 加上 addDynamicallyLegalOp 是精度最高但最易出错的配置。一个 Op 如果既非法又未被动态合法化,转换就会报告失败。建议从 addIllegalOp<SpecificOp>() 开始逐步收紧。

九、本篇后续

Part 3(Pass 管理、模式重写、方言转换)覆盖了 MLIR 编译流水线的完整基础设施。Part 4 进入 AI 编译的核心——Tensor、Linalg、Affine、SCF、GPU 方言的语义和优化策略。

参考资料

官方文档(A 级)

源码(A 级)

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-09 · compiler / architecture

【编译器与 MLIR】MLIR 全景图与设计哲学

从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。


By .