顺子の杂货铺
生命不息,折腾不止,且行且珍惜~

0x03-聚合与TopN问题

DMIT VPS

聚合与TopN问题

本篇是《大数据算法与UDF系列》的第3篇,深入讲解聚合函数的底层原理,以及大数据场景下TopN问题的各种解法。


1. 聚合函数基础

1.1 什么是聚合?

聚合(Aggregation) 是将多行数据压缩成一行或多行的操作:

原始数据 (4行):
┌────────┬───────┐
│ name   │ score │
├────────┼───────┤
│ Alice  │ 90    │
│ Bob    │ 85    │
│ Charlie│ 95    │
│ David  │ 80    │
└────────┴───────┘
        ↓ 聚合
聚合结果 (1行):
┌───────────────┐
│ avg_score:87.5│
└───────────────┘

1.2 Spark聚合执行原理

┌─────────────────────────────────────────────────────────────────┐
│                    Spark 聚合执行流程                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   输入: 按行读取                                                │
│   ────────────────────────────────────────────────────────     │
│        Row1, Row2, Row3, Row4, Row5...                        │
│                                                                 │
│   Step 1: shuffle (按Group Key分区)                            │
│   ────────────────────────────────────────────────────────     │
│        Group A → partition1                                     │
│        Group B → partition2                                     │
│        Group C → partition3                                     │
│                                                                 │
│   Step 2: 内存聚合                                             │
│   ────────────────────────────────────────────────────────     │
│        partition1: A [90, 85, 70] → 聚合 → (A, 245)            │
│        partition2: B [95, 88]    → 聚合 → (B, 183)            │
│        partition3: C [60]       → 聚合 → (C, 60)              │
│                                                                 │
│   Step 3: 输出                                                 │
│   ────────────────────────────────────────────────────────     │
│        ┌────────┬────────┐                                     │
│        │ Group  │ Total  │                                     │
│        ├────────┼────────┤                                     │
│        │ A      │ 245    │                                     │
│        │ B      │ 183    │                                     │
│        │ C      │ 60     │                                     │
│        └────────┴────────┘                                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.3 常用聚合函数

from pyspark.sql import functions as F

# 基础聚合
df.agg(
    F.count("*"),           # 总行数
    F.count("col"),         # 非空行数
    F.countDistinct("col"),# 去重计数
    F.sum("col"),           # 求和
    F.avg("col"),           # 平均
    F.min("col"),           # 最小值
    F.max("col"),           # 最大值
    F.first("col"),         # 第一个值
    F.last("col"),          # 最后一个值
)

2. 分组聚合

2.1 基础分组

# 按单列分组
df.groupBy("department").agg(
    F.sum("salary").alias("total_salary"),
    F.avg("salary").alias("avg_salary"),
    F.count("employee_id").alias("employee_count")
)
// Scala版本
import org.apache.spark.sql.functions._

df.groupBy("department").agg(
  sum("salary").as("total_salary"),
  avg("salary").as("avg_salary"),
  count("employee_id").as("employee_count")
)

2.2 多列分组

# 按多列分组
df.groupBy("department", "year").agg(
    F.sum("sales").alias("total_sales"),
    F.max("sales").alias("max_sales")
)

2.3 分组后过滤(Having)

# SQL: SELECT department, SUM(salary) as total
#      FROM employees GROUP BY department
#      HAVING SUM(salary) > 100000

df.groupBy("department") \
  .agg(F.sum("salary").alias("total")) \
  .filter(F.col("total") > 100000)

3. 聚合优化:预聚合

3.1 为什么要预聚合?

场景:1年的日活数据(3.6亿条),需要统计月度UV

❌ 低效方案:每次查询都从原始数据计算
   - 读取3.6亿条数据
   - 按月分组、去重、计数
   - 每次查询都要几分钟

✅ 优化方案:预聚合(日活→月活)
   - 每天凌晨计算日活,存入日活表(365条)
   - 查询月度UV只需读取12条数据,毫秒级

3.2 预聚合实现

# 每天凌晨运行:计算日活
daily_uv = df.groupBy("date").agg(
    F.approx_count_distinct("user_id").alias("daily_uv")
)
daily_uv.write.mode("overwrite").partitionBy("date") \
    .saveAsTable("daily_uv_table")

# 查询月度UV:直接汇总日活表
monthly_uv = spark.sql("""
    SELECT
        substr(date, 1, 7) as month,
        SUM(daily_uv) as monthly_uv
    FROM daily_uv_table
    WHERE date >= '2024-01-01' AND date < '2024-02-01'
    GROUP BY substr(date, 1, 7)
""")

4. TopN问题详解

4.1 什么是TopN?

TopN 是指从数据集中找出排名前N的记录:

原始数据(按销售额):
┌────────┬───────┐
│ product│ sales │
├────────┼───────┤
│ A      │ 1000  │
│ B      │ 800   │
│ C      │ 600   │
│ D      │ 400   │
│ E      │ 200   │
└────────┴───────┘

Top3 结果:
┌────────┬───────┐
│ product│ sales │
├────────┼───────┤
│ A      │ 1000  │
│ B      │ 800   │
│ C      │ 600   │
└────────┴───────┘

4.2 TopN的业务场景

场景需求
销售分析每个类别销量Top10商品
用户分析活跃度Top100用户
日志分析最近100条错误日志
实时榜单实时热度Top10

4.3 精确TopN实现

方法1:窗口函数 + 过滤

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 每个类别销售额Top3
window_spec = Window.partitionBy("category").orderBy(F.desc("sales"))

top3 = df.withColumn("rank", F.row_number().over(window_spec)) \
         .filter(F.col("rank") <= 3) \
         .drop("rank")
// Scala
import org.apache.spark.sql.expressions.Window

val windowSpec = Window.partitionBy("category").orderBy(desc("sales"))

val top3 = df.withColumn("rank", row_number().over(windowSpec))
  .filter(col("rank") <= 3)
  .drop("rank")

执行流程图

┌─────────────────────────────────────────────────────────────────┐
│                  TopN 执行流程 (窗口函数方法)                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   原始数据:                                                     │
│   ┌─────────┬──────┬──────┐                                    │
│   │ category│product│sales │                                    │
│   ├─────────┼──────┼──────┤                                    │
│   │ A       │ P1   │ 1000 │                                    │
│   │ A       │ P2   │ 800  │                                    │
│   │ A       │ P3   │ 600  │                                    │
│   │ A       │ P4   │ 400  │                                    │
│   │ B       │ P5   │ 900  │                                    │
│   │ B       │ P6   │ 700  │                                    │
│   │ B       │ P7   │ 500  │                                    │
│   └─────────┴──────┴──────┘                                    │
│                        ↓                                        │
│   Step 1: 按category分区,区内按sales排序                       │
│   ─────────────────────────────────────                        │
│   Category A:     Category B:                                  │
│   [P1:1000]       [P5:900]                                     │
│   [P2:800 ]       [P6:700]                                     │
│   [P3:600 ]       [P7:500]                                     │
│   [P4:400 ]                                                │
│                        ↓                                        │
│   Step 2: 添加排名                                              │
│   ─────────────────────────────────────                        │
│   ┌─────────┬──────┬──────┬─────┐                              │
│   │ category│product│sales │rank │                             │
│   ├─────────┼──────┼──────┼─────┤                              │
│   │ A       │ P1   │ 1000 │  1  │                              │
│   │ A       │ P2   │ 800  │  2  │  ← 取rank <= 3               │
│   │ A       │ P3   │ 600  │  3  │                              │
│   │ A       │ P4   │ 400  │  4  │  ← 过滤掉                   │
│   │ B       │ P5   │ 900  │  1  │                              │
│   │ B       │ P6   │ 700  │  2  │  ← 取rank <= 3               │
│   │ B       │ P7   │ 500  │  3  │                              │
│   └─────────┴──────┴──────┴─────┘                              │
│                        ↓                                        │
│   最终结果:                                                     │
│   ┌─────────┬──────┬──────┐                                    │
│   │ category│product│sales │                                    │
│   ├─────────┼──────┼──────┤                                    │
│   │ A       │ P1   │ 1000 │                                    │
│   │ A       │ P2   │ 800  │                                    │
│   │ A       │ P3   │ 600  │                                    │
│   │ B       │ P5   │ 900  │                                    │
│   │ B       │ P6   │ 700  │                                    │
│   │ B       │ P7   │ 500  │                                    │
│   └─────────┴──────┴──────┘                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

4.4 分组TopN:每个组取TopN

完整示例

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# 创建示例数据
data = [
    ("2024-01-01", "Electronics", "iPhone", 1000),
    ("2024-01-01", "Electronics", "MacBook", 800),
    ("2024-01-01", "Electronics", "iPad", 600),
    ("2024-01-01", "Electronics", "AirPods", 400),
    ("2024-01-01", "Electronics", "Watch", 200),
    ("2024-01-01", "Clothing", "Jacket", 500),
    ("2024-01-01", "Clothing", "Jeans", 300),
    ("2024-01-01", "Clothing", "Shirt", 200),
    ("2024-01-01", "Clothing", "Hat", 100),
    ("2024-01-02", "Electronics", "iPhone", 1100),
    ("2024-01-02", "Electronics", "MacBook", 900),
    ("2024-01-02", "Electronics", "iPad", 650),
]

df = spark.createDataFrame(data, ["date", "category", "product", "sales"])

# 方案1:窗口函数(适合中小数据量)
window_spec = Window.partitionBy("date", "category").orderBy(F.desc("sales"))

result = df.withColumn("rank", F.row_number().over(window_spec)) \
           .filter(F.col("rank") <= 3) \
           .drop("rank")

result.show()

Scala版本

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

object TopNDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TopNDemo")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    val data = Seq(
      ("2024-01-01", "Electronics", "iPhone", 1000),
      ("2024-01-01", "Electronics", "MacBook", 800),
      ("2024-01-01", "Electronics", "iPad", 600),
      ("2024-01-01", "Electronics", "AirPods", 400),
      ("2024-01-01", "Electronics", "Watch", 200),
      ("2024-01-01", "Clothing", "Jacket", 500),
      ("2024-01-01", "Clothing", "Jeans", 300),
      ("2024-01-01", "Clothing", "Shirt", 200),
      ("2024-01-02", "Electronics", "iPhone", 1100),
      ("2024-01-02", "Electronics", "MacBook", 900)
    )

    val df = data.toDF("date", "category", "product", "sales")

    val windowSpec = Window.partitionBy("date", "category")
      .orderBy(desc("sales"))

    val result = df.withColumn("rank", row_number().over(windowSpec))
      .filter(col("rank") <= 3)
      .drop("rank")

    result.show()

    spark.stop()
  }
}

5. 近似TopN优化

5.1 何时需要近似算法?

场景数据量推荐方案
每日Top100100万精确方案
实时Top1001亿/秒流式近似算法
历史全量分析10亿+采样 + 近似

5.2 近似TopN:空间优化的秘密

精确TopN问题:
- 需要保存所有数据以便排序
- 内存占用 = O(N),N=总数据量
- 时间复杂度 = O(N log N)

近似TopN(Count-Min Sketch):
- 只保存TopN的候选
- 内存占用 = O(N_top * epsilon)
- 时间复杂度 = O(N)

空间对比(1亿条数据,取Top100):
- 精确:需要保存1亿条
- 近似:只需保存~1万条(采样率1%)

5.3 近似TopN实现

使用Spark内置近似函数

# 近似计数(用于估算基数)
df.agg(F.approx_count_distinct("user_id", 0.01).alias("uv"))

# 近似分位数
df.approxQuantile("salary", [0.25, 0.5, 0.75], 0.01)

自定义近似TopN UDF

import heapq
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType

def approximate_topn(n=100):
    """近似TopN:使用堆保持TopN"""
    heap = []

    def add(item):
        if len(heap) < n:
            heapq.heappush(heap, item)
        elif item[1] > heap[0][1]:
            heapq.heapreplace(heap, item)
        return sorted(heap, reverse=True)

    return add

# 使用
topn_udf = udf(approximate_topn(10), ArrayType(...))

6. 高级TopN技巧

6.1 TopN with Tie(并列排名)

# 使用 RANK 或 DENSE_RANK 处理并列
window_spec = Window.partitionBy("category").orderBy(F.desc("sales"))

# RANK: 1,1,3,4(跳跃)
df.withColumn("rank", F.rank().over(window_spec))

# DENSE_RANK: 1,1,2,3(连续)
df.withColumn("dense_rank", F.dense_rank().over(window_spec))

# 取Top3(包含并列)
df.withColumn("dense_rank", F.dense_rank().over(window_spec)) \
  .filter(F.col("dense_rank") <= 3)

6.2 TopN with Window Frame

# 取每个类别销量Top3,以及对应的销售额
window_spec = Window.partitionBy("category") \
                    .orderBy(F.desc("sales")) \
                    .rowsBetween(Window.currentRow, Window.currentRow + 2)

# 窗口内求和(Top3的总和)
df.withColumn("top3_sum", F.sum("sales").over(window_spec))

6.3 全局TopN vs 分组TopN

# 全局TopN(不分組)
window_global = Window.orderBy(F.desc("sales"))
df.withColumn("global_rank", F.row_number().over(window_global)) \
  .filter(F.col("global_rank") <= 10)

# 分组TopN(按类别)
window_group = Window.partitionBy("category").orderBy(F.desc("sales"))
df.withColumn("group_rank", F.row_number().over(window_group)) \
  .filter(F.col("group_rank") <= 10)

7. 常见问题与优化

7.1 数据倾斜

问题:某个分组数据量特别大

# ❌ 低效:直接分组TopN
df.groupBy("category", "product") \
  .agg(F.sum("sales").alias("total_sales"))
# 数据倾斜!

# ✅ 优化:先按category分组,再取TopN
# 第一步:每个category内部先TopN
category_top = df.groupBy("category", "product") \
    .agg(F.sum("sales").alias("sales")) \
    .withColumn("rank", F.row_number().over(
        Window.partitionBy("category").orderBy(F.desc("sales"))
    )) \
    .filter(F.col("rank") <= 100)  # 先过滤

# 第二步:全局再聚合
result = category_top.groupBy("category", "product") \
    .agg(F.sum("sales").alias("total_sales"))

7.2 内存溢出

问题:数据量太大,内存不够

解决方案

  1. 增加分区数

    df.repartition(1000)  # 增加分区数,减少单分区数据
  2. 使用磁盘

    spark.conf.set("spark.sql.shuffle.spill", "true")
  3. 分批处理

    # 分批处理每个key
    for batch in df.toLocalIterator():
       process(batch)

7.3 重复计算

问题:多个TopN查询重复扫描数据

解决方案:预计算 + 缓存

# 预计算所有分组TopN
precomputed = df.groupBy("category", "product") \
    .agg(F.sum("sales").alias("total_sales"))

precomputed.cache()  # 缓存结果

# 后续查询直接从缓存读取
precomputed.filter(F.col("category") == "Electronics") \
    .orderBy(F.desc("total_sales")).limit(10)

8. 实时TopN(流处理)

8.1 Flink SQL实现

-- 滚动TopN:每5分钟更新一次Top10
SELECT *
FROM (
    SELECT
        ROW_NUMBER() OVER (ORDER BY cnt DESC) as rank,
        product_id,
        cnt
    FROM (
        SELECT
            product_id,
            COUNT(*) as cnt
        FROM orders
        WHERE order_time >= CURRENT_TIMESTAMP - INTERVAL '5' MINUTE
        GROUP BY product_id
    )
)
WHERE rank <= 10

8.2 PySpark Streaming

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 增量TopN(滑动窗口)
streaming_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "orders") \
    .load()

# 10分钟窗口,每5分钟更新
result = streaming_df \
    .groupBy(F.window("timestamp", "10 minutes"), "product_id") \
    .agg(F.count("*").alias("cnt")) \
    .withColumn("rank", F.row_number().over(
        Window.partitionBy("window").orderBy(F.desc("cnt"))
    )) \
    .filter(F.col("rank") <= 10)

9. 完整示例:电商销售分析

9.1 需求

分析每个品类、每天的:

  1. 销售额Top3商品
  2. 销售额总计
  3. 销售额占比

9.2 完整代码

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# 模拟数据
sales = [
    ("2024-01-01", "Electronics", "iPhone", 1000),
    ("2024-01-01", "Electronics", "MacBook", 800),
    ("2024-01-01", "Electronics", "iPad", 600),
    ("2024-01-01", "Electronics", "AirPods", 400),
    ("2024-01-01", "Clothing", "Jacket", 500),
    ("2024-01-01", "Clothing", "Jeans", 300),
    ("2024-01-02", "Electronics", "iPhone", 1100),
    ("2024-01-02", "Electronics", "MacBook", 900),
    ("2024-01-02", "Electronics", "iPad", 700),
    ("2024-01-02", "Clothing", "Jacket", 550),
    ("2024-01-02", "Clothing", "Jeans", 320),
    ("2024-01-02", "Clothing", "Shirt", 180),
]

df = spark.createDataFrame(sales, ["date", "category", "product", "sales"])

# Step 1: 按日期+品类+商品聚合
agg_df = df.groupBy("date", "category", "product") \
    .agg(F.sum("sales").alias("total_sales"))

# Step 2: 添加品类内排名
window_spec = Window.partitionBy("date", "category") \
                     .orderBy(F.desc("total_sales"))

ranked_df = agg_df.withColumn(
    "rank",
    F.row_number().over(window_spec)
)

# Step 3: 取Top3
top3 = ranked_df.filter(F.col("rank") <= 3)

# Step 4: 计算品类总销售额
category_total = agg_df.groupBy("date", "category") \
    .agg(F.sum("total_sales").alias("category_total"))

# Step 5: 关联计算占比
result = top3.join(category_total, ["date", "category"]) \
    .withColumn("sales_ratio",
        F.round(F.col("total_sales") / F.col("category_total") * 100, 2)
    ) \
    .select("date", "category", "product", "total_sales", "rank", "sales_ratio")

result.show()

输出结果

+----------+-----------+-------+-----------+----+----------+
|      date|   category|product|total_sales|rank|sales_ratio|
+----------+-----------+-------+-----------+----+----------+
|2024-01-01|Electronics| iPhone|       1000|   1|     40.82|
|2024-01-01|Electronics|MacBook|        800|   2|     32.65|
|2024-01-01|Electronics|    iPad|        600|   3|     24.49|
|2024-01-01|  Clothing| Jacket|        500|   1|     62.5|
|2024-01-01|  Clothing|  Jeans|        300|   2|      37.5|
|2024-01-02|Electronics| iPhone|       1100|   1|     38.46|
|2024-01-02|Electronics|MacBook|        900|   2|     31.58|
|2024-01-02|Electronics|    iPad|        700|   3|     24.56|
|2024-01-02|  Clothing| Jacket|        550|   1|      56.7|
|2024-01-02|  Clothing|  Jeans|        320|   2|     32.99|
|2024-01-02|  Clothing|  Shirt|        180|   3|     18.56|
+----------+-----------+-------+-----------+----+----------+

10. 练习题

练习1:双重TopN

找出每个部门、每个月销售额Top3的员工。

练习2:TopN with 聚合

计算每个品类Top3商品的销售额占该品类总销售额的比例。

练习3:实时TopN

使用PySpark Structured Streaming实现实时Top10热门商品。


赞(0)
未经允许不得转载:順子の杂货铺 » 0x03-聚合与TopN问题
搬瓦工VPS

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

分享创造快乐

联系我们联系我们