土法炼钢兴趣小组的算法知识备份

【编译器与 MLIR】与 AI 框架的接口设计

文章导航

分类入口
compilerarchitecture
标签入口
#mlir#llvm#compiler#pytorch#tensorflow#jax#onnx#torch-mlir#mhlo#stablehlo#framework-bridge

目录

与 AI 框架的接口设计

本系列前十三篇覆盖了 MLIR 的编译器基础设施和核心方言。现在回到起点:一个 PyTorch/TensorFlow/JAX 的用户模型,怎么变成 MLIR IR?

这是 AI 编译器最”接地气”的问题——框架桥接层的设计决定了最终生成的代码质量和编译器的适用范围。

一、四种框架,五种桥接方案

PyTorch ───→ Torch-MLIR ───→ torch dialect ──→ linalg ──→ ... ──→ LLVM/SPIR-V
TensorFlow → TF / MHLO ────→ mhlo dialect ───→ linalg ──→ ... ──→ LLVM/SPIR-V
JAX ──────→ StableHLO ─────→ stablehlo dialect → linalg ──→ ... ──→ LLVM/SPIR-V
ONNX ─────→ ONNX-MLIR ─────→ onnx dialect ────→ krnl ────→ ... ──→ LLVM
           (或通过 onnx-mlir → stablehlo → linalg)

虽然最终都汇入 MLIR 方言栈,但每个框架的”入口方言”和降阶路径有本质差异。

二、Torch-MLIR:PyTorch 的图捕获与降低

Torch-MLIR 的任务是:将 PyTorch 的 Eager 模式计算图(动态 Python → traced graph)翻译为 MLIR。它的核心流程分两步:

PyTorch model (nn.Module)
   │
   ▼
torch.fx (符号化图捕获)
   │   └── torch.export / torch.fx.symbolic_trace
   │
   ▼
torch-mlir (翻译为 torch dialect)
   │   └── torch_mlir.compile()
   │
   ▼
torch dialect ──→ (torch-to-linalg) ──→ linalg ──→ ... (复用标准 MLIR 管线)

2.1 torch 方言的设计

torch 方言试图直接映射 PyTorch 的算子语义(Aten IR)。它的 Op 集合覆盖 PyTorch 的数百个算子:

// torch dialect — Aten 语义的 MLIR 表示
func.func @forward(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,1000],f32> {
  %0 = torch.aten.conv2d %arg0, %weight, %bias, ...
    : !torch.vtensor<...>, !torch.vtensor<...>, ... -> !torch.vtensor<...>
  %1 = torch.aten.relu %0 : ...
  %2 = torch.aten.max_pool2d %1, ...
  %3 = torch.aten.linear %2, %fc_weight, %fc_bias : ...
  return %3 : !torch.vtensor<[1,1000],f32>
}

torch 方言的降阶方式是”分类处理”:

2.2 分解机制(Decomposition)

PyTorch 有大量复合算子——例如 torch.ops.aten.batch_norm 可以被分解为 meanvarsubdiv 等基本操作的组合。Torch-MLIR 在 torch 方言层上做算子分解,将复杂算子拆成更细粒度的操作后再降阶为 Linalg:

// 分解前
%bn = torch.aten.batch_norm %x, %weight, %bias, %running_mean, ...

// 分解后(概念层面)
%mean = torch.aten.mean %x, dim=[0,2,3]
%var = torch.aten.var %x, dim=[0,2,3]
%x_norm = torch.aten.sub %x, %mean
%x_scaled = torch.aten.div %x_norm, %sqrt_var
%result = torch.aten.add %x_scaled * %weight, %bias

分解策略减少了需要直接从 torch 降阶到 linalg 的 Op 种类——只有分解后的”原子算子”需要降阶规则。

三、StableHLO / MHLO:Google 生态的统一 IR

StableHLO(Stable High-Level Operations)是 Google 提出的跨框架 ML 操作的统一表示——它是 MHLO(Meta HLO)的稳定化版本,同时也是 XLA HLO 的 MLIR 方言表示。JAX、TensorFlow、PyTorch(通过 PyTorch/XLA)都通过 StableHLO 桥接到 MLIR。

3.1 StableHLO 的定位

JAX ────↘
TensorFlow → StableHLO dialect → (legalize-to-linalg) → linalg → ... → 代码生成
PyTorch/XLA ↗

StableHLO 中的核心操作与 XLA HLO 一一对应:

// StableHLO IR — 与 XLA HLO 语义等价
func.func @main(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x1000xf32> {
  %0 = stablehlo.convolution(%arg0, %weight)
    dim_numbers = [b,0,1,f]x[0,1,i,o]->[b,0,1,f],
    window = {stride = [1,1], pad = [[1,1],[1,1]], ...}
    : (tensor<...>, tensor<...>) -> tensor<...>
  %1 = stablehlo.broadcast_in_dim %bias, dims = [3]
  %2 = stablehlo.add %0, %1
  %3 = stablehlo.maximum %2, %c0
  return %3
}

3.2 StableHLO → Linalg 的降阶

stablehlo-to-linalg Pass 映射 StableHLO 操作到 Linalg 结构化操作——卷积用 linalg.conv_2d,逐元素操作用 linalg.generic,归约用带归约维度的 linalg.generic。这个过程比 torch → linalg 更直接,因为 StableHLO 的操作语义和 Linalg 的命名操作有很高的重合度。

3.3 StableHLO 与 MHLO 的区别

维度 MHLO StableHLO
稳定性 随 XLA 版本频繁变化 独立版本,向后兼容
接口 未正式承诺 承诺向后兼容(类似 ONNX 的承诺)
适用场景 Google 内部 TF/JAX 跨框架互操作

StableHLO 是 MHLO 的标准化版本——它的目标是成为 AI 编译器领域的”LLVM IR”:一个足够稳定的中间表示,所有框架都可以安全地降阶到它,下游工具链可以放心地消费这个表示。

四、ONNX-MLIR:ONNX 的 MLIR 编译路径

ONNX-MLIR 将 ONNX 计算图翻译为 onnx 方言,然后通过多阶段降阶生成 LLVM IR:

ONNX model (.onnx)
   │
   ▼
onnx dialect ──→ (shape 推断) ──→ krnl dialect ──→ affine ──→ llvm

krnl 方言是 ONNX-MLIR 独有的——它在 ONNX 语义和通用的 affine/scf 之间提供了中间抽象层。这与其他通过 linalg 的路径不同——ONNX-MLIR 选择了自己的降阶路线。

ONNX-MLIR 也可以通过 onnx-to-stablehlo 路径进入 StableHLO 生态:

ONNX → onnx dialect → stablehlo → linalg → ... (复用 StableHLO 管线)

五、四种方案对比

维度 Torch-MLIR StableHLO ONNX-MLIR
入口方言 torch stablehlo onnx
DL 算子覆盖 数百个 Aten Op(Torch-MLIR 持续扩展) 百余个 StableHLO Op(见 StableHLO spec 百余个 ONNX Op(见 ONNX Operators
核心降阶路径 torch → linalg stablehlo → linalg onnx → krnl → affine/llvm
降阶的复杂度 高(算子多样,需要分解) 中(算子更规则)
工程成熟度 积极开发中 成熟(Google 全面支持) 成熟(IBM/Linux Foundation)
优化能力 复用 Linalg 管线(tiling/fusion/vec) 同左 自主优化管线
复用 Linalg 管线 部分(通过 stablehlo 路径可复用)
适用场景 PyTorch 生态 TF/JAX/任何支持 XLA 的框架 框架互操作(.onnx 文件)

工程选择

六、StableHLO 的 “IR 收敛点” 角色

MLIR 论文的愿景是多个框架共享一套编译管线。在 AI 编译领域,目前最接近这一愿景的是 StableHLO:

PyTorch ──(torch-to-stablehlo)──↘
TensorFlow ───(原生)──────────────→ StableHLO → linalg → affine/scf → llvm/spirv
JAX ──────────(原生)─────────────↗
ONNX ───(onnx-to-stablehlo)─────↗

所有路径汇入 StableHLO 后,共享完全相同的 Linalg→Affine→SCF→LLVM/SPIRV 降阶管线——tiling、fusion、bufferization、代码生成全部复用。

这条路径上的核心 Pass:

mlir-opt input.mlir \
  --stablehlo-legalize-to-linalg \
  --linalg-bufferize \
  --convert-linalg-to-loops \
  --lower-affine \
  --convert-scf-to-cf \
  --convert-to-llvm

七、框架桥接的工程挑战

桥接层有几个共通的工程难题:

算子覆盖不完整:PyTorch 有数百个 Aten 算子,ONNX 标准也在不断扩张。Torch-MLIR 和 ONNX-MLIR 都不会追求 100% 覆盖——未覆盖的算子通过”分解为基本操作”或”回退到框架的原生实现”来处理。

动态形状处理:PyTorch 和 JAX 都支持动态形状(运行时决定的 tensor size)。MLIR 的 tensor<?x?xf32> 表示可以承载动态形状,但某些优化(如 tiling 大小决策)依赖静态形状信息。

数值精度差异:不同框架的同一操作(如 conv2d)在浮点舍入、padding 模式、stride 语义上存在细微差异。桥接层需要精确保持每个框架的数值语义。

控制流翻译:PyTorch 的条件语句(if x.sum() > 0: ... else: ...)在 torch.fx 中是 Python 级的——如何翻译为 MLIR 的 scf.if 需要仔细处理。

八、本篇后续

Part 4 的四篇文章覆盖了 AI 编译器的完整链路——Tensor/Linalg(计算抽象)、Affine/SCF(调度)、GPU(硬件映射)、框架桥接(入口)。下一篇进入全系列的实战环节——从零构建一个微型 Tensor DSL,在一篇文章中走完方言设计、降阶 Pass、LLVM IR 生成的全流程。

参考资料

社区项目(B 级)

官方文档(A 级)

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-09 · compiler / architecture

【编译器与 MLIR】MLIR 全景图与设计哲学

从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。

2026-06-09 · compiler / architecture

【编译器与 MLIR】操作、方言与 IR 的 C++ 表示

深入 Operation、Op、Value、Block、Region 的 C++ 内存布局与继承体系:CRTP 模板包装、SSA 值的两种来源、Use 链表的遍历方法。这是后续所有 Pass 写作的基础。


By .