Memory Bank Compression for Continual Adaptation of Large Language Models
Memory Bank Compression for Continual Learning of Large Models: From Theory to Practice
Problem Setup: The Memory Dilemma in Continual Learning
At the intersection of reinforcement learning and continual learning, we face a core challenge: how can a large language model (LLM) keep absorbing new knowledge without forgetting old knowledge (i.e., avoiding catastrophic forgetting) and without letting storage costs grow unboundedly?
Modeling the problem as an MDP:
- State: the LLM's current parameter state plus the contents of the external memory bank
- Action: which parameters to update, and how to compress the memory bank
- Reward: new-task accuracy - forgetting on old tasks - memory-bank storage cost (a minimal sketch of this trade-off follows this list)
- Transition: the model-update process triggered by each arrival in the data stream
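To make the trade-off concrete, here is a minimal sketch of the scalar reward; the `forgetting` measure, the byte-denominated storage cost, and the weight `lam` are illustrative assumptions, not part of any fixed formulation:

```python
def mbc_reward(new_task_acc: float, forgetting: float,
               memory_bytes: int, lam: float = 1e-9) -> float:
    """Reward = new-task accuracy - forgetting on old tasks - storage cost.

    new_task_acc: accuracy on the newest task
    forgetting:   average accuracy drop on previously learned tasks
    memory_bytes: current memory-bank size; lam converts bytes into reward units
    """
    return new_task_acc - forgetting - lam * memory_bytes

# e.g. mbc_reward(0.85, 0.03, 2_000_000) == 0.85 - 0.03 - 0.002 = 0.818
```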
This is an online learning problem. Much as in off-policy RL, we must keep optimizing a policy over a data stream while balancing exploration (learning new knowledge) against exploitation (retaining old knowledge).
Relation to classic RL algorithms:
- An experience-replay mechanism in the spirit of DQN (the memory bank stores historical information)
- An Actor-Critic-style division of labor (the LLM as the main model, the memory bank as an auxiliary)
- Discrete representations borrowed from VQ-VAE (codebook compression)
Algorithmic Principles: The Mathematical Basis of Memory Bank Compression
Core Idea
Traditional memory-augmented methods store each sample's representation directly in the memory bank, so storage grows linearly with the stream. MBC (Memory Bank Compression) instead maps continuous memory representations onto a discrete codebook via vector quantization, shrinking each stored entry from a d-dimensional float vector to a single ⌈log₂ K⌉-bit codeword index.
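A back-of-the-envelope estimate makes the savings concrete (assuming $d = 768$ float32 features and $K = 256$ codewords): a raw memory costs $768 \times 4 \approx 3$ KB per entry, while the compressed bank pays a one-off $256 \times 768 \times 4 \approx 0.79$ MB for the codebook plus a single $\lceil \log_2 256 \rceil = 8$-bit index per entry. At $10^6$ entries, that is roughly 3 GB of raw vectors versus under 2 MB of codebook plus indices, a compression factor of over 1,000×.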
Mathematical Derivation
1. Value function of the memory bank
Define the value of the memory bank as its contribution to model performance:
\[V(\mathcal{M}) = \mathbb{E}_{(x,y) \sim \mathcal{D}_{\text{test}}} [\log P_\theta(y|x, \mathcal{M})]\]where $\mathcal{M}$ is the memory bank and $\theta$ are the LLM parameters.
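As a sketch of how $V(\mathcal{M})$ could be estimated empirically, assuming a model that exposes conditional log-likelihoods (the `log_prob_fn` interface here is a placeholder for illustration, not part of the method):

```python
import torch

def estimate_memory_value(log_prob_fn, test_pairs, memory_bank) -> float:
    """Monte-Carlo estimate of V(M) = E_{(x,y)~D_test}[log P_theta(y | x, M)].

    log_prob_fn(x, y, memory_bank) -> scalar tensor, log P_theta(y | x, M)
    test_pairs: iterable of (x, y) samples drawn from the test distribution
    """
    with torch.no_grad():
        logps = [log_prob_fn(x, y, memory_bank).item() for x, y in test_pairs]
    return sum(logps) / max(len(logps), 1)
```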
2. Compression objective
Introduce a codebook $\mathcal{C} = \{c_1, \dots, c_K\}$ of $K$ codewords and optimize:
\[\min_{\mathcal{C}, \phi} \mathbb{E}_{m \sim \mathcal{M}} [\|m - c_{\phi(m)}\|^2] + \lambda \cdot \text{Size}(\mathcal{C})\]- $\phi(m)$: the encoder, mapping a memory $m$ to the index of its nearest codeword
- $\lambda$: the storage-cost weight
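Concretely, $\phi$ is plain nearest-neighbour assignment; a minimal torch sketch of the quantization-error term:

```python
import torch

def phi(m: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Nearest-codeword encoder: m [batch, d], codebook [K, d] -> indices [batch]."""
    return torch.cdist(m, codebook).argmin(dim=1)

m = torch.randn(4, 8)          # a toy batch of memories
codebook = torch.randn(16, 8)  # K = 16 codewords
idx = phi(m, codebook)
quant_error = ((m - codebook[idx]) ** 2).sum(dim=1).mean()  # E[||m - c_phi(m)||^2]
```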
3. Bellman-style online update
At time step $t$, upon receiving new data $(x_t, y_t)$:
\[\mathcal{C}_{t+1} = \arg\min_{\mathcal{C}} \sum_{i=1}^{t} \alpha^{t-i} \|m_i - c_{\phi(m_i)}\|^2\]The exponential decay factor $\alpha$ balances the importance of new data against old.
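For a fixed assignment $\phi$, setting the gradient of this objective with respect to $c_k$ to zero gives a closed-form minimizer: each codeword is the exponentially weighted mean of the memories assigned to it, which can be maintained online with two running statistics. This is exactly the EMA codebook update used in the implementation below, up to a constant factor that cancels in the ratio:
\[c_k^{*} = \frac{\sum_{i \le t,\, \phi(m_i)=k} \alpha^{t-i}\, m_i}{\sum_{i \le t,\, \phi(m_i)=k} \alpha^{t-i}}, \qquad N_k \leftarrow \alpha N_k + n_k^{(t)}, \quad S_k \leftarrow \alpha S_k + \textstyle\sum_{\phi(m_i)=k} m_i, \quad c_k \leftarrow S_k / N_k\]where $n_k^{(t)}$ is the number of step-$t$ memories assigned to codeword $k$.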
4. Reset mechanism against codebook collapse
When the usage frequency of a codeword $c_k$ falls below a threshold $\tau$:
\[c_k \leftarrow m_{\text{worst}} + \epsilon, \quad \text{where} \quad m_{\text{worst}} = \arg\max_{m \in \mathcal{M}_t} \|m - c_{\phi(m)}\|\]i.e., the unused codeword is re-seeded near the memory with the largest reconstruction error.
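The reference implementation later in this post simplifies this to random re-initialization; a minimal sketch closer to the rule above, assuming the current batch of memories and their quantized versions are at hand:

```python
import torch

@torch.no_grad()
def reset_dead_code(codebook: torch.Tensor, k: int,
                    memories: torch.Tensor, quantized: torch.Tensor,
                    eps: float = 1e-2) -> None:
    """Re-seed dead codeword k near the memory with the largest reconstruction error."""
    errors = ((memories - quantized) ** 2).sum(dim=1)  # ||m - c_phi(m)||^2 per memory
    m_worst = memories[errors.argmax()]
    codebook[k] = m_worst + eps * torch.randn_like(m_worst)
```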
Algorithm Pseudocode
Algorithm: MBC (Memory Bank Compression)
Input:
- LLM with LoRA adapters θ
- Codebook C = {c₁, ..., cₖ}
- Data stream {(xₜ, yₜ)}ₜ₌₁^∞
Hyperparameters:
- Codebook size K
- Reset threshold τ
- Learning rate η
For each timestep t:
1. Extract the key/value representation of the new sample:
m_t = Encoder(x_t)
2. Vector quantization (find the nearest codeword):
idx = argmin_k ||m_t - c_k||
m̂_t = c_idx
3. Memory-augmented inference with compressed memories:
ŷ_t = LLM(x_t | {m̂_1, ..., m̂_t})
4. Compute the loss and update:
L = CrossEntropy(ŷ_t, y_t) + β||m_t - m̂_t||²
θ ← θ - η∇_θ L
C ← C - η∇_C L
5. Codebook reset (prevent collapse):
For each c_k:
If usage(c_k) < τ:
c_k ← m_worst + noise
6. Update the memory-bank indices:
M_indices.append(idx)
Key innovations:
- Online codebook optimization: no pre-training required; the model compresses as it learns
- Adaptive resets: prevent "dead" codewords and preserve the codebook's expressive capacity
- LoRA integration: only the low-rank adapters are updated, reducing compute cost
Implementation: A Simple Environment (QA Task)
Environment Definition
We use a simplified question-answering data-stream environment:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple, Dict
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
# Simulated QA data stream for continual learning
class ContinualQADataset(Dataset):
    """
    Simulated data stream: each time period contains QA pairs on a different
    topic, emulating a concept-drift scenario.
    """
    def __init__(self, num_topics=5, samples_per_topic=100, vocab_size=1000):
        self.data = []
        self.topics = []
        # Generate QA pairs for each topic
        for topic_id in range(num_topics):
            for _ in range(samples_per_topic):
                # Simplified representation: each topic draws tokens from its
                # own range (ids 0..999 overall, matching vocab_size)
                question = torch.randint(
                    topic_id * 200,
                    (topic_id + 1) * 200,
                    (20,)  # question length 20
                )
                answer = torch.randint(
                    topic_id * 200,
                    (topic_id + 1) * 200,
                    (10,)  # answer length 10
                )
                self.data.append((question, answer))
                self.topics.append(topic_id)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx][0], self.data[idx][1], self.topics[idx]

# Create the data stream.
# Note: shuffle=False keeps topics in temporal order; shuffling would mix all
# topics uniformly and remove the concept drift the dataset is meant to simulate.
def create_data_stream(batch_size=8):
    dataset = ContinualQADataset(num_topics=5, samples_per_topic=100)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
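A quick sanity check of the stream defined above:

```python
stream = create_data_stream(batch_size=8)
questions, answers, topics = next(iter(stream))
print(questions.shape, answers.shape, topics.shape)
# torch.Size([8, 20]) torch.Size([8, 10]) torch.Size([8])
```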
Core Component Implementations
1. The Vector-Quantized Codebook
class VectorQuantizer(nn.Module):
    """
    Vector quantization module: maps continuous memories onto a discrete codebook.
    Follows VQ-VAE, adapted for online learning.
    """
    def __init__(
        self,
        codebook_size: int = 256,       # codebook size (K)
        embedding_dim: int = 768,       # memory vector dimension
        commitment_cost: float = 0.25,  # commitment loss weight (beta)
        decay: float = 0.99,            # EMA decay rate
        epsilon: float = 1e-5
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        # Codebook: K d-dimensional vectors
        self.codebook = nn.Embedding(codebook_size, embedding_dim)
        self.codebook.weight.data.uniform_(-1/codebook_size, 1/codebook_size)
        # Statistics for EMA updates
        self.register_buffer('cluster_size', torch.zeros(codebook_size))
        self.register_buffer('embed_avg', self.codebook.weight.data.clone())
        self.decay = decay
        self.epsilon = epsilon
        # Usage frequency of each codeword (used for resets)
        self.register_buffer('usage_count', torch.zeros(codebook_size))
        self.reset_threshold = 0.01  # reset codes used less than 1% of the time

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        Args:
            z: input memories [batch_size, embedding_dim]
        Returns:
            z_q: quantized memories
            loss: VQ loss
            indices: codeword indices
        """
        # Flatten the input
        z_flattened = z.view(-1, self.embedding_dim)
        # Distances: ||z - c_k||^2 = ||z||^2 + ||c_k||^2 - 2<z, c_k>
        distances = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(self.codebook.weight ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, self.codebook.weight.t())
        )
        # Index of the nearest codeword
        indices = torch.argmin(distances, dim=1)
        # Update usage counts
        self.usage_count.index_add_(
            0, indices, torch.ones_like(indices, dtype=torch.float)
        )
        # Quantize: replace each input with its codeword
        z_q = self.codebook(indices).view(z.shape)
        # VQ losses (detach placement follows VQ-VAE):
        # 1. Codebook loss: pulls codewords toward the (frozen) encoder outputs
        # 2. Commitment loss: pulls encoder outputs toward the (frozen) codewords,
        #    preventing the encoder's output space from drifting
        codebook_loss = F.mse_loss(z_q, z.detach())
        commitment_loss = F.mse_loss(z_q.detach(), z)
        loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through estimator: quantized values forward, gradients flow to z
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices

    def update_codebook_ema(self, z: torch.Tensor, indices: torch.Tensor):
        """
        Update the codebook with an exponential moving average (EMA).
        More stable than gradient descent and well suited to online learning.
        """
        indices_onehot = F.one_hot(indices, self.codebook_size).float()
        # Update per-code cluster sizes
        self.cluster_size.mul_(self.decay).add_(
            indices_onehot.sum(0), alpha=1 - self.decay
        )
        # Laplace smoothing
        n = self.cluster_size.sum()
        cluster_size = (
            (self.cluster_size + self.epsilon)
            / (n + self.codebook_size * self.epsilon) * n
        )
        # EMA of the per-code sums of assigned vectors
        embed_sum = torch.matmul(indices_onehot.t(), z.view(-1, self.embedding_dim))
        self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
        # Write back the codebook weights
        self.codebook.weight.data.copy_(self.embed_avg / cluster_size.unsqueeze(1))

    def reset_unused_codes(self):
        """
        Reset under-used codewords to prevent codebook collapse.
        """
        total_usage = self.usage_count.sum()
        if total_usage == 0:
            return
        usage_freq = self.usage_count / total_usage
        unused_mask = usage_freq < self.reset_threshold
        if unused_mask.any():
            # Ideally, re-seed near the memories with the largest reconstruction
            # error (as derived above); simplified here to random resets.
            num_reset = unused_mask.sum().item()
            new_codes = torch.randn(num_reset, self.embedding_dim).to(self.codebook.weight.device)
            self.codebook.weight.data[unused_mask] = new_codes * 0.1
            # Reset the usage counts
            self.usage_count[unused_mask] = 0
            print(f"Reset {num_reset} unused codewords")
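A quick usage sketch of the quantizer in isolation (toy sizes for readability; continues from the imports and class above):

```python
vq = VectorQuantizer(codebook_size=32, embedding_dim=16)
z = torch.randn(4, 16)
z_q, vq_loss, indices = vq(z)
print(z_q.shape, round(vq_loss.item(), 4), indices.tolist())
# z_q matches z's shape; the straight-through estimator lets gradients reach z
```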
2. The LoRA Low-Rank Adapter
class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation (LoRA) layer.
    Principle: W' = W + AB, with A ∈ R^(in×r), B ∈ R^(r×out), r << min(in, out).
    Only A and B are trained; the original weight W stays frozen.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # The two low-rank LoRA matrices (A small Gaussian, B zero, so the
        # increment starts at zero)
        self.lora_A = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch_size, seq_len, in_features]
        """
        # LoRA increment: x @ A @ B
        lora_output = (x @ self.lora_A @ self.lora_B) * self.scaling
        return lora_output

class LoRALinear(nn.Module):
    """
    Linear layer with LoRA: combines the frozen base layer with the LoRA increment.
    """
    def __init__(
        self,
        linear_layer: nn.Linear,
        rank: int = 8,
        alpha: float = 16.0
    ):
        super().__init__()
        self.linear = linear_layer
        self.linear.weight.requires_grad = False  # freeze the base weight
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False  # freeze the base bias as well
        self.lora = LoRALayer(
            linear_layer.in_features,
            linear_layer.out_features,
            rank,
            alpha
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output + LoRA increment
        return self.linear(x) + self.lora(x)
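A usage sketch (continuing from the classes above) confirming that only the low-rank increment is trainable:

```python
base = nn.Linear(768, 768)
lora_linear = LoRALinear(base, rank=8, alpha=16.0)
out = lora_linear(torch.randn(2, 10, 768))  # [2, 10, 768]

trainable = sum(p.numel() for p in lora_linear.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora_linear.parameters())
print(f"{trainable} / {total} parameters trainable")  # 12288 / 602880, about 2%
```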
3. The Memory-Augmented LLM
class MemoryAugmentedLLM(nn.Module):
    """
    A large language model with an external memory bank.
    Architecture:
    1. Encoder: extracts key/value representations of the input
    2. Vector quantization: compresses memories
    3. Memory retrieval: fetches relevant memories from the codebook
    4. LLM inference: generation augmented with the retrieved memories
    """
    def __init__(
        self,
        hidden_dim: int = 768,
        codebook_size: int = 256,
        lora_rank: int = 8,
        max_memory_size: int = 1000
    ):
        super().__init__()
        # Simplified encoder (in practice, use a pretrained LLM's encoder)
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Vector quantizer
        self.vq = VectorQuantizer(
            codebook_size=codebook_size,
            embedding_dim=hidden_dim
        )
        # LoRA adapters (applied to the attention Q, K, V)
        self.lora_q = LoRALayer(hidden_dim, hidden_dim, rank=lora_rank)
        self.lora_k = LoRALayer(hidden_dim, hidden_dim, rank=lora_rank)
        self.lora_v = LoRALayer(hidden_dim, hidden_dim, rank=lora_rank)
        # Simplified decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Memory bank: stores codeword indices rather than raw vectors
        self.memory_indices = deque(maxlen=max_memory_size)
        self.memory_keys = deque(maxlen=max_memory_size)  # keys used for retrieval

    def encode_memory(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode the input into a memory representation and quantize it.
        Returns:
            memory: raw memory vectors
            quantized_memory: quantized memories
            vq_loss: vector-quantization loss
            indices: codeword indices
        """
        # Extract the memory representation (mean-pool over the sequence)
        memory = self.encoder(x.mean(dim=1))  # [batch, hidden_dim]
        # Vector quantization
        quantized_memory, vq_loss, indices = self.vq(memory)
        return memory, quantized_memory, vq_loss, indices

    def retrieve_memory(self, query: torch.Tensor, top_k: int = 5) -> torch.Tensor:
        """
        Retrieve relevant memories from the memory bank.
        Args:
            query: query vectors [batch, hidden_dim]
            top_k: number of memories to retrieve
        Returns:
            retrieved: retrieved memories [batch, top_k, hidden_dim]
        """
        if len(self.memory_indices) == 0:
            return torch.zeros(query.size(0), top_k, query.size(-1)).to(query.device)
        # Similarity between the query and all memory keys
        memory_keys_tensor = torch.stack(list(self.memory_keys)).to(query.device)
        similarities = torch.matmul(query, memory_keys_tensor.t())  # [batch, num_memories]
        # Select the top-k
        top_k = min(top_k, len(self.memory_indices))
        topk_indices = torch.topk(similarities, k=top_k, dim=1).indices
        # Look the selected entries up in the codebook
        retrieved = []
        for batch_idx in range(query.size(0)):
            batch_memories = []
            for k_idx in topk_indices[batch_idx]:
                code_idx = self.memory_indices[k_idx.item()]
                memory_vector = self.vq.codebook.weight[code_idx]
                batch_memories.append(memory_vector)
            retrieved.append(torch.stack(batch_memories))
        return torch.stack(retrieved)

    def forward(
        self,
        x: torch.Tensor,
        use_memory: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        Args:
            x: input [batch, seq_len, hidden_dim]
            use_memory: whether to use memory augmentation
        Returns:
            output: model output
            vq_loss: vector-quantization loss
        """
        batch_size, seq_len, hidden_dim = x.shape
        # 1. Encode the current input into a memory
        memory, quantized_memory, vq_loss, indices = self.encode_memory(x)
        # 2. Retrieve relevant memories
        if use_memory and len(self.memory_indices) > 0:
            retrieved_memories = self.retrieve_memory(memory, top_k=5)
            # Broadcast the pooled retrieved memories across the sequence
            memory_context = retrieved_memories.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        else:
            memory_context = torch.zeros_like(x)
        # 3. LoRA-augmented attention
        q = x + self.lora_q(x)
        k = (x + memory_context) + self.lora_k(x + memory_context)
        v = (x + memory_context) + self.lora_v(x + memory_context)
        # Simplified attention
        attn_weights = torch.softmax(
            torch.matmul(q, k.transpose(-2, -1)) / (hidden_dim ** 0.5),
            dim=-1
        )
        attn_output = torch.matmul(attn_weights, v)
        # 4. Decode
        output = self.decoder(attn_output)
        return output, vq_loss

    def update_memory(self, keys: torch.Tensor, indices: torch.Tensor):
        """
        Update the memory bank (store only indices and keys, not raw vectors).
        """
        for key, idx in zip(keys, indices):
            self.memory_keys.append(key.detach().cpu())
            self.memory_indices.append(idx.item())
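A minimal end-to-end smoke test of the model before wiring up training:

```python
model = MemoryAugmentedLLM(hidden_dim=768, codebook_size=256, lora_rank=8)
x = torch.randn(2, 20, 768)                   # [batch, seq_len, hidden_dim]
output, vq_loss = model(x, use_memory=True)   # memory bank is empty on the first call
memory, _, _, indices = model.encode_memory(x)
model.update_memory(memory, indices)          # store compressed entries (keys + indices)
print(output.shape, len(model.memory_indices))  # torch.Size([2, 20, 768]) 2
```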
Training Loop
```python
class MBCTrainer:
    """
    MBC trainer: online continual learning.
    """
    def __init__(
        self,
        model: MemoryAugmentedLLM,
        learning_rate: float = 1e-4,
        vq_weight: float = 0.25,
        reset_interval: int = 100
    ):
        self.model = model
        self.vq_weight = vq_weight
        self.reset_interval = reset_interval
        # Optimize only the LoRA parameters and the VQ codebook.
        # (The codebook is also updated by EMA in train_step; the EMA write-back
        # overwrites the gradient step, so the vq parameter group is optional.)
        self.optimizer = torch.optim.Adam([
            {'params': model.lora_q.parameters()},
            {'params': model.lora_k.parameters()},
            {'params': model.lora_v.parameters()},
            {'params': model.vq.parameters()},
        ], lr=learning_rate)
        # Training metrics
        self.metrics = {
            'loss': [],
            'vq_loss': [],
            'accuracy': [],
            'memory_size': [],
            'codebook_usage': []
        }
    def train_step(
        self,
        questions: torch.Tensor,
        answers: torch.Tensor
    ) -> Dict[str, float]:
        """
        One training step.
        Args:
            questions: [batch, question_len, hidden_dim]
            answers: [batch, answer_len, hidden_dim]
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Forward pass
        output, vq_loss = self.model(questions, use_memory=True)
        # Task loss (simplified to MSE; a real LM would use cross-entropy).
        # Questions (len 20) and answers (len 10) differ in length, so compare
        # only the first answer_len positions of the output.
        task_loss = F.mse_loss(output[:, :answers.size(1)], answers)
        # Total loss
        total_loss = task_loss + self.vq_weight * vq_loss
        # Backward pass
        total_loss.backward()
        self.optimizer.step()
        # EMA codebook update (note: this second encoding pass also increments
        # the quantizer's usage counts)
        with torch.no_grad():
            memory, _, _, indices = self.model.encode_memory(questions)
            self.model.vq.update_codebook_ema(memory, indices)
            self.model.update_memory(memory, indices)
        return {
            'loss': total_loss.item(),
            'vq_loss': vq_loss.item(),
            'task_loss': task_loss.item()
        }
    def evaluate(
        self,
        test_loader: DataLoader,
        topic_id: int = None
    ) -> float:
        """
        Evaluate the model on a given topic.
        """
        self.model.eval()
        total_loss = 0
        num_batches = 0
        with torch.no_grad():
            for questions, answers, topics in test_loader:
                if topic_id is not None:
                    mask = topics == topic_id
                    if not mask.any():
                        continue
                    questions = questions[mask]
                    answers = answers[mask]
                # Convert to float and add a feature dimension
                questions = questions.float().unsqueeze(-1).expand(-1, -1, 768)
                answers = answers.float().unsqueeze(-1).expand(-1, -1, 768)
                output, _ = self.model(questions, use_memory=True)
                loss = F.mse_loss(output[:, :answers.size(1)], answers)  # align lengths
                total_loss += loss.item()
                num_batches += 1
        return total_loss / max(num_batches, 1)
    def train_online(
        self,
        data_stream: DataLoader,
        num_steps: int = 1000,
        eval_interval: int = 100
    ):
        """
        Main online-training loop.
        """
        step = 0
        data_iter = iter(data_stream)
        # Held-out samples per topic, used to measure forgetting
        topic_test_data = {i: [] for i in range(5)}
        while step < num_steps:
            try:
                questions, answers, topics = next(data_iter)
            except StopIteration:
                data_iter = iter(data_stream)
                questions, answers, topics = next(data_iter)
            # Stash test samples
            for i in range(len(topics)):
                topic_id = topics[i].item()
                if len(topic_test_data[topic_id]) < 20:
                    topic_test_data[topic_id].append(
                        (questions[i], answers[i])
                    )
            # Convert to the model's input format
            questions = questions.float().unsqueeze(-1).expand(-1, -1, 768)
            answers = answers.float().unsqueeze(-1).expand(-1, -1, 768)
            # Training step
            metrics = self.train_step(questions, answers)
            # Record metrics
            self.metrics['loss'].append(metrics['loss'])
            self.metrics['vq_loss'].append(metrics['vq_loss'])
            self.metrics['memory_size'].append(len(self.model.memory_indices))
            # Codebook usage rate
            usage = (self.model.vq.usage_count > 0).sum().item()
            self.metrics['codebook_usage'].append(
                usage / self.model.vq.codebook_size
            )
            # Periodically reset unused codewords
            if step % self.reset_interval == 0 and step > 0:
                self.model.vq.reset_unused_codes()
            # Periodic evaluation
            if step % eval_interval == 0:
                print(f"\nStep {step}/{num_steps}")
                print(f"Loss: {metrics['loss']:.4f} | VQ loss: {metrics['vq_loss']:.4f}")
                print(f"Memory bank size: {len(self.model.memory_indices)}")
                print(f"Codebook usage: {self.metrics['codebook_usage'][-1]:.2%}")
                # Evaluate all topics seen so far (to detect forgetting)
                for topic_id in range(5):
                    if len(topic_test_data[topic_id]) > 0:
                        # Build a temporary test set
                        test_questions = torch.stack([x[0] for x in topic_test_data[topic_id]])
                        test_answers = torch.stack([x[1] for x in topic_test_data[topic_id]])
                        test_questions = test_questions.float().unsqueeze(-1).expand(-1, -1, 768)
                        test_answers = test_answers.float().unsqueeze(-1).expand(-1, -1, 768)
                        with torch.no_grad():
                            output, _ = self.model(test_questions, use_memory=True)
                            topic_loss = F.mse_loss(output[:, :test_answers.size(1)], test_answers).item()
                        print(f"  Topic {topic_id} loss: {topic_loss:.4f}")
            step += 1
```
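Putting the pieces together, a minimal driver for the components above (step counts and hyperparameters are illustrative):

```python
model = MemoryAugmentedLLM(hidden_dim=768, codebook_size=256, lora_rank=8)
trainer = MBCTrainer(model, learning_rate=1e-4, vq_weight=0.25, reset_interval=100)
stream = create_data_stream(batch_size=8)
trainer.train_online(stream, num_steps=1000, eval_interval=100)
```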