聚合与TopN问题
本篇是《大数据算法与UDF系列》的第3篇,深入讲解聚合函数的底层原理,以及大数据场景下TopN问题的各种解法。
1. 聚合函数基础
1.1 什么是聚合?
聚合(Aggregation) 是将多行数据压缩成一行或多行的操作:
原始数据 (4行):
┌────────┬───────┐
│ name │ score │
├────────┼───────┤
│ Alice │ 90 │
│ Bob │ 85 │
│ Charlie│ 95 │
│ David │ 80 │
└────────┴───────┘
↓ 聚合
聚合结果 (1行):
┌───────────────┐
│ avg_score:87.5│
└───────────────┘1.2 Spark聚合执行原理
┌─────────────────────────────────────────────────────────────────┐
│ Spark 聚合执行流程 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入: 按行读取 │
│ ──────────────────────────────────────────────────────── │
│ Row1, Row2, Row3, Row4, Row5... │
│ │
│ Step 1: shuffle (按Group Key分区) │
│ ──────────────────────────────────────────────────────── │
│ Group A → partition1 │
│ Group B → partition2 │
│ Group C → partition3 │
│ │
│ Step 2: 内存聚合 │
│ ──────────────────────────────────────────────────────── │
│ partition1: A [90, 85, 70] → 聚合 → (A, 245) │
│ partition2: B [95, 88] → 聚合 → (B, 183) │
│ partition3: C [60] → 聚合 → (C, 60) │
│ │
│ Step 3: 输出 │
│ ──────────────────────────────────────────────────────── │
│ ┌────────┬────────┐ │
│ │ Group │ Total │ │
│ ├────────┼────────┤ │
│ │ A │ 245 │ │
│ │ B │ 183 │ │
│ │ C │ 60 │ │
│ └────────┴────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘1.3 常用聚合函数
from pyspark.sql import functions as F
# 基础聚合
df.agg(
F.count("*"), # 总行数
F.count("col"), # 非空行数
F.countDistinct("col"),# 去重计数
F.sum("col"), # 求和
F.avg("col"), # 平均
F.min("col"), # 最小值
F.max("col"), # 最大值
F.first("col"), # 第一个值
F.last("col"), # 最后一个值
)2. 分组聚合
2.1 基础分组
# 按单列分组
df.groupBy("department").agg(
F.sum("salary").alias("total_salary"),
F.avg("salary").alias("avg_salary"),
F.count("employee_id").alias("employee_count")
)// Scala版本
import org.apache.spark.sql.functions._
df.groupBy("department").agg(
sum("salary").as("total_salary"),
avg("salary").as("avg_salary"),
count("employee_id").as("employee_count")
)2.2 多列分组
# 按多列分组
df.groupBy("department", "year").agg(
F.sum("sales").alias("total_sales"),
F.max("sales").alias("max_sales")
)2.3 分组后过滤(Having)
# SQL: SELECT department, SUM(salary) as total
# FROM employees GROUP BY department
# HAVING SUM(salary) > 100000
df.groupBy("department") \
.agg(F.sum("salary").alias("total")) \
.filter(F.col("total") > 100000)3. 聚合优化:预聚合
3.1 为什么要预聚合?
场景:1年的日活数据(3.6亿条),需要统计月度UV
❌ 低效方案:每次查询都从原始数据计算
- 读取3.6亿条数据
- 按月分组、去重、计数
- 每次查询都要几分钟
✅ 优化方案:预聚合(日活→月活)
- 每天凌晨计算日活,存入日活表(365条)
- 查询月度UV只需读取12条数据,毫秒级3.2 预聚合实现
# 每天凌晨运行:计算日活
daily_uv = df.groupBy("date").agg(
F.approx_count_distinct("user_id").alias("daily_uv")
)
daily_uv.write.mode("overwrite").partitionBy("date") \
.saveAsTable("daily_uv_table")
# 查询月度UV:直接汇总日活表
monthly_uv = spark.sql("""
SELECT
substr(date, 1, 7) as month,
SUM(daily_uv) as monthly_uv
FROM daily_uv_table
WHERE date >= '2024-01-01' AND date < '2024-02-01'
GROUP BY substr(date, 1, 7)
""")4. TopN问题详解
4.1 什么是TopN?
TopN 是指从数据集中找出排名前N的记录:
原始数据(按销售额):
┌────────┬───────┐
│ product│ sales │
├────────┼───────┤
│ A │ 1000 │
│ B │ 800 │
│ C │ 600 │
│ D │ 400 │
│ E │ 200 │
└────────┴───────┘
Top3 结果:
┌────────┬───────┐
│ product│ sales │
├────────┼───────┤
│ A │ 1000 │
│ B │ 800 │
│ C │ 600 │
└────────┴───────┘4.2 TopN的业务场景
| 场景 | 需求 |
|---|---|
| 销售分析 | 每个类别销量Top10商品 |
| 用户分析 | 活跃度Top100用户 |
| 日志分析 | 最近100条错误日志 |
| 实时榜单 | 实时热度Top10 |
4.3 精确TopN实现
方法1:窗口函数 + 过滤
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# 每个类别销售额Top3
window_spec = Window.partitionBy("category").orderBy(F.desc("sales"))
top3 = df.withColumn("rank", F.row_number().over(window_spec)) \
.filter(F.col("rank") <= 3) \
.drop("rank")// Scala
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy("category").orderBy(desc("sales"))
val top3 = df.withColumn("rank", row_number().over(windowSpec))
.filter(col("rank") <= 3)
.drop("rank")执行流程图:
┌─────────────────────────────────────────────────────────────────┐
│ TopN 执行流程 (窗口函数方法) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 原始数据: │
│ ┌─────────┬──────┬──────┐ │
│ │ category│product│sales │ │
│ ├─────────┼──────┼──────┤ │
│ │ A │ P1 │ 1000 │ │
│ │ A │ P2 │ 800 │ │
│ │ A │ P3 │ 600 │ │
│ │ A │ P4 │ 400 │ │
│ │ B │ P5 │ 900 │ │
│ │ B │ P6 │ 700 │ │
│ │ B │ P7 │ 500 │ │
│ └─────────┴──────┴──────┘ │
│ ↓ │
│ Step 1: 按category分区,区内按sales排序 │
│ ───────────────────────────────────── │
│ Category A: Category B: │
│ [P1:1000] [P5:900] │
│ [P2:800 ] [P6:700] │
│ [P3:600 ] [P7:500] │
│ [P4:400 ] │
│ ↓ │
│ Step 2: 添加排名 │
│ ───────────────────────────────────── │
│ ┌─────────┬──────┬──────┬─────┐ │
│ │ category│product│sales │rank │ │
│ ├─────────┼──────┼──────┼─────┤ │
│ │ A │ P1 │ 1000 │ 1 │ │
│ │ A │ P2 │ 800 │ 2 │ ← 取rank <= 3 │
│ │ A │ P3 │ 600 │ 3 │ │
│ │ A │ P4 │ 400 │ 4 │ ← 过滤掉 │
│ │ B │ P5 │ 900 │ 1 │ │
│ │ B │ P6 │ 700 │ 2 │ ← 取rank <= 3 │
│ │ B │ P7 │ 500 │ 3 │ │
│ └─────────┴──────┴──────┴─────┘ │
│ ↓ │
│ 最终结果: │
│ ┌─────────┬──────┬──────┐ │
│ │ category│product│sales │ │
│ ├─────────┼──────┼──────┤ │
│ │ A │ P1 │ 1000 │ │
│ │ A │ P2 │ 800 │ │
│ │ A │ P3 │ 600 │ │
│ │ B │ P5 │ 900 │ │
│ │ B │ P6 │ 700 │ │
│ │ B │ P7 │ 500 │ │
│ └─────────┴──────┴──────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘4.4 分组TopN:每个组取TopN
完整示例
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
spark = SparkSession.builder.getOrCreate()
# 创建示例数据
data = [
("2024-01-01", "Electronics", "iPhone", 1000),
("2024-01-01", "Electronics", "MacBook", 800),
("2024-01-01", "Electronics", "iPad", 600),
("2024-01-01", "Electronics", "AirPods", 400),
("2024-01-01", "Electronics", "Watch", 200),
("2024-01-01", "Clothing", "Jacket", 500),
("2024-01-01", "Clothing", "Jeans", 300),
("2024-01-01", "Clothing", "Shirt", 200),
("2024-01-01", "Clothing", "Hat", 100),
("2024-01-02", "Electronics", "iPhone", 1100),
("2024-01-02", "Electronics", "MacBook", 900),
("2024-01-02", "Electronics", "iPad", 650),
]
df = spark.createDataFrame(data, ["date", "category", "product", "sales"])
# 方案1:窗口函数(适合中小数据量)
window_spec = Window.partitionBy("date", "category").orderBy(F.desc("sales"))
result = df.withColumn("rank", F.row_number().over(window_spec)) \
.filter(F.col("rank") <= 3) \
.drop("rank")
result.show()Scala版本
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
object TopNDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("TopNDemo")
.master("local[*]")
.getOrCreate()
import spark.implicits._
val data = Seq(
("2024-01-01", "Electronics", "iPhone", 1000),
("2024-01-01", "Electronics", "MacBook", 800),
("2024-01-01", "Electronics", "iPad", 600),
("2024-01-01", "Electronics", "AirPods", 400),
("2024-01-01", "Electronics", "Watch", 200),
("2024-01-01", "Clothing", "Jacket", 500),
("2024-01-01", "Clothing", "Jeans", 300),
("2024-01-01", "Clothing", "Shirt", 200),
("2024-01-02", "Electronics", "iPhone", 1100),
("2024-01-02", "Electronics", "MacBook", 900)
)
val df = data.toDF("date", "category", "product", "sales")
val windowSpec = Window.partitionBy("date", "category")
.orderBy(desc("sales"))
val result = df.withColumn("rank", row_number().over(windowSpec))
.filter(col("rank") <= 3)
.drop("rank")
result.show()
spark.stop()
}
}5. 近似TopN优化
5.1 何时需要近似算法?
| 场景 | 数据量 | 推荐方案 |
|---|---|---|
| 每日Top100 | 100万 | 精确方案 |
| 实时Top100 | 1亿/秒 | 流式近似算法 |
| 历史全量分析 | 10亿+ | 采样 + 近似 |
5.2 近似TopN:空间优化的秘密
精确TopN问题:
- 需要保存所有数据以便排序
- 内存占用 = O(N),N=总数据量
- 时间复杂度 = O(N log N)
近似TopN(Count-Min Sketch):
- 只保存TopN的候选
- 内存占用 = O(N_top * epsilon)
- 时间复杂度 = O(N)
空间对比(1亿条数据,取Top100):
- 精确:需要保存1亿条
- 近似:只需保存~1万条(采样率1%)5.3 近似TopN实现
使用Spark内置近似函数
# 近似计数(用于估算基数)
df.agg(F.approx_count_distinct("user_id", 0.01).alias("uv"))
# 近似分位数
df.approxQuantile("salary", [0.25, 0.5, 0.75], 0.01)自定义近似TopN UDF
import heapq
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType
def approximate_topn(n=100):
"""近似TopN:使用堆保持TopN"""
heap = []
def add(item):
if len(heap) < n:
heapq.heappush(heap, item)
elif item[1] > heap[0][1]:
heapq.heapreplace(heap, item)
return sorted(heap, reverse=True)
return add
# 使用
topn_udf = udf(approximate_topn(10), ArrayType(...))6. 高级TopN技巧
6.1 TopN with Tie(并列排名)
# 使用 RANK 或 DENSE_RANK 处理并列
window_spec = Window.partitionBy("category").orderBy(F.desc("sales"))
# RANK: 1,1,3,4(跳跃)
df.withColumn("rank", F.rank().over(window_spec))
# DENSE_RANK: 1,1,2,3(连续)
df.withColumn("dense_rank", F.dense_rank().over(window_spec))
# 取Top3(包含并列)
df.withColumn("dense_rank", F.dense_rank().over(window_spec)) \
.filter(F.col("dense_rank") <= 3)6.2 TopN with Window Frame
# 取每个类别销量Top3,以及对应的销售额
window_spec = Window.partitionBy("category") \
.orderBy(F.desc("sales")) \
.rowsBetween(Window.currentRow, Window.currentRow + 2)
# 窗口内求和(Top3的总和)
df.withColumn("top3_sum", F.sum("sales").over(window_spec))6.3 全局TopN vs 分组TopN
# 全局TopN(不分組)
window_global = Window.orderBy(F.desc("sales"))
df.withColumn("global_rank", F.row_number().over(window_global)) \
.filter(F.col("global_rank") <= 10)
# 分组TopN(按类别)
window_group = Window.partitionBy("category").orderBy(F.desc("sales"))
df.withColumn("group_rank", F.row_number().over(window_group)) \
.filter(F.col("group_rank") <= 10)7. 常见问题与优化
7.1 数据倾斜
问题:某个分组数据量特别大
# ❌ 低效:直接分组TopN
df.groupBy("category", "product") \
.agg(F.sum("sales").alias("total_sales"))
# 数据倾斜!
# ✅ 优化:先按category分组,再取TopN
# 第一步:每个category内部先TopN
category_top = df.groupBy("category", "product") \
.agg(F.sum("sales").alias("sales")) \
.withColumn("rank", F.row_number().over(
Window.partitionBy("category").orderBy(F.desc("sales"))
)) \
.filter(F.col("rank") <= 100) # 先过滤
# 第二步:全局再聚合
result = category_top.groupBy("category", "product") \
.agg(F.sum("sales").alias("total_sales"))7.2 内存溢出
问题:数据量太大,内存不够
解决方案:
增加分区数
df.repartition(1000) # 增加分区数,减少单分区数据使用磁盘
spark.conf.set("spark.sql.shuffle.spill", "true")分批处理
# 分批处理每个key for batch in df.toLocalIterator(): process(batch)
7.3 重复计算
问题:多个TopN查询重复扫描数据
解决方案:预计算 + 缓存
# 预计算所有分组TopN
precomputed = df.groupBy("category", "product") \
.agg(F.sum("sales").alias("total_sales"))
precomputed.cache() # 缓存结果
# 后续查询直接从缓存读取
precomputed.filter(F.col("category") == "Electronics") \
.orderBy(F.desc("total_sales")).limit(10)8. 实时TopN(流处理)
8.1 Flink SQL实现
-- 滚动TopN:每5分钟更新一次Top10
SELECT *
FROM (
SELECT
ROW_NUMBER() OVER (ORDER BY cnt DESC) as rank,
product_id,
cnt
FROM (
SELECT
product_id,
COUNT(*) as cnt
FROM orders
WHERE order_time >= CURRENT_TIMESTAMP - INTERVAL '5' MINUTE
GROUP BY product_id
)
)
WHERE rank <= 108.2 PySpark Streaming
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# 增量TopN(滑动窗口)
streaming_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "orders") \
.load()
# 10分钟窗口,每5分钟更新
result = streaming_df \
.groupBy(F.window("timestamp", "10 minutes"), "product_id") \
.agg(F.count("*").alias("cnt")) \
.withColumn("rank", F.row_number().over(
Window.partitionBy("window").orderBy(F.desc("cnt"))
)) \
.filter(F.col("rank") <= 10)9. 完整示例:电商销售分析
9.1 需求
分析每个品类、每天的:
- 销售额Top3商品
- 销售额总计
- 销售额占比
9.2 完整代码
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
spark = SparkSession.builder.getOrCreate()
# 模拟数据
sales = [
("2024-01-01", "Electronics", "iPhone", 1000),
("2024-01-01", "Electronics", "MacBook", 800),
("2024-01-01", "Electronics", "iPad", 600),
("2024-01-01", "Electronics", "AirPods", 400),
("2024-01-01", "Clothing", "Jacket", 500),
("2024-01-01", "Clothing", "Jeans", 300),
("2024-01-02", "Electronics", "iPhone", 1100),
("2024-01-02", "Electronics", "MacBook", 900),
("2024-01-02", "Electronics", "iPad", 700),
("2024-01-02", "Clothing", "Jacket", 550),
("2024-01-02", "Clothing", "Jeans", 320),
("2024-01-02", "Clothing", "Shirt", 180),
]
df = spark.createDataFrame(sales, ["date", "category", "product", "sales"])
# Step 1: 按日期+品类+商品聚合
agg_df = df.groupBy("date", "category", "product") \
.agg(F.sum("sales").alias("total_sales"))
# Step 2: 添加品类内排名
window_spec = Window.partitionBy("date", "category") \
.orderBy(F.desc("total_sales"))
ranked_df = agg_df.withColumn(
"rank",
F.row_number().over(window_spec)
)
# Step 3: 取Top3
top3 = ranked_df.filter(F.col("rank") <= 3)
# Step 4: 计算品类总销售额
category_total = agg_df.groupBy("date", "category") \
.agg(F.sum("total_sales").alias("category_total"))
# Step 5: 关联计算占比
result = top3.join(category_total, ["date", "category"]) \
.withColumn("sales_ratio",
F.round(F.col("total_sales") / F.col("category_total") * 100, 2)
) \
.select("date", "category", "product", "total_sales", "rank", "sales_ratio")
result.show()输出结果:
+----------+-----------+-------+-----------+----+----------+
| date| category|product|total_sales|rank|sales_ratio|
+----------+-----------+-------+-----------+----+----------+
|2024-01-01|Electronics| iPhone| 1000| 1| 40.82|
|2024-01-01|Electronics|MacBook| 800| 2| 32.65|
|2024-01-01|Electronics| iPad| 600| 3| 24.49|
|2024-01-01| Clothing| Jacket| 500| 1| 62.5|
|2024-01-01| Clothing| Jeans| 300| 2| 37.5|
|2024-01-02|Electronics| iPhone| 1100| 1| 38.46|
|2024-01-02|Electronics|MacBook| 900| 2| 31.58|
|2024-01-02|Electronics| iPad| 700| 3| 24.56|
|2024-01-02| Clothing| Jacket| 550| 1| 56.7|
|2024-01-02| Clothing| Jeans| 320| 2| 32.99|
|2024-01-02| Clothing| Shirt| 180| 3| 18.56|
+----------+-----------+-------+-----------+----+----------+10. 练习题
练习1:双重TopN
找出每个部门、每个月销售额Top3的员工。
练习2:TopN with 聚合
计算每个品类Top3商品的销售额占该品类总销售额的比例。
练习3:实时TopN
使用PySpark Structured Streaming实现实时Top10热门商品。
順子の杂货铺


