土法炼钢兴趣小组的算法知识备份

【编译器与 MLIR】从零构建一个微型 Tensor DSL

文章导航

分类入口
compilerarchitecture
标签入口
#mlir#llvm#compiler#dsl#tensor#dialect#lowering#jit#hands-on

目录

从零构建一个微型 Tensor DSL

这是全系列的”总复习”——前面十四章讲的概念、API、框架将在这一篇集中落地:设计一个自定义方言、用 ODS 声明 Op、写降阶 Pass,经标准 mlir-opt 管线生成 LLVM 方言并用 mlir-translate 导出 LLVM IR 验证。

目标不追求性能——追求的是一条代码量在可读范围内、降阶逻辑完整的端到端编译链。完整可构建工程请参考 MLIR Toy 教程 的 CMake 布局;本篇聚焦 Tiny 方言相对 Toy 的增量(张量 Op 与 tiny-to-linalg Pass)。

一、我们要做什么

定义一个小型的 Tensor 方言——只包含四个 Op:

Op 语义
tiny.constant 编译期常量(从 DenseElementsAttr 创建)
tiny.add 逐元素张量加法
tiny.mul 逐元素张量乘法
tiny.matmul 矩阵乘法

tiny 方言降阶到 linalg 方言,然后复用 MLIR 的标准 linalg → llvm 管线生成 LLVM 方言,最后用 mlir-translate --mlir-to-llvmir 导出 LLVM IR 检查降阶结果。

tiny dialect ──→ (tiny-to-linalg) ──→ linalg ──→ (标准管线) ──→ llvm 方言 ──→ mlir-translate ──→ LLVM IR

我们只写 tiny 方言定义和 tiny-to-linalg 降阶——后面的 linalg → llvm 是 MLIR 提供的标准 Pass。

二、项目结构

tiny-dsl/
├── CMakeLists.txt
├── include/
│   └── Tiny/
│       ├── CMakeLists.txt
│       ├── TinyDialect.h
│       ├── TinyDialect.td
│       ├── TinyOps.td
│       └── TinyPasses.h
├── lib/
│   └── Tiny/
│       ├── CMakeLists.txt
│       ├── TinyDialect.cpp
│       └── TinyToLinalg.cpp
└── test/
    └── tiny_matmul.mlir

三、方言定义:TinyDialect

3.1 TinyDialect.td

// include/Tiny/TinyDialect.td
#ifndef TINY_DIALECT_TD
#define TINY_DIALECT_TD

include "mlir/IR/OpBase.td"

def Tiny_Dialect : Dialect {
  let name = "tiny";
  let summary = "A minimal tensor DSL for learning MLIR compiler construction";
  let description = [{
    Tiny is a minimal dialect that provides a small set of tensor
    operations (constant, add, mul, matmul) and demonstrates how to
    build a complete lowering pipeline from a custom dialect to LLVM IR.
  }];
  let cppNamespace = "::mlir::tiny";
}

#endif

3.2 TinyOps.td

// include/Tiny/TinyOps.td
#ifndef TINY_OPS_TD
#define TINY_OPS_TD

include "TinyDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"

// 基类:所有 tiny Op 的共同属性
class Tiny_Op<string mnemonic, list<Trait> traits = []>
    : Op<Tiny_Dialect, mnemonic, traits>;

// === tiny.constant ===
def Tiny_ConstantOp : Tiny_Op<"constant", [Pure]> {
  let summary = "constant tensor value";
  let arguments = (ins
    ElementsAttr:$value
  );
  let results = (outs
    AnyRankedTensor:$output
  );
  let assemblyFormat = "$value attr-dict `:` type($output)";
  let hasFolder = 1;
}

// === tiny.add ===
def Tiny_AddOp : Tiny_Op<"add", [Pure, SameOperandsAndResultType]> {
  let summary = "element-wise tensor addition";
  let arguments = (ins
    AnyRankedTensor:$lhs,
    AnyRankedTensor:$rhs
  );
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}

// === tiny.mul ===
def Tiny_MulOp : Tiny_Op<"mul", [Pure, SameOperandsAndResultType]> {
  let summary = "element-wise tensor multiplication";
  let arguments = (ins
    AnyRankedTensor:$lhs,
    AnyRankedTensor:$rhs
  );
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}

// === tiny.matmul ===
def Tiny_MatmulOp : Tiny_Op<"matmul", [Pure]> {
  let summary = "matrix multiplication";
  let arguments = (ins
    AnyRankedTensor:$lhs,
    AnyRankedTensor:$rhs
  );
  let results = (outs
    AnyRankedTensor:$result
  );

  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
      // 推导结果形状:[M,K] x [K,N] → [M,N]
      auto lhsType = lhs.getType().cast<RankedTensorType>();
      auto rhsType = rhs.getType().cast<RankedTensorType>();
      SmallVector<int64_t> resultShape = {
        lhsType.getShape()[0],
        rhsType.getShape()[1]
      };
      auto resultType = RankedTensorType::get(
        resultShape, lhsType.getElementType());
      return build($_builder, $_state, resultType, lhs, rhs);
    }]>
  ];

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `x` type($rhs) `->` type($result)";
  let hasVerifier = 1;
}
#endif

关键设计决策: - Tiny_MatmulOp 用了自定义 builder 来推导结果形状([M,K] × [K,N] → [M,N])。 - hasVerifier = 1 意味着需要在 C++ 中实现 verify() 方法检查维度兼容。 - 所有数值 Op 标记为 Pure——无副作用,支持死代码消除和 CSE。

四、方言实现:TinyDialect.cpp

// lib/Tiny/TinyDialect.cpp
#include "Tiny/TinyDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::tiny;

#include "Tiny/TinyDialect.cpp.inc"

void TinyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "Tiny/TinyOps.cpp.inc"
#undef GET_OP_LIST
  >();
}

// tiny.matmul 的验证器
LogicalResult TinyMatmulOp::verify() {
  auto lhsType = getLhs().getType().cast<RankedTensorType>();
  auto rhsType = getRhs().getType().cast<RankedTensorType>();

  // 检查 rank:matmul 必须是 2D
  if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
    return emitOpError("operands must be 2D tensors (matrices)");

  // 检查 inner dimension 匹配:[M, K] × [K, N]
  if (lhsType.getDimSize(1) != rhsType.getDimSize(0))
    return emitOpError("inner dimensions must match")
           << ": lhs shape = [" << lhsType.getDimSize(0)
           << ", " << lhsType.getDimSize(1)
           << "], rhs shape = [" << rhsType.getDimSize(0)
           << ", " << rhsType.getDimSize(1) << "]";

  return success();
}

五、降阶 Pass:TinyToLinalg

这是本项目的核心——将 tiny 方言降阶为 linalg 方言。

// lib/Tiny/TinyToLinalg.cpp
#include "Tiny/TinyDialect.h"
#include "Tiny/TinyOps.h"
#include "Tiny/TinyPasses.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::tiny;

// ============================================================
// Pattern 1: tiny.constant → arith.constant
// ============================================================
struct TinyConstantToLinalg : public OpConversionPattern<TinyConstantOp> {
  using OpConversionPattern<TinyConstantOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TinyConstantOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getResult().getType().cast<RankedTensorType>();
    auto constant = rewriter.create<arith::ConstantOp>(
        op.getLoc(), resultType, op.getValue());
    rewriter.replaceOp(op, constant.getResult());
    return success();
  }
};

// ============================================================
// Pattern 2: tiny.add → linalg.generic (element-wise)
// ============================================================
struct TinyAddToLinalg : public OpConversionPattern<TinyAddOp> {
  using OpConversionPattern<TinyAddOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TinyAddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    auto resultType = op.getResult().getType().cast<RankedTensorType>();
    auto elementType = resultType.getElementType();

    // 构造 identity indexing maps(逐元素操作——一一映射)
    auto maps = AffineMap::getMultiDimIdentityMap(
        resultType.getRank(), rewriter.getContext());

    // 创建一个空的 tensor 作为输出
    Value empty = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), resultType.getShape(), elementType);

    // 用 linalg.generic 替代 tiny.add
    auto generic = rewriter.create<linalg::GenericOp>(
        op.getLoc(), resultType,
        ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        ValueRange{empty},
        ArrayRef<AffineMap>{maps, maps, maps},
        ArrayRef<utils::IteratorType>{
            utils::IteratorType::parallel,
            utils::IteratorType::parallel},
        [](OpBuilder &bodyBuilder, Location loc,
           ValueRange args) {
          Value add = bodyBuilder.create<arith::AddFOp>(
              loc, args[0], args[1]);
          bodyBuilder.create<linalg::YieldOp>(loc, add);
        });

    rewriter.replaceOp(op, generic.getResults());
    return success();
  }
};

// ============================================================
// Pattern 3: tiny.mul → linalg.generic (element-wise)
// ============================================================
struct TinyMulToLinalg : public OpConversionPattern<TinyMulOp> {
  using OpConversionPattern<TinyMulOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TinyMulOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    auto resultType = op.getResult().getType().cast<RankedTensorType>();
    auto elementType = resultType.getElementType();
    auto maps = AffineMap::getMultiDimIdentityMap(
        resultType.getRank(), rewriter.getContext());
    Value empty = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), resultType.getShape(), elementType);

    auto generic = rewriter.create<linalg::GenericOp>(
        op.getLoc(), resultType,
        ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        ValueRange{empty},
        ArrayRef<AffineMap>{maps, maps, maps},
        ArrayRef<utils::IteratorType>{
            utils::IteratorType::parallel,
            utils::IteratorType::parallel},
        [](OpBuilder &bodyBuilder, Location loc, ValueRange args) {
          Value mul = bodyBuilder.create<arith::MulFOp>(
              loc, args[0], args[1]);
          bodyBuilder.create<linalg::YieldOp>(loc, mul);
        });

    rewriter.replaceOp(op, generic.getResults());
    return success();
  }
};

// ============================================================
// Pattern 4: tiny.matmul → linalg.matmul
// ============================================================
struct TinyMatmulToLinalg : public OpConversionPattern<TinyMatmulOp> {
  using OpConversionPattern<TinyMatmulOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TinyMatmulOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    auto lhsType = op.getLhs().getType().cast<RankedTensorType>();
    auto resultType = op.getResult().getType().cast<RankedTensorType>();
    auto elementType = resultType.getElementType();

    // 创建初始化为零的输出 tensor
    Value zero = rewriter.create<arith::ConstantOp>(
        op.getLoc(), elementType,
        rewriter.getFloatAttr(elementType, 0.0));
    Value empty = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), resultType.getShape(), elementType);
    Value filled = rewriter.create<linalg::FillOp>(
        op.getLoc(), ValueRange{zero}, ValueRange{empty})
        .getResult(0);

    // 用 linalg.matmul 替代 tiny.matmul
    auto matmul = rewriter.create<linalg::MatmulOp>(
        op.getLoc(),
        ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        ValueRange{filled});

    rewriter.replaceOp(op, matmul.getResults());
    return success();
  }
};

// ============================================================
// Pass: TinyToLinalgPass
// ============================================================
namespace {
struct TinyToLinalgPass
    : public PassWrapper<TinyToLinalgPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TinyToLinalgPass)

  void runOnOperation() override {
    auto *context = &getContext();
    ConversionTarget target(*context);

    // 目标:只有 linalg, tensor, arith, func 方言的 Op 是合法的
    target.addLegalDialect<linalg::LinalgDialect>();
    target.addLegalDialect<tensor::TensorDialect>();
    target.addLegalDialect<arith::ArithDialect>();
    target.addLegalDialect<func::FuncDialect>();

    // 源:tiny 方言的所有 Op 必须被消除
    target.addIllegalDialect<TinyDialect>();

    RewritePatternSet patterns(context);
    patterns.add<TinyConstantToLinalg, TinyAddToLinalg,
                 TinyMulToLinalg, TinyMatmulToLinalg>(context);

    if (failed(applyFullConversion(getOperation(), target,
                                    std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace

// 注册 Pass
std::unique_ptr<Pass> mlir::tiny::createTinyToLinalgPass() {
  return std::make_unique<TinyToLinalgPass>();
}

六、降阶管线与 LLVM IR 导出

将自定义 Pass 注册进 mlir-opt 后,可用命令行走完降阶并导出 LLVM IR(MLIR 19.x;Pass 名以你构建的注册名为准):

# 1. tiny → linalg(自定义 Pass,注册名示例:--tiny-to-linalg)
mlir-opt test/tiny_matmul.mlir \
  --tiny-to-linalg \
  -canonicalize -cse \
  -one-shot-bufferize \
  -convert-linalg-to-loops \
  -lower-affine \
  -convert-scf-to-cf \
  -convert-to-llvm \
  -o tiny_llvm.mlir

# 2. llvm 方言 → LLVM IR 文本
mlir-translate tiny_llvm.mlir --mlir-to-llvmir -o tiny_matmul.ll

验证标准:tiny_llvm.mlir 中不再出现 tiny. 前缀的 Op;tiny_matmul.ll 包含 @main 函数定义。若要进一步 JIT 执行,参考 MLIR 官方 mlir/examples/toy/Ch7/toy.cpp 中的 ExecutionEngine 集成。

七、测试输入:tiny_matmul.mlir

// test/tiny_matmul.mlir
module {
  func.func @main() -> tensor<2x2xf32> {
    %A = tiny.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
    %B = tiny.constant dense<[[5.0, 6.0], [7.0, 8.0]]> : tensor<2x2xf32>
    %C = tiny.matmul %A, %B : tensor<2x2xf32> x tensor<2x2xf32> -> tensor<2x2xf32>
    return %C : tensor<2x2xf32>
  }
}

预期矩阵结果(手算验证,非运行时输出):

[[19.0, 22.0],
 [43.0, 50.0]]

验证:A[0,:]·B[:,0] = 1×5+2×7 = 19A[0,:]·B[:,1] = 1×6+2×8 = 22;依此类推。运行时数值验证需接入 ExecutionEngine 或外部测试框架,本篇以 IR/LLVM IR 降阶正确性为验收标准。

八、从 linalg.matmul 到 LLVM IR 的标准管线

经过 TinyToLinalgPass 后,IR 变为纯 linalg + tensor 方言。后续标准管线处理:

mlir-opt tiny_matmul_linalg.mlir \
  -canonicalize \
  -cse \
  -one-shot-bufferize \
  -convert-linalg-to-loops \
  -lower-affine \
  -convert-scf-to-cf \
  -convert-to-llvm

每一步的中间 IR 可以分别查看,方便调试和理解每层方言的语义。

九、CMake 集成

# CMakeLists.txt
cmake_minimum_required(VERSION 3.20)
project(TinyDSL)

find_package(MLIR REQUIRED CONFIG)

set(LLVM_TARGET_DEFINITIONS TinyOps.td)
mlir_tablegen(TinyOps.h.inc -gen-op-decls)
mlir_tablegen(TinyOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(TinyOpsIncGen)

set(LLVM_TARGET_DEFINITIONS TinyDialect.td)
mlir_tablegen(TinyDialect.h.inc -gen-dialect-decls)
mlir_tablegen(TinyDialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(TinyDialectIncGen)

add_mlir_library(TinyDialect
  TinyDialect.cpp
  TinyToLinalg.cpp

  DEPENDS TinyOpsIncGen TinyDialectIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRInferTypeOpInterface
  MLIRLinalgDialect
  MLIRTensorDialect
  MLIRArithDialect
)

十、本篇后续与系列收官

这篇把全系列的概念汇集到了一次降阶管线实践中。完整 CMake 工程请对照 Toy 教程搭建;接下来的两篇(调试工作流、IREE 集成)提供实用工具链知识。最后一篇(总结与未来趋势)回到宏观视角,讨论 MLIR 2.0 规划和编译器工程的开放问题。

参考资料

官方文档(A 级)

源码(A 级)

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-09 · compiler / architecture

【编译器与 MLIR】MLIR 全景图与设计哲学

从 Module-Operation-Region-Block 四层结构出发,系统讲解 MLIR 的三条核心设计原则:渐进降阶、方言可组合性、基础设施复用,配合 IREE、CIRCT、Torch-MLIR 等实际案例建立心智模型。

2026-06-09 · compiler / architecture

【编译器与 MLIR】类型系统与属性

解析 MLIR 的类型体系:内建类型(Integer、Float、Tensor、MemRef)与自定义方言类型的注册机制;区分 Type 与 Attribute 的设计意图;通过 OpBuilder 理解类型和属性在 IR 构造中的实际角色。


By .