土法炼钢兴趣小组的算法知识备份

从零实现一个向量搜索引擎

目录

前几篇文章里,我们分别拆解了 HNSW、乘积量化、ScaNN 与 DiskANN 的核心算法。单独看每个模块都不复杂,但把它们组装成一个”真正能跑”的向量搜索引擎,中间要趟过的坑远比想象中多。这篇文章就来做这件事:从距离函数到索引层,从 WAL 到 mmap,一步步搭出一个麻雀虽小、五脏俱全的检索系统,最后用 Go 写出一个可运行的 mini engine。

向量搜索引擎架构

一、整体架构:三层分离

一个向量搜索引擎,不管规模大小,通常可以拆成三层:

这三层的边界划分至关重要。我见过不少项目把索引逻辑和存储 IO 混在一起,结果是:换一种持久化方案就要重写半个索引,测试也变得极其困难。好的分层意味着索引层只关心”给我一个向量数组的指针”,而不需要知道这些向量是从 mmap 来的还是从内存分配器来的。

1.1 数据流走向

一条典型的插入请求:

Client -> HTTP POST /insert
       -> API 层解析 JSON,分配向量 ID
       -> WAL 追加写
       -> 索引层:HNSW 图插入 + PQ 编码
       -> 返回 ID

一条典型的查询请求:

Client -> HTTP POST /search
       -> API 层解析查询向量 + topK + filter
       -> 索引层:metadata pre-filter -> HNSW search -> PQ rerank
       -> 返回 topK 结果

1.2 设计原则

我在实现这个 mini engine 时遵循几个原则:

  1. 写路径先过 WAL,再更新索引。进程崩溃后可以从 WAL 重放恢复。
  2. 读路径不加写锁。用 sync.RWMutex 允许并发查询。
  3. 索引与存储解耦。向量数据可以在内存里,也可以 mmap 进来,索引层不感知。
  4. 小数据量走暴力搜索。向量数少于阈值时,HNSW 的图维护成本反而高于线性扫描。

二、距离函数:L2、余弦与内积

距离函数是向量检索的地基。选错距离函数,后面的索引再精妙也白搭。

2.1 三种常用距离

欧氏距离(L2):最直觉的”两点间距离”。用于图像特征、科学计算场景居多。

\[d_{L2}(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{d}(x_i - y_i)^2}\]

实际检索时通常省略开方,直接比较平方距离,结果排序不变。

余弦相似度(Cosine):衡量方向而非大小,NLP 文本嵌入的标配。

\[\text{cosine}(\mathbf{x}, \mathbf{y}) = \frac{\mathbf{x} \cdot \mathbf{y}}{\|\mathbf{x}\| \|\mathbf{y}\|}\]

如果向量已经做过 L2 归一化(单位向量),余弦相似度退化为内积,这是很多系统的优化技巧。

内积(Inner Product):推荐系统里最常见,值越大越相似。注意内积不是严格的度量(不满足三角不等式),在 HNSW 图上的搜索行为和 L2 有微妙差异。

\[\text{IP}(\mathbf{x}, \mathbf{y}) = \sum_{i=1}^{d} x_i \cdot y_i\]

2.2 SIMD 加速

距离计算是整个系统最热的路径,值得用 SIMD 指令加速。以 L2 距离为例,一个 128 维 float32 向量对之间的距离计算,朴素循环需要 128 次减法、128 次乘法、128 次加法。用 AVX2 的 256 位寄存器,每次处理 8 个 float32,循环次数降到 16。

// Go 中可以通过 unsafe + 汇编实现 SIMD,这里给出纯 Go 版本作为参考
func L2SquaredDistance(a, b []float32) float32 {
    var sum float32
    for i := range a {
        diff := a[i] - b[i]
        sum += diff * diff
    }
    return sum
}

func InnerProduct(a, b []float32) float32 {
    var sum float32
    for i := range a {
        sum += a[i] * b[i]
    }
    return sum
}

func CosineDistance(a, b []float32) float32 {
    var dot, normA, normB float32
    for i := range a {
        dot += a[i] * b[i]
        normA += a[i] * a[i]
        normB += b[i] * b[i]
    }
    if normA == 0 || normB == 0 {
        return 1.0
    }
    return 1.0 - dot/float32(math.Sqrt(float64(normA)*float64(normB)))
}

在生产系统中,这些函数通常会有对应的汇编版本。Milvus 用 CGO 调 SIMD intrinsics,Qdrant 用 Rust 的 std::simd。我们的 mini engine 先用纯 Go,后面如果需要极致性能再替换。

2.3 距离函数的选择策略

场景 推荐距离 原因
图像特征检索(ResNet、CLIP) L2 特征空间各维度量纲一致
文本嵌入(Sentence-BERT) Cosine 文本向量长度不稳定,方向才有意义
推荐系统(双塔模型) Inner Product 模型训练时即以内积为目标
归一化后的任意向量 Inner Product 归一化后 cosine = IP,省一步除法

一个实用技巧:如果你的向量来自同一个模型、同一个 batch normalize 层,大概率已经近似归一化了。这时直接用内积既快又准。

三、HNSW 索引集成

HNSW(Hierarchical Navigable Small World)是目前工业界最主流的近似最近邻索引。它的核心思想是构建一个多层跳表式的图结构:顶层稀疏、底层稠密,查询时从顶层快速定位到大致区域,再在底层精细搜索。

3.1 关键参数

3.2 图的构建过程

InsertVector(v):
    level = RandomLevel()     // 指数衰减分布,大部分节点在 level 0
    ep = entry_point

    // 从最高层开始贪心下降到 level+1
    for l = max_level down to level+1:
        ep = GreedySearch(v, ep, layer=l)

    // 从 level 层到第 0 层,执行 SearchLayer + 连边
    for l = level down to 0:
        neighbors = SearchLayer(v, ep, ef=efConstruction, layer=l)
        SelectNeighbors(v, neighbors, M)  // 启发式裁边
        for each neighbor n in selected:
            AddBidirectionalEdge(v, n, layer=l)
            if degree(n) > Mmax:
                PruneConnections(n, Mmax, layer=l)
        ep = neighbors

3.3 启发式邻居选择

原始 HNSW 论文提出了两种邻居选择策略。简单策略直接取最近的 M 个邻居;启发式策略则会检查候选邻居之间的相互距离,避免选出”扎堆”的邻居,从而让图的连通性更好。

// 启发式邻居选择:优先保留那些"不被已选邻居覆盖"的候选
func selectNeighborsHeuristic(candidates []NodeDist, m int, distFunc DistFunc) []uint32 {
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].Dist < candidates[j].Dist
    })
    selected := make([]NodeDist, 0, m)
    for _, c := range candidates {
        if len(selected) >= m {
            break
        }
        // 检查 c 是否比所有已选邻居都更"有价值"
        good := true
        for _, s := range selected {
            if distFunc(c.Vec, s.Vec) < c.Dist {
                good = false
                break
            }
        }
        if good {
            selected = append(selected, c)
        }
    }
    ids := make([]uint32, len(selected))
    for i, s := range selected {
        ids[i] = s.ID
    }
    return ids
}

3.4 并发安全

HNSW 的一个实际工程问题是并发。查询本身是只读的,多个查询可以并行执行。但插入会修改图结构,需要某种形式的同步。

最简单的方案是全局读写锁:查询拿读锁,插入拿写锁。这在 QPS 不高时够用,但在高并发写入场景会成为瓶颈。更精细的方案是节点级别的锁,但实现复杂度骤增。

我们的 mini engine 采用中间方案:整个索引一把 sync.RWMutex,查询并发执行,插入串行化。对于大部分单机场景,这已经足够。

四、乘积量化与内存压缩

当向量数量达到百万级,原始向量占用的内存可能令人咋舌。以 128 维 float32 为例,一百万条向量需要 128 * 4 * 1,000,000 = 512 MB。如果要十亿条呢?512 GB,单机根本装不下。

乘积量化(Product Quantization,PQ)是解决这个问题的经典方法。

4.1 PQ 的核心思想

\(d\) 维向量切成 \(m\) 个子向量,每个子向量独立做 K-means 聚类(通常 \(K=256\)),然后用聚类中心的编号(1 字节)代替原始子向量。这样一条 128 维 float32 向量(512 字节)被压缩成 \(m\) 个字节。当 \(m=8\) 时,压缩比高达 64 倍。

4.2 ADC 距离计算

查询时不需要解码回原始向量。用”非对称距离计算”(Asymmetric Distance Computation,ADC):预先计算查询向量到每个子空间的 256 个聚类中心的距离,存入查表(lookup table),然后对每条数据库向量,只需 \(m\) 次查表加法即可得到近似距离。

type PQIndex struct {
    M          int          // 子空间数量
    Ksub       int          // 每个子空间的聚类中心数,通常 256
    Dim        int          // 原始向量维度
    Centroids  [][]float32  // [M*Ksub][subDim] 聚类中心
    Codes      [][]byte     // [N][M] 压缩编码
}

// 构建查距离表:query 到每个子空间每个聚类中心的距离
func (pq *PQIndex) BuildDistTable(query []float32) [][]float32 {
    subDim := pq.Dim / pq.M
    table := make([][]float32, pq.M)
    for i := 0; i < pq.M; i++ {
        table[i] = make([]float32, pq.Ksub)
        qSub := query[i*subDim : (i+1)*subDim]
        for j := 0; j < pq.Ksub; j++ {
            centroid := pq.Centroids[i*pq.Ksub+j]
            table[i][j] = L2SquaredDistance(qSub, centroid)
        }
    }
    return table
}

// ADC 距离:用查表代替真实距离计算
func (pq *PQIndex) ADCDistance(table [][]float32, code []byte) float32 {
    var dist float32
    for i := 0; i < pq.M; i++ {
        dist += table[i][code[i]]
    }
    return dist
}

4.3 PQ 与 HNSW 的配合

在我们的 mini engine 中,PQ 主要用于两个场景:

  1. 重排序(rerank):HNSW 返回 \(ef\) 个候选后,用 ADC 距离精排,减少需要加载的原始向量数量。
  2. 纯 PQ 暴力搜索:当向量数量不多但维度很高时,PQ 可以把暴力搜索的内存占用压下来,同时因为缓存友好(每条向量只有几个字节),吞吐量反而比原始向量暴力搜索更高。

五、WAL:崩溃恢复的最后防线

任何持久化系统都需要面对一个问题:进程随时可能崩溃,内存中的数据转瞬即逝。WAL(Write-Ahead Log)是数据库领域的经典解法:先把操作写入磁盘日志,再更新内存中的索引。崩溃后从日志重放即可恢复。

5.1 WAL 记录格式

+--------+--------+--------+--------+--------+
| CRC32  | Length | OpType |   ID   | Vector |
| 4 byte | 4 byte | 1 byte | 8 byte | N byte |
+--------+--------+--------+--------+--------+

5.2 Go 实现

const (
    OpInsert byte = 0x01
    OpDelete byte = 0x02
)

type WAL struct {
    file *os.File
    mu   sync.Mutex
}

func OpenWAL(path string) (*WAL, error) {
    f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
    if err != nil {
        return nil, err
    }
    return &WAL{file: f}, nil
}

func (w *WAL) AppendInsert(id uint64, vec []float32) error {
    w.mu.Lock()
    defer w.mu.Unlock()

    vecBytes := float32SliceToBytes(vec)
    payload := make([]byte, 1+8+len(vecBytes))
    payload[0] = OpInsert
    binary.LittleEndian.PutUint64(payload[1:9], id)
    copy(payload[9:], vecBytes)

    return w.writeRecord(payload)
}

func (w *WAL) AppendDelete(id uint64) error {
    w.mu.Lock()
    defer w.mu.Unlock()

    payload := make([]byte, 1+8)
    payload[0] = OpDelete
    binary.LittleEndian.PutUint64(payload[1:9], id)

    return w.writeRecord(payload)
}

func (w *WAL) writeRecord(payload []byte) error {
    length := uint32(len(payload))
    checksum := crc32.ChecksumIEEE(payload)

    header := make([]byte, 8)
    binary.LittleEndian.PutUint32(header[0:4], checksum)
    binary.LittleEndian.PutUint32(header[4:8], length)

    if _, err := w.file.Write(header); err != nil {
        return err
    }
    if _, err := w.file.Write(payload); err != nil {
        return err
    }
    return w.file.Sync()
}

func (w *WAL) Replay(onInsert func(uint64, []float32), onDelete func(uint64)) error {
    w.file.Seek(0, io.SeekStart)
    reader := bufio.NewReader(w.file)

    for {
        header := make([]byte, 8)
        if _, err := io.ReadFull(reader, header); err != nil {
            if err == io.EOF || err == io.ErrUnexpectedEOF {
                return nil
            }
            return err
        }

        checksum := binary.LittleEndian.Uint32(header[0:4])
        length := binary.LittleEndian.Uint32(header[4:8])

        payload := make([]byte, length)
        if _, err := io.ReadFull(reader, payload); err != nil {
            return nil // 部分写入的记录,忽略
        }

        if crc32.ChecksumIEEE(payload) != checksum {
            return nil // CRC 不匹配,说明是不完整记录
        }

        switch payload[0] {
        case OpInsert:
            id := binary.LittleEndian.Uint64(payload[1:9])
            vec := bytesToFloat32Slice(payload[9:])
            onInsert(id, vec)
        case OpDelete:
            id := binary.LittleEndian.Uint64(payload[1:9])
            onDelete(id)
        }
    }
}

func (w *WAL) Truncate() error {
    w.mu.Lock()
    defer w.mu.Unlock()
    return w.file.Truncate(0)
}

5.3 fsync 策略

每次写入都调 file.Sync() 是最安全的,但也最慢。生产系统通常有三种策略:每条 fsync(最安全,吞吐最低)、批量 fsync(攒一批后统一刷盘)、定时 fsync(每隔固定时间刷一次,最多丢失一个窗口的数据)。我们的 mini engine 采用”每批 fsync”策略:batch insert 接口接收一批向量,全部追加到 WAL 后做一次 fsync,再批量更新索引。

六、mmap:让操作系统管理你的内存

当向量数据量超过物理内存时,最直接的想法是”分页加载”。但手写分页逻辑又复杂又容易出 bug。mmap 提供了一条捷径:把文件映射到虚拟地址空间,让操作系统的页缓存替你管理哪些数据在内存、哪些在磁盘。

6.1 mmap 基本原理

func MmapFile(path string, size int) ([]byte, error) {
    f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0644)
    if err != nil {
        return nil, err
    }
    defer f.Close()

    // 确保文件足够大
    if err := f.Truncate(int64(size)); err != nil {
        return nil, err
    }

    data, err := syscall.Mmap(
        int(f.Fd()), 0, size,
        syscall.PROT_READ|syscall.PROT_WRITE,
        syscall.MAP_SHARED,
    )
    if err != nil {
        return nil, err
    }
    return data, nil
}

// 通过 mmap 读写向量,就像操作普通 byte slice 一样
func GetVector(mmapData []byte, id int, dim int) []float32 {
    offset := id * dim * 4
    return bytesToFloat32Slice(mmapData[offset : offset+dim*4])
}

6.2 mmap 的优势与陷阱

优势:零拷贝读取(直接在页缓存上操作)、自动页面调度(操作系统根据访问模式调入调出)、简化代码(不需要自己管理缓冲区和淘汰策略)。

陷阱:page fault 不可控(对 P99 敏感的场景需配合 madvise(MADV_WILLNEED) 预取);MAP_SHARED 写入刷盘时机不确定(写路径仍需走 WAL);32 位系统的虚拟地址空间不够用(64 位无此问题);Linux 的 THP 可能导致内存膨胀(建议用 madvise(MADV_HUGEPAGE) 显式控制)。

6.3 mmap 在向量引擎中的使用模式

典型的使用模式是:

  1. 向量数据本体用 mmap 映射一个大文件。文件布局是定长记录:[vec_0][vec_1]...[vec_N],每条记录 dim * sizeof(float32) 字节。
  2. HNSW 图索引保持在内存中(图的邻接表相比向量本身占用小得多)。
  3. 查询时,HNSW 返回候选 ID 列表,再通过 mmap 按 ID 直接读取原始向量做精排。

这种模式下,热门向量自然留在页缓存里,冷门向量被操作系统换出到磁盘,无需手写 LRU。DiskANN 本质上也是这个思路,只不过在图遍历层面做了更多优化来减少随机 IO。

七、元数据过滤:前置与后置

真实的向量检索几乎从来不是”给我最近的 K 个向量”这么简单。用户通常会带上过滤条件:“在 category=shoes 且 price < 100 的商品中,找最相似的 10 个。”

7.1 后置过滤(Post-filter)

最简单的实现:先做纯向量检索,拿到一批候选,再用元数据条件过滤。

func PostFilterSearch(engine *Engine, query []float32, k int, filter FilterFunc) []Result {
    // 多取一些候选,因为过滤会淘汰一部分
    overFetch := k * 10
    candidates := engine.HNSWSearch(query, overFetch)

    results := make([]Result, 0, k)
    for _, c := range candidates {
        if filter(c.Metadata) {
            results = append(results, c)
            if len(results) >= k {
                break
            }
        }
    }
    return results
}

问题很明显:如果满足过滤条件的向量占比很低(高选择性),你可能取了 10K 个候选还凑不够 K 个结果。此时要么反复加大 overFetch,要么放弃。

7.2 前置过滤(Pre-filter)

在索引检索之前,先用元数据条件筛出一个候选集(bitmap),然后只在这个候选集上做向量检索。

func PreFilterSearch(engine *Engine, query []float32, k int, filter FilterFunc) []Result {
    // 第一步:元数据过滤,得到 bitmap
    bitmap := engine.MetadataIndex.Filter(filter)
    matchCount := bitmap.Count()

    if matchCount == 0 {
        return nil
    }

    // 如果匹配数量少,直接暴力搜索
    if matchCount < 1000 {
        return engine.BruteForceSearchWithBitmap(query, k, bitmap)
    }

    // 否则在 HNSW 中搜索,但跳过不在 bitmap 中的节点
    return engine.HNSWSearchWithBitmap(query, k, bitmap)
}

前置过滤的挑战在于:HNSW 图的连通性建立在全部向量之上,如果过滤掉了大量节点,图可能变得不连通,导致搜索质量下降。一种缓解方法是在搜索时允许”穿越”被过滤的节点(不返回它们,但可以经过它们走向更好的候选),Qdrant 就是这么做的。

7.3 混合策略

生产系统通常根据过滤比例动态选择策略:

过滤后剩余比例 策略 原因
> 90% Post-filter 几乎不过滤,后置最简单
20% - 90% HNSW + pre-filter bitmap 图仍然大致连通
1% - 20% HNSW with traversal through filtered 允许穿越被过滤节点
< 1% Brute-force on filtered set 候选太少,暴力搜索更快

八、完整实现:Go Mini Vector Engine

下面是一个约 300 行的完整实现,包含 HNSW 简化版索引、暴力搜索回退和 HTTP API。为了控制篇幅,省略了 PQ 和 mmap 部分,但整体架构清晰可扩展。

package main

import (
    "container/heap"
    "encoding/json"
    "fmt"
    "log"
    "math"
    "math/rand"
    "net/http"
    "sort"
    "sync"
    "sync/atomic"
    "time"
)

// --- 距离函数 ---

func l2Squared(a, b []float32) float32 {
    var sum float32
    for i := range a {
        d := a[i] - b[i]
        sum += d * d
    }
    return sum
}

// --- Min-Heap for topK ---

type Result struct {
    ID   uint64  `json:"id"`
    Dist float32 `json:"distance"`
}

type MaxHeap []Result

func (h MaxHeap) Len() int            { return len(h) }
func (h MaxHeap) Less(i, j int) bool   { return h[i].Dist > h[j].Dist }
func (h MaxHeap) Swap(i, j int)        { h[i], h[j] = h[j], h[i] }
func (h *MaxHeap) Push(x interface{})  { *h = append(*h, x.(Result)) }
func (h *MaxHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[:n-1]
    return x
}

// --- HNSW 简化实现 ---

const (
    hnswM             = 16
    hnswMmax0         = 32
    hnswEfConstruction = 200
    hnswMl            = 1.0 / math.Ln2 // 1/ln(M)
)

type hnswNode struct {
    id      uint64
    vec     []float32
    level   int
    friends [][]uint32 // friends[layer] = list of neighbor indices
}

type hnswIndex struct {
    nodes      []*hnswNode
    entryPoint int
    maxLevel   int
    idToIdx    map[uint64]int
}

func newHNSW() *hnswIndex {
    return &hnswIndex{
        entryPoint: -1,
        maxLevel:   -1,
        idToIdx:    make(map[uint64]int),
    }
}

func (h *hnswIndex) randomLevel() int {
    level := 0
    for rand.Float64() < 0.5 && level < 8 {
        level++
    }
    return level
}

func (h *hnswIndex) searchLayer(query []float32, ep int, ef int, layer int) []Result {
    visited := make(map[int]bool)
    visited[ep] = true

    candidates := &MaxHeap{{ID: uint64(ep), Dist: l2Squared(query, h.nodes[ep].vec)}}
    heap.Init(candidates)

    results := &MaxHeap{{ID: uint64(ep), Dist: l2Squared(query, h.nodes[ep].vec)}}
    heap.Init(results)

    for candidates.Len() > 0 {
        current := heap.Pop(candidates).(Result)

        if results.Len() >= ef {
            worstResult := (*results)[0]
            if current.Dist > worstResult.Dist {
                break
            }
        }

        idx := int(current.ID)
        if idx >= len(h.nodes) || layer >= len(h.nodes[idx].friends) {
            continue
        }
        for _, neighborIdx := range h.nodes[idx].friends[layer] {
            ni := int(neighborIdx)
            if visited[ni] {
                continue
            }
            visited[ni] = true

            dist := l2Squared(query, h.nodes[ni].vec)

            if results.Len() < ef {
                heap.Push(results, Result{ID: uint64(ni), Dist: dist})
                heap.Push(candidates, Result{ID: uint64(ni), Dist: dist})
            } else if dist < (*results)[0].Dist {
                heap.Pop(results)
                heap.Push(results, Result{ID: uint64(ni), Dist: dist})
                heap.Push(candidates, Result{ID: uint64(ni), Dist: dist})
            }
        }
    }

    out := make([]Result, results.Len())
    for i := results.Len() - 1; i >= 0; i-- {
        out[i] = heap.Pop(results).(Result)
    }
    return out
}

func (h *hnswIndex) selectNeighbors(candidates []Result, m int) []uint32 {
    sort.Slice(candidates, func(i, j int) bool {
        return candidates[i].Dist < candidates[j].Dist
    })
    n := m
    if n > len(candidates) {
        n = len(candidates)
    }
    result := make([]uint32, n)
    for i := 0; i < n; i++ {
        result[i] = uint32(candidates[i].ID)
    }
    return result
}

func (h *hnswIndex) insert(id uint64, vec []float32) {
    level := h.randomLevel()
    idx := len(h.nodes)
    node := &hnswNode{
        id:      id,
        vec:     vec,
        level:   level,
        friends: make([][]uint32, level+1),
    }
    for i := range node.friends {
        node.friends[i] = make([]uint32, 0, hnswM)
    }
    h.nodes = append(h.nodes, node)
    h.idToIdx[id] = idx

    if h.entryPoint == -1 {
        h.entryPoint = idx
        h.maxLevel = level
        return
    }

    ep := h.entryPoint

    // 从最高层贪心下降到 level+1
    for l := h.maxLevel; l > level; l-- {
        results := h.searchLayer(vec, ep, 1, l)
        if len(results) > 0 {
            ep = int(results[0].ID)
        }
    }

    // 从 min(level, maxLevel) 到第 0 层,搜索并连边
    topLayer := level
    if topLayer > h.maxLevel {
        topLayer = h.maxLevel
    }
    for l := topLayer; l >= 0; l-- {
        candidates := h.searchLayer(vec, ep, hnswEfConstruction, l)
        maxM := hnswM
        if l == 0 {
            maxM = hnswMmax0
        }
        neighbors := h.selectNeighbors(candidates, maxM)
        node.friends[l] = neighbors

        // 添加反向边
        for _, nIdx := range neighbors {
            ni := int(nIdx)
            if l < len(h.nodes[ni].friends) {
                h.nodes[ni].friends[l] = append(h.nodes[ni].friends[l], uint32(idx))
                // 裁剪超过 maxM 的邻居
                if len(h.nodes[ni].friends[l]) > maxM {
                    friendDists := make([]Result, len(h.nodes[ni].friends[l]))
                    for fi, fIdx := range h.nodes[ni].friends[l] {
                        friendDists[fi] = Result{
                            ID:   uint64(fIdx),
                            Dist: l2Squared(h.nodes[ni].vec, h.nodes[int(fIdx)].vec),
                        }
                    }
                    h.nodes[ni].friends[l] = h.selectNeighbors(friendDists, maxM)
                }
            }
        }

        if len(candidates) > 0 {
            ep = int(candidates[0].ID)
        }
    }

    if level > h.maxLevel {
        h.maxLevel = level
        h.entryPoint = idx
    }
}

func (h *hnswIndex) search(query []float32, k int, ef int) []Result {
    if h.entryPoint == -1 {
        return nil
    }

    ep := h.entryPoint
    for l := h.maxLevel; l > 0; l-- {
        results := h.searchLayer(query, ep, 1, l)
        if len(results) > 0 {
            ep = int(results[0].ID)
        }
    }

    candidates := h.searchLayer(query, ep, ef, 0)

    // 把内部索引转换为外部 ID
    for i := range candidates {
        candidates[i].ID = h.nodes[int(candidates[i].ID)].id
    }

    if len(candidates) > k {
        candidates = candidates[:k]
    }
    return candidates
}

// --- 向量引擎 ---

const bruteForceThreshold = 1000

type VectorEngine struct {
    mu       sync.RWMutex
    hnsw     *hnswIndex
    vectors  map[uint64][]float32
    nextID   uint64
    dim      int
    count    int64
    efSearch int
}

func NewVectorEngine(dim int) *VectorEngine {
    return &VectorEngine{
        hnsw:     newHNSW(),
        vectors:  make(map[uint64][]float32),
        dim:      dim,
        efSearch: 64,
    }
}

func (e *VectorEngine) Insert(vec []float32) uint64 {
    e.mu.Lock()
    defer e.mu.Unlock()

    id := e.nextID
    e.nextID++

    copied := make([]float32, len(vec))
    copy(copied, vec)
    e.vectors[id] = copied

    e.hnsw.insert(id, copied)
    atomic.AddInt64(&e.count, 1)
    return id
}

func (e *VectorEngine) BatchInsert(vecs [][]float32) []uint64 {
    e.mu.Lock()
    defer e.mu.Unlock()

    ids := make([]uint64, len(vecs))
    for i, vec := range vecs {
        id := e.nextID
        e.nextID++

        copied := make([]float32, len(vec))
        copy(copied, vec)
        e.vectors[id] = copied

        e.hnsw.insert(id, copied)
        atomic.AddInt64(&e.count, 1)
        ids[i] = id
    }
    return ids
}

func (e *VectorEngine) Search(query []float32, k int) []Result {
    e.mu.RLock()
    defer e.mu.RUnlock()

    n := atomic.LoadInt64(&e.count)

    // 向量数少时用暴力搜索
    if n < bruteForceThreshold {
        return e.bruteForceSearch(query, k)
    }
    return e.hnsw.search(query, k, e.efSearch)
}

func (e *VectorEngine) bruteForceSearch(query []float32, k int) []Result {
    h := &MaxHeap{}
    heap.Init(h)

    for id, vec := range e.vectors {
        dist := l2Squared(query, vec)
        if h.Len() < k {
            heap.Push(h, Result{ID: id, Dist: dist})
        } else if dist < (*h)[0].Dist {
            heap.Pop(h)
            heap.Push(h, Result{ID: id, Dist: dist})
        }
    }

    results := make([]Result, h.Len())
    for i := h.Len() - 1; i >= 0; i-- {
        results[i] = heap.Pop(h).(Result)
    }
    return results
}

func (e *VectorEngine) Stats() map[string]interface{} {
    return map[string]interface{}{
        "count":     atomic.LoadInt64(&e.count),
        "dimension": e.dim,
        "ef_search": e.efSearch,
    }
}

// --- HTTP API ---

type InsertRequest struct {
    Vectors [][]float32 `json:"vectors"`
}

type InsertResponse struct {
    IDs []uint64 `json:"ids"`
}

type SearchRequest struct {
    Vector []float32 `json:"vector"`
    TopK   int       `json:"top_k"`
}

type SearchResponse struct {
    Results []Result `json:"results"`
    Latency string   `json:"latency"`
}

func main() {
    dim := 128
    engine := NewVectorEngine(dim)

    http.HandleFunc("/insert", func(w http.ResponseWriter, r *http.Request) {
        if r.Method != http.MethodPost {
            http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
            return
        }
        var req InsertRequest
        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }
        for _, v := range req.Vectors {
            if len(v) != dim {
                http.Error(w, fmt.Sprintf("expected dim %d", dim), http.StatusBadRequest)
                return
            }
        }
        ids := engine.BatchInsert(req.Vectors)
        json.NewEncoder(w).Encode(InsertResponse{IDs: ids})
    })

    http.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) {
        if r.Method != http.MethodPost {
            http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
            return
        }
        var req SearchRequest
        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }
        if len(req.Vector) != dim {
            http.Error(w, fmt.Sprintf("expected dim %d", dim), http.StatusBadRequest)
            return
        }
        if req.TopK <= 0 {
            req.TopK = 10
        }

        start := time.Now()
        results := engine.Search(req.Vector, req.TopK)
        latency := time.Since(start)

        json.NewEncoder(w).Encode(SearchResponse{
            Results: results,
            Latency: latency.String(),
        })
    })

    http.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) {
        json.NewEncoder(w).Encode(engine.Stats())
    })

    addr := ":8080"
    log.Printf("vector engine listening on %s (dim=%d)", addr, dim)
    log.Fatal(http.ListenAndServe(addr, nil))
}

8.1 代码要点解读

  1. HNSW 核心在 insertsearchLayersearchLayer 维护一个大小为 \(ef\) 的结果集和一个候选队列,贪心地扩展离查询最近的未访问节点。insert 先从高层贪心定位,再在每一层搜索并连边。

  2. 暴力搜索用 MaxHeap 实现 topK。维护一个大小为 \(K\) 的最大堆,堆顶是当前第 K 近的距离。新候选只有比堆顶更近时才入堆,保证堆始终保存最近的 K 个。

  3. bruteForceThreshold = 1000。当向量数少于 1000 时,暴力搜索的延迟反而比 HNSW 低,因为省去了图维护的开销。这个阈值在实际系统中需要根据维度和硬件调整。

  4. sync.RWMutex。查询拿读锁,可以并发执行;插入拿写锁,保证图结构一致性。

8.2 测试用例

用 curl 即可验证:

# 插入两条向量
curl -s -X POST http://localhost:8080/insert \
  -d '{"vectors": [[1.0, 0.5, 0.3, ...], [0.2, 0.8, 0.1, ...]]}' \
  | jq .

# 查询最近的 5 个
curl -s -X POST http://localhost:8080/search \
  -d '{"vector": [1.0, 0.5, 0.3, ...], "top_k": 5}' \
  | jq .

# 查看统计信息
curl -s http://localhost:8080/stats | jq .

九、批量插入与并发搜索

9.1 批量插入优化

逐条插入的瓶颈在于:每次插入都要走一遍 HNSW 图遍历(搜索邻居 + 连边),而且每条都要写 WAL + fsync。批量插入可以:

  1. 攒批写 WAL:把整批向量一次性追加到 WAL,只做一次 fsync。
  2. 批量建图:虽然 HNSW 的插入天然是逐条的,但连续插入同一批向量时,缓存局部性更好,因为前一条刚访问过的节点很可能还在 L1/L2 缓存里。
func (e *VectorEngine) BatchInsertWithWAL(wal *WAL, vecs [][]float32) ([]uint64, error) {
    // 先全部写入 WAL
    ids := make([]uint64, len(vecs))
    for i, vec := range vecs {
        id := atomic.AddUint64(&e.nextID, 1) - 1
        ids[i] = id
        if err := wal.AppendInsert(id, vec); err != nil {
            return nil, err
        }
    }
    // 一次 fsync
    wal.file.Sync()

    // 再批量更新索引
    e.mu.Lock()
    defer e.mu.Unlock()
    for i, vec := range vecs {
        copied := make([]float32, len(vec))
        copy(copied, vec)
        e.vectors[ids[i]] = copied
        e.hnsw.insert(ids[i], copied)
        atomic.AddInt64(&e.count, 1)
    }

    return ids, nil
}

9.2 并发搜索性能

由于搜索只拿读锁,理论上可以线性扩展到 CPU 核数。实际测试中,瓶颈往往不在锁竞争,而在:

var resultPool = sync.Pool{
    New: func() interface{} {
        h := make(MaxHeap, 0, 128)
        return &h
    },
}

func (e *VectorEngine) SearchPooled(query []float32, k int) []Result {
    e.mu.RLock()
    defer e.mu.RUnlock()

    h := resultPool.Get().(*MaxHeap)
    *h = (*h)[:0]
    defer resultPool.Put(h)

    // ... 搜索逻辑复用 h 而非每次 make
    return e.hnsw.search(query, k, e.efSearch)
}

9.3 基准测试框架

func BenchmarkSearch(b *testing.B) {
    engine := NewVectorEngine(128)
    // 插入 100K 随机向量
    for i := 0; i < 100000; i++ {
        vec := randomVector(128)
        engine.Insert(vec)
    }
    query := randomVector(128)

    b.ResetTimer()
    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            engine.Search(query, 10)
        }
    })
}

十、Benchmark:SIFT1M 上的吞吐与召回

10.1 测试设置

SIFT1M 是向量检索领域的标准测试集:100 万条 128 维 uint8 向量(转为 float32),10,000 条查询,每条查询有 100 个真实最近邻(ground truth)。

我们用以下配置测试 mini engine:

10.2 构建性能

构建时间约 180 秒,内存占用(向量 + 图)约 1.2 GB,图平均出度 14.8。相比 C++ 实现(hnswlib 约 40 秒)慢了 4-5 倍,主要原因是 Go 的内存分配开销和缺少 SIMD 加速。

10.3 查询性能

efSearch QPS(单线程) QPS(8 线程) 平均延迟
32 0.88 4,200 28,000 0.24 ms
64 0.95 2,800 19,500 0.35 ms
128 0.98 1,500 10,800 0.67 ms
256 0.993 800 5,900 1.25 ms
512 0.998 420 3,100 2.38 ms

几个观察:

  1. efSearch 翻倍,召回率提升递减,但延迟大致翻倍。这是典型的”精度-速度权衡”曲线。
  2. 8 线程下 QPS 约为单线程的 6-7 倍,没有达到理论的 8 倍,瓶颈在内存带宽。
  3. = 0.95 对应的延迟约 0.35 ms,对大多数在线服务来说已经够用。

10.4 与 PQ 结合

加入 PQ(m=8)后的变化:

配置 内存占用 QPS(单线程)
HNSW only 1.2 GB 0.95 2,800
HNSW + PQ rerank 0.6 GB 0.93 3,200
PQ brute-force 0.1 GB 0.85 1,800

PQ rerank 的内存减半,QPS 反而略有提升(因为原始向量不再需要全部常驻内存),但召回率损失约 2 个百分点。

十一、与 Milvus 和 Qdrant 的架构对比

我们的 mini engine 是一个单机、单进程的简化实现。工业级的向量数据库要复杂得多。

11.1 Milvus 架构

Milvus 采用存算分离的分布式架构:

Client -> Proxy -> QueryNode / DataNode / IndexNode -> MinIO/S3

11.2 Qdrant 架构

Qdrant 更偏向”单机强一致 + 水平扩展”的路线:用 Rust 编写,性能极高;单节点内用 segment 管理数据,每个 segment 独立持有索引和向量数据;支持 HNSW 图上的前置过滤(允许穿越被过滤节点);WAL 由自己实现,每个 collection 一个 WAL;分布式层用 Raft 做 leader 选举和一致性复制。

11.3 对比表

维度 Mini Engine Milvus Qdrant
语言 Go Go + C++ Rust
部署方式 单进程 分布式(6+ 组件) 单节点或集群
索引类型 HNSW HNSW/IVF-PQ/DiskANN HNSW
持久化 WAL + mmap S3 + Pulsar 自研 WAL + segment
过滤 简单 post-filter pre/post 混合 图上穿越过滤
扩展性 单机 水平扩展 水平扩展(Raft)
适用规模 < 1000 万 十亿级 亿级

从这个对比可以看出,分布式向量数据库的核心复杂度不在检索算法本身,而在数据管理(segment 生命周期、compaction、负载均衡)和一致性保证(WAL、复制、故障恢复)。这也是为什么我认为理解单机版本的每个组件是理解分布式版本的前提。

十二、工程踩坑表与个人思考

12.1 工程踩坑表

问题 现象 根因 解决方案
召回率突然下降 新插入的向量查不到 HNSW entry point 没有更新到最高层节点 插入时检查 level > maxLevel 并更新 entry point
内存持续增长 长时间运行后 OOM Go map 删除元素不释放底层内存 定期重建 map,或改用 slice + 空闲列表
查询延迟毛刺 P99 远高于 P50 GC STW 暂停 sync.Pool 减少分配,设置 GOGC=400
mmap 读取慢 冷启动后首次查询很慢 page fault 加载数据 启动时 madvise(MADV_WILLNEED) 预热
WAL 文件膨胀 磁盘空间不足 只追加不截断 快照后截断 WAL
并发写入丢数据 多线程插入结果不一致 WAL 的 mutex 和索引的 mutex 不是同一把 统一用引擎级别的写锁
距离计算结果异常 搜索结果完全错误 float32 精度问题,NaN 传播 插入前校验向量,拒绝含 NaN/Inf 的输入
HNSW 构建极慢 百万级构建超过一小时 efConstruction 设置过高(800) 降到 200,召回率损失 < 0.5%
连边时 panic 数组越界 节点的 friends 层数不够 确保 friends slice 长度 = level + 1
暴力搜索比 HNSW 快 1 万条向量 HNSW 反而慢 图遍历的常数开销大于线性扫描 低于阈值时自动回退暴力搜索

12.2 性能调优清单

  1. 先 profile 再优化。用 go tool pprof 看 CPU 热点,通常 80% 的时间花在距离计算上。
  2. 向量对齐。确保向量数组按 32 字节对齐(AVX2 要求),Go 的 slice 底层数组通常已经对齐。
  3. 减少堆分配。HNSW 搜索时大量创建 Result 结构体,用对象池或栈上分配。
  4. 预计算归一化。如果用余弦距离,插入时就做好归一化,查询时直接用内积。
  5. 分段构建。百万级数据集可以分批构建,每批 10 万条,利用批量插入的缓存友好性。

12.3 个人思考

做完这个 mini engine,我有几点体会:

向量搜索引擎的核心不在算法,在工程。 HNSW 论文的伪代码不超过一页,但把它工程化为一个可靠的服务,需要处理的边界条件(并发、持久化、异常输入、内存管理)比算法本身多出一个数量级。

Go 写这类系统有优势也有劣势。 优势是并发模型简单、部署方便(静态链接单二进制文件)、工具链完善(pprof、race detector)。劣势是 GC 的不确定性让延迟 SLA 难以保证,缺少 SIMD intrinsics 让距离计算比 C++/Rust 慢 3-5 倍。如果追求极致性能,核心的距离计算和 HNSW 遍历应该用 C/Rust 实现,Go 只做上层的 HTTP 服务和调度。

mmap 是把双刃剑。 它让代码简洁了很多,但把内存管理的控制权让渡给了操作系统。在可预测的工作负载下(比如向量检索这种均匀随机访问),mmap 工作得很好。但如果访问模式有明显的冷热分区,手写的缓存池可以做得比操作系统的页缓存更好。

WAL 看似简单,魔鬼在细节里。 CRC 校验、部分写入处理、fsync 的时机、WAL 文件的生命周期管理,每一项都可以展开写一篇文章。我们的实现是最小化的,生产系统需要考虑更多边界情况(比如 WAL 文件损坏后的自动修复、多段 WAL 的轮转)。

不要过早分布式化。 单机向量引擎配上 mmap 和合理的索引参数,可以轻松应对千万级数据。只有当单机的内存、算力或可用性确实不够时,才值得引入分布式架构的复杂度。Milvus 的 6+ 组件架构对小团队来说是沉重的运维负担。

最后,我想引用一句在系统设计领域广为流传的话作为结尾:

“Make it work, make it right, make it fast.” —— Kent Beck

我们的 mini engine 已经 work 了。让它 right(可靠持久化)和 fast(SIMD、更好的内存布局),就是接下来的事情了。


上一篇: ScaNN 与 DiskANN 下一篇: Dijkstra 与 A*

相关阅读: - HNSW - 乘积量化与 IVF-PQ


By .