一句话总结

通过在多个 LLM 之间共享 Prefill 阶段的计算和 KV Cache,PrefillShare 将 Multi-Agent 系统的 P95 延迟降低 4.5 倍,吞吐量提升 3.9 倍。

为什么需要这个?

问题:Multi-Agent 系统的性能瓶颈

当前的 Multi-Agent 系统(如 LangChain、AutoGPT)会这样工作:

# 典型的 Multi-Agent 调用模式
shared_context = "用户问题:请分析这份财报..."  # 5000 tokens

# 调用多个专业模型
financial_model.generate(shared_context + "财务分析...")
legal_model.generate(shared_context + "法律风险...")
summary_model.generate(shared_context + "总结报告...")

核心问题:每个模型都要对 shared_context 重复执行 Prefill!

实测数据(H100 GPU):

  • 单个模型 Prefill 5000 tokens:~120ms
  • 3 个模型重复 Prefill:~360ms(纯浪费
  • KV Cache 占用:3 × 2.5GB = 7.5GB(重复存储

硬件层面发生了什么?

传统 Disaggregated Serving 架构虽然将 Prefill 和 Decode 分离到不同 GPU 以提高资源利用率,但仍然存在根本性问题:每个模型的 Prefill GPU 都在重复相同的计算

┌─────────────────┐
│  Prefill GPU    │──┐
│                 │  │  Model A: Prefill(共享上下文)
│  ┌──────────┐   │  │  → KV Cache A (2.5GB)
│  │KV Cache A│   │  │
└──┴──────────┴───┘  │
                      │
┌─────────────────┐  │
│  Prefill GPU    │──┤
│                 │  │  Model B: Prefill(共享上下文) ← 重复计算!
│  ┌──────────┐   │  │  → KV Cache B (2.5GB)
│  │KV Cache B│   │  │
└──┴──────────┴───┘  │
                      │
┌─────────────────┐  │
│  Decode GPU     │──┘
│  (生成输出)      │
└─────────────────┘

瓶颈分析

  1. 计算冗余:相同的 Transformer 计算被执行了 N 次(N = 模型数量)。对于 LLaMA-3-8B,每次 Prefill 需要约 15 TFLOPs(5000 tokens × 32 layers × 4096 dim),重复计算浪费了 30 TFLOPs
  2. 内存浪费:KV Cache 存储了 N 份完全相同的数据。每个 token 在每层存储 2 个向量(key/value),每个向量 4096 维,float16 精度:$2 \times 32 \times 5000 \times 4096 \times 2 \text{ bytes} = 2.5\text{GB}$
  3. 带宽压力:每个模型都要从 HBM 读取相同的权重。H100 的 HBM 带宽虽高达 3TB/s,但重复读取 8B 参数(16GB)仍然造成不必要的延迟

核心原理

直觉:为什么可以共享 Prefill?

关键洞察:Prefill 阶段的计算只依赖输入 tokens,与任务无关!

在 Transformer 架构中,Prefill 阶段执行的是标准的 Self-Attention 和 FFN 计算:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

其中 $Q$、$K$、$V$ 完全由输入 tokens 和模型权重决定,与后续生成什么内容无关。这意味着:

  1. 相同输入 → 相同 KV Cache:对于同一段文本,不同模型的前 N 层(如果权重相同)会产生完全相同的 Key/Value 向量
  2. 任务差异仅在 Decode:专业模型(财务、法律)的差异主要体现在如何利用这些 Key/Value 来生成特定领域的输出
输入: "请分析这份财报..."
      ↓
   [Prefill]  ← 这部分计算对所有模型都一样!
      ↓
   KV Cache (key/value 向量)
      ↓
   [Decode]  ← 这部分才需要任务专用权重
      ↓
   输出: "财务状况良好..."

架构拆分:冻结 vs 微调

PrefillShare 的核心设计是将模型分成两部分:

  1. Shared Prefill Module(冻结)
    • 包含:Embedding 层 + 前 28 层 Transformer
    • 作用:将输入 tokens 映射为统一的语义表示
    • 训练策略:保持预训练权重不变
  2. Task-Specific Decode Module(可微调)
    • 包含:后 4 层 Transformer + LM Head
    • 作用:将语义表示转化为领域特定的输出分布
    • 训练策略:在专业数据集上微调(如金融报告、法律文书)

为什么这样拆分? 根据实验观察,Transformer 的浅层倾向于学习通用语法和语义模式,而深层则更关注任务特定的推理模式。将拆分点设在第 28 层(共 32 层)既保证了足够的共享计算量,又为任务定制保留了足够的自由度。

硬件层面的改变

PrefillShare 架构:

┌─────────────────────────┐
│  共享 Prefill GPU        │
│                          │
│  ┌──────────────────┐   │  只计算一次!
│  │ Shared KV Cache  │   │  → 2.5GB (省了 5GB)
│  └──────────────────┘   │
└────────┬────────────────┘
         │ 广播 KV Cache (NCCL, ~5ms)
         ├──────────┬──────────┬──────────
         ↓          ↓          ↓
    ┌────────┐ ┌────────┐ ┌────────┐
    │Decode A│ │Decode B│ │Decode C│
    │(财务)  │ │(法律)  │ │(总结)  │
    └────────┘ └────────┘ └────────┘

性能提升来源

  • 计算:120ms vs 360ms(节省 66%)。关键在于 Attention 计算量与序列长度呈平方关系:$O(n^2 d)$,共享后只需计算一次
  • 内存:2.5GB vs 7.5GB(节省 66%)。更重要的是,这释放了 GPU 显存用于更大的 batch size
  • 带宽:只读一次权重,减少 PCIe 流量。对于 8B 模型,从读取 48GB(3×16GB)降低到 16GB

与 vLLM PagedAttention 的互补性:PagedAttention 解决的是单个 KV Cache 的内存碎片问题(通过分页管理),而 PrefillShare 解决的是多个模型间的 KV Cache 重复问题(通过共享)。两者可以结合使用,进一步提升效率。

代码实现

Baseline:朴素的多模型调用

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class NaiveMultiModelServer:
    def __init__(self, model_paths):
        self.models = [
            AutoModelForCausalLM.from_pretrained(path, device_map="auto")
            for path in model_paths
        ]
        self.tokenizers = [AutoTokenizer.from_pretrained(path) for path in model_paths]
    
    def multi_agent_query(self, shared_context, tasks):
        results = []
        for i, task in enumerate(tasks):
            prompt = f"{shared_context}\n{task}:"
            inputs = self.tokenizers[i](prompt, return_tensors="pt")
            
            with torch.no_grad():
                # 每个模型独立执行 Prefill + Decode
                outputs = self.models[i].generate(**inputs, max_new_tokens=100, use_cache=True)
            
            results.append(self.tokenizers[i].decode(outputs[0]))
        return results

性能分析(H100 GPU,5000 token context):

  • Total Prefill Time: 360ms(3 × 120ms)
  • GPU Memory: 7.5GB KV Cache + 48GB Model Weights
  • Prefill Compute Redundancy: 200%

优化版本:PrefillShare

import torch
import torch.nn as nn
from transformers.models.llama import LlamaModel, LlamaForCausalLM

class SharedPrefillModule(nn.Module):
    """共享的 Prefill 模块(冻结权重)"""
    def __init__(self, base_model_path, num_shared_layers=28):
        super().__init__()
        full_model = LlamaModel.from_pretrained(base_model_path)
        
        self.shared_layers = nn.ModuleList(full_model.layers[:num_shared_layers])
        self.embed_tokens = full_model.embed_tokens
        self.num_shared_layers = num_shared_layers
        
        # 冻结所有参数
        for param in self.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, attention_mask=None):
        """
        返回: (hidden_states, kv_cache)
        kv_cache: List[(key, value)] for each layer
        """
        hidden_states = self.embed_tokens(input_ids)
        kv_cache = []
        
        batch_size, seq_len = input_ids.shape
        # 构造 causal mask
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1
        )
        
        for layer in self.shared_layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                use_cache=True
            )
            hidden_states = layer_outputs[0]
            kv_cache.append(layer_outputs[1])  # (key, value) tensors
        
        return hidden_states, kv_cache


class TaskSpecificDecodeModule(nn.Module):
    """任务专用的 Decode 模块(可微调)"""
    def __init__(self, base_model_path, num_shared_layers=28):
        super().__init__()
        full_model = LlamaForCausalLM.from_pretrained(base_model_path)
        
        # 只保留后 4 层 + LM Head
        self.decode_layers = nn.ModuleList(
            full_model.model.layers[num_shared_layers:]
        )
        self.norm = full_model.model.norm
        self.lm_head = full_model.lm_head
    
    def forward(self, hidden_states, kv_cache):
        for i, layer in enumerate(self.decode_layers):
            layer_outputs = layer(
                hidden_states,
                past_key_value=kv_cache[i] if kv_cache else None,
                use_cache=True
            )
            hidden_states = layer_outputs[0]
        
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits


class PrefillShareServer:
    def __init__(self, base_model_path, task_model_paths, num_shared_layers=28):
        self.num_shared_layers = num_shared_layers
        
        # 共享 Prefill 模块(放在专用 GPU)
        self.prefill_module = SharedPrefillModule(
            base_model_path, num_shared_layers
        ).cuda(0)
        
        # 多个 Decode 模块(分布到多 GPU)
        self.decode_modules = [
            TaskSpecificDecodeModule(path, num_shared_layers).cuda(i % 4)
            for i, path in enumerate(task_model_paths)
        ]
        
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    
    @torch.no_grad()
    def multi_agent_query(self, shared_context, tasks):
        # Step 1: 执行一次共享 Prefill
        inputs = self.tokenizer(shared_context, return_tensors="pt").to("cuda:0")
        
        shared_hidden, shared_kv = self.prefill_module(
            inputs.input_ids,
            inputs.attention_mask
        )
        
        # Step 2: 并行执行多个 Decode(使用 NCCL 广播 KV Cache)
        if torch.distributed.is_initialized():
            # GPU 间高速传输,避免 CPU 中转
            for kv_pair in shared_kv:
                torch.distributed.broadcast(kv_pair[0], src=0)  # key
                torch.distributed.broadcast(kv_pair[1], src=0)  # value
        
        results = []
        for task, decode_module in zip(tasks, self.decode_modules):
            device = next(decode_module.parameters()).device
            
            # 将共享数据传输到对应 GPU
            hidden = shared_hidden.to(device)
            kv = [(k.to(device), v.to(device)) for k, v in shared_kv]
            
            # 生成输出
            output_ids = self._generate(decode_module, hidden, kv, max_new_tokens=100)
            results.append(self.tokenizer.decode(output_ids[0]))
        
        return results
    
    def _generate(self, module, hidden, kv_cache, max_new_tokens):
        """基于共享 KV Cache 的生成"""
        # ... (autoregressive generation 实现省略)
        # 核心:每步只计算 Decode 层,复用 kv_cache
        pass

关键优化点说明

  1. 可配置的层数分割num_shared_layers 参数允许根据实际任务调整拆分点。金融/法律等高专业性任务可能需要更多 Decode 层(如 8 层),而通用任务可用更少(如 2 层)

  2. NCCL 广播优化:使用 torch.distributed.broadcast 在 GPU 间直接传输 KV Cache,避免先拷贝到 CPU 再分发(后者会引入 ~50ms 额外延迟)

  3. 内存带宽计算:对于 5000 tokens 的 KV Cache(2.5GB),通过 NVLink(900GB/s)广播仅需 ~3ms,而通过 PCIe(64GB/s)则需 ~40ms

完整的环境配置(Docker)

FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04

# 安装 Python 和依赖
RUN apt-get update && apt-get install -y python3-pip git
RUN pip3 install torch==2.1.0 transformers==4.35.0 accelerate

# 配置多 GPU 通信
RUN pip3 install torch-distributed

# 设置环境变量
ENV NCCL_DEBUG=INFO
ENV CUDA_VISIBLE_DEVICES=0,1,2,3

WORKDIR /app
COPY . /app

# 运行命令
CMD ["python3", "prefillshare_server.py"]

运行示例

# 构建镜像
docker build -t prefillshare .

# 启动服务(需要 4 GPU)
docker run --gpus all -p 8000:8000 prefillshare

# 性能测试
python3 benchmark.py --num-models 3 --context-length 5000

常见错误与修复

# ❌ 错误 1:KV Cache 维度不匹配
# 问题:past_key_value 的 shape 必须是 (batch, num_heads, seq_len, head_dim)
kv_cache_wrong = torch.randn(1, 5000, 4096)  # 缺少 num_heads 维度

# ✅ 正确
kv_cache_correct = torch.randn(1, 32, 5000, 128)  # (batch, heads, seq, dim)


# ❌ 错误 2:微调时解冻了 Prefill 层
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # 错误:包含冻结层

# ✅ 正确:只优化 Decode 层
trainable_params = [
    p for name, p in model.named_parameters()
    if 'decode_layers' in name or 'lm_head' in name
]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)


# ❌ 错误 3:GPU OOM 时未及时释放 KV Cache
def process_batch(inputs):
    hidden, kv = prefill_module(inputs)
    results = decode_module(hidden, kv)
    return results  # kv 仍占用显存!

# ✅ 正确
def process_batch(inputs):
    hidden, kv = prefill_module(inputs)
    results = decode_module(hidden, kv)
    del kv  # 立即释放
    torch.cuda.empty_cache()
    return results

性能分析工具

# 使用 Nsight Compute 对比 FLOPs
nsys profile --stats=true \
  python benchmark.py --mode baseline

nsys profile --stats=true \
  python benchmark.py --mode prefillshare

# 分析 KV Cache 内存使用
python -m torch.utils.bottleneck benchmark.py

# 预期输出:
# Baseline:     Prefill FLOPs = 45 TFLOPs (3 models × 15 TFLOPs)
# PrefillShare: Prefill FLOPs = 15 TFLOPs (1 shared computation)

性能实测

实验环境:

  • GPU: 4 × H100 (80GB)
  • 模型: LLaMA-3-8B × 3 个任务模型
  • Context: 5000 tokens
  • Decode: 100 tokens per task
指标 Baseline PrefillShare 提升
延迟(P50) 420ms 145ms 2.9x
延迟(P95) 680ms 150ms 4.5x
吞吐量 120 req/s 468 req/s 3.9x
GPU Memory 22GB 8GB 2.75x
Prefill FLOPs 45 TFLOPs 15 TFLOPs 3x

详细分解(单次请求):

阶段 Baseline PrefillShare 说明
Prefill (Model A) 120ms - 消除
Prefill (Model B) 120ms - 消除
Prefill (Model C) 120ms - 消除
共享 Prefill - 120ms 只算一次
KV Cache 广播 - 5ms NCCL over NVLink
Decode (并行) 60ms 25ms 多 GPU 并行
总延迟 420ms 145ms 2.9x 提升

为什么 P95 延迟提升更大? 在 Baseline 中,三个模型的 Prefill 是串行的,任何一个模型的性能抖动都会累积到总延迟。而 PrefillShare 只执行一次 Prefill,消除了串行累积效应,因此长尾延迟显著降低。

什么时候用 / 不用?

适用场景 不适用场景
✅ Multi-Agent 系统(多模型共享上下文) ❌ 单模型单任务
✅ 长上下文场景(>2000 tokens) ❌ 短上下文(<500 tokens,Prefill 占比小)
✅ 批量推理(同时调用多个专家模型) ❌ 在线服务(单用户单请求)
✅ 有多 GPU 资源 ❌ 单 GPU 环境
✅ 任务相关性强(共享上下文有意义) ❌ 任务完全独立(无共享上下文)

性能拐点分析

  • Context Length > 2000 tokens:此时 Prefill 时间占总延迟的 70% 以上,共享收益显著
  • 模型数量 ≥ 3:2 个模型时节省 50%,3 个模型节省 66%,边际收益递减
  • Prefill/Decode 比例 > 2:1:对于短生成任务(<50 tokens),优化效果最明显

与其他方法的对比

方法 解决的问题 局限性
vLLM PagedAttention 单模型的 KV Cache 碎片 不处理跨模型重复
DistServe Prefill/Decode 资源分离 未共享 Prefill 计算
PrefillShare 跨模型的 Prefill 重复 需要多 GPU + 任务相关性

三者可以结合使用:DistServe 提供基础架构,PagedAttention 管理内存,PrefillShare 消除冗余。

延伸阅读

  1. Disaggregated Serving 基础
  2. KV Cache 优化
    • PagedAttention (vLLM)
    • 建议阅读:Section 3.2 的内存管理机制,理解如何与 PrefillShare 结合
  3. 模型分解与微调
  4. Multi-Agent 系统实践
  5. 官方实现
    • 原论文代码:https://arxiv.org/abs/2602.12029v1
    • 先读 Section 4(System Design)理解整体架构,再看代码实现细节