单卡复现 DeepSeek R1 Zero教程来了!
Page content
单卡复现 DeepSeek R1 Zero教程来了! by Datawhale
Datawhale干货
作者:邓恺俊,Datawhale成员
Datawhale干货
为什么单卡就能复现?
强化学习算法优化:集成了多种强化学习(RL)算法,并通过底层代码优化(如优化计算图、减少冗余操作),显著提升了大模型在推理和微调时的性能。 最新量化技术:大幅降低显存消耗,使得原本需要多卡的大模型也能在单卡上运行。 完整的 LoRA 和 QLoRA 微调支持:即使显存有限,也能通过少量资源复现 R1 Zero。
环境搭建
安装 Unsloth
本文中仅展示与前文有差异的代码部分,同时我们提供了完整的训练代码,请在文末获取。
注意:为了兼容 Unsloth,我们需要安装特定版本的 trl。具体命令如下:
# 安装 unsloth 和 vllm
pip install unsloth vllm
# 安装指定版本的 trl(兼容 unsloth)
pip install trl==0.15.0
参考自:https://docs.unsloth.ai/get-started/unsloth-notebooks
配置文件修改
大部分配置与之前的 Datawhale-R1.yaml 文件保持一致。为了支持单卡复现 R1 Zero,我们做了如下调整:
LoRA 参数设置:启用 LoRA 微调,调整 LoRA 秩数(lora_r)为 64(常用的选择有 8、16、32、64、128 等),并设置 lora_alpha 为 32。 限制回答长度:将 max_completion_length 设置为 1024,以控制输出长度。 优化器调整:优化器设置为 adamw_8bit,以加速训练。
# LoRA 参数调整
lora_r: 64 # LoRA 秩数,选择任意大于 0 的数字!建议使用 8, 16, 32, 64, 128
lora_alpha: 32 # LoRA alpha 值
# 训练参数
learning_rate: 1.0e-5 # 学习率,调整为1e-5
# GRPO 算法参数
beta: 0.001 # KL 惩罚因子
optim: adamw_8bit # 使用 8bit 优化器以加速训练
max_prompt_length: 256 # 输入 prompt 的最大长度
max_completion_length: 1024 # 输出回答长度,包含推理思维链
num_generations: 4
use_vllm: true # 启用 vLLM 加速推理
vllm_gpu_memory_utilization: 0.4 # vLLM 的 GPU 内存利用率(内存紧张时可适当降低)
LoRA微调参考:https://zhuanlan.zhihu.com/p/663557294
启动训练
启动训练的代码很简单,由于我们只需要单卡,不需要涉及到配置复杂的 Accelerate 库,直接运行以下代码即可运行。
python train_Datawhale-R1_unsloth.py --config Datawhale-R1_unsloth.yaml
训练代码优化解读
基于 Unsloth 框架,我们对原始代码做了简化和优化。主要思路有两点:
打补丁提升训练速度
在执行强化学习训练的代码之前,我们添加了两行代码,利用 PatchFastRL 函数对某些 RL 算法(如 GRPO)进行“打补丁”。这个操作实际上在底层优化了计算图、减少了冗余计算,从而加速训练过程。
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel) # 对 GRPO 算法打补丁
GRPO 训练函数的改进
模型加载:通过 FastLanguageModel.from_pretrained 方法加载预训练模型,并启用 vLLM 快速推理,同时支持 4 位加载(或 LoRA 16 位)。 PEFT 微调:利用 get_peft_model 方法对模型应用 LoRA 微调,指定了目标模块、LoRA 参数以及梯度检查点,确保在有限显存条件下依然能有效训练。
# 定义 GRPO 训练函数
def grpo_function(
model_args: ModelConfig,
dataset_args: DatasetArguments,
training_args: GRPOConfig,
callbacks: List,
):
# 记录模型参数
logger.info(f"Model parameters {model_args}")
# 记录训练/评估参数
logger.info(f"Training/evaluation parameters {training_args}")
# 从预训练模型加载模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_args.model_name_or_path, # 模型名称或路径
fast_inference=True, # 启用 vLLM 快速推理
load_in_4bit=True, # 是否以 4 位加载模型,False 表示使用 LoRA 16 位
max_lora_rank=model_args.lora_r, # 设置 LoRA 的最大秩
max_seq_length=training_args.max_completion_length, # 设置最大序列长度
gpu_memory_utilization=training_args.vllm_gpu_memory_utilization, # GPU 内存利用率,若内存不足可减少
attn_implementation=model_args.attn_implementation, # 设置注意力实现方式 flash attention
)
# PEFT 模型
model = FastLanguageModel.get_peft_model(
model,
r = model_args.lora_r,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj", # 如果 OOM 内存不足,可以移除 QKVO
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = model_args.lora_alpha, # 设置 LoRA 的 alpha 值
use_gradient_checkpointing = "unsloth", # 启用 unsloth 的梯度检查
random_state = training_args.seed, # 设置随机种子
)
参考自:https://unsloth.ai/blog/r1-reasoning
模型量化参考:LLM量化综合指南(8bits/4bits)https://zhuanlan.zhihu.com/p/671007819
训练结果与一些思考
Aha moment 是 RL 训练的结果吗?
思考长度越长越有效吗?
S1 文章的一些结论和思考
通过添加 "end-of-thinking token 分隔符" 和 "Final Answer" 来控制思考上限; 通过禁止生成分隔符并添加 "wait" 提示词来控制思考下限。
总结与展望
完整文件获取
原文链接