Model quality rarely starts degrading at the level of sophisticated algorithms; it usually starts with an unremarkable delay in the data pipeline. In building a real-time recommendation system, the first core challenge we face is feature freshness, especially for vectorized features. The moment a user action happens, we want that user's vector representation updated within seconds and fed into the next round of recommendation scoring. A traditional batch architecture, even a mini-batch one, carries minute-level latency that is simply unacceptable here.
The crux of the problem is that the feature lifecycle is split across three independent systems:
- **Data lake (source of truth)**: stores the raw user behavior logs, usually append-only; reliable but slow to query.
- **Online feature store (online serving)**: typically Redis or a similar KV store for low-latency lookups, but useless for approximate nearest neighbor (ANN) similarity search over vectors.
- **Monitoring (observability)**: a separate monitoring stack that usually tracks only system metrics such as CPU and memory, and knows almost nothing about the pipeline's internal health: feature latency, processing throughput, data quality, and so on.
When the pipeline breaks, say the upstream stream goes silent, the feature computation job hangs, or writes to the online store fail, we are usually the last to know. By the time a business metric such as CTR has dropped and we start tracing backwards, it is already too late.
Our goal is an end-to-end observable, real-time vector feature pipeline. It must satisfy the following requirements:
- Reliable data source: use Delta Lake, with its transactional guarantees, as the source of truth for behavior logs.
- Real-time processing: capture even small changes at the source and perform streaming processing and vectorization.
- Efficient vector storage: write the generated vectors into Milvus in real time for high-performance ANN queries.
- Deep observability: every step of the pipeline must expose precise, quantified metrics through Prometheus, enabling real-time monitoring and alerting on the health of the data flow.
The core data flow of the architecture looks like this:
graph TD
    subgraph "Data Source"
        A[User Behavior Events] --> B{Delta Lake Table};
        B -- Change Data Feed --> C[Spark Streaming Job];
    end
    subgraph "Real-time Processing & Vectorization"
        C -- processBatch --> D[Feature Engineering];
        D --> E[Embedding Model];
        E --> F[User/Item Vectors];
    end
    subgraph "Online Serving & Monitoring"
        F --> G{Milvus Collection};
        C -- exposes metrics --> H(Prometheus Endpoint);
        I[Prometheus Server] -- scrapes --> H;
        J[Grafana/Alertmanager] -- queries --> I;
    end
Step 1: Laying a Solid Data Foundation - Delta Lake and Change Data Feed
Everything starts with reliable data capture. We chose Delta Lake not because it is a data lake, but for its ACID transactions and, above all, its Change Data Feed (CDF) feature. CDF lets us consume every row-level change (INSERT, UPDATE, DELETE) made to a Delta table, much like subscribing to a database binlog.
First, we create a CDF-enabled Delta table to store the raw user behavior.
-- DDL for the raw user interaction table
CREATE TABLE user_interactions (
event_id STRING,
user_id STRING,
item_id STRING,
event_type STRING, -- e.g., 'click', 'view', 'purchase'
event_timestamp TIMESTAMP
)
USING DELTA
LOCATION '/path/to/delta/user_interactions'
TBLPROPERTIES (
'delta.enableChangeDataFeed' = 'true'
);
This TBLPROPERTIES clause is the key. Once it is enabled, every write to the user_interactions table additionally records change events, giving our downstream streaming job a precise, ordered data source with nothing missing and nothing duplicated.
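Before wiring up the streaming job, it helps to see what the change feed actually contains. The following is a minimal exploratory sketch of a batch read of the CDF (the startingVersion of 0 and the show() call are for exploration only); each change row carries the extra _change_type, _commit_version, and _commit_timestamp columns that the streaming job relies on.
# inspect_cdf.py -- exploratory batch read of the Change Data Feed
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("InspectCDF")
         .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
         .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
         .getOrCreate())

# Read every change recorded since table version 0.
changes_df = (spark.read.format("delta")
              .option("readChangeFeed", "true")
              .option("startingVersion", 0)
              .load("/path/to/delta/user_interactions"))

# CDF appends _change_type ('insert', 'update_preimage', 'update_postimage', 'delete'),
# _commit_version and _commit_timestamp to every row.
changes_df.select("user_id", "item_id", "_change_type",
                  "_commit_version", "_commit_timestamp").show(truncate=False)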
Step 2: The Pipeline Core - A Spark Streaming Job with Built-in Monitoring Probes
This is the engine of the whole system. We will write a PySpark Structured Streaming job that continuously reads changes from the user_interactions table, calls a (mocked) model service to generate vectors, and writes them to Milvus. The most important design decision is that this job is itself a Prometheus metrics exporter.
We use the prometheus_client library to start a lightweight HTTP server on the Spark driver node to expose custom metrics.
# feature_pipeline.py
import os
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from prometheus_client import start_http_server, Counter, Gauge, Histogram
# --- Prometheus Metrics Definition ---
# These metrics will be updated within the streaming query's foreachBatch function.
# They provide deep insights into the pipeline's health.
PIPELINE_RECORDS_PROCESSED = Counter(
'pipeline_records_processed_total',
'Total number of records processed by the pipeline',
['status'] # 'success' or 'failure'
)
PIPELINE_DATA_LATENCY = Gauge(
'pipeline_data_latency_seconds',
'End-to-end latency from event timestamp to processing time'
)
MILVUS_INSERT_LATENCY = Histogram(
'milvus_insert_latency_seconds',
'Latency of inserting vectors into Milvus',
buckets=[0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]
)
EMBEDDING_GENERATION_DURATION = Histogram(
'embedding_generation_duration_seconds',
'Duration of vector embedding generation',
buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5]
)
LAST_SUCCESSFUL_BATCH_TIMESTAMP = Gauge(
'pipeline_last_successful_batch_timestamp_seconds',
'Timestamp of the last successfully processed batch'
)
# --- Milvus Connection & Schema Definition ---
# In a real project, these should come from a config file.
MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost")
MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
COLLECTION_NAME = "user_vectors"
VECTOR_DIM = 128 # Dimension of the output vectors from the embedding model
def get_milvus_collection():
"""
Connects to Milvus and ensures the collection exists.
Returns the Milvus Collection object.
"""
try:
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
print(f"Successfully connected to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
except Exception as e:
print(f"Failed to connect to Milvus: {e}")
# In a production system, this should trigger a critical alert.
raise
if not utility.has_collection(COLLECTION_NAME):
print(f"Collection '{COLLECTION_NAME}' does not exist. Creating...")
fields = [
FieldSchema(name="user_id", dtype=DataType.VARCHAR, is_primary=True, max_length=256),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=VECTOR_DIM)
]
schema = CollectionSchema(fields, "User vectors collection")
collection = Collection(COLLECTION_NAME, schema)
# Create an index for efficient search
index_params = {
"metric_type": "L2",
"index_type": "IVF_FLAT",
"params": {"nlist": 1024}
}
collection.create_index(field_name="vector", index_params=index_params)
print("Collection and index created.")
else:
collection = Collection(COLLECTION_NAME)
collection.load()
return collection
# --- Mock Embedding Model ---
# This function simulates calling a model serving endpoint or loading a local model.
def generate_embeddings(user_ids):
"""
Simulates generating embeddings for a list of user IDs.
In a real scenario, this would involve a batch inference call.
Returns a dictionary mapping user_id to its vector.
"""
start_time = time.time()
# Replace this with your actual model inference logic
import numpy as np
embeddings = {
user_id: np.random.rand(VECTOR_DIM).tolist()
for user_id in user_ids
}
duration = time.time() - start_time
# Observe the duration for each user. In a batch, this is an approximation.
if user_ids:
EMBEDDING_GENERATION_DURATION.observe(duration / len(user_ids))
return embeddings
# --- The Core Processing Logic for Each Micro-Batch ---
def process_batch(batch_df, batch_id):
"""
This function is the heart of the streaming job.
It's executed for each micro-batch of data from the Delta CDF.
"""
start_ts = time.time()
print(f"--- Starting processing for batch ID: {batch_id} ---")
if batch_df.isEmpty():
print("Batch is empty, skipping.")
return
# Filter only for new insertions. We could also handle updates ('update_postimage').
# A common mistake is to not filter by `_change_type`, processing redundant data.
insert_df = batch_df.filter(col("_change_type") == "insert")
if insert_df.isEmpty():
print("No new inserts in this batch, skipping vector generation.")
return
# Collect user IDs to process. `collect()` can be a bottleneck. For very large batches,
# consider using `mapInPandas` or other distributed UDF approaches.
user_ids_to_process = [row.user_id for row in insert_df.select("user_id").distinct().collect()]
if not user_ids_to_process:
return
print(f"Processing {len(user_ids_to_process)} unique users for batch {batch_id}.")
try:
# 1. Generate Embeddings
embeddings_map = generate_embeddings(user_ids_to_process)
# 2. Prepare data for Milvus
# Milvus requires data in a list-of-lists format.
# This is a critical data transformation step.
milvus_data = [
[user_id for user_id in embeddings_map.keys()],
[vector for vector in embeddings_map.values()]
]
# 3. Upsert into Milvus
milvus_collection = get_milvus_collection()
with MILVUS_INSERT_LATENCY.time():
# Using `upsert` is crucial. It overwrites existing user vectors,
# ensuring the feature store always has the latest representation.
# Simple `insert` would fail on duplicate primary keys.
milvus_collection.upsert(milvus_data)
# 4. Update Prometheus Metrics on success
record_count = insert_df.count()
PIPELINE_RECORDS_PROCESSED.labels(status='success').inc(record_count)
# Calculate and record data latency.
# This is a key business-level metric.
# It measures the time from the actual event to its processing.
latest_event_ts_row = insert_df.agg({"event_timestamp": "max"}).collect()[0]
latest_event_ts = latest_event_ts_row[0]
if latest_event_ts:
latency = start_ts - latest_event_ts.timestamp()
PIPELINE_DATA_LATENCY.set(latency)
LAST_SUCCESSFUL_BATCH_TIMESTAMP.set_to_current_time()
print(f"Successfully processed {record_count} records and upserted to Milvus.")
except Exception as e:
# Robust error handling is non-negotiable in production.
print(f"!!! An error occurred in batch {batch_id}: {e}")
PIPELINE_RECORDS_PROCESSED.labels(status='failure').inc(insert_df.count())
# Depending on the error, you might want to implement retries with backoff,
# or push failed batches to a dead-letter queue (DLQ) for later analysis.
# For now, we just log and count the failure.
if __name__ == "__main__":
# --- Start Prometheus HTTP server ---
# This server will run on the Spark driver node.
# Ensure the driver's port (8000) is accessible by the Prometheus scraper.
PROMETHEUS_PORT = 8000
start_http_server(PROMETHEUS_PORT)
print(f"Prometheus metrics server started on port {PROMETHEUS_PORT}")
# --- Spark Session Initialization ---
spark = SparkSession.builder \
.appName("DeltaMilvusRealtimePipeline") \
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
.config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
.getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# --- Read from Delta Change Data Feed ---
# `readStream` is the entry point for Structured Streaming.
# `startingVersion` is set to 'latest' to process only new data.
# `maxFilesPerTrigger` controls the micro-batch size for latency tuning.
print("Starting to read from Delta table's Change Data Feed...")
cdf_stream_df = spark.readStream \
.format("delta") \
.option("readChangeFeed", "true") \
.option("startingVersion", "latest") \
.load("/path/to/delta/user_interactions")
# --- Start the Streaming Query ---
# `foreachBatch` provides the flexibility needed for complex operations
# like calling external services (Milvus, model servers) and updating metrics.
query = cdf_stream_df.writeStream \
.foreachBatch(process_batch) \
.option("checkpointLocation", "/path/to/checkpoints/delta_milvus_pipeline") \
.trigger(processingTime='10 seconds') \
.start()
print("Streaming query started. Waiting for termination...")
query.awaitTermination()
Code Walkthrough and Production Considerations
- **Choosing the metrics**: the Counter, Gauge, and Histogram defined above are not arbitrary.
  - PIPELINE_RECORDS_PROCESSED: a Counter is the right type for cumulative processing counts. Through the status label we can cleanly separate successful from failed records and spot error-rate changes quickly.
  - PIPELINE_DATA_LATENCY: a Gauge represents an instantaneous value. Data freshness is the lifeline of this pipeline, and this metric directly reflects its health.
  - MILVUS_INSERT_LATENCY: a Histogram is essential here. It reports not just the average latency but, through its buckets, the latency distribution. For example, if p99 spikes while the mean barely moves, there are long-tail requests, which may signal pressure inside Milvus.
  - LAST_SUCCESSFUL_BATCH_TIMESTAMP: a heartbeat metric. If this timestamp stops advancing for a while, the streaming job has hung or crashed and an alert must fire immediately.
- **The power and pitfalls of foreachBatch**: foreachBatch offers maximum flexibility, but it is easy to misuse. A common mistake is to process records one by one in a loop, which throws away Spark's parallelism. The right approach, as in the code above, is to operate on the entire batch_df at the DataFrame level and only pull data to the driver at the very end, when interacting with external systems (for example, collect()-ing the list of user IDs).
- **Idempotency**: milvus_collection.upsert() is the key to end-to-end idempotency. If a batch is retried by Spark after a network hiccup, upsert simply overwrites the existing user vectors instead of producing incorrect data from duplicate inserts.
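To close the loop on the "high-performance ANN query" goal, this is roughly how a serving component could query the user_vectors collection that the pipeline keeps fresh. It is a sketch only, assuming the collection and IVF_FLAT index created in get_milvus_collection; the nprobe value and the zero query vector are placeholders.
# search_sketch.py -- illustrative ANN query against the user_vectors collection
from pymilvus import Collection, connections

connections.connect("default", host="localhost", port="19530")
collection = Collection("user_vectors")
collection.load()  # the collection must be loaded into memory before searching

# In a real service this would be the embedding of the current user or request context.
query_vector = [0.0] * 128

results = collection.search(
    data=[query_vector],
    anns_field="vector",
    param={"metric_type": "L2", "params": {"nprobe": 16}},  # nprobe is a tuning placeholder
    limit=10,
    output_fields=["user_id"],
)
for hit in results[0]:
    print(hit.entity.get("user_id"), hit.distance)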
Step 3: Configuring Prometheus Scraping and Visualization
Our Spark job now exposes metrics on port :8000. Next we configure Prometheus to discover and scrape that endpoint.
Add the following scrape configuration to prometheus.yml:
# prometheus.yml
scrape_configs:
- job_name: 'spark_feature_pipeline'
# How often to scrape the metrics.
scrape_interval: 15s
# In a real K8s environment, you would use service discovery.
# For a simple setup, we can statically define the Spark driver's endpoint.
# This assumes the Spark driver is running on 'spark-driver-host.domain'.
static_configs:
- targets: ['spark-driver-host.domain:8000']
With Prometheus running, we can build a Grafana dashboard to visualize these metrics. A typical dashboard might include the following panels:
- Processing rate: rate(pipeline_records_processed_total{status="success"}[5m])
- Error rate: rate(pipeline_records_processed_total{status="failure"}[5m]) / rate(pipeline_records_processed_total[5m])
- End-to-end data latency: pipeline_data_latency_seconds
- Milvus insert latency (p99): histogram_quantile(0.99, sum(rate(milvus_insert_latency_seconds_bucket[5m])) by (le))
- Pipeline stall alert: time() - pipeline_last_successful_batch_timestamp_seconds > 300; when this expression is true, fire an alert indicating the pipeline has gone more than 5 minutes without a successfully processed batch (a corresponding alerting rule is sketched after this list).
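To turn the stall expression into an actual alert, it can be placed in a Prometheus alerting rules file. The snippet below is only a sketch: the file name, group name, and labels are placeholders, and Alertmanager routing is assumed to be configured separately.
# alert_rules.yml -- sketch of an alerting rule; names and labels are placeholders
groups:
  - name: feature_pipeline_alerts
    rules:
      - alert: FeaturePipelineStalled
        # No micro-batch has completed successfully for more than 5 minutes.
        expr: time() - pipeline_last_successful_batch_timestamp_seconds > 300
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "Real-time feature pipeline has stalled"
          description: "No micro-batch has been processed successfully in the last 5 minutes."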
Limitations and the Path Forward
Although this architecture solves the core real-time and observability problems, several areas still need hardening before it goes into a larger-scale production environment.
First, running the Prometheus exporter directly on the Spark driver is a single point of failure: a driver restart interrupts the metrics history. A more robust approach is a Prometheus Pushgateway, with the Spark executors pushing metrics to it, though that introduces new complexity and a new dependency.
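A minimal sketch of what that push could look like with prometheus_client, assuming a Pushgateway reachable at pushgateway:9091 (the host and job name are placeholders):
# push_metrics.py -- sketch of pushing a heartbeat metric to a Pushgateway
from prometheus_client import CollectorRegistry, Gauge, push_to_gateway

registry = CollectorRegistry()
last_batch_ts = Gauge(
    'pipeline_last_successful_batch_timestamp_seconds',
    'Timestamp of the last successfully processed batch',
    registry=registry
)
last_batch_ts.set_to_current_time()

# Prometheus then scrapes the gateway instead of the (ephemeral) Spark driver.
push_to_gateway('pushgateway:9091', job='spark_feature_pipeline', registry=registry)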
Second, the current error handling is still rudimentary. For unprocessable "poison pill" messages (for example, malformed records that make model inference fail), blind retries simply block the batch. A production-grade pipeline needs a dead-letter queue (DLQ): data that fails processing is written to another Delta table or a Kafka topic for later offline analysis and repair.
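One way the except branch of process_batch could be extended is sketched below; the DLQ path, helper name, and extra columns are assumptions for illustration, not part of the pipeline above.
# dlq_sketch.py -- routing a failed micro-batch to a dead-letter Delta table
from pyspark.sql.functions import current_timestamp, lit

DLQ_PATH = "/path/to/delta/user_interactions_dlq"  # hypothetical location

def send_to_dlq(batch_df, batch_id, error):
    """Append the failed micro-batch to a Delta table, annotated with failure context."""
    (batch_df
        .withColumn("failed_batch_id", lit(batch_id))
        .withColumn("failure_reason", lit(str(error)))
        .withColumn("failed_at", current_timestamp())
        .write.format("delta")
        .mode("append")
        .save(DLQ_PATH))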
Finally, the embedding model itself can become the bottleneck. At high data throughput, calling the generate_embeddings function from a single point drags down the whole batch. Distributing model inference with Pandas UDFs (mapInPandas) on GPU instances would squeeze out further processing performance; that optimization is the focus of the next iteration.
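A rough sketch of that direction is shown below, keeping the same random-vector mock in place of a real model; loading an actual model once per partition and batching calls to it (or to a GPU) is deliberately left out.
# map_in_pandas_sketch.py -- distributing embedding generation across executors
from typing import Iterator
import numpy as np
import pandas as pd

VECTOR_DIM = 128

def embed_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Runs on each executor: maps user_id rows to (user_id, vector) rows."""
    for pdf in batches:
        # Placeholder for real batch inference; float32 matches the array<float> schema below.
        vectors = np.random.rand(len(pdf), VECTOR_DIM).astype(np.float32)
        yield pd.DataFrame({
            "user_id": pdf["user_id"].to_numpy(),
            "vector": vectors.tolist(),
        })

# Inside process_batch, instead of collect() plus driver-side inference:
# vectors_df = (insert_df.select("user_id").distinct()
#               .mapInPandas(embed_partition, schema="user_id string, vector array<float>"))
# The resulting DataFrame can then be consumed in manageable chunks to feed the Milvus upsert.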