一句话总结

HullFT 将每条 query 的 embedding 表达为训练样本的稀疏凸组合,同时解决了测试时微调的两个瓶颈:如何快速选出相关且多样的支撑集,以及如何复用重复样本的梯度计算。


为什么测试时微调很难?

测试时微调(Test-Time Finetuning,TTFT)的思路很直接:推理前,针对当前 query 检索若干相关训练序列,在这些序列上短暂微调模型,再做推理。

问题在于,每条 query 都要做一次微调,延迟直接叠在推理路径上。现有方法陷入两难:

  • 快速检索(k-NN):用向量相似度取 Top-k,速度快但选出来的样本高度冗余,对质量提升有限
  • 多样性感知检索:显式优化覆盖度,效果好但每次都要跑一个优化问题,单次开销就把延迟打爆

HullFT 的核心洞见是:这两个问题其实共享同一个几何结构。把 query 表达成训练样本的凸组合,相关性和多样性都自然涌现——你需要足够近的点(相关性),也需要足够多方向的点才能”撑起” query 所在位置(多样性)。


核心方法:三步走

第一步:凸包重建(Convex Reconstruction)

给定 query embedding $q \in \mathbb{R}^d$ 和训练集 embeddings $X = {x_1, \ldots, x_n} \subset \mathbb{R}^d$,求:

\[\min_{w \in \Delta^n} \left\| \sum_{i=1}^n w_i x_i - q \right\|^2\]

其中 $\Delta^n = {w \mid w_i \geq 0,\ \sum_i w_i = 1}$ 是概率单纯形。

这个问题的解 $w^*$ 是稀疏的——大多数 $w_i = 0$,只有少量样本有正权重,这些样本就是支撑集。与 k-NN 不同,这里的选择是全局最优的:给定预算 k,它找到能最好还原 query embedding 的 k 个样本组合。

第二步:Frank-Wolfe 求解

直接对单纯形投影开销 $O(n \log n)$,而 Frank-Wolfe(条件梯度法)只需每轮做一次线性极小化,对单纯形而言就是 argmin——时间复杂度 $O(n)$。

迭代格式:

\[s^{(t)} = e_{i^*}, \quad i^* = \arg\min_i \left[ \nabla_w f(w^{(t)}) \right]_i\] \[w^{(t+1)} = (1 - \gamma_t)\, w^{(t)} + \gamma_t\, s^{(t)}\]

目标函数梯度为 $\nabla_w f = 2X(Xw - q)^T$,精确步长可解析求得:

\[\gamma^* = \frac{-\nabla f(w)^\top d}{2\, \|Xd\|^2}, \quad d = s^{(t)} - w^{(t)}\]

Frank-Wolfe 每步把一个新样本”加进来”,天然产生稀疏解,且每次迭代只扫一遍候选集。

第三步:整数化 + 梯度复用

凸组合权重 $w^$ 是连续的,但微调需要离散的训练样本集合。HullFT 把 $w^$ 转为整数多重集:给定总预算 $B$,令 $m_i = \lfloor w_i^* \cdot B \rfloor$,再把剩余配额分给小数部分最大的样本。

整数化后,某些样本会出现多次(multiplicity $m_i > 1$)。对于这些重复样本,梯度计算只需做一次,然后乘以倍数:

\[g_i^{\text{reuse}} = m_i \cdot \nabla_\theta \mathcal{L}(\theta;\, x_i)\]
这就是 Gradient Reuse——把本来 $\sum_i m_i$ 次前向-反向传播,降到只做 $ \text{support} $ 次。

代码实现

Frank-Wolfe 凸组合求解器

import torch
import torch.nn.functional as F

def frank_wolfe_reconstruct(query: torch.Tensor,
                            corpus: torch.Tensor,
                            n_iter: int = 50,
                            tol: float = 1e-6) -> torch.Tensor:
    """
    将 query 表达为 corpus 行向量的稀疏凸组合。
    query:  (d,)
    corpus: (n, d)  —— 预先 L2 归一化效果更好
    返回:  w (n,),稀疏的概率权重向量
    """
    n = corpus.shape[0]
    # 初始化:选最相似的单个样本
    sims = corpus @ query  # (n,)
    w = torch.zeros(n, dtype=query.dtype)
    w[sims.argmax()] = 1.0

    for _ in range(n_iter):
        recon = w @ corpus          # 当前重建 (d,)
        residual = recon - query    # (d,)

        # 梯度 ∂f/∂w_i = 2 * x_i · residual
        grad = 2.0 * (corpus @ residual)  # (n,)

        # FW 线性步:在单纯形上极小化线性近似 = 选 argmin
        s = torch.zeros_like(w)
        s[grad.argmin()] = 1.0

        d = s - w  # 下降方向
        Xd = d @ corpus  # (d,)

        denom = 2.0 * (Xd @ Xd)
        if denom < 1e-12:
            break

        # 精确线搜索
        gamma = (-(grad @ d) / denom).clamp(0.0, 1.0)
        w_new = w + gamma * d

        if (w_new - w).norm() < tol:
            w = w_new
            break
        w = w_new

    return w.clamp(min=0)  # 数值误差修正

权重整数化

def integerize_weights(w: torch.Tensor, budget: int) -> torch.Tensor:
    """
    把连续权重转为整数多重集,总和恰好等于 budget。
    采用"最大余数法"(Hamilton method),避免舍入偏差。
    """
    scaled = w * budget
    m = scaled.floor().long()
    remainder = budget - m.sum().item()

    if remainder > 0:
        # 把剩余名额分给小数部分最大的样本
        fracs = scaled - m.float()
        top_idx = fracs.topk(int(remainder)).indices
        m[top_idx] += 1

    return m  # shape (n,), sum == budget

梯度复用微调

def gradient_reuse_step(model, loss_fn, examples, multiplicities, optimizer):
    """
    对支撑集做一步微调,重复样本只计算一次梯度。
    examples:      List[样本]
    multiplicities: torch.Tensor (n,),整数倍数
    """
    optimizer.zero_grad()
    support_idx = multiplicities.nonzero(as_tuple=False).squeeze(-1)

    for idx in support_idx:
        m = multiplicities[idx].item()
        loss = loss_fn(model, examples[idx])
        # 反向传播,梯度乘以倍数,累积到 .grad
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.mul_(m)  # 梯度复用:乘以出现次数

    # 所有支撑样本的梯度已累积
    optimizer.step()

完整 HullFT 推理流程

class HullFT:
    def __init__(self, model, corpus_embeddings, corpus_examples,
                 budget=10, n_fw_iter=50, finetune_steps=1):
        self.model = model
        self.corpus_emb = F.normalize(corpus_embeddings, dim=-1)
        self.corpus_ex  = corpus_examples
        self.budget = budget
        self.n_fw_iter = n_fw_iter
        self.ft_steps = finetune_steps

    @torch.no_grad()
    def select(self, query_emb: torch.Tensor):
        q = F.normalize(query_emb, dim=-1)
        w = frank_wolfe_reconstruct(q, self.corpus_emb, self.n_fw_iter)
        m = integerize_weights(w, self.budget)
        return m  # 各样本在微调集中出现次数

    def adapt_and_infer(self, query, query_emb, loss_fn, infer_fn):
        # 保存原始权重
        original_state = {k: v.clone() for k, v in self.model.state_dict().items()}

        # 选择支撑集
        m = self.select(query_emb)

        # 微调
        opt = torch.optim.SGD(self.model.parameters(), lr=1e-4)
        for _ in range(self.ft_steps):
            gradient_reuse_step(self.model, loss_fn,
                                self.corpus_ex, m, opt)

        # 推理
        result = infer_fn(self.model, query)

        # 恢复原始权重(每次 query 互不干扰)
        self.model.load_state_dict(original_state)
        return result

实现中的坑

坑 1:corpus embedding 必须预先归一化

Frank-Wolfe 用内积衡量相似度,不归一化会让高范数样本被过度选中:

# 错误:直接用原始 embedding
corpus_emb = model.encode(corpus_texts)

# 正确:L2 归一化后存储
corpus_emb = F.normalize(model.encode(corpus_texts), dim=-1)

坑 2:多步微调时梯度复用会引入偏差

finetune_steps > 1 时,第二步的梯度是在更新后的参数上计算的,无法再复用第一步的梯度。论文结果主要在 finetune_steps=1 下成立,多步时需重新计算:

# 仅单步时梯度复用是精确的
# 多步微调时应退化为标准循环
if self.ft_steps > 1:
    # 展开重复样本,正常训练
    expanded = [self.corpus_ex[i] for i, cnt in enumerate(m)
                for _ in range(cnt.item())]

坑 3:Frank-Wolfe 收敛速度对初始点敏感

对高维 embedding(≥ 768d),从随机初始点出发可能需要数百次迭代。用最近邻作为起点可以把迭代次数降到 20-30 次。


论文结果 vs 现实

论文报告 HullFT 在语言建模(bits-per-byte 指标)上优于 kNN-TTFT 等基线,同时总 runtime 更低。我认为有几点需要注意:

可信度较高的部分:

  • Gradient Reuse 的加速是精确的,savings 随 support 的稀疏度线性增长
  • Frank-Wolfe 在低支撑数(k ≤ 10)时确实比暴力最优化快

需要谨慎的部分:

  • 论文的实验集中在语言建模任务;对 reasoning、code 等任务是否同样有效尚不明确
  • budget B 和 n_fw_iter 是新增超参数,实验室场景下调好了,生产环境不一定迁移

适用边界

适用场景 不适用场景
query 分布与训练集接近(in-distribution) OOD query 超出训练集凸包范围
已经在用 TTFT,想降低每次微调开销 对延迟极度敏感、不能接受任何额外检索开销
训练集足够大(> 10K),凸包覆盖度好 训练集极小(< 1K),凸包近似质量差
单步或少步微调(1-3 步) 需要大量微调步数才能收敛

我的看法

HullFT 最聪明的地方不是 Frank-Wolfe,也不是梯度复用——这两个技术都不新。聪明的是把它们串联成一个整体:凸组合权重自然产生重复样本,重复样本自然触发梯度复用。整个流程里没有多余的设计,每步都在服务下一步。

真正的开放问题是:TTFT 本身是否值得?每次推理都要微调模型,这在延迟预算有限的场景下依然是奢侈的。HullFT 改进了效率,但没有改变 TTFT 的根本代价。更可能的未来是把这种”基于检索的自适应”思路融合进 in-context learning 或 LoRA adapter 缓存里,而不是每次都动原始参数。

不过,对于离线批处理、医疗/法律等高准确率要求的场景,TTFT 是有实际意义的。HullFT 在这个细分赛道里是目前最干净的解法之一。


论文链接arxiv 2605.30337