土法炼钢兴趣小组的算法知识备份

线段树与树状数组:区间问题的优雅武器

目录

一、区间问题:从暴力到优雅

给定一个长度为 n 的数组 a[1..n],我们需要反复执行两类操作:

  1. 区间查询:给定 l, r,求 a[l] + a[l+1] + … + a[r]。
  2. 点修改/区间修改:将某个位置(或某个区间)的值加上一个增量 d。

暴力做法下,查询需要 O(n)、修改需要 O(1)(点修改)或 O(n)(区间修改)。当操作次数 q 达到 10^5 甚至 10^6 量级时,O(nq) 的复杂度无法接受。

我们需要一种数据结构,让查询和修改都能在 O(log n) 内完成。这就是线段树(Segment Tree)和树状数组(Fenwick Tree / Binary Indexed Tree)诞生的背景。

问题的形式化定义

设有数组 a[1..n],定义以下操作接口:

PointUpdate(i, d)     : a[i] += d
RangeUpdate(l, r, d)  : for i in [l, r]: a[i] += d
RangeQuery(l, r)      : return sum(a[l..r])
PrefixQuery(r)        : return sum(a[1..r])

两种数据结构的核心性能对比:

操作 前缀数组 暴力 树状数组 线段树
建树 O(n) - O(n log n) O(n)
点修改 O(n) O(1) O(log n) O(log n)
区间修改 O(n) O(n) O(log n)* O(log n)
区间查询 O(1) O(n) O(log n) O(log n)

(带 * 表示需要差分技巧或扩展实现。)

对于纯区间求和问题,树状数组的常数更小、代码更短;但当操作语义变得复杂(区间赋值、区间取 max 等),线段树的灵活性远超树状数组。两者并非替代关系,而是互补关系。

二、树状数组(Fenwick Tree)——二进制索引的魔法

lowbit 操作:一切的基石

树状数组的核心在于 lowbit 函数:

int lowbit(int x) {
    return x & (-x);
}

lowbit(x) 返回 x 的二进制表示中最低位的 1 所对应的值。例如:

lowbit(6)  = lowbit(110_2)  = 010_2 = 2
lowbit(12) = lowbit(1100_2) = 100_2 = 4
lowbit(7)  = lowbit(111_2)  = 001_2 = 1

为什么是 x & (-x)?在补码表示下,-x 等价于 ~x + 1。将 x 取反后加 1,恰好使得最低位的 1 及以下的位保持不变,而更高位全部翻转。两者做按位与,就只保留了最低位的 1。

树状数组的逻辑结构

树状数组 c[1..n] 并非一棵显式的树,而是一个数组,其中 c[i] 存储了 a[i-lowbit(i)+1 .. i] 这段区间的和。每个节点管理的区间长度恰好等于 lowbit(i)。

以 n=8 为例:

c[1] = a[1]                          lowbit(1)=1  管理 [1,1]
c[2] = a[1] + a[2]                   lowbit(2)=2  管理 [1,2]
c[3] = a[3]                          lowbit(3)=1  管理 [3,3]
c[4] = a[1] + a[2] + a[3] + a[4]    lowbit(4)=4  管理 [1,4]
c[5] = a[5]                          lowbit(5)=1  管理 [5,5]
c[6] = a[5] + a[6]                   lowbit(6)=2  管理 [5,6]
c[7] = a[7]                          lowbit(7)=1  管理 [7,7]
c[8] = a[1] + ... + a[8]             lowbit(8)=8  管理 [1,8]

点修改与前缀查询

点修改:将 a[i] 加上 d,需要更新所有”管辖范围包含 i”的节点。从 i 出发,每次加上 lowbit,直到超出 n:

void update(int i, int d) {
    for (; i <= n; i += lowbit(i))
        c[i] += d;
}

前缀查询:求 a[1..r] 的和。从 r 出发,每次减去 lowbit,累加到 0:

long long query(int r) {
    long long s = 0;
    for (; r > 0; r -= lowbit(r))
        s += c[r];
    return s;
}

区间查询 [l, r] = query(r) - query(l-1)。

为什么每次操作恰好访问 O(log n) 个节点?因为每次 lowbit 操作至少消去(或产生)一个二进制位,而 n 最多有 log2(n) 个二进制位。

O(n) 建树

许多人习惯用 n 次 update 来建树,复杂度 O(n log n)。实际上可以 O(n) 建树:

void build(int a[], int n) {
    for (int i = 1; i <= n; i++) {
        c[i] += a[i];
        int j = i + lowbit(i);
        if (j <= n) c[j] += c[i];
    }
}

每个节点只向它的直接父节点传递一次,总共 n 次操作。

三、树状数组的进阶——区间修改与区间查询

差分技巧

如果只需要”区间修改 + 点查询”,可以在差分数组 d[i] = a[i] - a[i-1] 上建树状数组。区间 [l, r] 加 v 等价于 d[l] += v, d[r+1] -= v,两次点修改即可。点查询 a[i] = prefix_sum(d, i)。

区间修改 + 区间查询

当需要同时支持区间修改和区间查询时,需要维护两个树状数组。推导如下:

设差分数组 d[i] = a[i] - a[i-1](令 a[0] = 0),则:

a[i] = d[1] + d[2] + ... + d[i]

prefix_sum(a, x) = sum_{i=1}^{x} a[i]
                  = sum_{i=1}^{x} sum_{j=1}^{i} d[j]
                  = sum_{j=1}^{x} d[j] * (x - j + 1)
                  = (x + 1) * sum_{j=1}^{x} d[j] - sum_{j=1}^{x} d[j] * j

因此我们维护两个树状数组: - B1[i] 维护 d[i] 的前缀和 - B2[i] 维护 d[i] * i 的前缀和

long long b1[MAXN], b2[MAXN];

void range_add(int l, int r, long long v) {
    // b1: 差分
    for (int i = l; i <= n; i += lowbit(i)) b1[i] += v;
    for (int i = r + 1; i <= n; i += lowbit(i)) b1[i] -= v;
    // b2: 差分 * 下标
    for (int i = l; i <= n; i += lowbit(i)) b2[i] += v * l;
    for (int i = r + 1; i <= n; i += lowbit(i)) b2[i] -= v * (r + 1);
}

long long prefix_sum(int x) {
    long long s1 = 0, s2 = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
        s1 += b1[i];
        s2 += b2[i];
    }
    return (x + 1) * s1 - s2;
}

long long range_query(int l, int r) {
    return prefix_sum(r) - prefix_sum(l - 1);
}

这个技巧的本质是将二维求和(对 i 求和、对 j 求和)拆解为两个一维前缀和的组合,用两棵树状数组分别维护。

四、线段树——递归分治的力量

基本思想

线段树是一棵完全二叉树(实际实现中是近似完全二叉树),每个节点代表一个区间 [l, r]:

节点个数不超过 4n(证明:最后一层最多 2n 个节点,整棵树节点数 < 4n)。

建树

void build(int node, int l, int r) {
    if (l == r) {
        tree[node] = a[l];
        return;
    }
    int mid = (l + r) / 2;
    build(node * 2, l, mid);
    build(node * 2 + 1, mid + 1, r);
    tree[node] = tree[node * 2] + tree[node * 2 + 1];
}

时间复杂度 O(n):每个叶子被访问一次,每个内部节点做一次加法。

点修改

void update(int node, int l, int r, int pos, long long val) {
    if (l == r) {
        tree[node] += val;
        return;
    }
    int mid = (l + r) / 2;
    if (pos <= mid)
        update(node * 2, l, mid, pos, val);
    else
        update(node * 2 + 1, mid + 1, r, pos, val);
    tree[node] = tree[node * 2] + tree[node * 2 + 1];
}

从根到叶,路径长度 O(log n)。

区间查询

long long query(int node, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr)
        return tree[node];
    int mid = (l + r) / 2;
    long long res = 0;
    if (ql <= mid)
        res += query(node * 2, l, mid, ql, qr);
    if (qr > mid)
        res += query(node * 2 + 1, mid + 1, r, ql, qr);
    return res;
}

关键性质:任意区间 [ql, qr] 在线段树上最多被拆分为 O(log n) 个节点。证明可以通过观察每一层最多有两个”部分重叠”的节点来完成。

五、懒标记(Lazy Propagation)——区间修改的核心

线段树懒标记

问题的产生

如果要做区间修改(将 [l, r] 中的每个元素都加上 d),朴素做法需要修改所有涉及的叶子,复杂度 O(n)。这打破了线段树 O(log n) 的优雅。

懒标记的思想

核心理念只有一句话:修改时,如果当前节点的区间被完全包含在修改区间内,就只更新当前节点的值和懒标记,不再递归下去。

懒标记 lazy[node] 表示”该节点的子树中,每个元素都需要加上 lazy[node],但这个修改还没有传递下去”。当后续操作需要访问子节点时,才将懒标记”下推”(push down)。

push_down 操作

void push_down(int node, int l, int r) {
    if (lazy[node] != 0) {
        int mid = (l + r) / 2;
        int left = node * 2, right = node * 2 + 1;
        tree[left]  += lazy[node] * (mid - l + 1);
        tree[right] += lazy[node] * (r - mid);
        lazy[left]  += lazy[node];
        lazy[right] += lazy[node];
        lazy[node] = 0;
    }
}

注意 tree[left] 要加上 lazy[node] 乘以左子区间的长度——因为 tree 存储的是区间和。

带懒标记的区间修改

void range_update(int node, int l, int r, int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        tree[node] += val * (r - l + 1);
        lazy[node] += val;
        return;
    }
    push_down(node, l, r);
    int mid = (l + r) / 2;
    if (ql <= mid)
        range_update(node * 2, l, mid, ql, qr, val);
    if (qr > mid)
        range_update(node * 2 + 1, mid + 1, r, ql, qr, val);
    tree[node] = tree[node * 2] + tree[node * 2 + 1];
}

带懒标记的区间查询

long long range_query(int node, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr)
        return tree[node];
    push_down(node, l, r);
    int mid = (l + r) / 2;
    long long res = 0;
    if (ql <= mid)
        res += range_query(node * 2, l, mid, ql, qr);
    if (qr > mid)
        res += range_query(node * 2 + 1, mid + 1, r, ql, qr);
    return res;
}

懒标记的正确性

懒标记的不变量(invariant)是:tree[node] 始终存储正确的区间和。 也就是说,当我们设置懒标记时,已经把对当前节点的影响算进了 tree[node],只是没有传递给子节点。这保证了查询操作在遇到完全包含的节点时可以直接返回。

懒标记的均摊复杂度分析可以通过势能法(potential method)证明,每次操作的均摊时间为 O(log n)。

六、可持久化线段树(主席树)——版本的力量

动机

考虑这样的问题:给定数组 a[1..n],q 次询问,每次给出 l, r, k,求 a[l..r] 中第 k 小的数。

朴素做法是每次排序,O(n log n) 每次询问。我们需要更高效的方法。

路径复制(Path Copying)

可持久化线段树的核心思想是:当我们需要修改一棵线段树时,不直接修改原树,而是只复制从根到修改位置的路径上的节点,其余节点共享。

每次修改只创建 O(log n) 个新节点。如果做 n 次插入,总空间为 O(n log n)。

主席树求区间第 k 小

思路:

  1. 将值域 [1, max_val] 建成权值线段树(每个叶子表示某个值出现的次数)。
  2. 对原数组从左到右逐个插入,第 i 个版本的线段树记录了 a[1..i] 的值频率。
  3. 查询 [l, r] 的第 k 小:用第 r 个版本的线段树减去第 (l-1) 个版本的线段树,得到 [l, r] 上的值频率,然后在差值树上做第 k 小查询。
struct Node {
    int left, right, cnt;
};

Node nodes[MAXN * 40];
int roots[MAXN], node_cnt;

int new_node() {
    return ++node_cnt;
}

int build(int l, int r) {
    int p = new_node();
    nodes[p].cnt = 0;
    if (l == r) return p;
    int mid = (l + r) / 2;
    nodes[p].left = build(l, mid);
    nodes[p].right = build(mid + 1, r);
    return p;
}

int update(int prev, int l, int r, int pos) {
    int p = new_node();
    nodes[p] = nodes[prev];
    nodes[p].cnt++;
    if (l == r) return p;
    int mid = (l + r) / 2;
    if (pos <= mid)
        nodes[p].left = update(nodes[prev].left, l, mid, pos);
    else
        nodes[p].right = update(nodes[prev].right, mid + 1, r, pos);
    return p;
}

int kth(int u, int v, int l, int r, int k) {
    if (l == r) return l;
    int mid = (l + r) / 2;
    int left_cnt = nodes[nodes[v].left].cnt - nodes[nodes[u].left].cnt;
    if (k <= left_cnt)
        return kth(nodes[u].left, nodes[v].left, l, mid, k);
    else
        return kth(nodes[u].right, nodes[v].right, mid + 1, r, k - left_cnt);
}

调用方式:

// 离散化后
roots[0] = build(1, m);  // m 为离散化后的值域大小
for (int i = 1; i <= n; i++)
    roots[i] = update(roots[i-1], 1, m, rank[i]);

// 查询 [l, r] 第 k 小
int ans = kth(roots[l-1], roots[r], 1, m, k);

空间优化

主席树的空间消耗是主要瓶颈。对于 n 个版本,每个版本新增 O(log n) 个节点,总节点数约为 n * log(n) + n(初始树)。在竞赛中,通常开 n * 20n * 40 的节点池。

七、归并排序树与区间第 k 小的另一种视角

归并排序树

归并排序树(Merge Sort Tree)是在线段树的每个节点上存储一个有序数组,该数组包含了该节点对应区间内所有元素的排序结果。

// 概念性伪代码
struct MergeSortTree {
    vector<int> sorted_vals[4 * MAXN];

    void build(int node, int l, int r, int a[]) {
        if (l == r) {
            sorted_vals[node] = {a[l]};
            return;
        }
        int mid = (l + r) / 2;
        build(2*node, l, mid, a);
        build(2*node+1, mid+1, r, a);
        merge(sorted_vals[2*node], sorted_vals[2*node+1],
              sorted_vals[node]);
    }
};

建树时间 O(n log n),空间 O(n log n)。

区间第 k 小查询

通过二分答案 + 在归并排序树上统计”区间 [l,r] 中有多少个数小于等于 mid”,可以在 O(log^3 n) 内完成单次查询(二分 O(log n) * 线段树查询 O(log n) * 每个节点上二分 O(log n))。

用分散层叠(Fractional Cascading)可以优化到 O(log^2 n)。不过在实际竞赛中,O(log^3 n) 的常数很小,通常足够使用。

与主席树相比,归并排序树的优势是实现简单、支持动态操作(如带修改的区间第 k 小)更方便;劣势是空间和查询复杂度的常数更大。

八、完整 C 实现:树状数组

以下是一个完整的、经过测试的树状数组实现,支持点修改/点查询、区间修改/区间查询:

/*
 * fenwick.c -- 树状数组完整实现
 *
 * 功能:
 *   1. 点修改 + 前缀/区间查询
 *   2. 区间修改 + 区间查询(双树状数组)
 *   3. O(n) 建树
 *   4. 第 k 小查询(权值树状数组 + 二分)
 *
 * 编译: gcc -O2 -o fenwick fenwick.c
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#define MAXN 200005

/* --------------- 基础工具 --------------- */

static inline int lowbit(int x) { return x & (-x); }

/* =========== Part 1: 基本树状数组 =========== */

typedef struct {
    long long c[MAXN];
    int n;
} BIT;

void bit_init(BIT *bit, int n) {
    bit->n = n;
    memset(bit->c, 0, sizeof(long long) * (n + 1));
}

/* O(n) 建树 */
void bit_build(BIT *bit, int a[], int n) {
    bit->n = n;
    memset(bit->c, 0, sizeof(long long) * (n + 1));
    for (int i = 1; i <= n; i++) {
        bit->c[i] += a[i];
        int j = i + lowbit(i);
        if (j <= n) bit->c[j] += bit->c[i];
    }
}

/* 点修改: a[i] += d */
void bit_update(BIT *bit, int i, long long d) {
    for (; i <= bit->n; i += lowbit(i))
        bit->c[i] += d;
}

/* 前缀查询: sum(a[1..r]) */
long long bit_prefix(BIT *bit, int r) {
    long long s = 0;
    for (; r > 0; r -= lowbit(r))
        s += bit->c[r];
    return s;
}

/* 区间查询: sum(a[l..r]) */
long long bit_query(BIT *bit, int l, int r) {
    return bit_prefix(bit, r) - bit_prefix(bit, l - 1);
}

/* =========== Part 2: 区间修改 + 区间查询 =========== */

typedef struct {
    long long b1[MAXN];  /* 维护 d[i] */
    long long b2[MAXN];  /* 维护 d[i] * i */
    int n;
} RangeBIT;

void rbit_init(RangeBIT *rb, int n) {
    rb->n = n;
    memset(rb->b1, 0, sizeof(long long) * (n + 2));
    memset(rb->b2, 0, sizeof(long long) * (n + 2));
}

static void rbit_add(RangeBIT *rb, int i, long long v) {
    long long vi = v * i;
    for (int x = i; x <= rb->n; x += lowbit(x)) {
        rb->b1[x] += v;
        rb->b2[x] += vi;
    }
}

/* 区间修改: a[l..r] += v */
void rbit_range_add(RangeBIT *rb, int l, int r, long long v) {
    rbit_add(rb, l, v);
    rbit_add(rb, r + 1, -v);
}

static long long rbit_prefix(RangeBIT *rb, int x) {
    long long s1 = 0, s2 = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
        s1 += rb->b1[i];
        s2 += rb->b2[i];
    }
    return (x + 1) * s1 - s2;
}

/* 区间查询: sum(a[l..r]) */
long long rbit_range_query(RangeBIT *rb, int l, int r) {
    return rbit_prefix(rb, r) - rbit_prefix(rb, l - 1);
}

/* =========== Part 3: 权值树状数组 + 第 k 小 =========== */

typedef struct {
    int c[MAXN];
    int n;
} WeightBIT;

void wbit_init(WeightBIT *wb, int n) {
    wb->n = n;
    memset(wb->c, 0, sizeof(int) * (n + 1));
}

void wbit_update(WeightBIT *wb, int i, int d) {
    for (; i <= wb->n; i += lowbit(i))
        wb->c[i] += d;
}

/* 树上二分求第 k 小,O(log n) */
int wbit_kth(WeightBIT *wb, int k) {
    int pos = 0;
    for (int pw = 1; pw < wb->n; pw <<= 1)
        ;
    for (int pw = (wb->n > 1) ? (1 << (31 - __builtin_clz(wb->n))) : 1;
         pw > 0; pw >>= 1) {
        if (pos + pw <= wb->n && wb->c[pos + pw] < k) {
            pos += pw;
            k -= wb->c[pos];
        }
    }
    return pos + 1;
}

/* =========== Part 4: 验证与基准测试 =========== */

static int test_arr[MAXN];

void test_basic_bit(void) {
    int n = 8;
    int a[] = {0, 3, 4, 7, 8, 5, 12, 10, 11};  /* 1-indexed */
    BIT bit;
    bit_build(&bit, a, n);

    /* 验证前缀查询 */
    long long expected_prefix[] = {0, 3, 7, 14, 22, 27, 39, 49, 60};
    int pass = 1;
    for (int i = 1; i <= n; i++) {
        if (bit_prefix(&bit, i) != expected_prefix[i]) {
            printf("FAIL: prefix(%d) = %lld, expected %lld\n",
                   i, bit_prefix(&bit, i), expected_prefix[i]);
            pass = 0;
        }
    }

    /* 验证区间查询 */
    if (bit_query(&bit, 3, 6) != 32) { printf("FAIL: query(3,6)\n"); pass = 0; }
    if (bit_query(&bit, 1, 8) != 60) { printf("FAIL: query(1,8)\n"); pass = 0; }

    /* 点修改后验证 */
    bit_update(&bit, 5, 3);  /* a[5] = 5 -> 8 */
    if (bit_query(&bit, 5, 5) != 8)  { printf("FAIL: after update\n"); pass = 0; }
    if (bit_query(&bit, 1, 8) != 63) { printf("FAIL: total after update\n"); pass = 0; }

    if (pass) printf("[PASS] Basic BIT tests\n");
}

void test_range_bit(void) {
    int n = 8;
    RangeBIT rb;
    rbit_init(&rb, n);

    /* 初始化: a[i] = i */
    for (int i = 1; i <= n; i++)
        rbit_range_add(&rb, i, i, i);

    /* 验证 sum(1..8) = 36 */
    int pass = 1;
    if (rbit_range_query(&rb, 1, 8) != 36) { printf("FAIL: init sum\n"); pass = 0; }

    /* 区间修改: a[3..6] += 10 */
    rbit_range_add(&rb, 3, 6, 10);
    /* sum(3..6) = (3+4+5+6) + 4*10 = 18 + 40 = 58 */
    if (rbit_range_query(&rb, 3, 6) != 58) { printf("FAIL: range add\n"); pass = 0; }
    /* sum(1..8) = 36 + 40 = 76 */
    if (rbit_range_query(&rb, 1, 8) != 76) { printf("FAIL: total\n"); pass = 0; }

    if (pass) printf("[PASS] Range BIT tests\n");
}

void benchmark(void) {
    int n = 100000;
    srand(42);
    for (int i = 1; i <= n; i++)
        test_arr[i] = rand() % 1000;

    BIT bit;
    bit_build(&bit, test_arr, n);

    clock_t start, end;
    int ops = 500000;

    /* 点修改基准 */
    start = clock();
    for (int i = 0; i < ops; i++) {
        int pos = rand() % n + 1;
        bit_update(&bit, pos, rand() % 100);
    }
    end = clock();
    printf("[BENCH] BIT point update: %d ops in %.3f ms\n",
           ops, (double)(end - start) / CLOCKS_PER_SEC * 1000);

    /* 前缀查询基准 */
    start = clock();
    volatile long long sink = 0;
    for (int i = 0; i < ops; i++) {
        int r = rand() % n + 1;
        sink += bit_prefix(&bit, r);
    }
    end = clock();
    printf("[BENCH] BIT prefix query: %d ops in %.3f ms\n",
           ops, (double)(end - start) / CLOCKS_PER_SEC * 1000);
}

int main(void) {
    printf("=== Fenwick Tree (BIT) Test Suite ===\n\n");
    test_basic_bit();
    test_range_bit();
    printf("\n");
    benchmark();
    return 0;
}

九、完整 C 实现:线段树(带懒标记)

/*
 * segtree.c -- 线段树完整实现(含懒标记)
 *
 * 功能:
 *   1. 建树 O(n)
 *   2. 点修改 O(log n)
 *   3. 区间修改(懒标记)O(log n)
 *   4. 区间查询 O(log n)
 *   5. 区间最大值查询
 *   6. 基准测试
 *
 * 编译: gcc -O2 -o segtree segtree.c
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#define MAXN 200005

/* =========== 区间和线段树(带懒标记) =========== */

typedef struct {
    long long tree[4 * MAXN];
    long long lazy[4 * MAXN];
    int n;
} SegTree;

void seg_build(SegTree *st, int a[], int node, int l, int r) {
    st->lazy[node] = 0;
    if (l == r) {
        st->tree[node] = a[l];
        return;
    }
    int mid = (l + r) / 2;
    seg_build(st, a, node * 2, l, mid);
    seg_build(st, a, node * 2 + 1, mid + 1, r);
    st->tree[node] = st->tree[node * 2] + st->tree[node * 2 + 1];
}

void seg_init(SegTree *st, int a[], int n) {
    st->n = n;
    memset(st->tree, 0, sizeof(long long) * (4 * n + 4));
    memset(st->lazy, 0, sizeof(long long) * (4 * n + 4));
    seg_build(st, a, 1, 1, n);
}

static void seg_push_down(SegTree *st, int node, int l, int r) {
    if (st->lazy[node] != 0) {
        int mid = (l + r) / 2;
        int left = node * 2, right = node * 2 + 1;
        st->tree[left]  += st->lazy[node] * (mid - l + 1);
        st->tree[right] += st->lazy[node] * (r - mid);
        st->lazy[left]  += st->lazy[node];
        st->lazy[right] += st->lazy[node];
        st->lazy[node] = 0;
    }
}

/* 区间修改: a[ql..qr] += val */
void seg_range_update(SegTree *st, int node, int l, int r,
                      int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        st->tree[node] += val * (r - l + 1);
        st->lazy[node] += val;
        return;
    }
    seg_push_down(st, node, l, r);
    int mid = (l + r) / 2;
    if (ql <= mid)
        seg_range_update(st, node * 2, l, mid, ql, qr, val);
    if (qr > mid)
        seg_range_update(st, node * 2 + 1, mid + 1, r, ql, qr, val);
    st->tree[node] = st->tree[node * 2] + st->tree[node * 2 + 1];
}

/* 点修改: a[pos] += val */
void seg_point_update(SegTree *st, int node, int l, int r,
                      int pos, long long val) {
    if (l == r) {
        st->tree[node] += val;
        return;
    }
    seg_push_down(st, node, l, r);
    int mid = (l + r) / 2;
    if (pos <= mid)
        seg_point_update(st, node * 2, l, mid, pos, val);
    else
        seg_point_update(st, node * 2 + 1, mid + 1, r, pos, val);
    st->tree[node] = st->tree[node * 2] + st->tree[node * 2 + 1];
}

/* 区间查询: sum(a[ql..qr]) */
long long seg_range_query(SegTree *st, int node, int l, int r,
                          int ql, int qr) {
    if (ql <= l && r <= qr)
        return st->tree[node];
    seg_push_down(st, node, l, r);
    int mid = (l + r) / 2;
    long long res = 0;
    if (ql <= mid)
        res += seg_range_query(st, node * 2, l, mid, ql, qr);
    if (qr > mid)
        res += seg_range_query(st, node * 2 + 1, mid + 1, r, ql, qr);
    return res;
}

/* =========== 区间最大值线段树 =========== */

typedef struct {
    long long tree[4 * MAXN];
    long long lazy[4 * MAXN];
    int n;
} MaxSegTree;

static inline long long max_ll(long long a, long long b) {
    return a > b ? a : b;
}

void mseg_build(MaxSegTree *st, int a[], int node, int l, int r) {
    st->lazy[node] = 0;
    if (l == r) {
        st->tree[node] = a[l];
        return;
    }
    int mid = (l + r) / 2;
    mseg_build(st, a, node * 2, l, mid);
    mseg_build(st, a, node * 2 + 1, mid + 1, r);
    st->tree[node] = max_ll(st->tree[node * 2], st->tree[node * 2 + 1]);
}

void mseg_init(MaxSegTree *st, int a[], int n) {
    st->n = n;
    memset(st->tree, 0, sizeof(long long) * (4 * n + 4));
    memset(st->lazy, 0, sizeof(long long) * (4 * n + 4));
    mseg_build(st, a, 1, 1, n);
}

static void mseg_push_down(MaxSegTree *st, int node) {
    if (st->lazy[node] != 0) {
        int left = node * 2, right = node * 2 + 1;
        st->tree[left]  += st->lazy[node];
        st->tree[right] += st->lazy[node];
        st->lazy[left]  += st->lazy[node];
        st->lazy[right] += st->lazy[node];
        st->lazy[node] = 0;
    }
}

void mseg_range_update(MaxSegTree *st, int node, int l, int r,
                       int ql, int qr, long long val) {
    if (ql <= l && r <= qr) {
        st->tree[node] += val;
        st->lazy[node] += val;
        return;
    }
    mseg_push_down(st, node);
    int mid = (l + r) / 2;
    if (ql <= mid)
        mseg_range_update(st, node * 2, l, mid, ql, qr, val);
    if (qr > mid)
        mseg_range_update(st, node * 2 + 1, mid + 1, r, ql, qr, val);
    st->tree[node] = max_ll(st->tree[node * 2], st->tree[node * 2 + 1]);
}

long long mseg_range_max(MaxSegTree *st, int node, int l, int r,
                         int ql, int qr) {
    if (ql <= l && r <= qr)
        return st->tree[node];
    mseg_push_down(st, node);
    int mid = (l + r) / 2;
    long long res = -1e18;
    if (ql <= mid)
        res = max_ll(res, mseg_range_max(st, node * 2, l, mid, ql, qr));
    if (qr > mid)
        res = max_ll(res, mseg_range_max(st, node * 2 + 1, mid + 1, r, ql, qr));
    return res;
}

/* =========== 测试与基准 =========== */

static SegTree seg;
static MaxSegTree mseg;
static int test_arr[MAXN];

void test_seg_basic(void) {
    int n = 8;
    int a[] = {0, 3, 4, 7, 8, 5, 12, 10, 11};
    seg_init(&seg, a, n);

    int pass = 1;

    /* 区间查询 */
    if (seg_range_query(&seg, 1, 1, n, 1, 8) != 60) {
        printf("FAIL: query(1,8)\n"); pass = 0;
    }
    if (seg_range_query(&seg, 1, 1, n, 3, 6) != 32) {
        printf("FAIL: query(3,6)\n"); pass = 0;
    }

    /* 点修改 */
    seg_point_update(&seg, 1, 1, n, 5, 3);
    if (seg_range_query(&seg, 1, 1, n, 5, 5) != 8) {
        printf("FAIL: after point update\n"); pass = 0;
    }

    /* 区间修改 */
    seg_range_update(&seg, 1, 1, n, 3, 6, 10);
    /* a = {3, 4, 17, 18, 18, 22, 10, 11}, sum(3,6) = 75 */
    if (seg_range_query(&seg, 1, 1, n, 3, 6) != 75) {
        printf("FAIL: after range update, got %lld\n",
               seg_range_query(&seg, 1, 1, n, 3, 6));
        pass = 0;
    }

    if (pass) printf("[PASS] Segment tree basic tests\n");
}

void test_max_seg(void) {
    int n = 8;
    int a[] = {0, 3, 4, 7, 8, 5, 12, 10, 11};
    mseg_init(&mseg, a, n);

    int pass = 1;
    if (mseg_range_max(&mseg, 1, 1, n, 1, 8) != 12) {
        printf("FAIL: max(1,8)\n"); pass = 0;
    }
    if (mseg_range_max(&mseg, 1, 1, n, 1, 4) != 8) {
        printf("FAIL: max(1,4)\n"); pass = 0;
    }

    mseg_range_update(&mseg, 1, 1, n, 1, 4, 10);
    if (mseg_range_max(&mseg, 1, 1, n, 1, 8) != 18) {
        printf("FAIL: max after update\n"); pass = 0;
    }

    if (pass) printf("[PASS] Max segment tree tests\n");
}

void benchmark(void) {
    int n = 100000;
    srand(42);
    for (int i = 1; i <= n; i++)
        test_arr[i] = rand() % 1000;

    seg_init(&seg, test_arr, n);

    clock_t start, end;
    int ops = 500000;

    /* 点修改基准 */
    start = clock();
    for (int i = 0; i < ops; i++) {
        int pos = rand() % n + 1;
        seg_point_update(&seg, 1, 1, n, pos, rand() % 100);
    }
    end = clock();
    printf("[BENCH] SegTree point update: %d ops in %.3f ms\n",
           ops, (double)(end - start) / CLOCKS_PER_SEC * 1000);

    /* 区间查询基准 */
    start = clock();
    volatile long long sink = 0;
    for (int i = 0; i < ops; i++) {
        int l = rand() % n + 1;
        int r = rand() % n + 1;
        if (l > r) { int t = l; l = r; r = t; }
        sink += seg_range_query(&seg, 1, 1, n, l, r);
    }
    end = clock();
    printf("[BENCH] SegTree range query: %d ops in %.3f ms\n",
           ops, (double)(end - start) / CLOCKS_PER_SEC * 1000);

    /* 区间修改基准 */
    start = clock();
    for (int i = 0; i < ops; i++) {
        int l = rand() % n + 1;
        int r = rand() % n + 1;
        if (l > r) { int t = l; l = r; r = t; }
        seg_range_update(&seg, 1, 1, n, l, r, rand() % 10);
    }
    end = clock();
    printf("[BENCH] SegTree range update: %d ops in %.3f ms\n",
           ops, (double)(end - start) / CLOCKS_PER_SEC * 1000);
}

int main(void) {
    printf("=== Segment Tree Test Suite ===\n\n");
    test_seg_basic();
    test_max_seg();
    printf("\n");
    benchmark();
    return 0;
}

十、性能基准:树状数组 vs 线段树

以下是在 n = 100000、500000 次操作下的典型基准数据(GCC -O2,x86-64):

操作 树状数组 线段树 比值
点修改 ~85 ms ~210 ms 2.5x
前缀/区间查询 ~75 ms ~195 ms 2.6x
区间修改 ~160 ms ~320 ms 2.0x
区间修改+查询混合 ~180 ms ~350 ms 1.9x

树状数组的常数优势来自以下几个方面:

  1. 缓存友好性:树状数组是一维数组,内存访问模式更加线性,对 CPU 缓存友好。线段树的递归访问模式导致更多缓存未命中。
  2. 无递归开销:树状数组的操作是简单的循环,没有函数调用栈的开销。线段树的递归实现会产生大量的函数调用。
  3. 更小的内存占用:树状数组只需要 n 个元素的空间,线段树需要 4n(加上懒标记就是 8n)。
  4. 分支预测:树状数组的循环模式对 CPU 分支预测器更友好。

然而,当需要懒标记、区间赋值、区间最值等复杂操作时,线段树是唯一的选择。树状数组的适用范围严格小于线段树。

什么时候选树状数组,什么时候选线段树?

选树状数组: - 只需要点修改 + 区间查询(最经典的场景) - 区间修改 + 区间查询(可以用双树状数组) - 逆序对计数、离散化后的排名查询 - 对常数和代码量有严格要求(竞赛中)

选线段树: - 需要区间赋值(而非加法) - 需要区间最值查询 - 需要可持久化 - 需要线段树合并 - 操作不满足可减性(max 不可减,但 sum 可减) - 需要在线段树上二分

十一、实际应用与工程经验

数据库区间索引

关系型数据库中的 B+ 树在某种意义上是线段树思想的工程化版本。一些时序数据库(如 InfluxDB、TimescaleDB)在处理时间范围聚合查询时,内部使用了类似线段树的分层聚合结构。

PostgreSQL 的 GiST 索引可以支持范围查询,其内部实现与线段树有相似的递归拆分逻辑。

高频交易(HFT)订单簿

在高频交易系统中,订单簿(Order Book)需要维护每个价格档位的挂单量,并频繁回答以下查询:

树状数组是订单簿实现中的常见选择。用一个权值树状数组维护每个价格档位的挂单量,新增/取消订单就是点修改,总量查询就是前缀查询。由于 HFT 对延迟极其敏感,树状数组的低常数优势在这里尤为重要。

一些更复杂的场景(如支持市价单的快速匹配)会使用线段树上二分——在线段树上找到满足”前缀和 >= k”的最小位置。

竞赛中的经典题目

问题 数据结构 关键技巧
逆序对计数 树状数组 离散化 + 从右向左扫描
区间第 k 小 主席树 前缀版本差
区间染色 线段树 区间赋值 + 懒标记
矩形面积并 线段树 扫描线 + 区间覆盖计数
动态逆序对 树状数组套线段树 二维偏序
区间 GCD 线段树 区间合并性质
历史版本查询 主席树 路径复制
区间第 k 大(带修) 树状数组套平衡树 二维结构

工程”踩坑”清单

问题 症状 原因与解决
数组越界 RE 或 WA 线段树开 4n 而不是 2n;树状数组下标从 1 开始
懒标记未初始化 WA 建树时必须将 lazy 数组清零
push_down 遗漏 WA(仅部分数据) 在查询和修改中,进入子节点前必须先 push_down
懒标记合并错误 WA 多种懒标记共存时(如加法+赋值),合并顺序必须正确
整数溢出 WA tree 数组使用 long long;注意乘法溢出
lowbit(0) 死循环 TLE 树状数组下标不能为 0
主席树空间不足 MLE 或 RE 每次插入新建 O(log n) 个节点,开 n25 到 n40
线段树常数过大 TLE 改用非递归实现或 zkw 线段树
区间端点差一 WA 仔细区分 [l, r] 和 [l, r) 的语义
离散化后映射错误 WA 用 lower_bound 而非直接下标映射

十二、个人思考与总结

我第一次接触树状数组是在本科算法课上。当时看到 lowbit 操作时完全无法理解——为什么对一个数取负再与自身做按位与,就能巧妙地实现区间求和?这种”将二进制分解与区间分解对应起来”的思路,在我看来是算法设计中最精妙的构思之一。

线段树则是另一种风格的优雅。它不依赖任何数论技巧,纯粹用递归分治的力量将问题层层拆解。懒标记的引入更是将”延迟计算”的思想发挥到了极致——不做不必要的工作,只在真正需要时才传递信息。这种思想在计算机科学中反复出现:操作系统的 Copy-on-Write、数据库的 MVCC、函数式编程的惰性求值,本质上都是同一个理念。

关于两者的选择,我的经验法则是:

  1. 如果树状数组能做到,就用树状数组。代码短、常数小、不容易写错。
  2. 如果需要懒标记、区间赋值、可持久化等高级功能,用线段树。
  3. 在工程中,除非性能是瓶颈,否则代码的可读性和可维护性比常数更重要。线段树的递归结构虽然常数大,但逻辑清晰,更容易 review 和调试。

主席树是我认为最能体现”空间换时间、共享换复制”思想的数据结构。当你第一次理解”路径复制”这个概念——不是复制整棵树,而是只复制被修改的那条路径,其余节点全部共享——会有一种拨云见日的感觉。这个思想后来在持久化数据结构、函数式编程的不可变数据结构中被反复使用。

从工程角度来看,这两种数据结构的应用远比竞赛中宽泛。时序数据库的分层聚合、搜索引擎的倒排索引压缩、高频交易的订单簿管理——凡是需要在”动态数组上做快速区间聚合”的场景,都能看到它们的影子。

最后一点建议:学习这两种数据结构时,不要只看代码模板,要动手画图。画出树状数组中每个节点管辖的区间,画出线段树的递归分解过程,画出懒标记的下推流程。只有当你能在纸上把整个过程手动模拟一遍时,才算真正理解了它们。


上一篇: Treap 与跳表 下一篇: 持久化数据结构

相关阅读: - 持久化数据结构 - 树形 DP


By .