顺子の杂货铺
生命不息,折腾不止,且行且珍惜~

0x05-近似算法HyperLogLog

DMIT VPS

近似算法HyperLogLog

本篇是《大数据算法与UDF系列》的第5篇,讲解大数据中的基数估计神器——HyperLogLog(HLL),它可以用极小的空间计算亿级UV数据。


1. 什么是基数估计?

1.1 问题背景

在数据分析中,经常需要统计去重后的数量(基数):

UV (Unique Visitor): 独立访客数
Cardinality: 去重后的元素个数

困难

  • 数据量大(亿级别)
  • 需要精确去重
  • 传统方案内存开销大

1.2 精确方案 vs 近似方案

方案空间精度速度
HashSetO(N)100%
数据库DISTINCTO(N)100%
HyperLogLogO(1)~2%
BitmapO(N/8)100%
空间对比(10亿用户):
- HashSet: ~40GB
- Bitmap: 125MB
- HyperLogLog: ~12KB ✨

2. HyperLogLog原理

2.1 核心思想

抛硬币实验

连续抛出硬币,直到出现正面为止

实验1: 正          → 1次
实验2: 反正正       → 3次
实验3: 正          → 1次
实验4: 凹凸正       → 4次
...

关键发现

  • 如果某个实验抛了很多次才出现正面,说明前面的硬币大部分都是反面
  • 如果硬币是公平的,长期来看平均需要抛2次

类比到哈希

  • 对每个元素计算哈希
  • 哈希值的二进制表示相当于"抛硬币"
  • 统计第一个"1"出现的位置
哈希值: 001001...  → 第一个1在第3位
哈希值: 101001...  → 第一个1在第1位
哈希值: 000001...  → 第一个1在第6位

2.2 算法图解

┌─────────────────────────────────────────────────────────────────┐
│                    HyperLogLog 原理                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Step 1: 哈希                                                   │
│   ─────────────────────────────────────                        │
│                                                                 │
│   原始数据: user_1, user_2, user_3, user_4                    │
│       ↓ hash                                                    │
│   user_1 → 0b100101010...  (第一个1在第3位)                   │
│   user_2 → 0b001010101...  (第一个1在第4位)                   │
│   user_3 → 0b110010101...  (第一个1在第1位)                   │
│   user_4 → 0b010001010...  (第一个1在第5位)                   │
│                                                                 │
│   Step 2: 分桶统计                                              │
│   ─────────────────────────────────────                        │
│                                                                 │
│   假设16个桶: bucket[0] ~ bucket[15]                           │
│   取哈希的前4位决定桶号:                                        │
│       user_1 → 0b1001 → bucket 9 → max_pos = 3               │
│       user_2 → 0b0010 → bucket 2 → max_pos = 4               │
│       user_3 → 0b1100 → bucket 12 → max_pos = 1              │
│       user_4 → 0b0100 → bucket 4 → max_pos = 5               │
│                                                                 │
│   bucket[0]: 0, bucket[1]: 0, bucket[2]: 4, ...              │
│                                                                 │
│   Step 3: 估算                                                  │
│   ─────────────────────────────────────                        │
│                                                                 │
│   m = 16  (桶数)                                               │
│   M = max(bucket) = 5                                          │
│                                                                 │
│   估算公式: 基数 ≈ m * 2^M                                     │
│            = 16 * 2^5 = 16 * 32 = 512                          │
│                                                                 │
│   实际数量: 4                                                   │
│   估算结果: 512  ← 误差较大(因为样本太少)                       │
│                                                                 │
│   实际应用中,数据量越大,估算越准确                             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

2.3 估算公式

E = m * 2^M

其中:
- m: 桶的数量(通常是2^n,n=4~16)
- M: 所有桶中最大的计数值
- E: 估算的基数

修正公式(小数据量时):

if E < 5m/2:
    使用线性计数: E_raw = m * log(m / (m - V))
    其中 V = 值为0的桶数量

3. Spark中的HyperLogLog

3.1 内置函数

Spark SQL提供了approx_count_distinct函数:

from pyspark.sql import functions as F

# 近似计数(标准误差约4%)
result = df.agg(
    F.approx_count_distinct("user_id").alias("uv")
)

# 指定标准误差
result = df.agg(
    F.approx_count_distinct("user_id", 0.01).alias("uv_1pct")  # 1%误差
)

3.2 Python实现

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# 创建1万条测试数据(有重复)
import random
data = [(str(i), random.randint(1, 1000)) for i in range(10000)]
df = spark.createDataFrame(data, ["user_id", "value"])

# 1. 精确去重
exact_count = df.select("user_id").distinct().count()
print(f"精确去重: {exact_count}")

# 2. 近似去重(默认误差~4%)
approx_count = df.agg(F.approx_count_distinct("user_id")).collect()[0][0]
print(f"近似去重(默认): {approx_count}")

# 3. 近似去重(1%误差)
approx_count_1pct = df.agg(F.approx_count_distinct("user_id", 0.01)).collect()[0][0]
print(f"近似去重(1%误差): {approx_count_1pct}")

# 4. 按分组近似计数
grouped = df.groupBy("value").agg(
    F.approx_count_distinct("user_id").alias("uv")
)
grouped.show()

3.3 Scala实现

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object HyperLogLogDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("HyperLogLogDemo")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // 创建测试数据
    val data = (1 to 10000).map { i =>
      (s"user_$i", s"category_${i % 1000}")
    }.toDF("user_id", "category")

    // 精确去重
    val exactCount = data.select("user_id").distinct().count()
    println(s"精确去重: $exactCount")

    // 近似去重(默认)
    val approxCount = data.agg(approx_count_distinct("user_id")).first().getLong(0)
    println(s"近似去重(默认): $approxCount")

    // 近似去重(1%误差)
    val approxCount1pct = data.agg(approx_count_distinct("user_id", 0.01)).first().getLong(0)
    println(s"近似去重(1%误差): $approxCount1pct")

    // 按分组近似计数
    val grouped = data.groupBy("category")
      .agg(approx_count_distinct("user_id").as("uv"))

    grouped.show(10)

    spark.stop()
  }
}

4. 深入理解误差

4.1 误差来源

标准误差 = 1.04 / sqrt(m)

其中 m = 桶的数量

常用配置:
- 12位哈希: m = 4096 → 误差 ≈ 1.625%
- 默认: m = 1024 → 误差 ≈ 3.25%
- 粗略估计: m = 256 → 误差 ≈ 6.5%

4.2 精度 vs 空间

桶数内存标准误差适用场景
6464 bytes~13%快速估算
256256 bytes~6.5%一般分析
10241 KB~3.25%推荐默认值
40964 KB~1.6%精确分析
1638416 KB~0.8%高精度

4.3 误差实验

# 验证误差
import random

for rsd in [0.5, 0.25, 0.1, 0.05, 0.01]:
    error_ratios = []
    for _ in range(100):
        # 生成10000个唯一ID
        unique_ids = list(range(10000))
        # 随机采样进行近似估计
        sample_size = int(10000 * rsd * 10)
        sample = random.sample(unique_ids, sample_size)
        # 使用简单估算
        estimated = sample_size * (1 / rsd)
        error = abs(estimated - 10000) / 10000
        error_ratios.append(error)

    print(f"采样率 {rsd*100:.0f}%: 平均误差 {sum(error_ratios)/len(error_ratios)*100:.2f}%")

5. 实战:UV统计

5.1 业务场景

实时统计:

  • 全站UV
  • 频道UV
  • 页面UV

5.2 完整代码

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import random

spark = SparkSession.builder.getOrCreate()

# 模拟用户行为日志(1天数据)
channels = ["home", "product", "cart", "checkout", "payment"]
data = []

for i in range(100000):
    user_id = f"user_{random.randint(1, 10000)}"
    channel = random.choice(channels)
    timestamp = f"2024-01-01 {random.randint(0,23):02d}:{random.randint(0,59):02d}:00"
    data.append((user_id, channel, timestamp))

df = spark.createDataFrame(data, ["user_id", "channel", "timestamp"])
df = df.withColumn("ts", F.to_timestamp("timestamp"))

print("=" * 60)
print("1. 全站UV(精确 vs 近似)")
print("=" * 60)

exact_uv = df.select("user_id").distinct().count()
approx_uv = df.agg(F.approx_count_distinct("user_id")).collect()[0][0]

print(f"精确UV:   {exact_uv}")
print(f"近似UV:   {approx_uv}")
print(f"误差率:   {abs(approx_uv - exact_uv) / exact_uv * 100:.2f}%")

print("\n" + "=" * 60)
print("2. 各频道UV(近似)")
print("=" * 60)

channel_uv = df.groupBy("channel").agg(
    F.approx_count_distinct("user_id").alias("uv"),
    F.count("*").alias("pv")
).orderBy(F.desc("uv"))

channel_uv.show()

print("\n" + "=" * 60)
print("3. 每小时UV(近似,使用滑动窗口)")
print("=" * 60)

hourly_uv = df.groupBy(
    F.window("ts", "1 hour", "1 hour")
).agg(
    F.approx_count_distinct("user_id").alias("uv"),
    F.count("*").alias("pv")
).select(
    F.col("window.start").alias("hour"),
    "uv",
    "pv"
).orderBy("hour")

hourly_uv.show(truncate=False)

print("\n" + "=" * 60)
print("4. 各频道每小时UV(近似)")
print("=" * 60)

channel_hourly_uv = df.groupBy(
    "channel",
    F.window("ts", "1 hour", "1 hour")
).agg(
    F.approx_count_distinct("user_id").alias("uv")
).select(
    "channel",
    F.col("window.start").alias("hour"),
    "uv"
).orderBy("channel", "hour")

channel_hourly_uv.show(truncate=False)

5.3 性能对比

import time

# 创建大数据量(1000万)
big_data = [(f"user_{i % 500000}", i) for i in range(10000000)]
big_df = spark.createDataFrame(big_data, ["user_id", "value"])

# 精确去重
start = time.time()
exact = big_df.select("user_id").distinct().count()
exact_time = time.time() - start
print(f"精确去重: {exact} 条,耗时 {exact_time:.2f}秒")

# 近似去重(默认)
start = time.time()
approx = big_df.agg(F.approx_count_distinct("user_id")).collect()[0][0]
approx_time = time.time() - start
print(f"近似去重: {approx} 条,耗时 {approx_time:.4f}秒")

print(f"性能提升: {exact_time/approx_time:.0f}x")
print(f"误差率: {abs(approx - exact) / exact * 100:.2f}%")

6. HLL与其他算法对比

6.1 基数估计算法对比

算法空间精度特点
HashSetO(N)100%简单精确
BitmapN/8100%需要连续ID
LogLogm~1.3%HLL前身
HyperLogLogm~1%最常用
HyperLogLog++m~0.8%Google改进版

6.2 Spark中的选择

# 场景1: 数据量小(<100万),需要精确
df.select("user_id").distinct().count()

# 场景2: 数据量大,允许2%误差
df.agg(F.approx_count_distinct("user_id"))

# 场景3: 需要分组统计UV
df.groupBy("date").agg(
    F.approx_count_distinct("user_id").alias("uv")
)

# 场景4: 需要跨天UV(需要合并HLL)
# 使用HLL Merge函数
df.groupBy("date").agg(
    F.approx_count_distinct("user_id").alias("daily_uv")
).agg(
    F.approx_count_distinct(F.concat_ws(",", F.collect_list("daily_uv"))).alias("weekly_uv")
)

7. 自定义HyperLogLog

7.1 Python实现

import hashlib
import math

class HyperLogLog:
    """HyperLogLog Python实现"""

    def __init__(self, p=12):
        """
        p: 位数,决定桶数量 m = 2^p
        常用 p=12, m=4096, 标准误差约1.6%
        """
        self.p = p
        self.m = 1 << p  # 2^p
        self.registers = [0] * self.m
        self.alpha = self._get_alpha()

    def _get_alpha(self):
        """获取修正因子"""
        if self.p <= 10:
            return 0.673
        elif self.p == 11:
            return 0.697
        elif self.p == 12:
            return 0.709
        else:
            return 0.7213 / (1 + 1.079 / self.m)

    def _hash(self, item):
        """计算哈希值,返回前p位作为桶索引"""
        h = hashlib.sha256(str(item).encode()).hexdigest()
        # 取前16位作为哈希值
        hash_val = int(h[:16], 16)
        bucket = hash_val & (self.m - 1)  # 取低p位
        # 取剩余位计算第一个1的位置
        hash_for_count = hash_val >> self.p
        return bucket, hash_for_count

    def add(self, item):
        """添加一个元素"""
        bucket, hash_val = self._hash(item)

        # 计算前导0的数量(从左边第一个1开始数)
        # 实际上我们计算:从最高位开始,第一个1出现的位置
        # 由于hash_val是整数,我们计算 32 - leading_zeros
        leading_zeros = hash_val.bit_length()
        if leading_zeros == 0:
            max_ones = 32 - self.p + 1  # 如果hash_val=0
        else:
            max_ones = 32 - leading_zeros + 1

        # 更新桶的最大值
        self.registers[bucket] = max(self.registers[bucket], max_ones)

    def count(self):
        """估算基数"""
        # 计算调和平均数
        m = self.m
        registers = self.registers

        # 检查是否为空
        if all(r == 0 for r in registers):
            return 0

        # 调和平均
        sum_inv = sum(1.0 / (1 << r) for r in registers)
        estimate = self.alpha * m * m / sum_inv

        # 小数据量修正
        if estimate < 5 * m / 2:
            zero_count = registers.count(0)
            if zero_count != 0:
                estimate = m * math.log(m / zero_count)

        # 大数据量修正
        if estimate > (1 << 32):
            estimate = -(1 << 32) * math.log(1 - estimate / (1 << 32))

        return int(estimate)

    @staticmethod
    def merge(*hlls):
        """合并多个HLL"""
        if not hlls:
            return None

        p = hlls[0].p
        result = HyperLogLog(p)

        for hll in hlls:
            for i in range(hll.m):
                result.registers[i] = max(result.registers[i], hll.registers[i])

        return result

# 使用示例
hll = HyperLogLog(p=12)

# 添加10000个元素(实际只有100个唯一)
for i in range(10000):
    hll.add(i % 100)

print(f"实际基数: 100")
print(f"估算基数: {hll.count()}")
print(f"误差率: {abs(hll.count() - 100) / 100 * 100:.2f}%")

8. 常见问题

8.1 数据倾斜

# ❌ 问题:某些key数据量特别大
df.groupBy("category").agg(
    F.approx_count_distinct("user_id")
)

# ✅ 优化:先过滤异常值
df.filter(F.col("category") != "unknown").groupBy("category").agg(
    F.approx_count_distinct("user_id")
)

8.2 跨天UV计算

# 方案1: 每天分别计算,存储HLL
# 每天的HLL可以合并

# 方案2: 使用Spark的HLL merge功能
# 在Spark 3.0+ 中,可以使用 builtin 函数
from pyspark.sql.functions import approx_count_distinct

# 正确的跨天UV计算
daily_uv = df.groupBy("date").agg(
    F.approx_count_distinct("user_id").alias("daily_uv")
)

# 总体UV(近似)
total_uv = df.agg(F.approx_count_distinct("user_id"))

8.3 内存问题

# HLL 本身非常小,不会造成内存问题
# 但如果数据量太大,可以:
# 1. 增加分区数
df.repartition(1000)

# 2. 使用采样
sampled = df.sample(0.1)
sampled.agg(F.approx_count_distinct("user_id"))

9. 练习题

练习1

对比不同p值(8, 10, 12, 14)对精度的影响。

练习2

使用HLL实现每日UV、每周UV、每月UV的统计。

练习3

实现两个HLL的合并操作。

练习4

将HLL应用到实际的用户行为数据中,验证误差率。


赞(0)
未经允许不得转载:順子の杂货铺 » 0x05-近似算法HyperLogLog
搬瓦工VPS

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

分享创造快乐

联系我们联系我们