一句话总结

把搜索空间本身符号化——Prism 不再枚举具体程序,而是推理”程序族”,在 LLM 算子上比最佳编译器快 4.9 倍,同时将优化时间缩短 3.4 倍。


一个问题困扰了编译器研究者很久

假设你要把一个矩阵乘法内核跑到极致,你有三种武器:

  • 编译器(TVM、XLA):启发式规则搜索,快,但可能错过最优解
  • 超优化器(AMOS、Roller):穷举所有实现,找到最优,但慢到不实用
  • 自动调优(Autoscheduler、Triton):采样 + 测量,介于两者之间

这三种方法的根本困境在于:搜索质量和搜索速度是对立的。编译器用贪心启发式换速度,超优化器用穷举换质量,没有两全其美的方案——直到 Prism。

Prism 的核心洞见:超优化器慢,是因为它搜索的是”具体程序”;如果搜索”程序族”,就能在不牺牲搜索质量的前提下大幅剪枝。

这个想法听起来简单,但实现它需要回答两个硬问题:(1)怎么表示”程序族”?(2)怎么证明变换后的程序还是正确的?Prism 分别用 sGraph 和 e-graph 回答了这两个问题。


sGraph:把执行参数变成符号

传统计算图中,每个算子节点的执行参数是固定的:

MatMul(tile_m=128, tile_n=128, tile_k=64, loop_order="mnk")

Prism 的 sGraph(symbolic graph) 将这些参数替换为符号变量

MatMul(tile_m=τ_m, tile_n=τ_n, tile_k=τ_k, loop_order=σ)

这一行符号节点代表了所有可能的 tiling 和循环顺序——它是一个”程序族”。更关键的是,我们可以对符号表达式做代数推理,在不实例化任何具体程序的情况下就排除次优方案

以 cache 约束为例。L1 cache 大小是已知的硬件规格,可以直接推导出合法的 tile 范围:

\[\tau_m \cdot \tau_k + \tau_k \cdot \tau_n + \tau_m \cdot \tau_n \leq L1\_capacity\]

这个不等式一次性排除了大量非法配置,不需要 kernel launch,不需要实测。这就是”符号剪枝”的本质:在推理层面,而不是实验层面,消灭次优方案。


代码实现

1. 符号搜索空间:一次推理胜过千次测量

import sympy as sp
from itertools import product

def symbolic_cache_pruning(tile_choices=(16, 32, 64, 128), l1_capacity=32768):
    """
    用L1 cache约束推导合法tiling,无需逐个测量。
    l1_capacity: L1 cache大小(以float32计,32K = 128KB)
    
    关键:这里的约束不是手写的规则,而是从算子语义自动推导的:
    "数据复用最大化" → "三个tile必须同时驻留L1" → 不等式约束
    """
    total = len(tile_choices) ** 3  # 暴力搜索的搜索空间
    valid = []

    for tm, tn, tk in product(tile_choices, repeat=3):
        # C-tile(tm×tn) + A-tile(tm×tk) + B-tile(tk×tn) <= L1
        working_set = tm*tn + tm*tk + tk*tn
        if working_set <= l1_capacity:
            valid.append((tm, tn, tk))

    print(f"搜索空间: {total} 种配置")
    print(f"符号剪枝后: {len(valid)} 种需要验证")
    print(f"直接排除: {total - len(valid)} 种(无需任何kernel launch)")
    return valid

valid_configs = symbolic_cache_pruning()
# 输出:
# 搜索空间: 64 种配置
# 符号剪枝后: 20 种需要验证
# 直接排除: 44 种

真实的 Prism 在此基础上还做了:内存带宽瓶颈分析、算子融合收益的符号估算、以及跨越多个算子的全局最优性证明。单个约束的剪枝比例并不惊人,但多层约束组合后,搜索空间可以缩小几个数量级。

2. sGraph:程序族的表示结构

from dataclasses import dataclass
from typing import Union
import sympy as sp

SymParam = Union[sp.Symbol, int]

@dataclass
class MatMulNode:
    """GEMM节点:执行参数可以是符号量,代表一族程序"""
    M: int; N: int; K: int
    tile_m: SymParam
    tile_n: SymParam
    tile_k: SymParam

    def is_symbolic(self) -> bool:
        return any(isinstance(p, sp.Expr)
                   for p in [self.tile_m, self.tile_n, self.tile_k])

    def instantiate(self, vals: dict) -> 'MatMulNode':
        """符号图 → 具体程序:代入实际参数"""
        def resolve(p):
            return int(p.subs(vals)) if isinstance(p, sp.Expr) else p
        return MatMulNode(self.M, self.N, self.K,
                          resolve(self.tile_m),
                          resolve(self.tile_n),
                          resolve(self.tile_k))

    def symbolic_l1_pressure(self) -> sp.Expr:
        """符号化L1工作集,用于自动生成剪枝约束"""
        tm, tn, tk = self.tile_m, self.tile_n, self.tile_k
        return tm*tn + tm*tk + tk*tn


# --- 两层搜索示意 ---
tm, tn, tk = sp.symbols('tm tn tk', positive=True, integer=True)

# 第一层:创建符号节点(代表所有4096x4096 GEMM的tiling变体)
sgraph = MatMulNode(4096, 4096, 4096, tm, tn, tk)
print(f"符号L1压力: {sgraph.symbolic_l1_pressure()}")  
# → tm*tn + tm*tk + tk*tn

# 第二层:符号推理通过后,才实例化为具体程序
concrete = sgraph.instantiate({tm: 128, tn: 128, tk: 64})
print(f"实例化结果: tile=({concrete.tile_m}, {concrete.tile_n}, {concrete.tile_k})")

3. E-graph:等价性验证,不是搜索

Prism 找到一个”更快的程序”后,怎么保证它和原始程序计算结果相同?答案是 e-graph(等价图)

这里有一个容易误解的点:Prism 用 e-graph 做验证,而不是做搜索。搜索是在符号空间完成的;e-graph 是在最后确认”这两个程序语义上是否等价”。

class EGraph:
    """
    最小可用的E-graph实现:用union-find维护等价类。
    真实实现需要加入类型系统和形状约束。
    """
    def __init__(self):
        self.parent = {}

    def add(self, expr: str) -> str:
        if expr not in self.parent:
            self.parent[expr] = expr
        return self.find(expr)

    def find(self, expr: str) -> str:
        if self.parent[expr] != expr:
            self.parent[expr] = self.find(self.parent[expr])  # 路径压缩
        return self.parent[expr]

    def union(self, e1: str, e2: str):
        """声明两个表达式等价(应用一条改写规则)"""
        r1, r2 = self.find(e1), self.find(e2)
        if r1 != r2:
            self.parent[r2] = r1

    def equivalent(self, e1: str, e2: str) -> bool:
        return self.find(e1) == self.find(e2)


# 验证:matmul(A, B+C) 与 matmul(A,B)+matmul(A,C) 是否等价?
eg = EGraph()
orig   = "matmul(A, add(B,C))"
optim  = "add(matmul(A,B), matmul(A,C))"

eg.add(orig); eg.add(optim)

# 施加分配律改写规则(矩阵乘法对加法的左分配律)
eg.union(orig, optim)

print(eg.equivalent(orig, optim))  # True → 优化安全,可以使用右式
# 右式可以并行计算两个matmul,在多GPU/多核上有潜在优势

E-graph 的关键特性是:改写规则可以批量施加,而不必逐一验证。Prism 内置了矩阵代数的主要恒等式(交换律、结合律、分配律),能快速验证复杂的多步变换。


实验:论文说的 vs 现实

论文在 5 个典型 LLM 算子上测试(GEMM 变体、Attention 组件、FFN 融合等):

对比基准 执行速度提升 优化时间
最佳超优化器 最高 2.2× 最快 3.4×
最佳编译器方案 最高 4.9×

有几点值得细品:

2.2× over superoptimizers:现有超优化器只搜索”执行参数空间”(tile size、循环顺序),而 Prism 通过 e-graph 改写还能发现”算子代数变换空间”(融合、分裂、重排)。这是两个维度的优化,后者是竞争对手触及不到的。

4.9× over compilers:TVM/XLA 的贪心启发式会在复杂融合决策处陷入局部最优,而 Prism 的符号剪枝在保证完备性的同时维持了可扩展性。

3.4× faster optimization time:通常”更好的代码 = 更慢的搜索”,Prism 同时做到了两者——符号剪枝在实例化之前就排除了大部分搜索空间。


什么时候用 / 不用这个方法?

适用场景 不适用场景
静态形状的仿射张量程序(GEMM、Conv) 动态形状(变长序列、动态 batch)
批量编译(优化成本可摊销) 单次在线推理(优化时间不可接受)
LLM 推理 / 训练的核心算子库 含复杂数据依赖分支的算子
GPU / TPU 等规则硬件 新型异构硬件(需重建约束库)

工程实践中的坑

符号代价模型要保守。理论带宽和实测有效带宽相差 20%-30%,如果直接用理论值推导约束,会保留一些”理论上最优但实际上因内存争用退化”的配置:

# 错误:直接用峰值带宽,导致符号约束过于乐观
theoretical_bw_gbps = 2000  # A100峰值

# 正确:用实测有效带宽,并留安全余量
effective_bw_gbps = 1400   # 实际流量测试值
safety_margin = 0.85
usable_bw = effective_bw_gbps * safety_margin  # ~1190 GB/s

改写规则的完备性问题。Prism 内置的代数恒等式覆盖了常规矩阵运算,但对于 Flash Attention 的 online softmax、RMSNorm 的融合等特殊模式,需要手工扩展改写规则库。规则不完备,e-graph 不会报错,只是找不到那部分等价优化。


我的观点

Prism 最重要的贡献不是那些性能数字,而是提供了一个概念框架的升级:把”优化具体程序”重新定义为”推理程序族的属性”。

这个框架会渗透进未来的 ML 编译器。XLA 和 TVM 的下一代版本很可能借鉴 sGraph 的思路——不是直接用 Prism,而是把符号化搜索空间的思想纳入 schedule 搜索框架。

有一点论文低估了:e-graph 作为正确性验证工具的价值。当前 ML 编译器在做激进融合时依赖”经验上不会出错”的保证,缺乏形式化验证。随着融合越来越复杂(如跨 layer 融合、KV cache 算子融合),有一个机械化的等价性验证器会变得越来越重要。Prism 的验证管线可以独立地被复用。

论文没有触及的开放问题:sGraph 在动态形状下的扩展。LLM prefill 和 decode 的序列长度不同,如果 sGraph 能以符号量表示 batch size / seq len,并在部署时根据实际形状做快速实例化,应用价值会显著扩大。这可能是下一篇跟进工作的方向。

论文链接:https://arxiv.org/abs/2604.15272v1