与 AI 框架的接口设计
本系列前十三篇覆盖了 MLIR 的编译器基础设施和核心方言。现在回到起点:一个 PyTorch/TensorFlow/JAX 的用户模型,怎么变成 MLIR IR?
这是 AI 编译器最”接地气”的问题——框架桥接层的设计决定了最终生成的代码质量和编译器的适用范围。
一、四种框架,五种桥接方案
PyTorch ───→ Torch-MLIR ───→ torch dialect ──→ linalg ──→ ... ──→ LLVM/SPIR-V
TensorFlow → TF / MHLO ────→ mhlo dialect ───→ linalg ──→ ... ──→ LLVM/SPIR-V
JAX ──────→ StableHLO ─────→ stablehlo dialect → linalg ──→ ... ──→ LLVM/SPIR-V
ONNX ─────→ ONNX-MLIR ─────→ onnx dialect ────→ krnl ────→ ... ──→ LLVM
(或通过 onnx-mlir → stablehlo → linalg)
虽然最终都汇入 MLIR 方言栈,但每个框架的”入口方言”和降阶路径有本质差异。
二、Torch-MLIR:PyTorch 的图捕获与降低
Torch-MLIR 的任务是:将 PyTorch 的 Eager 模式计算图(动态 Python → traced graph)翻译为 MLIR。它的核心流程分两步:
PyTorch model (nn.Module)
│
▼
torch.fx (符号化图捕获)
│ └── torch.export / torch.fx.symbolic_trace
│
▼
torch-mlir (翻译为 torch dialect)
│ └── torch_mlir.compile()
│
▼
torch dialect ──→ (torch-to-linalg) ──→ linalg ──→ ... (复用标准 MLIR 管线)
2.1 torch 方言的设计
torch 方言试图直接映射 PyTorch
的算子语义(Aten IR)。它的 Op 集合覆盖 PyTorch
的数百个算子:
// torch dialect — Aten 语义的 MLIR 表示
func.func @forward(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,1000],f32> {
%0 = torch.aten.conv2d %arg0, %weight, %bias, ...
: !torch.vtensor<...>, !torch.vtensor<...>, ... -> !torch.vtensor<...>
%1 = torch.aten.relu %0 : ...
%2 = torch.aten.max_pool2d %1, ...
%3 = torch.aten.linear %2, %fc_weight, %fc_bias : ...
return %3 : !torch.vtensor<[1,1000],f32>
}
torch 方言的降阶方式是”分类处理”:
- 核心数值算子(conv、linear、norm、activation)→
映射到
linalg结构化操作。 - 非数值算子(shape
manipulation、indexing、copy)→ 映射到
tensor+memref方言。 - 不能映射的算子(如控制流、特定的 Aten 框架操作)→ 标记为”未分解”并报告。
2.2 分解机制(Decomposition)
PyTorch 有大量复合算子——例如
torch.ops.aten.batch_norm 可以被分解为
mean、var、sub、div
等基本操作的组合。Torch-MLIR 在 torch
方言层上做算子分解,将复杂算子拆成更细粒度的操作后再降阶为
Linalg:
// 分解前
%bn = torch.aten.batch_norm %x, %weight, %bias, %running_mean, ...
// 分解后(概念层面)
%mean = torch.aten.mean %x, dim=[0,2,3]
%var = torch.aten.var %x, dim=[0,2,3]
%x_norm = torch.aten.sub %x, %mean
%x_scaled = torch.aten.div %x_norm, %sqrt_var
%result = torch.aten.add %x_scaled * %weight, %bias
分解策略减少了需要直接从 torch 降阶到
linalg 的 Op
种类——只有分解后的”原子算子”需要降阶规则。
三、StableHLO / MHLO:Google 生态的统一 IR
StableHLO(Stable High-Level Operations)是 Google 提出的跨框架 ML 操作的统一表示——它是 MHLO(Meta HLO)的稳定化版本,同时也是 XLA HLO 的 MLIR 方言表示。JAX、TensorFlow、PyTorch(通过 PyTorch/XLA)都通过 StableHLO 桥接到 MLIR。
3.1 StableHLO 的定位
JAX ────↘
TensorFlow → StableHLO dialect → (legalize-to-linalg) → linalg → ... → 代码生成
PyTorch/XLA ↗
StableHLO 中的核心操作与 XLA HLO 一一对应:
// StableHLO IR — 与 XLA HLO 语义等价
func.func @main(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x1000xf32> {
%0 = stablehlo.convolution(%arg0, %weight)
dim_numbers = [b,0,1,f]x[0,1,i,o]->[b,0,1,f],
window = {stride = [1,1], pad = [[1,1],[1,1]], ...}
: (tensor<...>, tensor<...>) -> tensor<...>
%1 = stablehlo.broadcast_in_dim %bias, dims = [3]
%2 = stablehlo.add %0, %1
%3 = stablehlo.maximum %2, %c0
return %3
}
3.2 StableHLO → Linalg 的降阶
stablehlo-to-linalg Pass 映射 StableHLO
操作到 Linalg 结构化操作——卷积用
linalg.conv_2d,逐元素操作用
linalg.generic,归约用带归约维度的
linalg.generic。这个过程比
torch → linalg 更直接,因为 StableHLO
的操作语义和 Linalg 的命名操作有很高的重合度。
3.3 StableHLO 与 MHLO 的区别
| 维度 | MHLO | StableHLO |
|---|---|---|
| 稳定性 | 随 XLA 版本频繁变化 | 独立版本,向后兼容 |
| 接口 | 未正式承诺 | 承诺向后兼容(类似 ONNX 的承诺) |
| 适用场景 | Google 内部 TF/JAX | 跨框架互操作 |
StableHLO 是 MHLO 的标准化版本——它的目标是成为 AI 编译器领域的”LLVM IR”:一个足够稳定的中间表示,所有框架都可以安全地降阶到它,下游工具链可以放心地消费这个表示。
四、ONNX-MLIR:ONNX 的 MLIR 编译路径
ONNX-MLIR 将 ONNX 计算图翻译为 onnx
方言,然后通过多阶段降阶生成 LLVM IR:
ONNX model (.onnx)
│
▼
onnx dialect ──→ (shape 推断) ──→ krnl dialect ──→ affine ──→ llvm
krnl 方言是 ONNX-MLIR 独有的——它在 ONNX
语义和通用的 affine/scf
之间提供了中间抽象层。这与其他通过 linalg
的路径不同——ONNX-MLIR 选择了自己的降阶路线。
ONNX-MLIR 也可以通过 onnx-to-stablehlo
路径进入 StableHLO 生态:
ONNX → onnx dialect → stablehlo → linalg → ... (复用 StableHLO 管线)
五、四种方案对比
| 维度 | Torch-MLIR | StableHLO | ONNX-MLIR |
|---|---|---|---|
| 入口方言 | torch |
stablehlo |
onnx |
| DL 算子覆盖 | 数百个 Aten Op(Torch-MLIR 持续扩展) | 百余个 StableHLO Op(见 StableHLO spec) | 百余个 ONNX Op(见 ONNX Operators) |
| 核心降阶路径 | torch → linalg | stablehlo → linalg | onnx → krnl → affine/llvm |
| 降阶的复杂度 | 高(算子多样,需要分解) | 中(算子更规则) | 中 |
| 工程成熟度 | 积极开发中 | 成熟(Google 全面支持) | 成熟(IBM/Linux Foundation) |
| 优化能力 | 复用 Linalg 管线(tiling/fusion/vec) | 同左 | 自主优化管线 |
| 复用 Linalg 管线 | 是 | 是 | 部分(通过 stablehlo 路径可复用) |
| 适用场景 | PyTorch 生态 | TF/JAX/任何支持 XLA 的框架 | 框架互操作(.onnx 文件) |
工程选择
- 如果你的 AI 模型是 PyTorch
写的:Torch-MLIR 是第一选择。它直接消费
torch.fx计算图,不需要中间格式转换。 - 如果框架已经集成 XLA(TF、JAX):StableHLO 是自然选择——TF 和 JAX 原生支持 XLA 编译器,StableHLO 是其 MLIR 版本。
- 如果需要跨框架互操作:ONNX→StableHLO 提供了从 ONNX 生态进入 MLIR 的路径。
- 如果你在做通用 MLIR 编译器开发:关注 StableHLO→Linalg 这条路径——这是当前收敛点最多的降阶管线。
六、StableHLO 的 “IR 收敛点” 角色
MLIR 论文的愿景是多个框架共享一套编译管线。在 AI 编译领域,目前最接近这一愿景的是 StableHLO:
PyTorch ──(torch-to-stablehlo)──↘
TensorFlow ───(原生)──────────────→ StableHLO → linalg → affine/scf → llvm/spirv
JAX ──────────(原生)─────────────↗
ONNX ───(onnx-to-stablehlo)─────↗
所有路径汇入 StableHLO 后,共享完全相同的 Linalg→Affine→SCF→LLVM/SPIRV 降阶管线——tiling、fusion、bufferization、代码生成全部复用。
这条路径上的核心 Pass:
mlir-opt input.mlir \
--stablehlo-legalize-to-linalg \
--linalg-bufferize \
--convert-linalg-to-loops \
--lower-affine \
--convert-scf-to-cf \
--convert-to-llvm七、框架桥接的工程挑战
桥接层有几个共通的工程难题:
算子覆盖不完整:PyTorch 有数百个 Aten 算子,ONNX 标准也在不断扩张。Torch-MLIR 和 ONNX-MLIR 都不会追求 100% 覆盖——未覆盖的算子通过”分解为基本操作”或”回退到框架的原生实现”来处理。
动态形状处理:PyTorch 和 JAX
都支持动态形状(运行时决定的 tensor size)。MLIR 的
tensor<?x?xf32>
表示可以承载动态形状,但某些优化(如 tiling
大小决策)依赖静态形状信息。
数值精度差异:不同框架的同一操作(如
conv2d)在浮点舍入、padding 模式、stride
语义上存在细微差异。桥接层需要精确保持每个框架的数值语义。
控制流翻译:PyTorch
的条件语句(if x.sum() > 0: ... else: ...)在
torch.fx 中是 Python 级的——如何翻译为 MLIR 的
scf.if 需要仔细处理。
八、本篇后续
Part 4 的四篇文章覆盖了 AI 编译器的完整链路——Tensor/Linalg(计算抽象)、Affine/SCF(调度)、GPU(硬件映射)、框架桥接(入口)。下一篇进入全系列的实战环节——从零构建一个微型 Tensor DSL,在一篇文章中走完方言设计、降阶 Pass、LLVM IR 生成的全流程。
参考资料
社区项目(B 级)
- Torch-MLIR — https://github.com/llvm/torch-mlir
- StableHLO — https://github.com/openxla/stablehlo
- ONNX-MLIR — https://github.com/onnx/onnx-mlir
- IREE — https://github.com/iree-org/iree
官方文档(A 级)
- MLIR Torch Dialect — https://github.com/llvm/torch-mlir/blob/main/docs/development.md
- StableHLO Specification — https://github.com/openxla/stablehlo/blob/main/docs/spec.md
- ONNX Operators — https://onnx.ai/onnx/operators/
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【编译器与 MLIR】MLIR 全景图与设计哲学
从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。
【编译器与 MLIR】编译器的挑战与 IR 的裂变
从三阶段编译器局限出发,串联 Halide、XLA、TVM 的 IR 裂变,说明 DSA 与 AI 编译器为何需要 MLIR 这类可组合的多层 IR 框架。
【编译器与 MLIR】环境搭建与第一个 MLIR 程序
从零构建 LLVM/MLIR 工程,用 mlir-opt 理解 .mlir 文本表示,运行规范化 Pass 并逐行解读转换结果,建立从命令行到 IR 变换的直觉。
【编译器与 MLIR】操作、方言与 IR 的 C++ 表示
深入 Operation、Op、Value、Block、Region 的 C++ 内存布局与继承体系:CRTP 模板包装、SSA 值的两种来源、Use 链表的遍历方法。这是后续所有 Pass 写作的基础。