2018 年,MIT 的 Tim Kraska 联合 Google Brain 发表了一篇标题极为大胆的论文: The Case for Learned Index Structures。 论文的核心论点只有一句话——索引本质上是一个从 key 到存储位置的函数映射, 而机器学习模型天然擅长学习函数。 这句话像一颗炸弹,震动了整个数据库社区。 有人拍案叫绝,有人嗤之以鼻,但无论持何种态度, 几乎所有做存储与索引的研究者都不得不正视这个方向。
本文将从 Kraska 的核心洞察出发,逐步展开学习索引的理论基础、 典型数据结构(RMI、PGM-index、ALEX、LIPP)、 可更新索引的挑战、Bloom filter 的替代方案, 并给出一个完整的 C++ 实现。最后讨论工程落地中的真实问题, 以及我个人对这个方向的看法。
一、核心洞察:索引就是 CDF 模型
传统 B-tree 索引做了什么?给定一个 key,它告诉你这个 key 对应的记录在磁盘上的偏移量(或在内存中的位置)。如果底层数据是排序的,那这个问题可以换一种说法:给定 key,返回它在有序数组中的排名(rank),也就是累积分布函数(CDF)的值。
设有序数组 data[0..N-1],对于查询 key
k,我们需要找到位置 p,使得:
p = F(k) * N
其中 F(k) 是 key 的经验 CDF,即所有 key
中小于等于 k 的比例。
这个公式揭示了一个惊人的事实:索引查找等价于 CDF 预测。 如果我们能用一个机器学习模型来近似 CDF,就能直接预测出 key 的位置, 而不需要逐层遍历 B-tree 的内部节点。
让我们把这个想法具体化。假设数据服从均匀分布,那么 CDF 就是一条直线:
F(k) = (k - min_key) / (max_key - min_key)
一个简单的线性模型 pos = w * key + b
就能完美拟合。
当然,现实中的数据分布很少是完美均匀的,但这不妨碍我们用更复杂的模型来逼近真实
CDF。
这个洞察的深刻之处在于:B-tree
的每一层本质上都是一个”粗粒度的分段常数函数”, 用
O(log N) 层来逼近 CDF。
而一个训练好的回归模型可能只需要一次前向传播(几次乘法和加法)就能给出预测,
时间复杂度接近 O(1)。
当然,模型预测不可能完全精确,总会有误差。 设预测位置为
p_hat,真实位置为 p,最大误差为
err, 那么我们只需要在
[p_hat - err, p_hat + err]
范围内做一次局部搜索。 如果 err
足够小,这个搜索的代价远小于 B-tree 的
O(log N)。
下面的伪代码展示了最基本的学习索引查找流程:
def learned_lookup(model, data, key):
# 模型预测位置
predicted_pos = model.predict(key)
predicted_pos = clamp(predicted_pos, 0, len(data) - 1)
# 在误差范围内做局部搜索
lo = max(0, predicted_pos - model.max_error)
hi = min(len(data) - 1, predicted_pos + model.max_error)
# 二分查找精确定位
while lo <= hi:
mid = (lo + hi) // 2
if data[mid] == key:
return mid
elif data[mid] < key:
lo = mid + 1
else:
hi = mid - 1
return -1 # not found关键指标有两个:
- 模型推理时间:模型越复杂,推理越慢,但拟合越好。
- 最大误差(max error):误差越大,局部搜索范围越大。
学习索引的设计本质上就是在这两者之间做权衡。
二、递归模型索引(RMI):分而治之的分阶段模型
Kraska 论文中提出的第一个具体结构叫做 Recursive Model Index(RMI)。 它的核心思想是:用一个模型来拟合整个 CDF 太难了, 不如把问题分层——顶层模型做粗粒度预测,选择下一层的子模型; 子模型在更小的 key 范围内做更精细的预测。
RMI 的结构类似一棵树,通常有 2-3 层(stage):
- Stage 0:一个根模型,通常是简单的线性回归。输入 key,输出一个浮点数,用于选择 Stage 1 中的哪个子模型。
- Stage 1:
M_1个子模型,每个负责一段 key 范围。输出用于选择 Stage 2 中的子模型。 - Stage 2(叶子层):
M_2个子模型,每个输出最终的预测位置。
查找流程如下:
key -> Stage 0 model -> index i
-> Stage 1 model[i] -> index j
-> Stage 2 model[j] -> predicted position p
-> local search in [p - err, p + err]
每个 stage 的模型可以是不同类型的:线性回归、简单神经网络、样条函数等。论文中的实验表明,即使所有层都用简单的线性模型,效果也相当不错。
RMI 的训练过程是自顶向下的:
def train_rmi(data, keys, stages):
"""
data: 排序后的 key 数组
stages: 每层的模型数量,如 [1, 100, 10000]
"""
models = []
# Stage 0:在所有数据上训练根模型
root = LinearRegression()
root.fit(keys, positions)
models.append([root])
for stage_idx in range(1, len(stages)):
stage_models = [None] * stages[stage_idx]
# 用上一层的预测来分配 key 到当前层的模型
for key, pos in zip(keys, positions):
predicted = models[stage_idx - 1][...].predict(key)
model_idx = int(predicted * stages[stage_idx] / len(data))
model_idx = clamp(model_idx, 0, stages[stage_idx] - 1)
# 将 (key, pos) 分配给 stage_models[model_idx]
assign(stage_models[model_idx], key, pos)
# 训练当前层的每个模型
for m in stage_models:
m.fit()
models.append(stage_models)
return modelsRMI 的优点:
- 空间效率极高:每个线性模型只需要两个参数(斜率和截距), 10000 个模型也只占 80KB。相比之下,一个百万条记录的 B-tree 内部节点可能占数 MB。
- 缓存友好:模型参数紧凑,容易放进 L1/L2 缓存。
- 查找速度快:2-3 次线性模型推理 + 局部搜索,总时间常数很小。
RMI 的局限:
- 不支持插入和删除:这是最致命的问题,后面会详细讨论。
- 最大误差可能很大:某些分布下,个别叶子模型的误差可能非常大。
- 训练需要全量数据:不能增量训练。
三、PGM-index:最优分段线性近似
RMI 的一个问题是:它的模型结构(每层多少个模型)需要人工调参, 而且不能保证全局最优。 Ferragina 和 Vinciguerra 在 2020 年提出的 Piecewise Geometric Model index(PGM-index) 解决了这个问题。
PGM-index 的核心思想非常优雅:给定误差阈值
epsilon,用最少的线段来近似
CDF。这是一个经典的计算几何问题,有 O(N)
时间的最优算法。然后在分段之上递归构建索引,形成一棵树——树的每一层都是对下层段地址的
PGM 近似。
分段线性近似的关键算法是 optimal piecewise linear
approximation。给定有序点集
{(key_i, i)} 和误差阈值
epsilon,找到最少数量的线段,使得每条线段覆盖的所有点到线段的距离不超过
epsilon。
// 最优分段线性近似的核心算法(简化版)
struct Segment {
double slope;
double intercept;
size_t start_key;
size_t start_pos;
};
std::vector<Segment> optimal_pla(
const std::vector<int64_t>& keys,
size_t epsilon
) {
std::vector<Segment> segments;
size_t seg_start = 0;
while (seg_start < keys.size()) {
// 初始化可行域(一个锥形区域)
double upper_slope = INFINITY;
double lower_slope = -INFINITY;
double intercept_at_start = seg_start;
size_t seg_end = seg_start;
for (size_t i = seg_start + 1; i < keys.size(); ++i) {
double dx = keys[i] - keys[seg_start];
if (dx == 0) { seg_end = i; continue; }
double upper = (double(i - seg_start) + epsilon) / dx;
double lower = (double(i - seg_start) - epsilon) / dx;
// 收缩可行域
upper_slope = std::min(upper_slope, upper);
lower_slope = std::max(lower_slope, lower);
if (upper_slope < lower_slope) {
// 可行域为空,当前线段到此结束
break;
}
seg_end = i;
}
double slope = (upper_slope + lower_slope) / 2.0;
segments.push_back({slope, intercept_at_start,
(size_t)keys[seg_start], seg_start});
seg_start = seg_end + 1;
}
return segments;
}PGM-index 的递归构建过程:
Level 0: 对原始数据做分段线性近似,得到 S_0 个段
Level 1: 对 S_0 个段的起始位置做分段线性近似,得到 S_1 个段
Level 2: 对 S_1 个段的起始位置做分段线性近似,得到 S_2 个段
...
直到某层只有一个段(根节点)
查找过程:
1. 从根节点开始
2. 用当前层的线性模型预测下一层的段索引
3. 在 [predicted - epsilon, predicted + epsilon] 范围内二分查找精确段
4. 进入下一层,重复 2-3
5. 到达最底层后,预测数据位置并局部搜索
PGM-index 的理论保证非常强:
| 属性 | 值 |
|---|---|
| 空间复杂度 | O(N / epsilon) |
| 查找时间 | O(log(N/epsilon) * log(epsilon)) |
| 构建时间 | O(N) |
| 误差保证 | 每层最多 epsilon |
与 RMI 相比,PGM-index 不需要调参,自动适应数据分布,并且有严格的最坏情况保证。
四、ALEX:支持动态更新的学习索引
RMI 和 PGM-index 都是只读结构。现实世界的数据库需要频繁地插入和删除数据, 这怎么办? 2020 年,Microsoft Research 的 Ding 等人提出了 ALEX(Adaptive Learned indEX), 首次让学习索引真正可用于动态工作负载。
ALEX 的核心设计有三个:
1. Gapped Array(带间隙的数组)
传统数组在中间插入需要 O(N) 时间来移动元素。
ALEX
在数组中预留了间隙(gap),使得插入只需要移动局部的少量元素。
传统数组: [1, 3, 5, 7, 9, 11, 13, 15]
ALEX gapped array: [1, _, 3, 5, _, _, 7, 9, _, 11, _, 13, 15, _]
^ ^^^ ^ ^ ^
gap gaps gap gap gap
间隙的密度根据模型预测的插入位置动态调整。如果某个区域插入频繁,那里就会有更多间隙。
2. 指数搜索代替二分搜索
学习索引预测的位置有误差,需要局部搜索。 ALEX 选择了指数搜索(exponential search)而非二分搜索:
// 指数搜索:从预测位置开始,步长倍增
size_t exponential_search(
const int64_t* data,
size_t predicted_pos,
int64_t key,
size_t data_size
) {
// 先判断方向
if (data[predicted_pos] == key) return predicted_pos;
if (data[predicted_pos] < key) {
// 向右搜索
size_t step = 1;
size_t pos = predicted_pos;
while (pos + step < data_size && data[pos + step] < key) {
step *= 2;
}
// 在 [pos + step/2, min(pos + step, data_size-1)] 内二分
return binary_search(data, pos + step / 2,
std::min(pos + step, data_size - 1), key);
} else {
// 向左搜索
size_t step = 1;
size_t pos = predicted_pos;
while (pos >= step && data[pos - step] > key) {
step *= 2;
}
size_t lo = (pos >= step) ? pos - step : 0;
return binary_search(data, lo, pos - step / 2, key);
}
}指数搜索的妙处在于:当预测精确时(误差为
e),只需 O(log e) 步;
而二分搜索在固定范围内始终是
O(log(2 * max_err))。
对于模型预测精度好的区域,指数搜索更快。
3. 自适应节点分裂与合并
ALEX 的树结构会根据工作负载动态调整:
- 当一个叶节点的间隙用完、冲突过多时,分裂为多个子节点。
- 当相邻叶节点的利用率过低时,合并成一个节点。
- 分裂和合并时,重新训练对应的线性模型。
这与 B-tree 的分裂/合并机制非常相似,但 ALEX 的触发条件和重组策略都考虑了模型质量。ALEX 的整体结构如下:
[Internal Node: Linear Model]
/ | \
[Internal] [Internal] [Internal]
/ \ / \ / \
[Leaf] [Leaf] [Leaf] [Leaf] [Leaf] [Leaf]
| | | | | |
[GA] [GA] [GA] [GA] [GA] [GA]
GA = Gapped Array
每个内部节点包含一个线性模型和一个指针数组。每个叶节点包含一个线性模型和一个 gapped array。
ALEX 在 SOSD benchmark 上的表现非常出色:在大部分分布下,查找速度与最优的只读学习索引相当,同时支持每秒数百万次的插入操作。
五、LIPP:面向写入优化的学习索引
2021 年,Wu 等人提出的 LIPP(Learned Index with Precise Positions) 走了一条不同的路。LIPP 的核心观点是:与其像 ALEX 那样在 gapped array 上做文章,不如直接把插入的元素放到模型预测的位置上。
LIPP 的设计要点:
冲突链
当两个 key 被模型预测到同一个位置时,形成冲突。 LIPP 不移动已有元素,而是把冲突的元素挂在一个链上:
Slot[i]: key_a -> key_b -> key_c (collision chain)
Slot[i+1]: key_d
Slot[i+2]: (empty)
Slot[i+3]: key_e -> key_f
模型与节点设计
每个 LIPP 节点包含一个线性模型、一个数组(大小根据模型和预期数据量确定), 以及每个槽中存放的 key-value 对、子节点指针或冲突链。
节点重建
当冲突链变长(超过阈值),节点触发重建: - 收集节点中所有 key - 用新的线性模型重新拟合 - 扩大数组容量以减少冲突
// LIPP 节点重建(概念代码)
void rebuild_node(LIPPNode* node) {
// 1. 收集所有 key-value 对
auto all_kvs = collect_all_kvs(node);
// 2. 计算新的容量(通常是 key 数量的 1.2-1.5 倍)
size_t new_capacity = all_kvs.size() * expansion_factor;
// 3. 训练新的线性模型
LinearModel new_model;
new_model.train(all_kvs, new_capacity);
// 4. 重新分配 key 到新数组
node->resize(new_capacity);
node->model = new_model;
for (auto& kv : all_kvs) {
size_t pos = new_model.predict(kv.key);
node->insert_at(pos, kv);
}
}LIPP 的优势在于插入非常快——大多数情况下只需要一次模型预测和一次数组写入,不需要移动元素。代价是查找时可能需要遍历冲突链。在写入密集型工作负载下,LIPP 比 ALEX 有明显优势。
六、可更新学习索引的核心挑战
学习索引面临的最根本挑战是:插入操作会破坏 CDF 模型。
考虑一个简单的例子。假设我们有 100 个均匀分布的 key:
keys: [0, 1, 2, 3, ..., 99]
CDF: 0, 0.01, 0.02, ..., 1.0
线性模型 pos = key 完美拟合。现在插入 50
个新 key,全部在 [40, 60] 范围内:
keys: [0, 1, ..., 39, 40, 40.1, 40.2, ..., 59.8, 59.9, 60, 61, ..., 99]
原来的线性模型完全失效了——[40, 60] 范围内的
CDF 陡然变陡, 模型的预测误差急剧增大。
各种可更新学习索引对这个问题的解法不同,但基本思路可以归纳为:
| 策略 | 代表 | 优点 | 缺点 |
|---|---|---|---|
| 预留间隙 | ALEX | 插入快,不需要立即重训练 | 间隙浪费空间 |
| 冲突链 | LIPP | 写入极快,无需移动元素 | 读取需遍历链 |
| 增量重建 | PGM-index (dynamic) | 保持严格误差界 | 重建代价高 |
| Delta buffer | FITing-Tree | 分摊重建成本 | 查找需检查 buffer |
| 版本化模型 | DILI | 多版本模型平滑切换 | 实现复杂 |
一个通用的 delta buffer 方案是这样的:
写入流程:
1. 新数据先写入内存中的 buffer(通常是一个小型 B-tree 或跳表)
2. 当 buffer 大小超过阈值时,与主索引合并(merge)
3. 合并时重新训练模型
读取流程:
1. 先查主索引(学习索引)
2. 再查 buffer
3. 取两者中更新的结果
这和 LSM-tree 的思路非常相似。实际上,已经有研究者探索将学习索引与 LSM-tree 结合,例如 Bourbon(OSDI 2020)就是在 LevelDB 中用学习索引替换每层的 B-tree 索引。
七、学习化 Bloom Filter:用模型替代概率数据结构
Kraska 2018 年论文的第二部分提出了一个同样大胆的想法: 用机器学习模型替代 Bloom filter。
传统 Bloom filter 的工作原理:
Insert(key): 对 key 计算 k 个哈希,将对应 bit 置 1
Query(key): 对 key 计算 k 个哈希,检查所有 bit 是否为 1
-> 全部为 1:可能存在(false positive 可能)
-> 有 0:一定不存在
Bloom filter 的空间大小取决于目标 false positive rate(FPR)。 对于 1% 的 FPR,每个元素大约需要 10 bit。
学习化 Bloom filter 的思路是:将”key 是否在集合中”视为一个二分类问题。 训练一个模型来区分”存在的 key”和”不存在的 key”。
def learned_bloom_filter(model, backup_filter, key):
# 模型先做预测
score = model.predict(key) # 输出 [0, 1] 之间的概率
if score > threshold:
return True # 模型认为 key 存在
else:
# 模型不确定的情况,用备份 Bloom filter 兜底
return backup_filter.query(key)这个设计有一个关键的细节:模型的 false negative(漏判)是不可接受的, 因为 Bloom filter 的核心保证是没有 false negative。 所以需要一个备份 Bloom filter 来存储模型预测为”不存在”但实际存在的 key。
空间节省的来源在于: - 如果模型足够好,大多数”存在的 key”都能被正确预测 - 备份 Bloom filter 只需要存储很少的 key(模型漏掉的那些) - 模型本身的大小通常比全量 Bloom filter 小得多
论文报告称,在某些数据集上,学习化 Bloom filter 比传统 Bloom filter 节省 70% 的空间,同时保持相同的 FPR。
但有几个工程上的问题需要注意:
模型训练需要负样本:你需要知道”不在集合中的 key 长什么样”。在某些场景下(如 LSM-tree 的层间过滤),负样本就是其他层的 key,很容易获取。但在通用场景下,负样本的生成并不总是直观的。
模型推理延迟:即使是简单的神经网络,推理时间也比几次哈希计算慢得多。在延迟敏感的场景下需要权衡。
更新代价:集合变化时需要重训练模型,或者用备份 filter 吸收增量。
// 学习化 Bloom filter 的简化实现
class LearnedBloomFilter {
NeuralNetwork model_;
BloomFilter backup_;
double threshold_;
public:
void build(const std::vector<int64_t>& keys,
const std::vector<int64_t>& non_keys,
double target_fpr) {
// 训练二分类模型
model_.train(keys, non_keys);
// 找到使 false negative rate = 0 的阈值
threshold_ = find_threshold(model_, keys);
// 把模型漏掉的 key 放入备份 filter
std::vector<int64_t> missed;
for (auto k : keys) {
if (model_.predict(k) < threshold_) {
missed.push_back(k);
}
}
backup_.build(missed, target_fpr);
}
bool query(int64_t key) {
if (model_.predict(key) >= threshold_) {
return true;
}
return backup_.query(key);
}
};八、完整 C++ 实现:简单学习索引
下面给出一个完整的、可编译运行的 C++ 实现。 它包含一个简单的两层 RMI(根模型 + 叶子模型),支持点查询和范围查询。
// learned_index.hpp
// 一个完整的简单学习索引实现
// 支持:构建、点查询、范围查询
// 模型:两层 RMI,所有层使用线性回归
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>
namespace learned_index {
// 线性模型:pos = slope * key + intercept
struct LinearModel {
double slope = 0.0;
double intercept = 0.0;
// 用最小二乘法训练
void train(const std::vector<int64_t>& keys,
const std::vector<double>& positions) {
assert(keys.size() == positions.size());
size_t n = keys.size();
if (n == 0) { slope = 0; intercept = 0; return; }
if (n == 1) { slope = 0; intercept = positions[0]; return; }
double sum_x = 0, sum_y = 0, sum_xy = 0, sum_xx = 0;
for (size_t i = 0; i < n; ++i) {
double x = static_cast<double>(keys[i]);
double y = positions[i];
sum_x += x;
sum_y += y;
sum_xy += x * y;
sum_xx += x * x;
}
double denom = n * sum_xx - sum_x * sum_x;
if (std::abs(denom) < 1e-10) {
slope = 0;
intercept = sum_y / n;
} else {
slope = (n * sum_xy - sum_x * sum_y) / denom;
intercept = (sum_y - slope * sum_x) / n;
}
}
double predict(int64_t key) const {
return slope * static_cast<double>(key) + intercept;
}
};
// 叶子模型:线性模型 + 误差界
struct LeafModel {
LinearModel model;
size_t start_pos = 0; // 该模型负责的数据起始位置
size_t end_pos = 0; // 该模型负责的数据结束位置
int64_t min_key = 0;
int64_t max_key = 0;
size_t max_error = 0; // 最大预测误差
};
class SimpleLearnedIndex {
public:
// 构建索引
void build(const std::vector<int64_t>& sorted_data,
size_t num_leaf_models = 0) {
data_ = sorted_data;
size_t n = data_.size();
if (n == 0) return;
// 默认叶子模型数量:sqrt(N)
if (num_leaf_models == 0) {
num_leaf_models = std::max<size_t>(
1, static_cast<size_t>(std::sqrt(n)));
}
num_leaves_ = num_leaf_models;
// --- 训练根模型 ---
std::vector<int64_t> all_keys(n);
std::vector<double> all_positions(n);
for (size_t i = 0; i < n; ++i) {
all_keys[i] = data_[i];
all_positions[i] = static_cast<double>(i);
}
root_model_.train(all_keys, all_positions);
// --- 将 key 分配到叶子模型 ---
leaves_.resize(num_leaves_);
std::vector<std::vector<int64_t>> leaf_keys(num_leaves_);
std::vector<std::vector<double>> leaf_positions(num_leaves_);
for (size_t i = 0; i < n; ++i) {
double pred = root_model_.predict(data_[i]);
// 将预测值映射到叶子模型索引
size_t leaf_idx = static_cast<size_t>(
pred / static_cast<double>(n) * num_leaves_);
leaf_idx = std::min(leaf_idx, num_leaves_ - 1);
leaf_keys[leaf_idx].push_back(data_[i]);
leaf_positions[leaf_idx].push_back(
static_cast<double>(i));
}
// --- 训练每个叶子模型 ---
for (size_t i = 0; i < num_leaves_; ++i) {
auto& leaf = leaves_[i];
if (leaf_keys[i].empty()) {
// 空叶子节点,设置为默认
leaf.start_pos = 0;
leaf.end_pos = 0;
leaf.max_error = 0;
continue;
}
leaf.model.train(leaf_keys[i], leaf_positions[i]);
leaf.min_key = leaf_keys[i].front();
leaf.max_key = leaf_keys[i].back();
leaf.start_pos = static_cast<size_t>(
leaf_positions[i].front());
leaf.end_pos = static_cast<size_t>(
leaf_positions[i].back());
// 计算最大误差
leaf.max_error = 0;
for (size_t j = 0; j < leaf_keys[i].size(); ++j) {
double pred = leaf.model.predict(leaf_keys[i][j]);
double actual = leaf_positions[i][j];
size_t err = static_cast<size_t>(
std::abs(pred - actual)) + 1;
leaf.max_error = std::max(leaf.max_error, err);
}
}
}
// 点查询:返回 key 的位置,未找到返回 -1
int64_t lookup(int64_t key) const {
if (data_.empty()) return -1;
size_t n = data_.size();
// Stage 0:根模型选择叶子
double root_pred = root_model_.predict(key);
size_t leaf_idx = static_cast<size_t>(
root_pred / static_cast<double>(n) * num_leaves_);
leaf_idx = std::min(leaf_idx, num_leaves_ - 1);
const auto& leaf = leaves_[leaf_idx];
if (leaf.start_pos == leaf.end_pos &&
leaf.max_error == 0) {
return -1;
}
// Stage 1:叶子模型预测位置
double pred = leaf.model.predict(key);
int64_t predicted_pos = static_cast<int64_t>(
std::round(pred));
// 限制搜索范围
int64_t lo = std::max<int64_t>(
0, predicted_pos - static_cast<int64_t>(
leaf.max_error));
int64_t hi = std::min<int64_t>(
static_cast<int64_t>(n) - 1,
predicted_pos + static_cast<int64_t>(
leaf.max_error));
lo = std::max<int64_t>(lo, 0);
hi = std::min<int64_t>(hi,
static_cast<int64_t>(n) - 1);
// 二分搜索精确定位
while (lo <= hi) {
int64_t mid = lo + (hi - lo) / 2;
if (data_[mid] == key) return mid;
if (data_[mid] < key) lo = mid + 1;
else hi = mid - 1;
}
return -1;
}
// 范围查询:返回 [lo_key, hi_key] 范围内的所有 key
std::vector<int64_t> range_query(
int64_t lo_key, int64_t hi_key) const {
std::vector<int64_t> result;
if (data_.empty()) return result;
// 找到 lo_key 的起始位置
int64_t start = find_lower_bound(lo_key);
if (start < 0) start = 0;
// 扫描直到超出 hi_key
for (size_t i = static_cast<size_t>(start);
i < data_.size() && data_[i] <= hi_key; ++i) {
if (data_[i] >= lo_key) {
result.push_back(data_[i]);
}
}
return result;
}
// 获取统计信息
void print_stats() const {
size_t total_max_err = 0;
size_t non_empty = 0;
for (const auto& leaf : leaves_) {
if (leaf.max_error > 0) {
total_max_err += leaf.max_error;
++non_empty;
}
}
double avg_err = non_empty > 0
? static_cast<double>(total_max_err) / non_empty
: 0;
std::cout << "=== Learned Index Stats ===" << '\n';
std::cout << "Data size: " << data_.size()
<< '\n';
std::cout << "Leaf models: " << num_leaves_
<< '\n';
std::cout << "Non-empty leaves:" << non_empty
<< '\n';
std::cout << "Avg max error: " << avg_err
<< '\n';
std::cout << "Model memory: "
<< model_size_bytes() << " bytes"
<< '\n';
}
size_t model_size_bytes() const {
// 根模型:2 * sizeof(double) = 16 bytes
// 每个叶子模型:2 * sizeof(double) + metadata
return 16 + num_leaves_ * sizeof(LeafModel);
}
private:
int64_t find_lower_bound(int64_t key) const {
size_t n = data_.size();
double root_pred = root_model_.predict(key);
size_t leaf_idx = static_cast<size_t>(
root_pred / static_cast<double>(n) * num_leaves_);
leaf_idx = std::min(leaf_idx, num_leaves_ - 1);
const auto& leaf = leaves_[leaf_idx];
double pred = leaf.model.predict(key);
int64_t pos = static_cast<int64_t>(std::round(pred));
int64_t lo = std::max<int64_t>(
0, pos - static_cast<int64_t>(leaf.max_error));
int64_t hi = std::min<int64_t>(
static_cast<int64_t>(n) - 1,
pos + static_cast<int64_t>(leaf.max_error));
// 标准 lower_bound
while (lo < hi) {
int64_t mid = lo + (hi - lo) / 2;
if (data_[mid] < key) lo = mid + 1;
else hi = mid;
}
return lo;
}
std::vector<int64_t> data_;
LinearModel root_model_;
std::vector<LeafModel> leaves_;
size_t num_leaves_ = 0;
};
} // namespace learned_index下面是测试和基准程序:
// benchmark.cpp
// 编译:g++ -O2 -std=c++17 -o benchmark benchmark.cpp
// 运行:./benchmark
#include "learned_index.hpp"
#include <chrono>
#include <random>
using namespace learned_index;
using Clock = std::chrono::high_resolution_clock;
// 生成不同分布的数据
std::vector<int64_t> generate_uniform(size_t n) {
std::vector<int64_t> data(n);
for (size_t i = 0; i < n; ++i) data[i] = i * 10;
return data;
}
std::vector<int64_t> generate_normal(size_t n) {
std::mt19937_64 rng(42);
std::normal_distribution<double> dist(5e8, 1e7);
std::vector<int64_t> data(n);
for (size_t i = 0; i < n; ++i) {
data[i] = static_cast<int64_t>(dist(rng));
}
std::sort(data.begin(), data.end());
data.erase(std::unique(data.begin(), data.end()),
data.end());
return data;
}
std::vector<int64_t> generate_lognormal(size_t n) {
std::mt19937_64 rng(42);
std::lognormal_distribution<double> dist(20.0, 2.0);
std::vector<int64_t> data(n);
for (size_t i = 0; i < n; ++i) {
data[i] = static_cast<int64_t>(dist(rng));
}
std::sort(data.begin(), data.end());
data.erase(std::unique(data.begin(), data.end()),
data.end());
return data;
}
double benchmark_learned_index(
const std::vector<int64_t>& data,
const std::vector<int64_t>& queries
) {
SimpleLearnedIndex idx;
idx.build(data);
auto start = Clock::now();
volatile int64_t sink = 0;
for (auto q : queries) {
sink = idx.lookup(q);
}
auto end = Clock::now();
double ns = std::chrono::duration_cast<
std::chrono::nanoseconds>(end - start).count();
return ns / queries.size();
}
double benchmark_binary_search(
const std::vector<int64_t>& data,
const std::vector<int64_t>& queries
) {
auto start = Clock::now();
volatile bool sink = false;
for (auto q : queries) {
sink = std::binary_search(
data.begin(), data.end(), q);
}
auto end = Clock::now();
double ns = std::chrono::duration_cast<
std::chrono::nanoseconds>(end - start).count();
return ns / queries.size();
}
int main() {
const size_t N = 10'000'000;
const size_t Q = 1'000'000;
struct TestCase {
std::string name;
std::vector<int64_t> data;
};
std::vector<TestCase> tests = {
{"Uniform", generate_uniform(N)},
{"Normal", generate_normal(N)},
{"Lognormal", generate_lognormal(N)},
};
for (auto& tc : tests) {
// 生成查询(随机选取已有 key)
std::mt19937_64 rng(123);
std::uniform_int_distribution<size_t> dist(
0, tc.data.size() - 1);
std::vector<int64_t> queries(Q);
for (size_t i = 0; i < Q; ++i) {
queries[i] = tc.data[dist(rng)];
}
std::cout << "\n--- " << tc.name
<< " (N=" << tc.data.size() << ") ---\n";
SimpleLearnedIndex idx;
idx.build(tc.data);
idx.print_stats();
double learned_ns = benchmark_learned_index(
tc.data, queries);
double binary_ns = benchmark_binary_search(
tc.data, queries);
std::cout << "Learned index: "
<< learned_ns << " ns/query\n";
std::cout << "Binary search: "
<< binary_ns << " ns/query\n";
std::cout << "Speedup: "
<< binary_ns / learned_ns << "x\n";
}
return 0;
}九、学习索引何时胜出、何时败退:数据分布的决定性影响
学习索引的性能高度依赖数据分布。这不是一个小细节,而是决定学习索引能否落地的核心因素。
学习索引占优的场景
1. 近似均匀分布的数据
当 key 近似均匀分布时,CDF 接近一条直线,一个简单的线性模型就能达到极低的误差。这是学习索引的最佳场景。
数据分布: key = [0, 1, 2, 3, ..., N-1]
CDF: 几乎完美的直线
线性模型误差:接近 0
查找速度: 1-2 次内存访问
vs B-tree: O(log N) 次指针追踪,每次可能 cache miss
典型的真实例子:时间戳索引(如日志系统中的时间戳),自增主键索引。
2. 分段线性的数据
如果 CDF 可以用少量的线段近似(即数据呈现几段不同斜率的线性趋势),PGM-index 这类分段线性结构就非常高效。
3. 平滑的 CDF
正态分布、Zipf 分布等,虽然不是线性的,但 CDF 是平滑的。用少量的模型就能达到很低的误差。
学习索引吃亏的场景
1. 高度不规则的分布
如果数据中有大量的”密度突变”(某些区域极稠密,某些极稀疏),模型很难学好,误差会很大。
数据示例:
[1, 2, 3, 1000000, 1000001, 1000002, 2000000000, ...]
密度突变:误差可能达到数百甚至数千
2. 字符串 key
学习索引最初是为数值 key 设计的。对于变长字符串 key,需要先将字符串映射到数值空间,这个映射本身就可能引入大量误差。虽然有 SIndex(SIGMOD 2020)等工作尝试解决这个问题,但目前效果仍不如 B-tree。
3. 多维 key
学习索引在高维 key 上的表现更差。虽然有 Flood(SIGMOD 2020)和 Tsunami(VLDB 2020)等工作,但多维学习索引目前仍处于研究早期。
SOSD Benchmark 的启示
Search on Sorted Data(SOSD)benchmark 提供了一组标准化的数据集和评测框架。来自 SOSD 的关键发现:
查找延迟 (ns/op)
数据集 B-tree PGM ALEX RMI
---------------------------------------------
uniform_dense 350 95 105 80
normal 380 120 115 110
lognormal 400 180 150 200
books (real) 420 160 140 130
osm (real) 450 350 250 280
wiki_ts (real) 410 130 125 115
注意:以上数字为概念性示意,实际数值取决于硬件和具体实现。关键观察是:在大多数分布下学习索引都比 B-tree 快,但在某些真实数据集(如 OSM 地理数据)上优势不明显甚至消失。
十、基准测试:学习索引 vs B+tree 的真实对比
为了更具体地理解学习索引的优劣,让我们做一组更细致的对比分析。
测试维度
维度 选项
-----------------------------------------------
数据规模 1M, 10M, 100M, 1B
数据分布 uniform, normal, lognormal, zipf
操作类型 point lookup, range scan, insert
读写比例 100:0, 95:5, 50:50, 0:100
点查询性能
对于只读的点查询,学习索引的优势主要来自两个方面:
更少的内存访问:B+tree 的每一层都是一次指针追踪,在大数据量下很可能导致 cache miss。学习索引的模型推理只涉及少量浮点运算,全部在寄存器/L1 缓存中完成。
更小的索引体积:一个 10M 条记录的 B+tree,内部节点可能占 50-100MB。同样数据量的 RMI(10000 个线性模型)只需要 160KB。
内存占用对比(10M 条 int64 key):
B+tree (fanout=256):
层数 = log_256(10M) ≈ 3
内部节点 ≈ 10M / 256 + 10M / 256^2 + 1 ≈ 39K 节点
每节点 4KB → 总计 ≈ 156 MB
RMI (10000 leaf models):
根模型: 16 bytes
叶子模型: 10000 * 56 bytes ≈ 547 KB
总计 ≈ 548 KB
PGM-index (epsilon=64):
段数 ≈ 10M / 64 ≈ 156K
每段 24 bytes → 总计 ≈ 3.6 MB
范围查询性能
范围查询的性能差异主要在”定位起始位置”这一步。一旦找到起始位置,后续的顺序扫描两者没有区别(都是遍历有序数组)。
# 范围查询伪代码
def range_query(index, lo, hi):
start_pos = index.lookup_lower_bound(lo) # 这一步有差异
result = []
for i in range(start_pos, len(data)): # 这一步没差异
if data[i] > hi:
break
result.append(data[i])
return result插入性能
这是学习索引最弱的环节。在 50:50 读写比例下:
吞吐量 (Mops/s)
索引 读写 50:50 纯写
---------------------------------------------
B+tree (STX) 4.2 3.8
ALEX 3.5 2.1
LIPP 3.8 3.2
PGM (dynamic) 2.8 1.5
B+tree 在写入密集场景下仍然是王者,因为它的分裂/合并机制经过了数十年的工程优化。学习索引的插入涉及到模型失效检测和潜在的重训练,开销更大。
空间-时间权衡
学习索引的一个独特优势是可以通过调整误差阈值来精确控制空间-时间权衡。
PGM-index 在不同 epsilon 下的表现:
epsilon 段数 索引大小 查找时间
-------------------------------------------------
4 2,500,000 57.2 MB 85 ns
16 625,000 14.3 MB 120 ns
64 156,250 3.6 MB 165 ns
256 39,063 0.9 MB 210 ns
1024 9,766 0.2 MB 260 ns
这种平滑的权衡曲线是 B-tree 很难做到的——B-tree 的节点大小相对固定,调节空间占用的唯一方式是改变 fanout,而 fanout 的选择空间有限。
十一、工业界落地现状
学习索引从论文到产品的路还很长,但已经有一些值得关注的进展。
Google 是学习索引的发源地(Kraska 在 Google Brain 期间完成了核心论文)。据公开信息,Google 内部的部分系统已经采用了学习索引的思想,特别是在 Bigtable 和分布式排序系统中。但具体细节没有公开发表。
Amazon
Amazon 的 DynamoDB 团队在 SIGMOD 2022 上发表了关于学习索引在云数据库中应用的论文。他们发现学习索引在特定的只读工作负载下(如冷数据的二级索引)效果很好,但在通用的 OLTP 工作负载下仍然不如 B-tree。
学术界原型
以下项目是目前质量较高的开源实现:
| 项目 | 类型 | 语言 | 特点 |
|---|---|---|---|
| PGM-index | 只读/动态 | C++ | 理论最优,头文件实现 |
| ALEX | 动态 | C++ | Microsoft Research,支持读写 |
| LIPP | 动态 | C++ | 清华大学,写入优化 |
| RadixSpline | 只读 | C++ | 极简实现,适合嵌入 |
| CDFShop | 只读 | C++ | 自动选择最佳模型 |
| SOSD | Benchmark | C++ | 标准化评测框架 |
我的判断
学习索引目前还没有在主流数据库产品中被广泛采用。原因不是性能不够好,而是工程化的代价太高:
- 主流数据库的 B-tree/LSM-tree 实现经过了数十年的打磨,可靠性极高。替换核心索引结构的风险巨大。
- 学习索引的优势在特定数据分布下才显著,不具有普适性。
- 可更新学习索引的实现复杂度远高于传统索引。
我认为学习索引最可能首先在以下场景落地:只读分析型系统(如列存数据库、数据湖的索引层),嵌入式索引(如替代 LSM-tree 内部的 B-tree 索引),以及特化的内存数据库(如时序数据库,数据分布相对规律)。
十二、工程陷阱与个人看法
工程陷阱速查表
| 陷阱 | 描述 | 应对策略 |
|---|---|---|
| 模型训练成本 | 构建索引时需要训练模型,在数据量大时可能需要数秒甚至数分钟 | 增量训练;只在批量加载时训练,增量更新用 delta buffer |
| 尾延迟(tail latency) | 模型在个别 key 上的预测误差可能极大,导致最坏情况查找时间远超平均值 | 设置误差上限;误差过大的区域回退到 B-tree |
| 分布漂移(distribution shift) | 随着数据的不断插入,数据分布与训练时不同,模型逐渐失效 | 监控模型误差;定期重训练;自适应机制(如 ALEX 的节点分裂) |
| 浮点精度 | 线性模型使用 double 运算,在极端 key 值下可能出现精度问题 | 使用规范化的 key 表示;对预测位置做 clamp |
| 并发控制 | 学习索引的并发更新比 B-tree 更复杂(模型重训练需要全局锁或 RCU) | 乐观并发控制;读写分离;版本化模型 |
| 字符串 key | 将字符串映射到数值空间并不直观,映射质量影响模型效果 | 前缀编码;字典编码;只在数值 key 上使用学习索引 |
| 空间碎片化 | Gapped array 和冲突链可能导致内存碎片 | 定期 compaction;使用内存池 |
| 批量加载优化 | 初始数据加载时应该利用数据已排序的特性 | 直接用有序数据训练,不需要先构建再插入 |
| 模型选择 | 不同的数据分布适合不同类型的模型(线性/样条/小型神经网络) | CDFShop 等自动选择工具;先 profile 再决定 |
| 多线程扩展 | 模型推理本身是无状态的(适合并发),但更新操作需要同步 | 读路径无锁;写路径分区或批量化 |
个人看法
学习索引是一个真正的范式转换,但它的影响力更多体现在思维方式上,而非直接替代传统索引。
Kraska 最大的贡献不在于某个具体的数据结构,而在于他打开了一扇门:系统组件可以被视为学习问题。这个思路后来催生了 learned cardinality estimation、learned query optimization、learned scheduling 等一系列工作,统称为 “ML for Systems”。
从纯工程角度看,我认为学习索引在以下方面已经被证明是有价值的:
空间效率:学习索引的空间占用通常比 B-tree 小一到两个数量级。在内存紧张的场景下,这个优势非常实际。
只读查询速度:在数据分布友好的场景下,学习索引的查找速度确实优于 B-tree。但”数据分布友好”这个前提条件往往被低估了。
激发传统索引的改进:学习索引的研究推动了人们重新审视 B-tree 的设计。例如,更好的前缀压缩、自适应节点大小等优化,有些是受学习索引启发而提出的。
但我也有一些保留意见:
benchmark 的局限性:很多论文中的对比实验使用的是非常简单的 B-tree 实现。当对手换成高度优化的 B-tree(如 ART、HOT、FAST)时,学习索引的优势会显著缩小。
系统级影响被低估:在真实数据库中,索引查找只是整个查询执行的一小部分。即使索引查找快了 2 倍,端到端的查询延迟改善可能只有 10-20%。
运维复杂度:生产环境需要监控模型质量、处理分布漂移、管理模型版本。这些运维负担是传统索引不存在的。
总的来说,我把学习索引视为索引设计工具箱中的一个新工具,而非 B-tree 的终结者。在合适的场景下使用合适的工具,这才是工程师应有的态度。
上一篇: MVCC 实现变体全解 下一篇: TCP 拥塞控制
相关阅读: - B-tree 深度解剖 - 查询优化器