Python 实现 GRPO 简版
Page content
Python 实现 GRPO 简版 by 数据STUDIO
今天我们将深入探讨GRPO的实现。先简要介绍这一概念,讨论方法,然后开始具体实现。
什么是GRPO?
GRPO是一种训练技术,旨在通过捕捉特定偏好的奖励函数来优化语言模型。与其他强化学习方法(如PPO或RLHF)不同,GRPO不需要复杂的评判模型和大量计算资源,而是直接优化语言模型,并通过在生成的响应组内计算相对优势来实现目标。
GRPO的关键特点
GRPO的独特之处
GRPO是一种新兴的强化学习技术,相比传统方法具有以下优势:
直接优化:不同于需要独立奖励模型的方法,GRPO直接使用显式奖励函数优化语言模型。 多奖励信号:可以定义多个奖励函数,针对生成内容的不同方面(如正确性、格式、风格)。 探索效率:GRPO通过在训练过程中为每个提示生成多个补全内容,有效探索输出空间。
奖励函数
代码实现了多个协同工作的奖励函数,用于指导模型:
correctness_reward_func:当模型提取的答案与真实答案匹配时,奖励2.0分。这是事实正确性的主要学习信号。 int_reward_func:当答案是数字时奖励0.5分,适用于数学问题,引导模型生成数值响应。 soft_format_reward_func和strict_format_reward_func:奖励正确的XML格式(0.5分),教导模型使用正确的标签结构响应。 xmlcount_reward_func:为每个正确使用的XML标签提供部分奖励(每个标签0.125分),形成平滑的学习梯度。
实现组件
奖励函数:根据特定标准评估模型输出: 正确性:检查提取的答案是否与真实答案匹配。 格式遵循:确保响应符合请求的XML格式。 整数检测:奖励数值答案。 数据集准备:使用GSM8K(数学应用题)数据集,并进行特定格式化。 训练配置:使用LoRA进行参数高效微调。
训练过程
对于数据集中的每个提示,模型生成多个补全内容(由num_generations设置,代码中为4)。 每个补全内容由所有奖励函数评估。 奖励用于更新模型权重,鼓励模型生成更高奖励的输出。 此过程持续指定的周期数。
参数高效微调
我们使用LoRA(低秩适应)高效微调模型。LoRA向注意力层添加小型可训练的“适配器”矩阵,大幅减少训练参数数量(通常>99%)。peft_config定义了目标层和适配器的秩。
实施考虑
使用较小的模型(Qwen2.5-1.5B-Instruct)以适应内存限制。 减小批次大小和生成数量以管理内存使用,并使用较小的数据集子集(20个示例)进行快速实验。 测试代码可在训练后立即评估结果。可通过增加max_samples行更全面的训练,或尝试不同的奖励函数。
代码实现
安装所需包
pip install -q transformers datasets trl peft accelerate
import re
import torch
import numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer
定义系统和响应提示
SYSTEM_PROMPT = """
请按以下格式响应:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
提取答案的辅助函数
def extract_xml_answer(text: str) -> str:
"""从XML格式的响应中提取答案部分。"""
if"<answer>" not in text or "</answer>" not in text:
return""
answer = text.split("</answer>")[-1]
answer = answer.split("<answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str:
"""从GSM8K格式中提取答案(###标记之后)。"""
if"####" not in text:
return""
return text.split("###")[1].strip().replace(".", "").replace("$", "")
加载并准备GSM8K数据集
def get_gsm8k_questions(split="train", max_samples=100) -> Dataset: """
加载GSM8K数据集并格式化为GRPO训练所需形式。
参数:
split: 使用的数据集划分(train, test)
max_samples: 使用的最大样本数(用于快速实验)
""" data = load_dataset('openai/gsm8k', 'main')[split] # 限制数据集大小以加快实验 if max_samples and max_samples < len(data): data = data.select(range(max_samples)) # 格式化数据为所需的提示结构 data = data.map(lambda x: { 'prompt': [ {'role': 'system', 'content': SYSTEM_PROMPT}, {'role': 'user', 'content': x['question']} ], 'answer': extract_hash_answer(x['answer']) }) return data
奖励函数
def correctness_reward_func(prompts, completions, answer, **kwargs):
"""
检查提取的答案是否与真实答案匹配的奖励函数。
正确答案返回2.0,否则返回0.0。
""" responses = [completion[0]['content'] for completion in completions] q = prompts[0][-1]['content'] extracted_responses = [extract_xml_answer(r) for r in responses] # 打印调试信息 if kwargs.get('debug', False) and len(responses) > 0: print('-'*20) print(f"问题:\n{q}") print(f"\n真实答案:\n{answer[0]}") print(f"\n模型响应:\n{responses[0]}") print(f"\n提取的答案:\n{extracted_responses[0]}") return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] def int_reward_func(completions, **kwargs) -> list[float]: """
检查提取的答案是否为数字的奖励函数。
整数答案返回0.5,否则返回0.0。
""" responses = [completion[0]['content'] for completion in completions] extracted_responses = [extract_xml_answer(r) for r in responses] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] def strict_format_reward_func(completions, **kwargs) -> list[float]: """
检查补全内容是否完全符合格式的奖励函数。
匹配格式返回0.5,否则返回0.0。
""" pattern = r"^\n<reasoning>.*?</reasoning>\n\n<answer>.*?</answer>\n$" responses = [completion[0]['content'] for completion in completions] matches = [bool(re.search(pattern, r, flags=re.DOTALL)) for r in responses] return [0.5 if match else 0.0 for match in matches] def soft_format_reward_func(completions, **kwargs) -> list[float]: """
宽松的格式检查奖励函数。
匹配格式返回0.5,否则返回0.0。
""" pattern = r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>" responses = [completion[0]['content'] for completion in completions] matches = [bool(re.search(pattern, r, flags=re.DOTALL)) for r in responses] return [0.5 if match else 0.0 for match in matches] def count_xml(text) -> float: """
统计XML标签并为每个正确放置的标签提供部分奖励。
""" count = 0.0 if text.count("<reasoning>") == 1: count += 0.125 if text.count("</reasoning>") == 1: count += 0.125 if text.count("<answer>") == 1: count += 0.125 if text.count("</answer>") == 1: count += 0.125 return count def xmlcount_reward_func(completions, **kwargs) -> list[float]: """
基于响应中XML标签计数的奖励函数。
""" contents = [completion[0]["content"] for completion in completions] return [count_xml(c) for c in contents]
模型设置
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# 设置输出目录和运行名称
output_dir = "outputs/Qwen-1.5B-GRPO"
run_name = "Qwen-1.5B-GRPO-gsm8k"
配置GRPO训练
training_args = GRPOConfig(
output_dir=output_dir, run_name=run_name, learning_rate=5e-6, adam_beta1=0.9, adam_beta2=0.99, weight_decay=0.1, warmup_ratio=0.1, lr_scheduler_type='cosine', logging_steps=1, bf16=False, # 设置为False,因为Colab不支持 fp16=True, # 使用fp16以提高兼容性 per_device_train_batch_size=4, # 增加以兼容GRPO gradient_accumulation_steps=2, num_generations=4, # 必须是per_device_train_batch_size的除数 max_prompt_length=256, max_completion_length=512, num_train_epochs=1, save_steps=50, max_grad_norm=0.1, report_to="none", log_on_each_node=False, )
配置LoRA进行参数高效微调
peft_config = LoraConfig(
r=8, # 从16减少以适应Colab内存
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # 简化目标模块
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
加载并准备模型
print(f"加载模型:{model_name}") model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, # 使用float16而非bfloat16 device_map="auto", # 让模型自动选择最佳设备配置 low_cpu_mem_usage=True, # 提高内存效率 trust_remote_code=True # 新模型有时需要 ) # 加载分词器 tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # 加载GSM8K数据集的子集 print("加载数据集...") dataset = get_gsm8k_questions(max_samples=20) print(f"数据集加载完成,共{len(dataset)}个示例")
初始化GRPO训练器
print("初始化GRPO训练器...")
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
int_reward_func,
correctness_reward_func
],
args=training_args,
train_dataset=dataset,
peft_config=peft_config # 启用LoRA进行高效微调
)
运行GRPO
# 开始训练
print("开始GRPO训练...") trainer.train() # 保存最终模型 print("训练完成。保存模型...") trainer.save_model() # 训练后测试模型 print("\n--- 测试训练后的模型 ---\n") # 生成预测的函数 def generate_prediction(model, tokenizer, question, max_length=512): prompt = [ {'role': 'system', 'content': SYSTEM_PROMPT}, {'role': 'user', 'content': question} ] # 格式化提示 messages = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) # 分词输入 inputs = tokenizer(messages, return_tensors="pt").to(model.device) # 生成响应 with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_length, do_sample=False ) # 解码响应 response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) return response # 测试数据集中的几个示例 test_examples = dataset.select(range(3)) for i, example in enumerate(test_examples): question = example['prompt'][-1]['content'] ground_truth = example['answer'] print(f"\n示例 {i+1}:") print(f"问题:{question}") print(f"真实答案:{ground_truth}") # 生成预测 response = generate_prediction(model, tokenizer, question) print(f"模型响应:{response}") # 提取答案 extracted_answer = extract_xml_answer(response) print(f"提取的答案:{extracted_answer}") print(f"是否正确:{extracted_answer == ground_truth}") print("-" * 50) print("搞定!")
总结
这一实现展示了GRPO的工作原理,以及如何利用它优化语言模型以适应特定格式和任务。数学问题解决任务与XML格式的结合,清晰地体现了该技术的能力。
真是一次有趣的实践!
作者:arjun链接:https://www.k-a.in/grpo-1B.html
编辑:AI翻译、「深度学习自然语言处理」公众号润色
🏴☠️宝藏级🏴☠️ 原创公众号『数据STUDIO』内容超级硬核。公众号以Python为核心语言,垂直于数据科学领域,包括可戳👉 Python|MySQL|数据分析|数据可视化|机器学习与数据挖掘|爬虫 等,从入门到进阶!
长按👇关注- 数据STUDIO -设为星标,干货速递
原文链接