LLMLingua
Microsoft Research — 用小模型的困惑度评分驱动 prompt 压缩,最高 20x 压缩率
架构概览
核心思想:困惑度评分
LLMLingua 的核心假设是:如果一个小型语言模型(GPT-2, 1.5B 参数)能轻松预测某个 token,那这个 token 对大模型来说也是冗余的。具体做法是用小模型对每个 token 计算困惑度(Perplexity = exp(CrossEntropyLoss))——高困惑度意味着模型难以预测,说明该 token 携带了不可替代的信息;低困惑度说明 token 可预测、可移除。通过设定阈值,保留困惑度高于阈值的 token,实现压缩。
三级预算分配
面对一个包含多段上下文(demonstration)、指令(instruction)和问题(question)的 prompt,LLMLingua 不是一刀切地压缩,而是分三级递进式分配 token 预算。第一级:上下文级——多段文档按 PPL 排名,贪心选择最相关的段落直到预算用尽。第二级:句子级——被选中段落内部,按句子 PPL 排名,保留最重要的句子。第三级:令牌级——被选中句子内部,逐 token 计算 PPL,低于阈值的移除。每一级都使用不同粒度的 PPL 评分。
迭代压缩与 KV Cache
令牌级压缩不是一次性处理整个 prompt,而是按 iterative_size(默认 200 token)分块迭代。每次处理一个块时,复用前一轮已计算的 KV cache,避免重复的注意力矩阵计算。当累计处理的 token 超出模型的 max_position_embeddings 时,直接在 KV cache 中截断中间部分(保留前缀 BOS 和尾部),相当于在推理层面同时做了「缓存压缩」。这让 LLMLingua 可以处理远超模型上下文窗口的输入。
LongLLMLingua:问题感知
LLMLingua v1 的困惑度评分是无条件的——只看 token 本身的可预测性,不考虑用户问了什么。LongLLMLingua 引入「条件困惑度」:将 question 和 context 拼接后一起送入小模型,只计算 context 部分的损失。这样评分变成了「给定这个问题,这段上下文有多重要」,而不是「这段文字本身有多可预测」。效果是:与问题高度相关的段落即使本身很平淡(低 PPL)也会被保留。
仅计算 context 自身的困惑度,不考虑问题
LLMLingua-2:分类加速
v1 和 LongLLMLingua 的瓶颈是多次前向传播——每个迭代块都需要计算一次 PPL。LLMLingua-2 换了思路:不再用困惑度,而是训练一个轻量级令牌分类器(基于 BERT/XLM-RoBERTa,约 350M 参数)。训练数据来自 GPT-4 蒸馏——让 GPT-4 标注哪些 token 应该保留/移除,再用这些标签训练分类器。推理时只需一次前向传播,输出每个 token 的保留概率,速度比 v1 快 3-6 倍。代价是失去了问题感知能力和迭代精调的灵活性。
LLMLingua v1
EMNLP 2023代码走读
从 get_ppl() 困惑度计算到 LLMLingua-2 令牌分类,走读 5 个关键函数的 Python 源码。
get_ppl():困惑度计算
核心评分函数。将文本送入小模型(GPT-2 / LLaMA),计算逐 token 的交叉熵损失。granularity="sentence" 返回损失均值(用于句子排名),granularity="token" 返回逐 token 损失向量(用于令牌级压缩)。 通过 use_cache=True 复用 KV cache。
# llmlingua/prompt_compressor.py
def get_ppl(self, text, granularity="sentence",
input_ids=None, past_key_values=None,
condition_mode="none", condition_pos_id=0):
if input_ids is None:
tokenized = self.tokenizer(text, return_tensors="pt")
input_ids = tokenized["input_ids"].to(self.device)
with torch.no_grad():
response = self.model(
input_ids[:, past_length:end],
attention_mask=attention_mask[:, :end],
past_key_values=past_key_values,
use_cache=True, # 复用 KV cache
)
# 计算逐 token 交叉熵损失
shift_logits = response.logits[..., :-1, :]
shift_labels = input_ids[..., past_length + 1 : end]
loss_fct = nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(active_logits, active_labels)
# 句子级: 返回标量均值; 令牌级: 返回逐 token 损失
res = loss.mean() if granularity == "sentence" else loss
return (res, past_key_values) if return_kv else rescontrol_context_budget():上下文级预算
第一级预算分配。对多段上下文按 PPL/BM25/embedding 排名,然后贪心选择直到 token 预算用尽。 支持动态预算公式(context_budget="+100") 和动态压缩比例(缓解 Lost in the Middle 问题——中间位置的段落压缩更激进)。
# llmlingua/prompt_compressor.py
def control_context_budget(self,
context, context_tokens_length,
target_token, question,
rank_method="longllmlingua",
context_budget="+100",
dynamic_context_compression_ratio=0.0):
# 1. 对上下文段落排名
sorted_contexts = self.get_rank_results(
context, question, rank_method,
condition_in_question, context_tokens_length,
)
# 2. 动态预算公式 (支持 "+100", "*0.8" 等)
target_token = eval("target_token" + context_budget)
# 3. 贪心选择:按重要性排序,选到预算用尽
used = []
for idx, _ in sorted_contexts:
target_token -= context_tokens_length[idx]
used.append(idx)
if target_token < 0:
break
# 4. 可选: 动态压缩比例(Lost in the Middle 缓解)
if dynamic_context_compression_ratio > 0:
dynamic_ratio = [ # 线性分布
i * (ratio / (N-1)) for i in range(-(N-1), N, 2)
][::-1]
return res, dynamic_ratio, usediterative_compress_prompt():迭代令牌压缩
第三级压缩的核心。按 iterative_size 分块处理,每块计算逐 token PPL, 用百分位数计算阈值(ratio=0.5 → 取 PPL 排序后第 50% 分位的值作为阈值), 高于阈值的 token 保留。超出位置编码上限时直接截断 KV cache 中间部分。
# llmlingua/prompt_compressor.py
def iterative_compress_prompt(self, context,
target_token, iterative_size=200):
# 计算每个迭代块的动态压缩比例
iterative_ratios = self.get_dynamic_compression_ratio(...)
input_ids = self.tokenizer(context)["input_ids"]
past_key_values = None
# 逐块迭代压缩
while end <= compressed_input_ids.shape[1]:
# 超出位置编码上限时,压缩 KV cache
if end > self.max_position_embeddings:
past_key_values = [
[torch.cat([k[..., :s, :], k[..., s+e:, :]], dim=-2),
torch.cat([v[..., :s, :], v[..., s+e:, :]], dim=-2)]
for k, v in past_key_values
]
# 计算当前块的逐 token PPL
loss, past_key_values = self.get_ppl(
"", "token", compressed_input_ids,
past_key_values=past_key_values, return_kv=True,
)
# 按比例计算阈值,选择令牌
threshold = self.get_estimate_threshold_base_distribution(
loss, ratio
)
compressed_input_ids = self.get_compressed_input(
loss, compressed_input_ids, threshold=threshold
)get_condition_ppl():条件困惑度
LongLLMLingua 的核心创新。三种模式:"none" 无条件 PPL;"before" question 前置,计算 context 在给定 question 下的困惑度;"after" question 后置,计算 question 在给定 context 下的困惑度。 通过 condition_pos_id 精确控制损失计算的区间。
# llmlingua/prompt_compressor.py — LongLLMLingua
def get_condition_ppl(self, text, question,
condition_in_question="none", granularity="sentence"):
if condition_in_question == "none":
# 无条件: PPL(text)
return self.get_ppl(text, granularity=granularity)
elif condition_in_question == "before":
# 问题前置: PPL(text | question)
# 拼接 question + text,只取 text 部分的损失
return self.get_ppl(
question + text,
condition_mode="after",
condition_pos_id=self.get_token_length(question) - 1,
)
elif condition_in_question == "after":
# 问题后置: PPL(question | text)
# 拼接 text + question,只取 text 部分的损失
return self.get_ppl(
text + question,
condition_mode="after",
condition_pos_id=self.get_token_length(text) - 1,
)__compress():LLMLingua-2 令牌分类
完全不同的路径——不用困惑度,直接用 BERT/XLM-RoBERTa 做令牌二分类(保留/移除)。 模型通过 GPT-4 蒸馏训练,推理时只需一次前向传播。子词概率通过 mean 聚合为单词概率, 然后用百分位数阈值选择。支持 force_tokens(强制保留特定词)和 force_reserve_digit(保留数字)。
# llmlingua/prompt_compressor.py — LLMLingua-2
def __compress(self, context_list, reduce_rate=0.5,
force_tokens=[], force_reserve_digit=False):
with torch.no_grad():
for batch in dataloader:
# BERT/XLM-RoBERTa 前向传播: 令牌二分类
outputs = self.model(input_ids=ids, attention_mask=mask)
probs = F.softmax(outputs.logits, dim=-1)
keep_probs = probs[j, :, 1] # 类别 1 = 保留
# 合并子词为完整单词
words, word_probs = self.__merge_token_to_word(
tokens, token_probs, force_tokens)
# 百分位数阈值 (与 v1 的 PPL 阈值类似)
threshold = np.percentile(
word_probs, int(100 * reduce_rate + 1)
)
# 选择:概率 > 阈值的词保留
for word, prob in zip(words, word_probs):
if prob > threshold:
keep_words.append(word)
word_labels.append(1) # 保留
else:
word_labels.append(0) # 移除
compressed = self.tokenizer.convert_tokens_to_string(keep_words)