Python 实现 GRPO 简版

Page content

Python 实现 GRPO 简版 by 数据STUDIO


今天我们将深入探讨GRPO的实现。先简要介绍这一概念,讨论方法,然后开始具体实现。

什么是GRPO?

GRPO是一种训练技术,旨在通过捕捉特定偏好的奖励函数来优化语言模型。与其他强化学习方法(如PPO或RLHF)不同,GRPO不需要复杂的评判模型和大量计算资源,而是直接优化语言模型,并通过在生成的响应组内计算相对优势来实现目标。

GRPO的关键特点

GRPO的独特之处

GRPO是一种新兴的强化学习技术,相比传统方法具有以下优势:

  • 直接优化:不同于需要独立奖励模型的方法,GRPO直接使用显式奖励函数优化语言模型。
  • 多奖励信号:可以定义多个奖励函数,针对生成内容的不同方面(如正确性、格式、风格)。
  • 探索效率:GRPO通过在训练过程中为每个提示生成多个补全内容,有效探索输出空间。

奖励函数

代码实现了多个协同工作的奖励函数,用于指导模型:

  • correctness_reward_func:当模型提取的答案与真实答案匹配时,奖励2.0分。这是事实正确性的主要学习信号。
  • int_reward_func:当答案是数字时奖励0.5分,适用于数学问题,引导模型生成数值响应。
  • soft_format_reward_func和strict_format_reward_func:奖励正确的XML格式(0.5分),教导模型使用正确的标签结构响应。
  • xmlcount_reward_func:为每个正确使用的XML标签提供部分奖励(每个标签0.125分),形成平滑的学习梯度。

实现组件

  • 奖励函数:根据特定标准评估模型输出:
    • 正确性:检查提取的答案是否与真实答案匹配。
    • 格式遵循:确保响应符合请求的XML格式。
    • 整数检测:奖励数值答案。
  • 数据集准备:使用GSM8K(数学应用题)数据集,并进行特定格式化。
  • 训练配置:使用LoRA进行参数高效微调。

训练过程

  • 对于数据集中的每个提示,模型生成多个补全内容(由num_generations设置,代码中为4)。
  • 每个补全内容由所有奖励函数评估。
  • 奖励用于更新模型权重,鼓励模型生成更高奖励的输出。
  • 此过程持续指定的周期数。

参数高效微调

我们使用LoRA(低秩适应)高效微调模型。LoRA向注意力层添加小型可训练的“适配器”矩阵,大幅减少训练参数数量(通常>99%)。peft_config定义了目标层和适配器的秩。

实施考虑

  • 使用较小的模型(Qwen2.5-1.5B-Instruct)以适应内存限制。
  • 减小批次大小和生成数量以管理内存使用,并使用较小的数据集子集(20个示例)进行快速实验。
  • 测试代码可在训练后立即评估结果。可通过增加max_samples行更全面的训练,或尝试不同的奖励函数。

代码实现

安装所需包

pip install -q transformers datasets trl peft accelerate
import re
import torch
import numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer

定义系统和响应提示

SYSTEM_PROMPT = """  
请按以下格式响应:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"
""

XML_COT_FORMAT = """  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"
""

提取答案的辅助函数

def extract_xml_answer(text: str) -> str:  
    """从XML格式的响应中提取答案部分。"""
    if"<answer>" not in text or "</answer>" not in text:  
        return""
    answer = text.split("</answer>")[-1]  
    answer = answer.split("<answer>")[0]  
    return answer.strip()  

def extract_hash_answer(text: str) -> str:  
    """从GSM8K格式中提取答案(###标记之后)。"""
    if"####" not in text:  
        return""
    return text.split("###")[1].strip().replace(".""").replace("$""")

加载并准备GSM8K数据集

def get_gsm8k_questions(split="train", max_samples=100) -> Dataset:      """  
    加载GSM8K数据集并格式化为GRPO训练所需形式。  
    参数:  
        split: 使用的数据集划分(train, test)  
        max_samples: 使用的最大样本数(用于快速实验)  
    "
""    data = load_dataset('openai/gsm8k''main')[split]      # 限制数据集大小以加快实验      if max_samples and max_samples < len(data):          data = data.select(range(max_samples))      # 格式化数据为所需的提示结构      data = data.map(lambda x: {          'prompt': [              {'role''system''content': SYSTEM_PROMPT},              {'role''user''content': x['question']}          ],          'answer': extract_hash_answer(x['answer'])      })      return data

奖励函数

def correctness_reward_func(prompts, completions, answer, **kwargs):  
    """  
    检查提取的答案是否与真实答案匹配的奖励函数。  
    正确答案返回2.0,否则返回0.0。  
    "
""    responses = [completion[0]['content'for completion in completions]      q = prompts[0][-1]['content']      extracted_responses = [extract_xml_answer(r) for r in responses]      # 打印调试信息      if kwargs.get('debug', False) and len(responses) > 0:          print('-'*20)          print(f"问题:\n{q}")          print(f"\n真实答案:\n{answer[0]}")          print(f"\n模型响应:\n{responses[0]}")          print(f"\n提取的答案:\n{extracted_responses[0]}")      return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]  def int_reward_func(completions, **kwargs) -> list[float]:      """  
    检查提取的答案是否为数字的奖励函数。  
    整数答案返回0.5,否则返回0.0。  
    "
""    responses = [completion[0]['content'for completion in completions]      extracted_responses = [extract_xml_answer(r) for r in responses]      return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]  def strict_format_reward_func(completions, **kwargs) -> list[float]:      """  
    检查补全内容是否完全符合格式的奖励函数。  
    匹配格式返回0.5,否则返回0.0。  
    "
""    pattern = r"^\n<reasoning>.*?</reasoning>\n\n<answer>.*?</answer>\n$"    responses = [completion[0]['content'for completion in completions]      matches = [bool(re.search(pattern, r, flags=re.DOTALL)) for r in responses]      return [0.5 if match else 0.0 for match in matches]  def soft_format_reward_func(completions, **kwargs) -> list[float]:      """  
    宽松的格式检查奖励函数。  
    匹配格式返回0.5,否则返回0.0。  
    "
""    pattern = r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>"    responses = [completion[0]['content'for completion in completions]      matches = [bool(re.search(pattern, r, flags=re.DOTALL)) for r in responses]      return [0.5 if match else 0.0 for match in matches]  def count_xml(text) -> float:      """  
    统计XML标签并为每个正确放置的标签提供部分奖励。  
    "
""    count = 0.0      if text.count("<reasoning>") == 1:          count += 0.125      if text.count("</reasoning>") == 1:          count += 0.125      if text.count("<answer>") == 1:          count += 0.125      if text.count("</answer>") == 1:          count += 0.125      return count  def xmlcount_reward_func(completions, **kwargs) -> list[float]:      """  
    基于响应中XML标签计数的奖励函数。  
    "
""    contents = [completion[0]["content"for completion in completions]      return [count_xml(c) for c in contents]

模型设置

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  
# 设置输出目录和运行名称  
output_dir = "outputs/Qwen-1.5B-GRPO"  
run_name = "Qwen-1.5B-GRPO-gsm8k"

配置GRPO训练

training_args = GRPOConfig(  
    output_dir=output_dir,      run_name=run_name,      learning_rate=5e-6,      adam_beta1=0.9,      adam_beta2=0.99,      weight_decay=0.1,      warmup_ratio=0.1,      lr_scheduler_type='cosine',      logging_steps=1,      bf16=False,  # 设置为False,因为Colab不支持      fp16=True,   # 使用fp16以提高兼容性      per_device_train_batch_size=4,  # 增加以兼容GRPO      gradient_accumulation_steps=2,      num_generations=4,  # 必须是per_device_train_batch_size的除数      max_prompt_length=256,      max_completion_length=512,      num_train_epochs=1,      save_steps=50,      max_grad_norm=0.1,      report_to="none",      log_on_each_node=False,  )

配置LoRA进行参数高效微调

peft_config = LoraConfig(  
    r=8,  # 从16减少以适应Colab内存  
    lora_alpha=32,  
    target_modules=["q_proj""k_proj""v_proj""o_proj"],  # 简化目标模块  
    task_type="CAUSAL_LM",  
    lora_dropout=0.05,  
)

加载并准备模型

print(f"加载模型:{model_name}")  model = AutoModelForCausalLM.from_pretrained(      model_name,      torch_dtype=torch.float16,  # 使用float16而非bfloat16      device_map="auto",          # 让模型自动选择最佳设备配置      low_cpu_mem_usage=True,    # 提高内存效率      trust_remote_code=True     # 新模型有时需要  )  # 加载分词器  tokenizer = AutoTokenizer.from_pretrained(model_name)  if tokenizer.pad_token is None:      tokenizer.pad_token = tokenizer.eos_token  # 加载GSM8K数据集的子集  print("加载数据集...")  dataset = get_gsm8k_questions(max_samples=20)  print(f"数据集加载完成,共{len(dataset)}个示例")

初始化GRPO训练器

print("初始化GRPO训练器...")  

trainer = GRPOTrainer(  
    model=model,  
    processing_class=tokenizer,  
    reward_funcs=[  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        int_reward_func,  
        correctness_reward_func  
    ],  
    args=training_args,  
    train_dataset=dataset,  
    peft_config=peft_config  # 启用LoRA进行高效微调  
)

运行GRPO

# 开始训练  
print("开始GRPO训练...")  trainer.train()  # 保存最终模型  print("训练完成。保存模型...")  trainer.save_model()  # 训练后测试模型  print("\n--- 测试训练后的模型 ---\n")  # 生成预测的函数  def generate_prediction(model, tokenizer, question, max_length=512):      prompt = [          {'role''system''content': SYSTEM_PROMPT},          {'role''user''content': question}      ]      # 格式化提示      messages = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)      # 分词输入      inputs = tokenizer(messages, return_tensors="pt").to(model.device)      # 生成响应      with torch.no_grad():          outputs = model.generate(              **inputs,              max_new_tokens=max_length,              do_sample=False          )      # 解码响应      response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)      return response  # 测试数据集中的几个示例  test_examples = dataset.select(range(3))  for i, example in enumerate(test_examples):      question = example['prompt'][-1]['content']      ground_truth = example['answer']      print(f"\n示例 {i+1}:")      print(f"问题:{question}")      print(f"真实答案:{ground_truth}")      # 生成预测      response = generate_prediction(model, tokenizer, question)      print(f"模型响应:{response}")      # 提取答案      extracted_answer = extract_xml_answer(response)      print(f"提取的答案:{extracted_answer}")      print(f"是否正确:{extracted_answer == ground_truth}")      print("-" * 50)  print("搞定!")

总结

这一实现展示了GRPO的工作原理,以及如何利用它优化语言模型以适应特定格式和任务。数学问题解决任务与XML格式的结合,清晰地体现了该技术的能力。

真是一次有趣的实践!

作者:arjun链接:https://www.k-a.in/grpo-1B.html

编辑:AI翻译、「深度学习自然语言处理」公众号润色


🏴‍☠️宝藏级🏴‍☠️ 原创公众号『数据STUDIO』内容超级硬核。公众号以Python为核心语言,垂直于数据科学领域,包括可戳👉 PythonMySQL数据分析数据可视化机器学习与数据挖掘爬虫 等,从入门到进阶!

长按👇关注- 数据STUDIO -设为星标,干货速递


原文链接