土法炼钢兴趣小组的算法知识备份

学习索引:当机器学习遇上数据库

目录

2018 年,MIT 的 Tim Kraska 联合 Google Brain 发表了一篇标题极为大胆的论文: The Case for Learned Index Structures。 论文的核心论点只有一句话——索引本质上是一个从 key 到存储位置的函数映射, 而机器学习模型天然擅长学习函数。 这句话像一颗炸弹,震动了整个数据库社区。 有人拍案叫绝,有人嗤之以鼻,但无论持何种态度, 几乎所有做存储与索引的研究者都不得不正视这个方向。

本文将从 Kraska 的核心洞察出发,逐步展开学习索引的理论基础、 典型数据结构(RMI、PGM-index、ALEX、LIPP)、 可更新索引的挑战、Bloom filter 的替代方案, 并给出一个完整的 C++ 实现。最后讨论工程落地中的真实问题, 以及我个人对这个方向的看法。

一、核心洞察:索引就是 CDF 模型

传统 B-tree 索引做了什么?给定一个 key,它告诉你这个 key 对应的记录在磁盘上的偏移量(或在内存中的位置)。如果底层数据是排序的,那这个问题可以换一种说法:给定 key,返回它在有序数组中的排名(rank),也就是累积分布函数(CDF)的值。

设有序数组 data[0..N-1],对于查询 key k,我们需要找到位置 p,使得:

p = F(k) * N

其中 F(k) 是 key 的经验 CDF,即所有 key 中小于等于 k 的比例。

这个公式揭示了一个惊人的事实:索引查找等价于 CDF 预测。 如果我们能用一个机器学习模型来近似 CDF,就能直接预测出 key 的位置, 而不需要逐层遍历 B-tree 的内部节点。

让我们把这个想法具体化。假设数据服从均匀分布,那么 CDF 就是一条直线:

F(k) = (k - min_key) / (max_key - min_key)

一个简单的线性模型 pos = w * key + b 就能完美拟合。 当然,现实中的数据分布很少是完美均匀的,但这不妨碍我们用更复杂的模型来逼近真实 CDF。

这个洞察的深刻之处在于:B-tree 的每一层本质上都是一个”粗粒度的分段常数函数”, 用 O(log N) 层来逼近 CDF。 而一个训练好的回归模型可能只需要一次前向传播(几次乘法和加法)就能给出预测, 时间复杂度接近 O(1)

当然,模型预测不可能完全精确,总会有误差。 设预测位置为 p_hat,真实位置为 p,最大误差为 err, 那么我们只需要在 [p_hat - err, p_hat + err] 范围内做一次局部搜索。 如果 err 足够小,这个搜索的代价远小于 B-tree 的 O(log N)

下面的伪代码展示了最基本的学习索引查找流程:

def learned_lookup(model, data, key):
    # 模型预测位置
    predicted_pos = model.predict(key)
    predicted_pos = clamp(predicted_pos, 0, len(data) - 1)

    # 在误差范围内做局部搜索
    lo = max(0, predicted_pos - model.max_error)
    hi = min(len(data) - 1, predicted_pos + model.max_error)

    # 二分查找精确定位
    while lo <= hi:
        mid = (lo + hi) // 2
        if data[mid] == key:
            return mid
        elif data[mid] < key:
            lo = mid + 1
        else:
            hi = mid - 1

    return -1  # not found

关键指标有两个:

  1. 模型推理时间:模型越复杂,推理越慢,但拟合越好。
  2. 最大误差(max error):误差越大,局部搜索范围越大。

学习索引的设计本质上就是在这两者之间做权衡。

二、递归模型索引(RMI):分而治之的分阶段模型

Kraska 论文中提出的第一个具体结构叫做 Recursive Model Index(RMI)。 它的核心思想是:用一个模型来拟合整个 CDF 太难了, 不如把问题分层——顶层模型做粗粒度预测,选择下一层的子模型; 子模型在更小的 key 范围内做更精细的预测。

RMI Architecture

RMI 的结构类似一棵树,通常有 2-3 层(stage):

查找流程如下:

key -> Stage 0 model -> index i
     -> Stage 1 model[i] -> index j
     -> Stage 2 model[j] -> predicted position p
     -> local search in [p - err, p + err]

每个 stage 的模型可以是不同类型的:线性回归、简单神经网络、样条函数等。论文中的实验表明,即使所有层都用简单的线性模型,效果也相当不错。

RMI 的训练过程是自顶向下的:

def train_rmi(data, keys, stages):
    """
    data:   排序后的 key 数组
    stages: 每层的模型数量,如 [1, 100, 10000]
    """
    models = []
    # Stage 0:在所有数据上训练根模型
    root = LinearRegression()
    root.fit(keys, positions)
    models.append([root])

    for stage_idx in range(1, len(stages)):
        stage_models = [None] * stages[stage_idx]
        # 用上一层的预测来分配 key 到当前层的模型
        for key, pos in zip(keys, positions):
            predicted = models[stage_idx - 1][...].predict(key)
            model_idx = int(predicted * stages[stage_idx] / len(data))
            model_idx = clamp(model_idx, 0, stages[stage_idx] - 1)
            # 将 (key, pos) 分配给 stage_models[model_idx]
            assign(stage_models[model_idx], key, pos)

        # 训练当前层的每个模型
        for m in stage_models:
            m.fit()

        models.append(stage_models)

    return models

RMI 的优点:

  1. 空间效率极高:每个线性模型只需要两个参数(斜率和截距), 10000 个模型也只占 80KB。相比之下,一个百万条记录的 B-tree 内部节点可能占数 MB。
  2. 缓存友好:模型参数紧凑,容易放进 L1/L2 缓存。
  3. 查找速度快:2-3 次线性模型推理 + 局部搜索,总时间常数很小。

RMI 的局限:

  1. 不支持插入和删除:这是最致命的问题,后面会详细讨论。
  2. 最大误差可能很大:某些分布下,个别叶子模型的误差可能非常大。
  3. 训练需要全量数据:不能增量训练。

三、PGM-index:最优分段线性近似

RMI 的一个问题是:它的模型结构(每层多少个模型)需要人工调参, 而且不能保证全局最优。 Ferragina 和 Vinciguerra 在 2020 年提出的 Piecewise Geometric Model index(PGM-index) 解决了这个问题。

PGM-index 的核心思想非常优雅:给定误差阈值 epsilon,用最少的线段来近似 CDF。这是一个经典的计算几何问题,有 O(N) 时间的最优算法。然后在分段之上递归构建索引,形成一棵树——树的每一层都是对下层段地址的 PGM 近似。

分段线性近似的关键算法是 optimal piecewise linear approximation。给定有序点集 {(key_i, i)} 和误差阈值 epsilon,找到最少数量的线段,使得每条线段覆盖的所有点到线段的距离不超过 epsilon

// 最优分段线性近似的核心算法(简化版)
struct Segment {
    double slope;
    double intercept;
    size_t start_key;
    size_t start_pos;
};

std::vector<Segment> optimal_pla(
    const std::vector<int64_t>& keys,
    size_t epsilon
) {
    std::vector<Segment> segments;
    size_t seg_start = 0;

    while (seg_start < keys.size()) {
        // 初始化可行域(一个锥形区域)
        double upper_slope = INFINITY;
        double lower_slope = -INFINITY;
        double intercept_at_start = seg_start;

        size_t seg_end = seg_start;

        for (size_t i = seg_start + 1; i < keys.size(); ++i) {
            double dx = keys[i] - keys[seg_start];
            if (dx == 0) { seg_end = i; continue; }

            double upper = (double(i - seg_start) + epsilon) / dx;
            double lower = (double(i - seg_start) - epsilon) / dx;

            // 收缩可行域
            upper_slope = std::min(upper_slope, upper);
            lower_slope = std::max(lower_slope, lower);

            if (upper_slope < lower_slope) {
                // 可行域为空,当前线段到此结束
                break;
            }
            seg_end = i;
        }

        double slope = (upper_slope + lower_slope) / 2.0;
        segments.push_back({slope, intercept_at_start,
                           (size_t)keys[seg_start], seg_start});
        seg_start = seg_end + 1;
    }

    return segments;
}

PGM-index 的递归构建过程:

Level 0: 对原始数据做分段线性近似,得到 S_0 个段
Level 1: 对 S_0 个段的起始位置做分段线性近似,得到 S_1 个段
Level 2: 对 S_1 个段的起始位置做分段线性近似,得到 S_2 个段
...
直到某层只有一个段(根节点)

查找过程:

1. 从根节点开始
2. 用当前层的线性模型预测下一层的段索引
3. 在 [predicted - epsilon, predicted + epsilon] 范围内二分查找精确段
4. 进入下一层,重复 2-3
5. 到达最底层后,预测数据位置并局部搜索

PGM-index 的理论保证非常强:

属性
空间复杂度 O(N / epsilon)
查找时间 O(log(N/epsilon) * log(epsilon))
构建时间 O(N)
误差保证 每层最多 epsilon

与 RMI 相比,PGM-index 不需要调参,自动适应数据分布,并且有严格的最坏情况保证。

四、ALEX:支持动态更新的学习索引

RMI 和 PGM-index 都是只读结构。现实世界的数据库需要频繁地插入和删除数据, 这怎么办? 2020 年,Microsoft Research 的 Ding 等人提出了 ALEX(Adaptive Learned indEX), 首次让学习索引真正可用于动态工作负载。

ALEX 的核心设计有三个:

1. Gapped Array(带间隙的数组)

传统数组在中间插入需要 O(N) 时间来移动元素。 ALEX 在数组中预留了间隙(gap),使得插入只需要移动局部的少量元素。

传统数组:  [1, 3, 5, 7, 9, 11, 13, 15]

ALEX gapped array: [1, _, 3, 5, _, _, 7, 9, _, 11, _, 13, 15, _]
                       ^           ^^^      ^         ^          ^
                      gap         gaps     gap       gap        gap

间隙的密度根据模型预测的插入位置动态调整。如果某个区域插入频繁,那里就会有更多间隙。

2. 指数搜索代替二分搜索

学习索引预测的位置有误差,需要局部搜索。 ALEX 选择了指数搜索(exponential search)而非二分搜索:

// 指数搜索:从预测位置开始,步长倍增
size_t exponential_search(
    const int64_t* data,
    size_t predicted_pos,
    int64_t key,
    size_t data_size
) {
    // 先判断方向
    if (data[predicted_pos] == key) return predicted_pos;

    if (data[predicted_pos] < key) {
        // 向右搜索
        size_t step = 1;
        size_t pos = predicted_pos;
        while (pos + step < data_size && data[pos + step] < key) {
            step *= 2;
        }
        // 在 [pos + step/2, min(pos + step, data_size-1)] 内二分
        return binary_search(data, pos + step / 2,
                           std::min(pos + step, data_size - 1), key);
    } else {
        // 向左搜索
        size_t step = 1;
        size_t pos = predicted_pos;
        while (pos >= step && data[pos - step] > key) {
            step *= 2;
        }
        size_t lo = (pos >= step) ? pos - step : 0;
        return binary_search(data, lo, pos - step / 2, key);
    }
}

指数搜索的妙处在于:当预测精确时(误差为 e),只需 O(log e) 步; 而二分搜索在固定范围内始终是 O(log(2 * max_err))。 对于模型预测精度好的区域,指数搜索更快。

3. 自适应节点分裂与合并

ALEX 的树结构会根据工作负载动态调整:

这与 B-tree 的分裂/合并机制非常相似,但 ALEX 的触发条件和重组策略都考虑了模型质量。ALEX 的整体结构如下:

         [Internal Node: Linear Model]
        /          |            \
  [Internal]   [Internal]    [Internal]
   /    \       /    \         /    \
[Leaf] [Leaf] [Leaf] [Leaf] [Leaf] [Leaf]
  |      |      |      |      |      |
[GA]   [GA]   [GA]   [GA]   [GA]   [GA]

GA = Gapped Array

每个内部节点包含一个线性模型和一个指针数组。每个叶节点包含一个线性模型和一个 gapped array。

ALEX 在 SOSD benchmark 上的表现非常出色:在大部分分布下,查找速度与最优的只读学习索引相当,同时支持每秒数百万次的插入操作。

五、LIPP:面向写入优化的学习索引

2021 年,Wu 等人提出的 LIPP(Learned Index with Precise Positions) 走了一条不同的路。LIPP 的核心观点是:与其像 ALEX 那样在 gapped array 上做文章,不如直接把插入的元素放到模型预测的位置上。

LIPP 的设计要点:

冲突链

当两个 key 被模型预测到同一个位置时,形成冲突。 LIPP 不移动已有元素,而是把冲突的元素挂在一个链上:

Slot[i]:  key_a  ->  key_b  ->  key_c  (collision chain)
Slot[i+1]: key_d
Slot[i+2]: (empty)
Slot[i+3]: key_e  ->  key_f

模型与节点设计

每个 LIPP 节点包含一个线性模型、一个数组(大小根据模型和预期数据量确定), 以及每个槽中存放的 key-value 对、子节点指针或冲突链。

节点重建

当冲突链变长(超过阈值),节点触发重建: - 收集节点中所有 key - 用新的线性模型重新拟合 - 扩大数组容量以减少冲突

// LIPP 节点重建(概念代码)
void rebuild_node(LIPPNode* node) {
    // 1. 收集所有 key-value 对
    auto all_kvs = collect_all_kvs(node);

    // 2. 计算新的容量(通常是 key 数量的 1.2-1.5 倍)
    size_t new_capacity = all_kvs.size() * expansion_factor;

    // 3. 训练新的线性模型
    LinearModel new_model;
    new_model.train(all_kvs, new_capacity);

    // 4. 重新分配 key 到新数组
    node->resize(new_capacity);
    node->model = new_model;
    for (auto& kv : all_kvs) {
        size_t pos = new_model.predict(kv.key);
        node->insert_at(pos, kv);
    }
}

LIPP 的优势在于插入非常快——大多数情况下只需要一次模型预测和一次数组写入,不需要移动元素。代价是查找时可能需要遍历冲突链。在写入密集型工作负载下,LIPP 比 ALEX 有明显优势。

六、可更新学习索引的核心挑战

学习索引面临的最根本挑战是:插入操作会破坏 CDF 模型

考虑一个简单的例子。假设我们有 100 个均匀分布的 key:

keys: [0, 1, 2, 3, ..., 99]
CDF:   0, 0.01, 0.02, ..., 1.0

线性模型 pos = key 完美拟合。现在插入 50 个新 key,全部在 [40, 60] 范围内:

keys: [0, 1, ..., 39, 40, 40.1, 40.2, ..., 59.8, 59.9, 60, 61, ..., 99]

原来的线性模型完全失效了——[40, 60] 范围内的 CDF 陡然变陡, 模型的预测误差急剧增大。

各种可更新学习索引对这个问题的解法不同,但基本思路可以归纳为:

策略 代表 优点 缺点
预留间隙 ALEX 插入快,不需要立即重训练 间隙浪费空间
冲突链 LIPP 写入极快,无需移动元素 读取需遍历链
增量重建 PGM-index (dynamic) 保持严格误差界 重建代价高
Delta buffer FITing-Tree 分摊重建成本 查找需检查 buffer
版本化模型 DILI 多版本模型平滑切换 实现复杂

一个通用的 delta buffer 方案是这样的:

写入流程:
1. 新数据先写入内存中的 buffer(通常是一个小型 B-tree 或跳表)
2. 当 buffer 大小超过阈值时,与主索引合并(merge)
3. 合并时重新训练模型

读取流程:
1. 先查主索引(学习索引)
2. 再查 buffer
3. 取两者中更新的结果

这和 LSM-tree 的思路非常相似。实际上,已经有研究者探索将学习索引与 LSM-tree 结合,例如 Bourbon(OSDI 2020)就是在 LevelDB 中用学习索引替换每层的 B-tree 索引。

七、学习化 Bloom Filter:用模型替代概率数据结构

Kraska 2018 年论文的第二部分提出了一个同样大胆的想法: 用机器学习模型替代 Bloom filter。

传统 Bloom filter 的工作原理:

Insert(key): 对 key 计算 k 个哈希,将对应 bit 置 1
Query(key):  对 key 计算 k 个哈希,检查所有 bit 是否为 1
             -> 全部为 1:可能存在(false positive 可能)
             -> 有 0:一定不存在

Bloom filter 的空间大小取决于目标 false positive rate(FPR)。 对于 1% 的 FPR,每个元素大约需要 10 bit。

学习化 Bloom filter 的思路是:将”key 是否在集合中”视为一个二分类问题。 训练一个模型来区分”存在的 key”和”不存在的 key”。

def learned_bloom_filter(model, backup_filter, key):
    # 模型先做预测
    score = model.predict(key)  # 输出 [0, 1] 之间的概率

    if score > threshold:
        return True   # 模型认为 key 存在
    else:
        # 模型不确定的情况,用备份 Bloom filter 兜底
        return backup_filter.query(key)

这个设计有一个关键的细节:模型的 false negative(漏判)是不可接受的, 因为 Bloom filter 的核心保证是没有 false negative。 所以需要一个备份 Bloom filter 来存储模型预测为”不存在”但实际存在的 key。

空间节省的来源在于: - 如果模型足够好,大多数”存在的 key”都能被正确预测 - 备份 Bloom filter 只需要存储很少的 key(模型漏掉的那些) - 模型本身的大小通常比全量 Bloom filter 小得多

论文报告称,在某些数据集上,学习化 Bloom filter 比传统 Bloom filter 节省 70% 的空间,同时保持相同的 FPR。

但有几个工程上的问题需要注意:

  1. 模型训练需要负样本:你需要知道”不在集合中的 key 长什么样”。在某些场景下(如 LSM-tree 的层间过滤),负样本就是其他层的 key,很容易获取。但在通用场景下,负样本的生成并不总是直观的。

  2. 模型推理延迟:即使是简单的神经网络,推理时间也比几次哈希计算慢得多。在延迟敏感的场景下需要权衡。

  3. 更新代价:集合变化时需要重训练模型,或者用备份 filter 吸收增量。

// 学习化 Bloom filter 的简化实现
class LearnedBloomFilter {
    NeuralNetwork model_;
    BloomFilter backup_;
    double threshold_;

public:
    void build(const std::vector<int64_t>& keys,
               const std::vector<int64_t>& non_keys,
               double target_fpr) {
        // 训练二分类模型
        model_.train(keys, non_keys);

        // 找到使 false negative rate = 0 的阈值
        threshold_ = find_threshold(model_, keys);

        // 把模型漏掉的 key 放入备份 filter
        std::vector<int64_t> missed;
        for (auto k : keys) {
            if (model_.predict(k) < threshold_) {
                missed.push_back(k);
            }
        }
        backup_.build(missed, target_fpr);
    }

    bool query(int64_t key) {
        if (model_.predict(key) >= threshold_) {
            return true;
        }
        return backup_.query(key);
    }
};

八、完整 C++ 实现:简单学习索引

下面给出一个完整的、可编译运行的 C++ 实现。 它包含一个简单的两层 RMI(根模型 + 叶子模型),支持点查询和范围查询。

// learned_index.hpp
// 一个完整的简单学习索引实现
// 支持:构建、点查询、范围查询
// 模型:两层 RMI,所有层使用线性回归

#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

namespace learned_index {

// 线性模型:pos = slope * key + intercept
struct LinearModel {
    double slope = 0.0;
    double intercept = 0.0;

    // 用最小二乘法训练
    void train(const std::vector<int64_t>& keys,
               const std::vector<double>& positions) {
        assert(keys.size() == positions.size());
        size_t n = keys.size();
        if (n == 0) { slope = 0; intercept = 0; return; }
        if (n == 1) { slope = 0; intercept = positions[0]; return; }

        double sum_x = 0, sum_y = 0, sum_xy = 0, sum_xx = 0;
        for (size_t i = 0; i < n; ++i) {
            double x = static_cast<double>(keys[i]);
            double y = positions[i];
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }

        double denom = n * sum_xx - sum_x * sum_x;
        if (std::abs(denom) < 1e-10) {
            slope = 0;
            intercept = sum_y / n;
        } else {
            slope = (n * sum_xy - sum_x * sum_y) / denom;
            intercept = (sum_y - slope * sum_x) / n;
        }
    }

    double predict(int64_t key) const {
        return slope * static_cast<double>(key) + intercept;
    }
};

// 叶子模型:线性模型 + 误差界
struct LeafModel {
    LinearModel model;
    size_t start_pos = 0;     // 该模型负责的数据起始位置
    size_t end_pos = 0;       // 该模型负责的数据结束位置
    int64_t min_key = 0;
    int64_t max_key = 0;
    size_t max_error = 0;     // 最大预测误差
};

class SimpleLearnedIndex {
public:
    // 构建索引
    void build(const std::vector<int64_t>& sorted_data,
               size_t num_leaf_models = 0) {
        data_ = sorted_data;
        size_t n = data_.size();
        if (n == 0) return;

        // 默认叶子模型数量:sqrt(N)
        if (num_leaf_models == 0) {
            num_leaf_models = std::max<size_t>(
                1, static_cast<size_t>(std::sqrt(n)));
        }
        num_leaves_ = num_leaf_models;

        // --- 训练根模型 ---
        std::vector<int64_t> all_keys(n);
        std::vector<double> all_positions(n);
        for (size_t i = 0; i < n; ++i) {
            all_keys[i] = data_[i];
            all_positions[i] = static_cast<double>(i);
        }
        root_model_.train(all_keys, all_positions);

        // --- 将 key 分配到叶子模型 ---
        leaves_.resize(num_leaves_);
        std::vector<std::vector<int64_t>> leaf_keys(num_leaves_);
        std::vector<std::vector<double>> leaf_positions(num_leaves_);

        for (size_t i = 0; i < n; ++i) {
            double pred = root_model_.predict(data_[i]);
            // 将预测值映射到叶子模型索引
            size_t leaf_idx = static_cast<size_t>(
                pred / static_cast<double>(n) * num_leaves_);
            leaf_idx = std::min(leaf_idx, num_leaves_ - 1);

            leaf_keys[leaf_idx].push_back(data_[i]);
            leaf_positions[leaf_idx].push_back(
                static_cast<double>(i));
        }

        // --- 训练每个叶子模型 ---
        for (size_t i = 0; i < num_leaves_; ++i) {
            auto& leaf = leaves_[i];
            if (leaf_keys[i].empty()) {
                // 空叶子节点,设置为默认
                leaf.start_pos = 0;
                leaf.end_pos = 0;
                leaf.max_error = 0;
                continue;
            }

            leaf.model.train(leaf_keys[i], leaf_positions[i]);
            leaf.min_key = leaf_keys[i].front();
            leaf.max_key = leaf_keys[i].back();
            leaf.start_pos = static_cast<size_t>(
                leaf_positions[i].front());
            leaf.end_pos = static_cast<size_t>(
                leaf_positions[i].back());

            // 计算最大误差
            leaf.max_error = 0;
            for (size_t j = 0; j < leaf_keys[i].size(); ++j) {
                double pred = leaf.model.predict(leaf_keys[i][j]);
                double actual = leaf_positions[i][j];
                size_t err = static_cast<size_t>(
                    std::abs(pred - actual)) + 1;
                leaf.max_error = std::max(leaf.max_error, err);
            }
        }
    }

    // 点查询:返回 key 的位置,未找到返回 -1
    int64_t lookup(int64_t key) const {
        if (data_.empty()) return -1;
        size_t n = data_.size();

        // Stage 0:根模型选择叶子
        double root_pred = root_model_.predict(key);
        size_t leaf_idx = static_cast<size_t>(
            root_pred / static_cast<double>(n) * num_leaves_);
        leaf_idx = std::min(leaf_idx, num_leaves_ - 1);

        const auto& leaf = leaves_[leaf_idx];
        if (leaf.start_pos == leaf.end_pos &&
            leaf.max_error == 0) {
            return -1;
        }

        // Stage 1:叶子模型预测位置
        double pred = leaf.model.predict(key);
        int64_t predicted_pos = static_cast<int64_t>(
            std::round(pred));

        // 限制搜索范围
        int64_t lo = std::max<int64_t>(
            0, predicted_pos - static_cast<int64_t>(
                leaf.max_error));
        int64_t hi = std::min<int64_t>(
            static_cast<int64_t>(n) - 1,
            predicted_pos + static_cast<int64_t>(
                leaf.max_error));
        lo = std::max<int64_t>(lo, 0);
        hi = std::min<int64_t>(hi,
            static_cast<int64_t>(n) - 1);

        // 二分搜索精确定位
        while (lo <= hi) {
            int64_t mid = lo + (hi - lo) / 2;
            if (data_[mid] == key) return mid;
            if (data_[mid] < key) lo = mid + 1;
            else hi = mid - 1;
        }

        return -1;
    }

    // 范围查询:返回 [lo_key, hi_key] 范围内的所有 key
    std::vector<int64_t> range_query(
        int64_t lo_key, int64_t hi_key) const {
        std::vector<int64_t> result;
        if (data_.empty()) return result;

        // 找到 lo_key 的起始位置
        int64_t start = find_lower_bound(lo_key);
        if (start < 0) start = 0;

        // 扫描直到超出 hi_key
        for (size_t i = static_cast<size_t>(start);
             i < data_.size() && data_[i] <= hi_key; ++i) {
            if (data_[i] >= lo_key) {
                result.push_back(data_[i]);
            }
        }

        return result;
    }

    // 获取统计信息
    void print_stats() const {
        size_t total_max_err = 0;
        size_t non_empty = 0;
        for (const auto& leaf : leaves_) {
            if (leaf.max_error > 0) {
                total_max_err += leaf.max_error;
                ++non_empty;
            }
        }
        double avg_err = non_empty > 0
            ? static_cast<double>(total_max_err) / non_empty
            : 0;

        std::cout << "=== Learned Index Stats ===" << '\n';
        std::cout << "Data size:       " << data_.size()
                  << '\n';
        std::cout << "Leaf models:     " << num_leaves_
                  << '\n';
        std::cout << "Non-empty leaves:" << non_empty
                  << '\n';
        std::cout << "Avg max error:   " << avg_err
                  << '\n';
        std::cout << "Model memory:    "
                  << model_size_bytes() << " bytes"
                  << '\n';
    }

    size_t model_size_bytes() const {
        // 根模型:2 * sizeof(double) = 16 bytes
        // 每个叶子模型:2 * sizeof(double) + metadata
        return 16 + num_leaves_ * sizeof(LeafModel);
    }

private:
    int64_t find_lower_bound(int64_t key) const {
        size_t n = data_.size();
        double root_pred = root_model_.predict(key);
        size_t leaf_idx = static_cast<size_t>(
            root_pred / static_cast<double>(n) * num_leaves_);
        leaf_idx = std::min(leaf_idx, num_leaves_ - 1);

        const auto& leaf = leaves_[leaf_idx];
        double pred = leaf.model.predict(key);
        int64_t pos = static_cast<int64_t>(std::round(pred));

        int64_t lo = std::max<int64_t>(
            0, pos - static_cast<int64_t>(leaf.max_error));
        int64_t hi = std::min<int64_t>(
            static_cast<int64_t>(n) - 1,
            pos + static_cast<int64_t>(leaf.max_error));

        // 标准 lower_bound
        while (lo < hi) {
            int64_t mid = lo + (hi - lo) / 2;
            if (data_[mid] < key) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }

    std::vector<int64_t> data_;
    LinearModel root_model_;
    std::vector<LeafModel> leaves_;
    size_t num_leaves_ = 0;
};

}  // namespace learned_index

下面是测试和基准程序:

// benchmark.cpp
// 编译:g++ -O2 -std=c++17 -o benchmark benchmark.cpp
// 运行:./benchmark

#include "learned_index.hpp"
#include <chrono>
#include <random>

using namespace learned_index;
using Clock = std::chrono::high_resolution_clock;

// 生成不同分布的数据
std::vector<int64_t> generate_uniform(size_t n) {
    std::vector<int64_t> data(n);
    for (size_t i = 0; i < n; ++i) data[i] = i * 10;
    return data;
}

std::vector<int64_t> generate_normal(size_t n) {
    std::mt19937_64 rng(42);
    std::normal_distribution<double> dist(5e8, 1e7);
    std::vector<int64_t> data(n);
    for (size_t i = 0; i < n; ++i) {
        data[i] = static_cast<int64_t>(dist(rng));
    }
    std::sort(data.begin(), data.end());
    data.erase(std::unique(data.begin(), data.end()),
               data.end());
    return data;
}

std::vector<int64_t> generate_lognormal(size_t n) {
    std::mt19937_64 rng(42);
    std::lognormal_distribution<double> dist(20.0, 2.0);
    std::vector<int64_t> data(n);
    for (size_t i = 0; i < n; ++i) {
        data[i] = static_cast<int64_t>(dist(rng));
    }
    std::sort(data.begin(), data.end());
    data.erase(std::unique(data.begin(), data.end()),
               data.end());
    return data;
}

double benchmark_learned_index(
    const std::vector<int64_t>& data,
    const std::vector<int64_t>& queries
) {
    SimpleLearnedIndex idx;
    idx.build(data);

    auto start = Clock::now();
    volatile int64_t sink = 0;
    for (auto q : queries) {
        sink = idx.lookup(q);
    }
    auto end = Clock::now();
    double ns = std::chrono::duration_cast<
        std::chrono::nanoseconds>(end - start).count();
    return ns / queries.size();
}

double benchmark_binary_search(
    const std::vector<int64_t>& data,
    const std::vector<int64_t>& queries
) {
    auto start = Clock::now();
    volatile bool sink = false;
    for (auto q : queries) {
        sink = std::binary_search(
            data.begin(), data.end(), q);
    }
    auto end = Clock::now();
    double ns = std::chrono::duration_cast<
        std::chrono::nanoseconds>(end - start).count();
    return ns / queries.size();
}

int main() {
    const size_t N = 10'000'000;
    const size_t Q = 1'000'000;

    struct TestCase {
        std::string name;
        std::vector<int64_t> data;
    };

    std::vector<TestCase> tests = {
        {"Uniform",   generate_uniform(N)},
        {"Normal",    generate_normal(N)},
        {"Lognormal", generate_lognormal(N)},
    };

    for (auto& tc : tests) {
        // 生成查询(随机选取已有 key)
        std::mt19937_64 rng(123);
        std::uniform_int_distribution<size_t> dist(
            0, tc.data.size() - 1);
        std::vector<int64_t> queries(Q);
        for (size_t i = 0; i < Q; ++i) {
            queries[i] = tc.data[dist(rng)];
        }

        std::cout << "\n--- " << tc.name
                  << " (N=" << tc.data.size() << ") ---\n";

        SimpleLearnedIndex idx;
        idx.build(tc.data);
        idx.print_stats();

        double learned_ns = benchmark_learned_index(
            tc.data, queries);
        double binary_ns = benchmark_binary_search(
            tc.data, queries);

        std::cout << "Learned index: "
                  << learned_ns << " ns/query\n";
        std::cout << "Binary search: "
                  << binary_ns << " ns/query\n";
        std::cout << "Speedup:       "
                  << binary_ns / learned_ns << "x\n";
    }

    return 0;
}

九、学习索引何时胜出、何时败退:数据分布的决定性影响

学习索引的性能高度依赖数据分布。这不是一个小细节,而是决定学习索引能否落地的核心因素。

学习索引占优的场景

1. 近似均匀分布的数据

当 key 近似均匀分布时,CDF 接近一条直线,一个简单的线性模型就能达到极低的误差。这是学习索引的最佳场景。

数据分布:  key = [0, 1, 2, 3, ..., N-1]
CDF:       几乎完美的直线
线性模型误差:接近 0
查找速度:   1-2 次内存访问

vs B-tree: O(log N) 次指针追踪,每次可能 cache miss

典型的真实例子:时间戳索引(如日志系统中的时间戳),自增主键索引。

2. 分段线性的数据

如果 CDF 可以用少量的线段近似(即数据呈现几段不同斜率的线性趋势),PGM-index 这类分段线性结构就非常高效。

3. 平滑的 CDF

正态分布、Zipf 分布等,虽然不是线性的,但 CDF 是平滑的。用少量的模型就能达到很低的误差。

学习索引吃亏的场景

1. 高度不规则的分布

如果数据中有大量的”密度突变”(某些区域极稠密,某些极稀疏),模型很难学好,误差会很大。

数据示例:
  [1, 2, 3, 1000000, 1000001, 1000002, 2000000000, ...]
  密度突变:误差可能达到数百甚至数千

2. 字符串 key

学习索引最初是为数值 key 设计的。对于变长字符串 key,需要先将字符串映射到数值空间,这个映射本身就可能引入大量误差。虽然有 SIndex(SIGMOD 2020)等工作尝试解决这个问题,但目前效果仍不如 B-tree。

3. 多维 key

学习索引在高维 key 上的表现更差。虽然有 Flood(SIGMOD 2020)和 Tsunami(VLDB 2020)等工作,但多维学习索引目前仍处于研究早期。

SOSD Benchmark 的启示

Search on Sorted Data(SOSD)benchmark 提供了一组标准化的数据集和评测框架。来自 SOSD 的关键发现:

                     查找延迟 (ns/op)
数据集          B-tree   PGM    ALEX    RMI
---------------------------------------------
uniform_dense     350     95     105     80
normal            380    120     115    110
lognormal         400    180     150    200
books (real)      420    160     140    130
osm (real)        450    350     250    280
wiki_ts (real)    410    130     125    115

注意:以上数字为概念性示意,实际数值取决于硬件和具体实现。关键观察是:在大多数分布下学习索引都比 B-tree 快,但在某些真实数据集(如 OSM 地理数据)上优势不明显甚至消失。

十、基准测试:学习索引 vs B+tree 的真实对比

为了更具体地理解学习索引的优劣,让我们做一组更细致的对比分析。

测试维度

维度            选项
-----------------------------------------------
数据规模        1M, 10M, 100M, 1B
数据分布        uniform, normal, lognormal, zipf
操作类型        point lookup, range scan, insert
读写比例        100:0, 95:5, 50:50, 0:100

点查询性能

对于只读的点查询,学习索引的优势主要来自两个方面:

  1. 更少的内存访问:B+tree 的每一层都是一次指针追踪,在大数据量下很可能导致 cache miss。学习索引的模型推理只涉及少量浮点运算,全部在寄存器/L1 缓存中完成。

  2. 更小的索引体积:一个 10M 条记录的 B+tree,内部节点可能占 50-100MB。同样数据量的 RMI(10000 个线性模型)只需要 160KB。

内存占用对比(10M 条 int64 key):

B+tree (fanout=256):
  层数 = log_256(10M) ≈ 3
  内部节点 ≈ 10M / 256 + 10M / 256^2 + 1 ≈ 39K 节点
  每节点 4KB → 总计 ≈ 156 MB

RMI (10000 leaf models):
  根模型: 16 bytes
  叶子模型: 10000 * 56 bytes ≈ 547 KB
  总计 ≈ 548 KB

PGM-index (epsilon=64):
  段数 ≈ 10M / 64 ≈ 156K
  每段 24 bytes → 总计 ≈ 3.6 MB

范围查询性能

范围查询的性能差异主要在”定位起始位置”这一步。一旦找到起始位置,后续的顺序扫描两者没有区别(都是遍历有序数组)。

# 范围查询伪代码
def range_query(index, lo, hi):
    start_pos = index.lookup_lower_bound(lo)  # 这一步有差异
    result = []
    for i in range(start_pos, len(data)):     # 这一步没差异
        if data[i] > hi:
            break
        result.append(data[i])
    return result

插入性能

这是学习索引最弱的环节。在 50:50 读写比例下:

                     吞吐量 (Mops/s)
索引             读写 50:50    纯写
---------------------------------------------
B+tree (STX)        4.2        3.8
ALEX                3.5        2.1
LIPP                3.8        3.2
PGM (dynamic)       2.8        1.5

B+tree 在写入密集场景下仍然是王者,因为它的分裂/合并机制经过了数十年的工程优化。学习索引的插入涉及到模型失效检测和潜在的重训练,开销更大。

空间-时间权衡

学习索引的一个独特优势是可以通过调整误差阈值来精确控制空间-时间权衡。

PGM-index 在不同 epsilon 下的表现:

epsilon    段数        索引大小     查找时间
-------------------------------------------------
4          2,500,000   57.2 MB      85 ns
16         625,000     14.3 MB      120 ns
64         156,250     3.6 MB       165 ns
256        39,063      0.9 MB       210 ns
1024       9,766       0.2 MB       260 ns

这种平滑的权衡曲线是 B-tree 很难做到的——B-tree 的节点大小相对固定,调节空间占用的唯一方式是改变 fanout,而 fanout 的选择空间有限。

十一、工业界落地现状

学习索引从论文到产品的路还很长,但已经有一些值得关注的进展。

Google

Google 是学习索引的发源地(Kraska 在 Google Brain 期间完成了核心论文)。据公开信息,Google 内部的部分系统已经采用了学习索引的思想,特别是在 Bigtable 和分布式排序系统中。但具体细节没有公开发表。

Amazon

Amazon 的 DynamoDB 团队在 SIGMOD 2022 上发表了关于学习索引在云数据库中应用的论文。他们发现学习索引在特定的只读工作负载下(如冷数据的二级索引)效果很好,但在通用的 OLTP 工作负载下仍然不如 B-tree。

学术界原型

以下项目是目前质量较高的开源实现:

项目 类型 语言 特点
PGM-index 只读/动态 C++ 理论最优,头文件实现
ALEX 动态 C++ Microsoft Research,支持读写
LIPP 动态 C++ 清华大学,写入优化
RadixSpline 只读 C++ 极简实现,适合嵌入
CDFShop 只读 C++ 自动选择最佳模型
SOSD Benchmark C++ 标准化评测框架

我的判断

学习索引目前还没有在主流数据库产品中被广泛采用。原因不是性能不够好,而是工程化的代价太高:

  1. 主流数据库的 B-tree/LSM-tree 实现经过了数十年的打磨,可靠性极高。替换核心索引结构的风险巨大。
  2. 学习索引的优势在特定数据分布下才显著,不具有普适性。
  3. 可更新学习索引的实现复杂度远高于传统索引。

我认为学习索引最可能首先在以下场景落地:只读分析型系统(如列存数据库、数据湖的索引层),嵌入式索引(如替代 LSM-tree 内部的 B-tree 索引),以及特化的内存数据库(如时序数据库,数据分布相对规律)。

十二、工程陷阱与个人看法

工程陷阱速查表

陷阱 描述 应对策略
模型训练成本 构建索引时需要训练模型,在数据量大时可能需要数秒甚至数分钟 增量训练;只在批量加载时训练,增量更新用 delta buffer
尾延迟(tail latency) 模型在个别 key 上的预测误差可能极大,导致最坏情况查找时间远超平均值 设置误差上限;误差过大的区域回退到 B-tree
分布漂移(distribution shift) 随着数据的不断插入,数据分布与训练时不同,模型逐渐失效 监控模型误差;定期重训练;自适应机制(如 ALEX 的节点分裂)
浮点精度 线性模型使用 double 运算,在极端 key 值下可能出现精度问题 使用规范化的 key 表示;对预测位置做 clamp
并发控制 学习索引的并发更新比 B-tree 更复杂(模型重训练需要全局锁或 RCU) 乐观并发控制;读写分离;版本化模型
字符串 key 将字符串映射到数值空间并不直观,映射质量影响模型效果 前缀编码;字典编码;只在数值 key 上使用学习索引
空间碎片化 Gapped array 和冲突链可能导致内存碎片 定期 compaction;使用内存池
批量加载优化 初始数据加载时应该利用数据已排序的特性 直接用有序数据训练,不需要先构建再插入
模型选择 不同的数据分布适合不同类型的模型(线性/样条/小型神经网络) CDFShop 等自动选择工具;先 profile 再决定
多线程扩展 模型推理本身是无状态的(适合并发),但更新操作需要同步 读路径无锁;写路径分区或批量化

个人看法

学习索引是一个真正的范式转换,但它的影响力更多体现在思维方式上,而非直接替代传统索引。

Kraska 最大的贡献不在于某个具体的数据结构,而在于他打开了一扇门:系统组件可以被视为学习问题。这个思路后来催生了 learned cardinality estimation、learned query optimization、learned scheduling 等一系列工作,统称为 “ML for Systems”。

从纯工程角度看,我认为学习索引在以下方面已经被证明是有价值的:

  1. 空间效率:学习索引的空间占用通常比 B-tree 小一到两个数量级。在内存紧张的场景下,这个优势非常实际。

  2. 只读查询速度:在数据分布友好的场景下,学习索引的查找速度确实优于 B-tree。但”数据分布友好”这个前提条件往往被低估了。

  3. 激发传统索引的改进:学习索引的研究推动了人们重新审视 B-tree 的设计。例如,更好的前缀压缩、自适应节点大小等优化,有些是受学习索引启发而提出的。

但我也有一些保留意见:

  1. benchmark 的局限性:很多论文中的对比实验使用的是非常简单的 B-tree 实现。当对手换成高度优化的 B-tree(如 ART、HOT、FAST)时,学习索引的优势会显著缩小。

  2. 系统级影响被低估:在真实数据库中,索引查找只是整个查询执行的一小部分。即使索引查找快了 2 倍,端到端的查询延迟改善可能只有 10-20%。

  3. 运维复杂度:生产环境需要监控模型质量、处理分布漂移、管理模型版本。这些运维负担是传统索引不存在的。

总的来说,我把学习索引视为索引设计工具箱中的一个新工具,而非 B-tree 的终结者。在合适的场景下使用合适的工具,这才是工程师应有的态度。


上一篇: MVCC 实现变体全解 下一篇: TCP 拥塞控制

相关阅读: - B-tree 深度解剖 - 查询优化器


By .