土法炼钢兴趣小组的算法知识备份

水塘抽样:未知大小数据流的公平抽样

目录

你有一条不知道何时结束的数据流,你只能扫描一遍,内存只够存 k 个样本。你怎么保证,当流结束的那一刻,每个元素被选中的概率恰好是 k/n?

这就是水塘抽样(Reservoir Sampling)要解决的问题。它的答案简洁到令人怀疑——一个 if 和一个随机数就够了。但简洁的背后是一个精致的概率不变量,以及三十多年来围绕它展开的一系列优化。

本文从最朴素的 Algorithm R 出发,逐步讲到跳跃优化的 Algorithm L、加权抽样的 A-Res 和 A-ExpJ,再到分布式场景下的合并策略。所有算法都附带完整的 C 实现。

水塘抽样替换过程

一、问题定义

1.1 形式化描述

给定一个数据流 S = a₁, a₂, …, aₙ,其中 n 事先未知。要求在单遍扫描(single pass)的约束下,维护一个大小恰好为 k 的样本集 R,使得当流结束时,S 中每个元素被选入 R 的概率均为 k/n。

约束条件:

  1. 单遍扫描:每个元素只被读取一次,不可回溯。
  2. 有限内存:只能存储 O(k) 个元素。
  3. 未知长度:处理第 i 个元素时不知道 n 的值。
  4. 均匀性:任意 k 元子集被选中的概率相等,即 1/C(n,k)。

1.2 为什么不能用简单方法

方法一:先数再抽。 需要两遍扫描,违反约束 1。在网络流、日志管道等场景中,数据流不可重放。

方法二:以固定概率 p 保留每个元素。 Bernoulli 采样可以做到单遍扫描,但样本数是一个随机变量——期望为 np,方差为 np(1-p)。当 n 未知时,无法事先确定 p 使得样本数恰好为 k。

方法三:保留所有数据再随机选。 违反约束 2,当 n 极大时内存不可承受。

水塘抽样的精妙之处在于:它在任意时刻 i 都维护了一个”到目前为止的完美样本”。这个不变量是整个算法的灵魂。

1.3 核心不变量

不变量 I(i):处理完第 i 个元素后(i ≥ k),水塘中的 k 个元素构成前 i 个元素的均匀随机样本。即对于任意 j ≤ i,元素 aⱼ 在水塘中的概率为 k/i。

一旦流结束(i = n),不变量自然给出 k/n 的均匀性保证。

二、Algorithm R:经典水塘抽样

2.1 算法描述

Algorithm R 由 Jeffrey Vitter 在 1985 年的论文 “Random Sampling with a Reservoir” 中正式分析,但其思想可以追溯到 Alan Waterman 的工作。

算法分两个阶段:

阶段一(填充):将前 k 个元素直接放入水塘。

阶段二(替换):对于第 i 个元素(i > k): 1. 生成一个 [1, i] 范围内的均匀随机整数 j。 2. 如果 j ≤ k,则用 aᵢ 替换水塘中位置 j 的元素。 3. 否则,丢弃 aᵢ。

2.2 伪代码

ALGORITHM-R(stream, k):
    R[1..k] ← stream 的前 k 个元素
    i ← k
    for each 后续元素 x in stream:
        i ← i + 1
        j ← RANDOM(1, i)
        if j ≤ k:
            R[j] ← x
    return R

2.3 C 实现

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <assert.h>

/* 生成 [lo, hi] 之间的均匀随机整数 */
static long rand_range(long lo, long hi)
{
    assert(hi >= lo);
    long range = hi - lo + 1;
    /* 拒绝采样消除模偏差 */
    long limit = RAND_MAX - (RAND_MAX % range);
    long r;
    do {
        r = random();
    } while (r >= limit);
    return lo + r % range;
}

/*
 * Algorithm R — 经典水塘抽样
 *
 * reservoir: 预分配的大小为 k 的数组
 * stream:    输入数据流数组(模拟)
 * n:         流的长度
 * k:         样本大小
 *
 * 返回值: 实际采样数(min(n, k))
 */
int reservoir_sample_R(int *reservoir, const int *stream, int n, int k)
{
    if (n <= 0 || k <= 0)
        return 0;

    int actual_k = (n < k) ? n : k;

    /* 阶段一:填充 */
    for (int i = 0; i < actual_k; i++)
        reservoir[i] = stream[i];

    /* 阶段二:替换 */
    for (int i = k; i < n; i++) {
        long j = rand_range(0, i);  /* [0, i] 均匀随机 */
        if (j < k)
            reservoir[j] = stream[i];
    }

    return actual_k;
}

2.4 正确性证明:归纳法

定理:Algorithm R 在处理完第 i 个元素后,水塘中每个元素被选中的概率恰好为 k/i。

证明(数学归纳法)

基础情形:i = k。前 k 个元素全部放入水塘,每个元素的概率为 k/k = 1。不变量成立。

归纳步骤:假设处理完第 i-1 个元素后,水塘中每个元素的概率为 k/(i-1)。现在处理第 i 个元素。

对于新元素 aᵢ: - j 在 [1, i] 中均匀随机 - P(aᵢ 被选中) = P(j ≤ k) = k/i  ✓

对于已在水塘中的元素 aⱼ(j < i): - P(aⱼ 仍在水塘) = P(aⱼ 在第 i-1 步后仍在水塘) × P(aⱼ 不被 aᵢ 替换) - = (k/(i-1)) × (1 - 1/i) - = (k/(i-1)) × ((i-1)/i) - = k/i  ✓

其中”不被替换”的概率分析:aᵢ 被选中替换某个位置的概率是 k/i,而替换到 aⱼ 所在位置的概率是 1/k,所以 aⱼ 被替换的概率是 (k/i) × (1/k) = 1/i。因此不被替换的概率是 1 - 1/i = (i-1)/i。

归纳完毕。当流结束时 i = n,每个元素的概率为 k/n。 □

2.5 复杂度分析

指标 复杂度
时间 O(n)
空间 O(k)
随机数生成次数 n - k
比较次数 n - k

对于 n 远大于 k 的场景(比如 n = 10⁹, k = 1000),生成 10⁹ 个随机数的开销并不小——这正是 Algorithm L 要解决的问题。

三、Algorithm L:跳跃优化

3.1 核心观察

Algorithm R 的问题在于:对于每一个到达的元素,都要生成一个随机数来决定是否替换。但当 n 远大于 k 时,绝大多数元素会被丢弃(概率约为 1 - k/n)。能不能直接跳过那些一定会被丢弃的元素?

Algorithm L(同样由 Vitter 在 1985 年提出)正是基于这个思路。它的关键洞察是:

两次成功替换之间跳过的元素数,服从参数相关的几何分布。

更精确地说,当前计数器为 i 时,下一个被选中替换的元素的位置 i’ 满足:

skip = i' - i - 1

其中 skip 的分布可以通过连续近似来高效采样。

3.2 推导

假设当前已处理 i 个元素。第 i+1 个元素被选中的概率是 k/(i+1)。如果没被选中,继续看第 i+2 个,被选中概率 k/(i+2),以此类推。

我们需要的是”下一个被选中元素之前跳过了多少个”,即首次成功的位置。这本质上是一个带有变化概率的几何分布。

Vitter 提出了一个优雅的近似:利用连续化的方法,将跳跃距离计算转化为对均匀随机变量的函数变换。

设 W 为一个关键的权重变量,初始值 W = exp(-log(random())/k),其中 random() 产生 (0,1) 均匀分布。跳跃距离为:

skip = floor(log(random()) / log(1 - W))

每次替换后更新 W:

W = W * exp(-log(random()) / k)

3.3 伪代码

ALGORITHM-L(stream, k):
    R[1..k] ← stream 的前 k 个元素
    W ← exp(log(random()) / k)
    i ← k
    loop:
        skip ← floor(log(random()) / log(1 - W))
        i ← i + skip + 1
        if i > n: break
        R[RANDOM(0, k-1)] ← stream[i]
        W ← W * exp(log(random()) / k)
    return R

3.4 C 实现

#include <math.h>

/* 生成 (0, 1) 开区间的均匀随机浮点数 */
static double rand_uniform(void)
{
    double r;
    do {
        r = (double)random() / (double)RAND_MAX;
    } while (r == 0.0 || r == 1.0);
    return r;
}

/*
 * Algorithm L — 跳跃优化水塘抽样
 *
 * 与 Algorithm R 接口一致,但内部通过跳跃减少随机数生成次数。
 */
int reservoir_sample_L(int *reservoir, const int *stream, int n, int k)
{
    if (n <= 0 || k <= 0)
        return 0;

    int actual_k = (n < k) ? n : k;

    /* 阶段一:填充 */
    for (int i = 0; i < actual_k; i++)
        reservoir[i] = stream[i];

    if (n <= k)
        return actual_k;

    /* 初始化权重 W */
    double W = exp(log(rand_uniform()) / k);

    int i = k;
    while (1) {
        /* 计算跳跃距离 */
        double skip_d = floor(log(rand_uniform()) / log(1.0 - W));
        /* 防止 skip 溢出 */
        if (skip_d > (double)(n - i))
            break;
        int skip = (int)skip_d;
        i += skip + 1;
        if (i >= n)
            break;
        /* 替换水塘中随机位置 */
        int j = (int)(rand_uniform() * k);
        if (j >= k) j = k - 1;  /* 安全边界 */
        reservoir[j] = stream[i];
        /* 更新权重 */
        W *= exp(log(rand_uniform()) / k);
    }

    return actual_k;
}

3.5 复杂度分析

指标 Algorithm R Algorithm L
时间 O(n) O(k(1 + log(n/k)))
空间 O(k) O(k)
随机数次数 n - k O(k(1 + log(n/k)))
元素比较 n - k 0(跳跃直达)

当 n = 10⁹, k = 1000 时: - Algorithm R 需要约 10⁹ 个随机数 - Algorithm L 只需要约 1000 × (1 + log(10⁶)) ≈ 1000 × 15 = 15000 个随机数

这是三个数量级的差距。但要注意,Algorithm L 仍然需要遍历流中的所有元素——跳跃只是避免了对每个元素生成随机数,并不能跳过读取操作(除非流支持随机访问)。

3.6 跳跃距离的直觉

为什么跳跃距离的期望是 n/k?

处理第 i 个元素时,替换概率为 k/i。两次替换之间的间隔期望约为 i/k。对所有替换求和:

总替换次数 ≈ ∑(i=k+1 to n) k/i ≈ k × ln(n/k)

总跳跃距离 ≈ n - k(所有跳过的元素),分摊到 k × ln(n/k) 次替换上,平均每次跳跃 ≈ (n-k) / (k × ln(n/k))。

对于 n 远大于 k 的典型情况,每次跳跃的平均长度确实在 n/k 量级。

四、加权水塘抽样

4.1 问题扩展

实际应用中,元素往往不是等权重的。例如: - 日志采样:错误日志的权重应高于普通日志 - 广告展示:按出价权重采样 - 数据分析:按数据重要性采样

形式化:每个元素 aᵢ 附带权重 wᵢ > 0。要求元素 aᵢ 被选入样本的概率正比于 wᵢ。

4.2 A-Res 算法(Efraimidis-Spirakis 2006)

A-Res(Algorithm with Reservoir)是由 Efraimidis 和 Spirakis 在 2006 年提出的加权水塘抽样算法。

核心思想:为每个元素计算一个”键”(key),然后保留键最大的 k 个元素。

对于元素 aᵢ(权重 wᵢ),其键的计算方式为:

keyᵢ = random()^(1/wᵢ)

其中 random() 产生 (0, 1) 均匀分布。

为什么这样有效? 键 keyᵢ = U^(1/wᵢ) 的分布满足:权重越大的元素,其键的分布越偏向 1(即键更大),因此更可能被保留。可以证明,保留键最大的 k 个元素等价于按权重做加权无放回抽样(Weighted Sampling Without Replacement, WSwR)。

4.3 A-Res 实现

typedef struct {
    int    value;
    double key;
} weighted_item_t;

/* 最小堆操作:维护堆顶为最小键 */
static void heap_sift_down(weighted_item_t *heap, int size, int pos)
{
    while (1) {
        int smallest = pos;
        int left  = 2 * pos + 1;
        int right = 2 * pos + 2;
        if (left < size && heap[left].key < heap[smallest].key)
            smallest = left;
        if (right < size && heap[right].key < heap[smallest].key)
            smallest = right;
        if (smallest == pos) break;
        weighted_item_t tmp = heap[pos];
        heap[pos] = heap[smallest];
        heap[smallest] = tmp;
        pos = smallest;
    }
}

static void heap_build(weighted_item_t *heap, int size)
{
    for (int i = size / 2 - 1; i >= 0; i--)
        heap_sift_down(heap, size, i);
}

/*
 * A-Res 加权水塘抽样
 *
 * values:  元素值数组
 * weights: 对应权重数组
 * n:       流长度
 * k:       样本大小
 * result:  输出的样本值数组(大小 k)
 */
int weighted_reservoir_ares(int *result, const int *values,
                            const double *weights, int n, int k)
{
    if (n <= 0 || k <= 0)
        return 0;

    int actual_k = (n < k) ? n : k;
    weighted_item_t *heap = malloc(k * sizeof(weighted_item_t));
    if (!heap) return 0;

    /* 阶段一:填充 */
    for (int i = 0; i < actual_k; i++) {
        heap[i].value = values[i];
        heap[i].key   = pow(rand_uniform(), 1.0 / weights[i]);
    }

    if (n <= k) {
        for (int i = 0; i < actual_k; i++)
            result[i] = heap[i].value;
        free(heap);
        return actual_k;
    }

    /* 建最小堆 */
    heap_build(heap, k);

    /* 阶段二:替换 */
    for (int i = k; i < n; i++) {
        double key = pow(rand_uniform(), 1.0 / weights[i]);
        if (key > heap[0].key) {
            heap[0].value = values[i];
            heap[0].key   = key;
            heap_sift_down(heap, k, 0);
        }
    }

    for (int i = 0; i < k; i++)
        result[i] = heap[i].value;
    free(heap);
    return k;
}

4.4 A-ExpJ 算法:指数跳跃优化

A-ExpJ 是 A-Res 的跳跃优化版本,思路类似 Algorithm L 对 Algorithm R 的优化。

核心思想:堆顶的最小键 T 是一个阈值。新元素的键必须大于 T 才能进入水塘。利用这个阈值,可以计算出需要跳过多少个元素才会出现一个键大于 T 的元素。

具体跳跃策略:

  1. 设当前阈值为 T(堆顶的键)。
  2. 生成 X = log(random()) / log(T),这给出了跳跃的权重和。
  3. 逐个累加元素权重,直到累加和超过 X,此时的元素即为下一个要替换进水塘的。
  4. 计算新元素的键(在保证大于 T 的条件下采样)。
  5. 替换堆顶,调整堆,更新阈值。
/*
 * A-ExpJ 加权水塘抽样(跳跃优化)
 * 简化实现,展示核心思路
 */
int weighted_reservoir_aexpj(int *result, const int *values,
                             const double *weights, int n, int k)
{
    if (n <= 0 || k <= 0)
        return 0;

    int actual_k = (n < k) ? n : k;
    weighted_item_t *heap = malloc(k * sizeof(weighted_item_t));
    if (!heap) return 0;

    /* 阶段一:填充 */
    for (int i = 0; i < actual_k; i++) {
        heap[i].value = values[i];
        heap[i].key   = pow(rand_uniform(), 1.0 / weights[i]);
    }

    if (n <= k) {
        for (int i = 0; i < actual_k; i++)
            result[i] = heap[i].value;
        free(heap);
        return actual_k;
    }

    heap_build(heap, k);

    /* 阶段二:指数跳跃替换 */
    double threshold = heap[0].key;
    double weight_sum = 0.0;
    double jump_target = log(rand_uniform()) / log(threshold);

    int i = k;
    while (i < n) {
        weight_sum += weights[i];
        if (weight_sum >= jump_target) {
            /* 此元素被选中,计算条件键 */
            double tw = pow(threshold, weights[i]);
            double key = pow(tw + rand_uniform() * (1.0 - tw),
                            1.0 / weights[i]);
            /* 替换堆顶 */
            heap[0].value = values[i];
            heap[0].key   = key;
            heap_sift_down(heap, k, 0);
            /* 更新阈值和跳跃目标 */
            threshold = heap[0].key;
            weight_sum = 0.0;
            jump_target = log(rand_uniform()) / log(threshold);
        }
        i++;
    }

    for (int i = 0; i < k; i++)
        result[i] = heap[i].value;
    free(heap);
    return k;
}

4.5 加权算法对比

算法 随机数次数 每元素开销 适用场景
A-Res O(n) O(log k) 通用加权采样
A-ExpJ O(k log(n/k)) 期望 O(log k) 大流量、权重均匀
朴素拒绝法 不可预测 不固定 不推荐用于流式

五、分布式水塘抽样

5.1 问题场景

现代数据系统中,数据流往往分布在多个节点上。例如:

问题:每个节点独立做水塘抽样后,如何合并得到全局的均匀样本?

5.2 合并策略

假设有 m 个节点,第 j 个节点处理了 nⱼ 个元素(∑nⱼ = n),各自维护大小为 k 的水塘 Rⱼ。

方法一:带计数合并

每个节点除了水塘还要记录自己处理的元素总数 nⱼ。合并时:

  1. 汇总所有 m×k 个候选元素和对应的 nⱼ 到协调节点。
  2. 协调节点对这 m×k 个元素做加权水塘抽样,其中第 j 个节点的每个元素权重为 nⱼ/k。
  3. 或者更简单地:以概率 nⱼ/n 从第 j 个节点取样。

方法二:键排序合并(基于 A-Res)

如果每个节点使用 A-Res 算法并记录每个元素的键:

  1. 汇总所有 m×k 个(元素, 键)对。
  2. 取键最大的 k 个。

这种方法天然正确,因为 A-Res 的全局最优 k 个键就是全局样本。

5.3 分布式实现

typedef struct {
    int *reservoir;     /* 水塘数组 */
    int  k;             /* 水塘大小 */
    long count;         /* 已处理元素数 */
} reservoir_state_t;

/*
 * 合并两个水塘
 * 结果存入 dst,src 的内容被消费
 */
void reservoir_merge(reservoir_state_t *dst, const reservoir_state_t *src)
{
    long total = dst->count + src->count;

    /* 对于 src 中的每个元素,以 src->count / total 的概率纳入 */
    for (int i = 0; i < src->k; i++) {
        /* 元素在 src 的水塘中的"等效概率"是 1/src->count × src->k */
        /* 简化方案:从合并的 2k 个候选中随机选 k 个 */
        long j = rand_range(0, total - 1);
        if (j < src->count) {
            /* 用 src 的元素替换 dst 中的随机位置 */
            int pos = (int)rand_range(0, dst->k - 1);
            dst->reservoir[pos] = src->reservoir[i];
        }
    }

    dst->count = total;
}

/*
 * 更精确的合并:两个水塘按比例采样
 */
void reservoir_merge_exact(int *result, int k,
                           const int *r1, long n1,
                           const int *r2, long n2)
{
    long total = n1 + n2;
    int idx = 0;

    /* 从 r1 中以 n1/total 的期望比例采样 */
    for (int i = 0; i < k && idx < k; i++) {
        if (rand_range(1, total - idx) <= (long)(k * n1 / total) - idx / 2) {
            result[idx++] = r1[i];
        }
    }
    /* 从 r2 中补齐 */
    for (int i = 0; idx < k; i++) {
        result[idx++] = r2[i % k];
    }
}

5.4 MapReduce 中的水塘抽样

在 MapReduce 框架中的典型实现模式:

Map 阶段:每个 Mapper 对自己的数据分片做 Algorithm R/L,输出 (键, 水塘, 计数) 三元组。

Reduce 阶段:单个 Reducer 收集所有 Mapper 的水塘,用合并策略得到全局样本。

关键点:Reduce 阶段的内存开销是 O(m × k),其中 m 是 Mapper 数量。当 m 和 k 都不大时,这是可接受的。

六、应用场景

6.1 A/B 测试中的用户采样

在 A/B 测试中,需要从用户流量中均匀抽取一定比例的用户进入实验组。水塘抽样的优势在于:

  1. 精确控制样本量:不像 Bernoulli 采样那样样本量随机波动。
  2. 公平性保证:每个用户被选中的概率严格相等。
  3. 流式处理:不需要预先知道总用户数。

但实际的 A/B 测试系统通常使用基于哈希的确定性分流(如 hash(user_id) % 100 < threshold),因为需要确保同一用户在多次请求中被分到同一组。水塘抽样更适合一次性的批量采样场景。

6.2 数据库采样

许多数据库引擎在查询优化中使用水塘抽样来收集统计信息:

6.3 机器学习训练数据选择

在大规模机器学习场景中,训练数据量远超内存容量。水塘抽样用于:

  1. 数据子集选择:从 TB 级数据中均匀抽取训练子集。
  2. 在线学习的经验回放:强化学习中的 experience replay buffer 本质上就是水塘抽样。
  3. 数据流上的 SGD:每个 mini-batch 是从到目前为止见过的所有数据的均匀样本。

6.4 网络监控与流量分析

七、与其他采样方法的对比

7.1 对比表

特性 水塘抽样 Bernoulli 采样 分层采样 系统采样
样本量 精确 k 个 随机(期望 np) 每层精确 精确
需要预知 n
单遍扫描 需要两遍或预知层大小
均匀性 严格均匀 每元素独立 层内均匀 周期性
内存 O(k) O(np) O(样本量) O(样本量)
适合流式 有限支持 有限支持
可加权 是(A-Res) 可以 天然支持 不直接支持

7.2 Bernoulli 采样 vs 水塘抽样

Bernoulli 采样的规则简单:以概率 p 独立保留每个元素。

优势: - 实现更简单 - 元素间独立,天然适合分布式 - 不需要维护状态(除了 p)

劣势: - 样本量不确定:Var = np(1-p),当 n 大时标准差可达 √(np) - 需要预估 n 来设定合适的 p - 无法保证精确 k 个样本

在实践中,如果对样本量的精确性要求不高(如日志采样只要”大约 1%“),Bernoulli 采样因为简单和无状态而更受欢迎。如果必须精确控制样本量(如 A/B 测试需要精确的统计功效),水塘抽样更合适。

7.3 分层采样的互补

分层采样(Stratified Sampling)将总体分成若干互不相交的层(stratum),在每层内独立采样。

水塘抽样可以与分层采样结合:对每一层的数据流分别做水塘抽样。这在以下场景中有用:

八、完整 C 实现

以下是一个包含所有主要变体的完整实现,带有测试和验证:

/*
 * reservoir_sampling.c — 水塘抽样完整实现
 *
 * 包含:Algorithm R, Algorithm L, A-Res, 分布式合并, 验证测试
 * 编译:gcc -O2 -o reservoir reservoir_sampling.c -lm
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <assert.h>

/* ========== 随机数工具 ========== */

static long rand_range(long lo, long hi)
{
    assert(hi >= lo);
    long range = hi - lo + 1;
    long limit = RAND_MAX - (RAND_MAX % range);
    long r;
    do {
        r = random();
    } while (r >= limit);
    return lo + r % range;
}

static double rand_uniform(void)
{
    double r;
    do {
        r = (double)random() / (double)RAND_MAX;
    } while (r == 0.0 || r == 1.0);
    return r;
}

/* ========== Algorithm R ========== */

int reservoir_R(int *reservoir, const int *stream, int n, int k)
{
    if (n <= 0 || k <= 0) return 0;
    int actual = (n < k) ? n : k;

    for (int i = 0; i < actual; i++)
        reservoir[i] = stream[i];

    for (int i = k; i < n; i++) {
        long j = rand_range(0, i);
        if (j < k)
            reservoir[j] = stream[i];
    }
    return actual;
}

/* ========== Algorithm L ========== */

int reservoir_L(int *reservoir, const int *stream, int n, int k)
{
    if (n <= 0 || k <= 0) return 0;
    int actual = (n < k) ? n : k;

    for (int i = 0; i < actual; i++)
        reservoir[i] = stream[i];

    if (n <= k) return actual;

    double W = exp(log(rand_uniform()) / k);
    int i = k - 1;

    while (1) {
        double skip = floor(log(rand_uniform()) / log(1.0 - W));
        if (skip > (double)(n - i - 1)) break;
        i += (int)skip + 1;
        if (i >= n) break;
        reservoir[(int)(rand_uniform() * k) % k] = stream[i];
        W *= exp(log(rand_uniform()) / k);
    }
    return actual;
}

/* ========== k=1 特例 ========== */

int reservoir_single(const int *stream, int n)
{
    assert(n > 0);
    int result = stream[0];
    for (int i = 1; i < n; i++) {
        if (rand_range(0, i) == 0)
            result = stream[i];
    }
    return result;
}

/* ========== A-Res 加权采样 ========== */

typedef struct {
    int    value;
    double key;
} witem_t;

static void min_heap_sift(witem_t *h, int size, int pos)
{
    while (1) {
        int s = pos, l = 2*pos+1, r = 2*pos+2;
        if (l < size && h[l].key < h[s].key) s = l;
        if (r < size && h[r].key < h[s].key) s = r;
        if (s == pos) break;
        witem_t t = h[pos]; h[pos] = h[s]; h[s] = t;
        pos = s;
    }
}

static void min_heap_build(witem_t *h, int size)
{
    for (int i = size/2 - 1; i >= 0; i--)
        min_heap_sift(h, size, i);
}

int reservoir_ares(int *result, const int *vals, const double *wts,
                   int n, int k)
{
    if (n <= 0 || k <= 0) return 0;
    int actual = (n < k) ? n : k;
    witem_t *heap = malloc(k * sizeof(witem_t));
    if (!heap) return 0;

    for (int i = 0; i < actual; i++) {
        heap[i].value = vals[i];
        heap[i].key   = pow(rand_uniform(), 1.0 / wts[i]);
    }

    if (n <= k) {
        for (int i = 0; i < actual; i++) result[i] = heap[i].value;
        free(heap);
        return actual;
    }

    min_heap_build(heap, k);

    for (int i = k; i < n; i++) {
        double key = pow(rand_uniform(), 1.0 / wts[i]);
        if (key > heap[0].key) {
            heap[0].value = vals[i];
            heap[0].key   = key;
            min_heap_sift(heap, k, 0);
        }
    }

    for (int i = 0; i < k; i++) result[i] = heap[i].value;
    free(heap);
    return k;
}

/* ========== 验证函数 ========== */

/*
 * 验证均匀性:对每个元素统计被选中的次数,
 * 理论期望 = trials * k / n
 */
void verify_uniformity(const char *name,
                       int (*sampler)(int*, const int*, int, int),
                       int n, int k, int trials)
{
    int *stream = malloc(n * sizeof(int));
    int *reservoir = malloc(k * sizeof(int));
    long *count = calloc(n, sizeof(long));

    for (int i = 0; i < n; i++)
        stream[i] = i;

    for (int t = 0; t < trials; t++) {
        sampler(reservoir, stream, n, k);
        for (int j = 0; j < k; j++)
            count[reservoir[j]]++;
    }

    double expected = (double)trials * k / n;
    double max_dev = 0;
    for (int i = 0; i < n; i++) {
        double dev = fabs((double)count[i] - expected) / expected;
        if (dev > max_dev) max_dev = dev;
    }

    printf("[%s] n=%d, k=%d, trials=%d\n", name, n, k, trials);
    printf("  期望次数: %.1f, 最大偏差: %.2f%%\n",
           expected, max_dev * 100);
    printf("  结论: %s\n\n",
           max_dev < 0.05 ? "通过(偏差 < 5%%)" : "可能存在偏差");

    free(stream);
    free(reservoir);
    free(count);
}

/* ========== 性能基准 ========== */

void benchmark(const char *name,
               int (*sampler)(int*, const int*, int, int),
               int n, int k)
{
    int *stream = malloc(n * sizeof(int));
    int *reservoir = malloc(k * sizeof(int));

    for (int i = 0; i < n; i++)
        stream[i] = i;

    struct timespec start, end;
    clock_gettime(CLOCK_MONOTONIC, &start);

    int rounds = 100;
    for (int t = 0; t < rounds; t++)
        sampler(reservoir, stream, n, k);

    clock_gettime(CLOCK_MONOTONIC, &end);

    double elapsed = (end.tv_sec - start.tv_sec)
                   + (end.tv_nsec - start.tv_nsec) / 1e9;
    printf("[%s] n=%d, k=%d, %d 轮: %.3f ms/轮\n",
           name, n, k, rounds, elapsed / rounds * 1000);

    free(stream);
    free(reservoir);
}

/* ========== 主程序 ========== */

int main(void)
{
    srandom(time(NULL));

    printf("===== 均匀性验证 =====\n\n");
    verify_uniformity("Algorithm R", reservoir_R, 100, 10, 100000);
    verify_uniformity("Algorithm L", reservoir_L, 100, 10, 100000);

    printf("===== 性能基准 =====\n\n");
    benchmark("Algorithm R", reservoir_R, 1000000, 100);
    benchmark("Algorithm L", reservoir_L, 1000000, 100);

    printf("===== k=1 特例测试 =====\n\n");
    {
        int stream[] = {0, 1, 2, 3, 4};
        int counts[5] = {0};
        int trials = 100000;
        for (int t = 0; t < trials; t++)
            counts[reservoir_single(stream, 5)]++;
        printf("k=1, n=5, %d 轮:\n", trials);
        for (int i = 0; i < 5; i++)
            printf("  元素 %d: %.2f%% (期望 20.00%%)\n",
                   i, 100.0 * counts[i] / trials);
    }

    return 0;
}

九、工程实践中的注意事项

9.1 常见陷阱表

陷阱 描述 后果 解决方案
模偏差 rand() % n 当 RAND_MAX 不是 n 的整数倍时有偏 样本不均匀 使用拒绝采样或更好的 RNG
浮点精度 pow(u, 1/w) 在 w 极大或极小时精度不足 加权采样偏差 使用对数域运算:exp(log(u)/w)
种子管理 固定种子导致每次结果相同 采样无随机性 使用时间+PID 组合种子,或使用 /dev/urandom
整数溢出 n 超过 int 范围(>2^31) 未定义行为 使用 int64_t 或 size_t
线程安全 全局 RNG 状态竞争 结果不可复现、概率偏差 每线程独立 RNG,或用 random_r
空流处理 未检查 n=0 或 k=0 段错误 入口参数校验
W 下溢 Algorithm L 中 W 趋近于 0 导致 log(1-W) ≈ -W 跳跃距离计算错误 当 W < ε 时回退到 Algorithm R
权重为零 A-Res 中 w=0 导致 1/w = +∞ 键为 0 或 NaN 过滤零权重元素或设下限

9.2 随机数生成器选择

水塘抽样对 RNG 质量有一定要求。以下是常见选择:

RNG 周期 速度 质量 推荐程度
glibc random() ~2^31 开发/测试用
drand48() 2^48 可用
xoshiro256** 2^256 极快 推荐
PCG-64 2^128 很快 推荐
/dev/urandom 无限 密码级 安全场景

对于生产环境的水塘抽样,建议使用 xoshiro256** 或 PCG 系列。它们在速度和统计质量之间取得了良好平衡。

9.3 线程安全设计

#include <pthread.h>

typedef struct {
    int    *reservoir;
    int     k;
    long    count;
    /* 每个实例独立的 RNG 状态 */
    unsigned long rng_state[4];  /* xoshiro256** */
} thread_safe_reservoir_t;

每个采样器实例持有自己的 RNG 状态,避免竞争。初始化时用不同的种子。

十、特殊变体与扩展

10.1 k=1 的简化

当只需要采集一个样本时,算法极大简化:

/* 从流中均匀随机选一个元素 */
int sample_one(stream_t *s)
{
    int result = stream_next(s);
    int count = 1;
    while (stream_has_next(s)) {
        count++;
        int x = stream_next(s);
        if (rand_range(1, count) == 1)
            result = x;
    }
    return result;
}

这个简化版本是著名面试题”随机选择链表中的一个节点”的标准解法(LeetCode 382)。

10.2 带过期的滑动窗口水塘抽样

在某些场景中,只关心最近 W 个元素中的样本(滑动窗口采样)。基本思路:

  1. 为每个进入水塘的元素记录其到达时间戳。
  2. 当有新元素到达时,检查水塘中是否有过期元素。
  3. 优先替换过期元素;如果没有过期元素,按标准水塘抽样逻辑处理。

这是链式采样(chain sampling)的一个应用场景。

10.3 带拒绝的精确加权采样

当需要严格的”按权重概率精确等于 wᵢ/∑wⱼ”时,A-Res 给出的是 WSwR(加权无放回抽样),而非 PPS(概率正比于大小抽样)。对于某些统计应用(如 Horvitz-Thompson 估计器),需要的是 PPS,此时需要使用更复杂的算法,如 Chao 的方法。

10.4 最小哈希与水塘抽样的联系

MinHash 可以被视为 k=1 水塘抽样在集合相似度估计中的应用:对每个元素计算哈希值,保留最小的(或最大的)就等价于从流中随机选一个。保留最小的 k 个(bottom-k sketch)则等价于水塘抽样。

十一、数学补充

11.1 耦合论证

除了归纳法,还有一种基于耦合(coupling)的优雅证明方式。

考虑一个”上帝视角”的随机过程:在流开始之前,为每个元素 aᵢ 独立生成一个 [0,1] 上的均匀随机数 Uᵢ。最终的样本就是 Uᵢ 最小的 k 个元素。

可以证明,Algorithm R 的执行过程与这个”上帝视角”过程是等价的(在分布意义上)。

具体来说,当处理第 i 个元素时,用 rand(1,i) ≤ k 来决定是否替换,等价于在 Uᵢ < k/i 时替换——而 k/i 恰好是当前水塘阈值的期望。

11.2 信息论下界

水塘抽样需要 O(k log(n/k)) 比特的随机性。这是因为最终结果是 C(n,k) 种等可能结果之一,需要 log₂ C(n,k) ≈ k log(n/k) 比特来编码。

Algorithm L 的随机数使用量 O(k(1 + log(n/k))) 已经接近这个信息论下界(相差一个常数因子),说明它在随机性使用上是近似最优的。

11.3 集中不等式

虽然水塘抽样保证了每个元素的边际概率为 k/n,但样本的质量也取决于样本均值对总体均值的逼近程度。

设 X̄ 为水塘样本的均值,μ 为总体均值。由于水塘抽样等价于无放回简单随机抽样(SRS),有:

Var(X̄) = σ²/k × (1 - k/n)

其中 (1 - k/n) 是有限总体校正因子(finite population correction)。当 k/n 很小时,Var(X̄) ≈ σ²/k,与有放回抽样相同。

配合 Hoeffding 不等式:

P(|X̄ - μ| > ε) ≤ 2 exp(-2kε² / (b-a)²)

其中 [a,b] 是元素值的范围。这给出了样本量 k 和精度 ε 之间的定量关系。

十二、个人思考

12.1 简洁性是算法之美

水塘抽样是我最喜欢的算法之一,原因不在于它有多高深,而在于它的核心思想可以用一句话概括:第 i 个元素以 k/i 的概率替换水塘中的随机位置。

这个规则如此简单,以至于很多人第一次见到时会下意识地觉得”这不可能是对的”。但数学归纳法的证明只有四行,干净利落。我认为这种”简单到令人怀疑”的算法才是真正优雅的算法。

12.2 Algorithm L 的工程价值被低估了

在我接触的很多生产系统中,工程师们使用的仍然是朴素的 Algorithm R。这当然没错——正确性是第一位的——但当数据流量达到每秒百万级时,每个元素都生成一个随机数的开销是实实在在的。

Algorithm L 的实现只比 Algorithm R 多了十几行代码,但在 n/k 很大的场景下(这几乎是所有生产场景),随机数生成次数从 O(n) 降到 O(k log(n/k)),这是值得做的优化。

12.3 加权采样的陷阱

A-Res 算法的正确性依赖于一个微妙的数学性质:U^(1/w) 的极值统计恰好给出了加权无放回抽样的概率。但在工程实现中,浮点精度是一个真实的问题。当权重差异非常大(如 10⁶ 倍)时,小权重元素的键会非常接近 0,大权重元素的键会非常接近 1,浮点表示的精度不足以区分它们。

我的建议是:在对数域做所有运算。将 key = U^(1/w) 变成 log(key) = log(U)/w,比较时直接比较 log(key) 即可。这不仅避免了 pow() 函数的精度问题,还避免了下溢。

12.4 面试中的水塘抽样

水塘抽样是面试热门题,常见的考法有:

  1. 链表随机节点(LeetCode 382):k=1 的水塘抽样。
  2. 随机翻牌:从未知长度的流中等概率选一个。
  3. 文件中随机一行:经典的 Unix 面试题。

面试中最容易犯的错误是用 rand() % i 而不处理模偏差。虽然在面试中这个细节通常不会扣分,但在生产代码中这是一个真实的 bug。

12.5 与流式算法家族的关系

水塘抽样是流式算法家族中最基础的一员。它与其他流式算法有着有趣的联系:

这些算法共同构成了流式数据处理的基础工具箱,每一个都是在有限空间和单遍扫描的约束下做出巧妙的取舍。

12.5 关于”公平”的哲学

水塘抽样保证的是一种特定意义上的”公平”——每个元素被选中的概率相等。但这真的是我们想要的”公平”吗?

在很多应用场景中,我们可能需要的是”代表性”而非”均匀性”。均匀随机抽取 1000 个用户,可能全都来自大城市(因为大城市用户多)。这时候分层采样或加权采样才是正确的选择。

算法的正确性从来不等于应用的正确性。选择哪种采样方法,取决于你真正想回答的问题是什么。

参考文献

  1. Vitter, J. S. (1985). “Random Sampling with a Reservoir.” ACM Transactions on Mathematical Software, 11(1), 37-57.

  2. Efraimidis, P. S., & Spirakis, P. G. (2006). “Weighted Random Sampling with a Reservoir.” Information Processing Letters, 97(5), 181-185.

  3. Li, K. (1994). “Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n))).” ACM Transactions on Mathematical Software, 20(4), 481-493.

  4. Chao, M. T. (1982). “A General Purpose Unequal Probability Sampling Plan.” Biometrika, 69(3), 653-656.

  5. Knuth, D. E. (1997). The Art of Computer Programming, Volume 2: Seminumerical Algorithms, 3rd Edition. Addison-Wesley. Section 3.4.2.

  6. Aggarwal, C. C. (2007). Data Streams: Models and Algorithms. Springer. Chapter 2.

  7. Cormode, G., & Duffield, N. (2014). “Sampling for Big Data: A Tutorial.” ACM SIGKDD Conference on Knowledge Discovery and Data Mining.

  8. Tillé, Y. (2006). Sampling Algorithms. Springer Series in Statistics.


算法系列导航上一篇:t-digest | 下一篇:MinHash 与 SimHash

相关阅读随机化算法:当运气成为武器 | 流式算法总论


By .