一些常用位运算技巧

Table of Contents

1 整数的符号

int v;     // 需要获得 v 的符号
int sign;  // 结果

// CHAR_BIT 是一个字节的比特数
// v < 0 则是 -1,否则 0
sign = -(v < 0);
// 使用位运算更快
sign = -(int)((unsigned int)v >> sizeof(int) * CHAR_BIT - 1);
// 或者更简单的写法
sign = v >> (sizeof(int) * CHAR_BIT - 1);

// 如果想要得到 -1 或 +1,只需要在最低位或 1
sign = 1 | (v >> sizeof(int) * CHAR_BIT -1);

// 如果想要 -1, 0 或 1, 则需要根据情况或 1 或 0
sign = (v != 0) | (v >> (sizeof(int) * CHAR_BIT - 1));
// 或者更简单的写法
sign = (v > 0) - (v < 0);

// 如果要判断一个值是不是非负的
sign = 1 ^ ((unsigned int)v >> (sizeof(int) * CHAR_BIT - 1)); // v < 0 则 0, 否则 1

// 判断两个整数是不是符号相反,只需要将最高位异或
int x, y; // 两个整数
bool f = ((x ^ y) < 0); // 符号相反就是真,否则为假

// 获得绝对值
unsigned int r; // 绝对值
int const mask = v >> sizeof(int) * CHAR_BIT - 1;
r = (v + mask) ^ mask;
// 或者
r = (v ^ mask) - mask;

2 最大值和最小值

int x;
int y;
int r; // 结果

// 最小值
r = y ^ ((x ^ y) & -(x < y));
// 最大值
r = x ^ ((x ^ y) & -(x < y));

如果 x>=y 那么 -(x<y) 将是全 1,所以 r=y^(x^y)&~0=y^x^y=x ,否则, -(x<y) 将是全 0, 此时 r=y^((x^y)&0)=y 。

也可以使用下面的运算,但下面的运算并不保险,因为 x-y 可能溢出

r = y + ((x - y) & ((x - y) >> (sizeof(int) * CHAR_BIT - 1))); // min(x, y)
r = x - ((x - y) & ((x - y) >> (sizeof(int) * CHAR_BIT - 1))); // max(x, y)

3 判断是不是 2 的幂

只要确定最多只有一个比特是 1

unsigned int v;
bool f; // 结果

f = (v & (v - 1)) == 0;

// 注意到 0 不是 2 的倍数,所以修正如下
f = v && !(v & (v - 1));

4 b 比特位的数值转换为整数

有时为了节省空间,我们使用几个比特表示一个整数,而不是直接用 int,比如 4 个比特表示一个整数。当需要将这几个比特读出时,如果是正数没有什么问题,如果是负数,1101 就要转换成 -3,而不是 13 。

unsigned b; // 整数 x 的比特位
int x;      // 低 b 位保存需要读取的整数值
int r;      // 结果
int const m = 1U << (b - 1); // 原整数符号位

x = x & ((1U << b) - 1); // 将无关比特清零
r = (x ^ m) - m;

5 如果满足条件将第 b 位置 0 或 1

bool f; // 条件
int b;  // 设置第 b 位
unsigned int m = 1 << b;

// if (f) w |= m; else w &= ~m;
w ^= (-f ^ w) & m;

6 条件满足取相反数

bool fDontNegate; // 为假时取反
int v;
int r;

r = (fDontNegate ^ (fDontNegate - 1)) * v;

如果需要条件为真时取反

bool fNegate;
int v;
int r;

r = (v ^ -fNegate) + fNegate;

7 统计 1 的数量,也就是 Hamming weight

先将比特位两两分组,组内两个比特累加获得组内的 Hamming weight。此后再计算每相邻两组 Hamming weight 的和,获得按照每组 4bit 分组后各组的 Hamming weight。以此类推不断求和即可得到总的 Hamming weight。

例如,要计算数字 372063667 的 Hamming weight,需要如下图所示依次对相邻两个数字相加即可。

00010110001011010011110110110011
 0 1 1 1 0 1 2 1 0 2 2 1 1 2 0 2
   1   2   1   3   2   3   3   2
       3       4       5       5
	       7              10
			      17

要高效的实现这个过程,需要一些位运算的技巧。下面将使用 C 语言实现这个过程,并一步步减少需要的运算量。下面假设需要计算一个 uint64_t 的 Hamming weight。为了保护我的手指,先定义一些需要用到的常量:

const uint64_t m1  = 0x5555555555555555; //binary: 0101...
const uint64_t m2  = 0x3333333333333333; //binary: 00110011..
const uint64_t m4  = 0x0f0f0f0f0f0f0f0f; //binary:  4 zeros,  4 ones ...
const uint64_t m8  = 0x00ff00ff00ff00ff; //binary:  8 zeros,  8 ones ...
const uint64_t m16 = 0x0000ffff0000ffff; //binary: 16 zeros, 16 ones ...
const uint64_t m32 = 0x00000000ffffffff; //binary: 32 zeros, 32 ones
const uint64_t hff = 0xffffffffffffffff; //binary: all ones
const uint64_t h01 = 0x0101010101010101; //the sum of 256 to the power of 0,1,2,3...

首先是第一个实现,需要 24 个运算:

int popcount64a(uint64_t x)
{
    x = (x & m1 ) + ((x >>  1) & m1 ); //put count of each  2 bits into those  2 bits
    x = (x & m2 ) + ((x >>  2) & m2 ); //put count of each  4 bits into those  4 bits
    x = (x & m4 ) + ((x >>  4) & m4 ); //put count of each  8 bits into those  8 bits
    x = (x & m8 ) + ((x >>  8) & m8 ); //put count of each 16 bits into those 16 bits
    x = (x & m16) + ((x >> 16) & m16); //put count of each 32 bits into those 32 bits
    x = (x & m32) + ((x >> 32) & m32); //put count of each 64 bits into those 64 bits
    return x;
}

这个实现复用了变量 x 的内存区域,将累加的中间结果也存储在 x 中。首先是 m1 这一行, (x & m1) 获得 x 的奇数比特位, (x >> 1) & m1 获得 x 的偶数比特位,两者累加的结果就是将 x 的比特两两分组,组内累加得到的结果存储在各个组内两个比特的位置。 m2 行也是一样的道理,累加相邻的 m1 分组中累加和,存储在原来两个相邻分组的 4bit 空间中。以此类推,最终获得 Hamming weight。这有点像线段树区间求和的过程。

第二个实现使用 17 个运算,是对上一个实现的改进:

int popcount64b(uint64_t x)
{
    x -= (x >> 1) & m1;             //put count of each 2 bits into those 2 bits
    x = (x & m2) + ((x >> 2) & m2); //put count of each 4 bits into those 4 bits
    x = (x + (x >> 4)) & m4;        //put count of each 8 bits into those 8 bits
    x += x >>  8;  //put count of each 16 bits into their lowest 8 bits
    x += x >> 16;  //put count of each 32 bits into their lowest 8 bits
    x += x >> 32;  //put count of each 64 bits into their lowest 8 bits
    return x & 0x7f;
}

m1 行与上一个实现的 m1 行做的是一样的事情,这需要一点解释:

\begin{eqnarray*} x &=& a + 2b + 2^{2}c + 2^{3}d + 2^{4}e + \cdots \\ (x >> 1) &=& b + 2c + 2^{2}d + 2^{3}e + 2^{4}f + \cdots \\ (x >> 1) \& m1 &=& b + 2^{2}d + 2^{4}f + \cdots \\ x-((x>>1)\&m1) &=& a + b + 2{2}c + 2^{2}d + \cdots \end{eqnarray*}

可以看到 m1 行就是是比特位两两分组后组内累加结果保存在本组对应两个比特位,与上一个实现的 m1 行相同。

m2 行和上一个实现是一样的,下面看 m4 行。因为 8 比特的 Hamming weigth 绝对不会超过 8,所以最后累加只会存储到最多 4 比特的空间内,所以直接移位相加取后四位就是这个分组的累加和。

因为 64 比特的 Hamming weight 存储空间绝对不会超过 8 比特,后面三行代码将目前所有分组 (8 bit 一组) 的值都累加到最右边的 8 个比特中,最后直接与 0x7f 返回最右边的一个字节即可。此处最后与 0xFF 或许更明白,但 7 比特的存储空间已经足够了。

第三个实现更加简单,使用 12 个运算,但其中有一个是乘法。这是使用运算最少的实现。

int popcount64c(uint64_t x)
{
    x -= (x >> 1) & m1;             //put count of each 2 bits into those 2 bits
    x = (x & m2) + ((x >> 2) & m2); //put count of each 4 bits into those 4 bits
    x = (x + (x >> 4)) & m4;        //put count of each 8 bits into those 8 bits
    return (x * h01) >> 56;  //returns left 8 bits of x + (x<<8) + (x<<16) + (x<<24) + ...
}

与上一个实现不一样的地方只有最后一行。 (x * h01) 意思是将目前所有 8 比特一组的分组都累加到最左边的一个字节上, 使用小学提到的竖式乘法表示,很容易哈先 x*h01 就是 x + (x<<8) + (x<<16) + (x<<24) + ... 。最后右移 56 比特得到最左边的 8 比特就是结果,将其返回。

8 交换两个整数值

int a, b;
a ^= b;
b ^= a;
a ^= b;

9 判断整数中有没有某个字节全 0

对于每个字节减 1,再观察是不是字节内最高位是不是进位所得。

#define ahszero(v) v - 0x01010101UL & ~v & 0x80808080UL

10 枚举整数位图表示的集合的子集

依次减 1 直到 0 为止,这里包括空集,所以到 -1 为止

t = s;
do {
    t = (t - 1) & t;
} while (t != s);

如果要枚举所有大小为 k 的子集

comb = (1 << k) - 1
while (comb < 1 << n) {
    x = comb & -comb;
    y = comb + x;
    comb = ((comb & ~y) / x >> 1) | y;
    // comb 就是大小为 k 的子集
}

By .