Open main menu
首页
专栏
课程
分类
归档
Chat
Sci-Hub
谷歌学术
Libgen
GitHub镜像
登录/注册
搜索
关闭
Previous
Previous
Next
Next
ChatGPT 中的人类反馈强化学习 (RLHF) 实战
sockstack
/
1322
/
2024-02-28 00:02:38
<p><span style="color: red; font-size: 18px">ChatGPT 可用网址,仅供交流学习使用,如对您有所帮助,请收藏并推荐给需要的朋友。</span><br><a href="https://ckai.xyz/?sockstack§ion=detail" target="__blank">https://ckai.xyz</a><br><br></p> <article class="baidu_pl"><div id="article_content" class="article_content clearfix"> <link rel="stylesheet" href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/kdoc_html_views-1a98987dfd.css"> <link rel="stylesheet" href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/ck_htmledit_views-25cebea3f9.css"> <div id="content_views" class="markdown_views prism-atom-one-dark"> <svg xmlns="http://www.w3.org/2000/svg" style="display: none;"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg><p></p> <div class="toc"> <h3>目录</h3> <ul> <li>1 前言</li> <li>2 人类反馈强化学习 (RLHF)</li> <li> <ul> <li>2.1 奖励模型 (RM)</li> <li>2.2 近端策略优化算法 (PPO)</li> </ul> </li> <li>3 总结</li> <li>4 参考</li> </ul> </div> <p></p> <hr> <p>团队博客: CSDN AI小组</p> <hr> <p>相关阅读</p> <ul> <li>ChatGPT 简介</li> <li>大语言模型浅探一</li> <li>关于 ChatGPT 必看的 10 篇论文</li> <li>从 ELMo 到 ChatGPT:历数 NLP 近 5 年必看大模型</li> </ul> <hr> <h1> <a id="1__12"></a>1 前言</h1> <p>在当今数字化的时代,ChatGPT 的火热程度不断升级。ChatGPT 可以处理复杂的语言任务,从而解放人力资源,提高工作效率,减少成本。ChatGPT 的先进技术和广泛应用,使得它成为了当今最炙手可热的人工智能技术之一。无论是企业、学术机构,还是科技爱好者,都对 ChatGPT 的应用前景充满期待。</p> <p>在这样的背景之下,CSDN AI 团队也想对 ChatGPT 进行简单的复现。根据ChatGPT官方博客可知,ChatGPT的训练方法与 InstructGPT 的训练方法基本一致 (如图1所示),只是使用的数据集不一样。故在训练方法上,我们主要参考 InstructGPT 进行复现,基础模型使用的是 RWKV,拆分后共包含以下四个阶段:</p> <ul> <li>(1) 语言模型预训练 (Language Model Pre-training);</li> <li>(2) 有监督指令微调 (Supervised Fine-Tuning, SFT);</li> <li>(3) 奖励模型的训练 (Reward Modeling, RM);</li> <li>(4) 使用近端策略优化算法进行强化学习 (Proximal Policy Optimization, PPO).</li> </ul> <p>第 (1)、(2) 阶段的 Pre-training 和 SFT 由 @zxm2015 完成,可参考文章大语言模型浅探一。本文主要介绍第 (3)、(4) 阶段的内容,即人类反馈强化学习 (Reinforcement Learning from Human Feedback, RLHF)。</p> <p><img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/5521856c7faf4ca19b802fae72d525ba.png#pic_center" alt="在这里插入图片描述"></p> <center>图1 InstructGPT 模型的训练过程 </center> <h1> <a id="2__RLHF_27"></a>2 人类反馈强化学习 (RLHF)</h1> <p>人类反馈强化学习 (RLHF) 是 ChatGPT 中一种用于改善其回答效果的算法。它是一种基于强化学习的方法,通过结合人类反馈来优化 ChatGPT 的回答。</p> <p>在 RLHF 中,ChatGPT 学习通过和人类用户的交互来提高其回答的质量。当 ChatGPT 生成一个回答时,它会将回答展示给用户并请求用户的反馈。用户可以对回答进行评分,比如“好”、“不错”、“一般”、“差”等。ChatGPT 会将用户的反馈作为奖励或惩罚信号,以此来更新自己的模型,以更好地满足用户的需求。</p> <p>RLHF 可分为两个部分。第一部分是奖励模型,人类反馈主要就体现在这个地方;第二部分采用近端策略优化算法的强化学习阶段,基于奖励模型的反馈来优化模型,最终得到满足人类偏好的语言模型。下面将对这两个部分进行详细的说明。</p> <h2> <a id="21__RM_33"></a>2.1 奖励模型 (RM)</h2> <p>在 RLHF 之前,语言模型已经进行了 SFT (后续称该模型为 SFT Model),而奖励模型的任务主要是对 SFT Model 的回复进行打分,打分越高表示回答效果越好。训练好奖励模型之后,就可以用于下一阶段的 PPO 进行强化学习的调优,奖励模型是 PPO 中的一个子部分,用于 PPO 训练时提供奖励信号。</p> <p><strong>(1) 模型的输入输出</strong><br> 模型的输入是用户提问 (Prompt) 和 SFT Model 回复 (Response) 的 pair 对 <Prompt, Response>,输出是一个奖励得分,如下图所示:</p> <p><img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/1c61ac4efd174f3fa73cdf063f24ddd8.png#pic_center" alt="在这里插入图片描述"></p> <center>图2 RM 的输入和输出 </center> <p><strong>(2) 数据集的构建</strong><br> 这个阶段主要是通过人工标注训练数据,来训练 RM,人类反馈就体现在这个地方。在 Prompts 数据集中随机抽取问题,对于每个问题,生成 K 个不同的回答。人类标注者对这些结果综合考虑(例如:相关性、富含信息性、有害信息等诸多标准)给出排名顺序。</p> <p>按照上述奖励模型的输入输出描述,构建数据集时应该是人工对 <Prompt, Response> 进行打分,但实际上对多个回答进行打分比较困难,得分是连续的,这会降低标注的速度。此外,我们其实关注的是多个选项之间哪个更好,哪个更差。所以标注的时候对多个选项进行排序就可以了,最后基于排序后的回答,构建数据集,选用合适的损失函数即可。</p> <p>通常情况下,人类进行排序任务,当选项为 4-9 个 (即 K∈{4, 5, 6, 7, 8, 9}) 时速度最快且效果最准确,此处我们设定 K=4。最终一个 Prompt 我们就可以得到 C(4, 2)=6 条训练样本。</p> <p>具体而言,假设我们选定了一个问题 x,接着使用 SFT Model 生成了 4 个回答 {y1, y2, y3, y4},人类标注者进行排序后为 y4 > y3 > y1 > y2},则得到的训练样本如下所示,左边<Prompt, Response>的得分要高于右边:</p> <blockquote> <p>(<x, y4>, <x, y3>)<br> (<x, y4>, <x, y1>)<br> (<x, y4>, <x, y2>)<br> (<x, y3>, <x, y1>)<br> (<x, y3>, <x, y2>)<br> (<x, y1>, <x, y2>)</p> </blockquote> <p><strong>(3) 损失函数</strong><br> 根据上面构建的数据集可知,我们没有连续的得分目标去训练奖励模型,但是有正负例样本对,所以损失函数如下所示,该损失函数需要最小化:<br> <img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/5803c4c1582e4656964011c9aa040ea2.png#pic_center" alt="在这里插入图片描述"><br> 其中,r(x,y) 为 <x, y> 输入到 RM 模型的得分,θ 是 RM 的参数,yw 和 yl 是输入为 x 时,SFT Model 生成的不同回答,其中人工标注时 yw > yl.</p> <pre><code class="prism language-python"><span class="token comment"># loss function</span> <span class="token keyword">def</span> <span class="token function">loss_function</span><span class="token punctuation">(</span>prefer_reward<span class="token punctuation">,</span> alter_reward<span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token keyword">return</span> <span class="token operator">-</span>torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>log<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>prefer_reward <span class="token operator">-</span> alter_reward<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span> </code></pre> <p><strong>(4) 核心代码</strong><br> RM 的网络结构相比于 SFT Model,并不需要做太大的改动,输入 <Prompt, Response> 后,直接取最后一个 token 的 embedding,在其后面接一个线性层计算奖励得分即可</p> <p>a) 线性层:</p> <pre><code class="prism language-python"><span class="token comment"># reward 得分计算</span> self<span class="token punctuation">.</span>pred_reward <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span>dim<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span> </code></pre> <p>b) forword 函数</p> <pre><code class="prism language-python"> <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span>x<span class="token punctuation">,</span>mask <span class="token operator">=</span> <span class="token boolean">None</span><span class="token punctuation">,</span>prompt_mask <span class="token operator">=</span> <span class="token boolean">None</span><span class="token punctuation">,</span>prompt_lengths <span class="token operator">=</span> <span class="token boolean">None</span><span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token comment"># prompt_mask 和 prompt_lengths 只能二选一</span><span class="token keyword">assert</span> <span class="token keyword">not</span> <span class="token punctuation">(</span>exists<span class="token punctuation">(</span>prompt_mask<span class="token punctuation">)</span> <span class="token keyword">and</span> exists<span class="token punctuation">(</span>prompt_lengths<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token comment"># derive prompt mask from prompt lengths</span><span class="token keyword">if</span> exists<span class="token punctuation">(</span>prompt_lengths<span class="token punctuation">)</span><span class="token punctuation">:</span>batch<span class="token punctuation">,</span> seq_len <span class="token operator">=</span> x<span class="token punctuation">.</span>shapearange <span class="token operator">=</span> torch<span class="token punctuation">.</span>arange<span class="token punctuation">(</span>seq_len<span class="token punctuation">,</span> device<span class="token operator">=</span>x<span class="token punctuation">.</span>device<span class="token punctuation">)</span>prompt_mask <span class="token operator">=</span> repeat<span class="token punctuation">(</span>arange<span class="token punctuation">,</span> <span class="token string">'n -> b n'</span><span class="token punctuation">,</span> b <span class="token operator">=</span> batch<span class="token punctuation">)</span> <span class="token operator"><</span> rearrange<span class="token punctuation">(</span>prompt_lengths<span class="token punctuation">,</span> <span class="token string">'b -> b 1'</span><span class="token punctuation">)</span><span class="token comment"># reward model should have an understanding of which section is prompt, and which section is response</span><span class="token comment"># 根据 prompt_mask 中 token 的 True 和 False,从 prompt_embed 或 response_embed 中取值</span><span class="token comment"># 如果为 True,则从 prompt_embed 中选,否则从 response_embed 中选</span>prompt_response_mask_embed <span class="token operator">=</span> torch<span class="token punctuation">.</span>stack<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>prompt_embed<span class="token punctuation">,</span>self<span class="token punctuation">.</span>response_embed<span class="token punctuation">,</span>self<span class="token punctuation">.</span>padding_embed<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">.</span>to<span class="token punctuation">(</span>prompt_mask<span class="token punctuation">.</span>device<span class="token punctuation">)</span>extra_embed <span class="token operator">=</span> <span class="token boolean">None</span><span class="token keyword">if</span> exists<span class="token punctuation">(</span>prompt_mask<span class="token punctuation">)</span><span class="token punctuation">:</span>extra_embed <span class="token operator">=</span> prompt_response_mask_embed<span class="token punctuation">[</span>prompt_mask<span class="token punctuation">]</span> <span class="token comment"># 获得最后一个 token 的 embedding</span>last_token_embeds <span class="token operator">=</span> self<span class="token punctuation">.</span>rwkv<span class="token punctuation">(</span>x<span class="token punctuation">,</span>extra_embed<span class="token operator">=</span>extra_embed<span class="token punctuation">,</span>rm_train<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token comment"># 计算奖励</span>reward <span class="token operator">=</span> self<span class="token punctuation">.</span>pred_reward<span class="token punctuation">(</span>last_token_embeds<span class="token punctuation">)</span>reward <span class="token operator">=</span> reward<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token keyword">return</span> reward </code></pre> <p>c) train_forward 函数</p> <pre><code class="prism language-python"> <span class="token keyword">def</span> <span class="token function">train_forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x_p<span class="token punctuation">,</span> x_a<span class="token punctuation">,</span> m_p<span class="token punctuation">,</span> m_a<span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token comment"># 因为前向传播的时候,需要过两次模型。所以反馈的时候需要冻结其中一次的参数</span><span class="token comment"># 不然梯度会被计算两次,在包含 deepspeed 框架下会报错</span><span class="token comment"># 报错信息:Gradient computed twice for this partition.</span><span class="token keyword">with</span> torch<span class="token punctuation">.</span>enable_grad<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>prefer_reward <span class="token operator">=</span> self<span class="token punctuation">.</span>forward<span class="token punctuation">(</span>x_p<span class="token punctuation">,</span> prompt_mask<span class="token operator">=</span>m_p<span class="token punctuation">)</span><span class="token keyword">with</span> torch<span class="token punctuation">.</span>no_grad<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>alter_reward <span class="token operator">=</span> self<span class="token punctuation">.</span>forward<span class="token punctuation">(</span>x_a<span class="token punctuation">,</span> prompt_mask<span class="token operator">=</span>m_a<span class="token punctuation">)</span><span class="token keyword">return</span> prefer_reward<span class="token punctuation">,</span> alter_reward </code></pre> <h2> <a id="22__PPO_135"></a>2.2 近端策略优化算法 (PPO)</h2> <p>近端策略优化算法(Proximal Policy Optimization, PPO)是一种深度强化学习算法,其目标是学习一个能够最大化长期累积回报的策略。</p> <p><img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/4dc66ae3af7047939cf9546760f25a23.png#pic_center" alt="在这里插入图片描述"></p> <center>图3 PPO 训练架构详细版本 </center> <p><strong>(1) PPO算法包含以下几个主要部分:</strong></p> <ul> <li> <p>a) 策略网络 (Policy Network)<br> 用于学习并输出给定状态下不同行动的概率分布。它通常是一个神经网络,可以根据环境的反馈进行更新。对应图3中的 Actor,使用 SFT Model 进行初始化,在 PPO 中需要参与训练。</p> </li> <li> <p>b) 价值网络 (Value Network)<br> 用于预测给定状态的预期回报值。它通常也是一个神经网络,它的输出可以用来计算优势函数,从而帮助更新策略网络。对应图3中的 Critic,使用 RM 进行初始化,在 PPO 中需要参与训练。</p> </li> <li> <p>c) 奖励模型<br> 对应图3中的 Reward Model,是 2.1 节中训练得到的模型,在 PPO 中不参与训练,只提供奖励信号,用于 PPO 的训练。</p> </li> <li> <p>d) SFT Model<br> 对应图3中的 Supervised Fine-Tune Model,用于更新策略网络,以使其能够产生更好的策略。通过限制每次更新的幅度,从而确保更新后的策略与原始策略之间的差异不会太大。该部分可以参与训练,也可以不参与,当参与训练时,PPO 被称为 PPO-ptx。</p> </li> <li> <p>e) 经验采样<br> 用于收集与环境交互的经验数据,以供策略网络和价值网络的更新使用。在PPO算法中,经验采样通常采用基于行动价值估计的策略。对应图3中顶部的 Prompts -> Actor -> Response 流程。</p> </li> </ul> <p><img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/d1bf5d2826ce4c298bcbfb6990ebe034.png#pic_center" alt="在这里插入图片描述"></p> <center>图4 PPO 训练架构简化版本 </center> <p><strong>(2)损失函数</strong></p> <ul> <li>a) actor loss (也称为 policy loss, 是最终要使用模型的 loss)<br> <img referrerpolicy="no-referrer" src="https://img-blog.csdnimg.cn/df8de2fdb4c14f4cae67450162db13f9.png#pic_center" alt="在这里插入图片描述"><br> 其中,πRL 是 actor,πSFT 是已经训练好的 SFT Model。损失函数的第1项和第2项是核心部分,第3项是可选项。该损失函数需要最大化。具体如下: <ul> <li>第一项:这一项是奖励模型 RM 奖励得分,奖励需要最大化;</li> <li>第二项:这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值;</li> <li>第三项:这一项是预训练梯度 (可选项),传统的 PPO 中一般不包含该项,InstructGPT 中加入这一项是为了避免 RLHF 导致大模型在公开的 NLP 评测任务上效果下降。加入该项之后被命名为 PPO-ptx。</li> </ul> </li> <li>b) critic loss (也称为 value loss)<br> 使用的是 clipped_value_loss。</li> </ul> <p><strong>(3)核心代码</strong><br> a) training_step</p> <pre><code class="prism language-python"> <span class="token keyword">def</span> <span class="token function">training_step</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> batch<span class="token punctuation">,</span> batch_idx<span class="token punctuation">,</span> optimizer_idx<span class="token punctuation">)</span><span class="token punctuation">:</span>sequences<span class="token punctuation">,</span> \prompt_masks<span class="token punctuation">,</span> \masks<span class="token punctuation">,</span> \old_action_probs<span class="token punctuation">,</span> \old_log_probs<span class="token punctuation">,</span> \rewards<span class="token punctuation">,</span> \old_values <span class="token operator">=</span> batch<span class="token comment"># PPO training</span>action_masks <span class="token operator">=</span> <span class="token operator">~</span>prompt_masks <span class="token operator">&</span> masksaction_logits<span class="token punctuation">,</span> values <span class="token operator">=</span> self<span class="token punctuation">.</span>actor_critic<span class="token punctuation">(</span>sequences<span class="token punctuation">,</span>mask <span class="token operator">=</span> action_masks<span class="token punctuation">)</span>action_logits <span class="token operator">=</span> shift<span class="token punctuation">(</span>action_logits<span class="token punctuation">,</span> shift<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token operator">-</span><span class="token number">2</span><span class="token punctuation">)</span> <span class="token comment"># need to shift along sequence dimension by 1, since actions start from the last prompt (state) token</span>action_len <span class="token operator">=</span> old_log_probs<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span>action_probs <span class="token operator">=</span> action_logits<span class="token punctuation">.</span>softmax<span class="token punctuation">(</span>dim <span class="token operator">=</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>action_log_probs <span class="token operator">=</span> log_prob<span class="token punctuation">(</span>action_probs<span class="token punctuation">,</span> sequences<span class="token punctuation">)</span>action_log_probs <span class="token operator">=</span> action_log_probs<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span>action_len<span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token comment"># calculate entropies, taking into account which part of the sequence is actually an action</span>entropies <span class="token operator">=</span> masked_entropy<span class="token punctuation">(</span>action_probs<span class="token punctuation">,</span> mask <span class="token operator">=</span> action_masks<span class="token punctuation">)</span><span class="token comment"># calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not</span>kl_div_loss <span class="token operator">=</span> <span class="token number">0.</span><span class="token keyword">if</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>kl_div_loss_weight <span class="token operator">></span> <span class="token number">0</span><span class="token punctuation">:</span>kl_div_loss <span class="token operator">=</span> masked_kl_div<span class="token punctuation">(</span>action_probs<span class="token punctuation">,</span> old_action_probs<span class="token punctuation">,</span> mask <span class="token operator">=</span> action_masks<span class="token punctuation">)</span> <span class="token operator">*</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>kl_div_loss_weight<span class="token comment"># handle non-pooled values</span>normalize_kwargs <span class="token operator">=</span> <span class="token builtin">dict</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token keyword">if</span> old_values<span class="token punctuation">.</span>ndim <span class="token operator">==</span> <span class="token number">2</span><span class="token punctuation">:</span>old_values<span class="token punctuation">,</span> values <span class="token operator">=</span> <span class="token builtin">map</span><span class="token punctuation">(</span><span class="token keyword">lambda</span> t<span class="token punctuation">:</span> shift<span class="token punctuation">(</span>t<span class="token punctuation">,</span> shift <span class="token operator">=</span> <span class="token number">1</span><span class="token punctuation">,</span> dim <span class="token operator">=</span> <span class="token operator">-</span><span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span>old_values<span class="token punctuation">,</span> values<span class="token punctuation">)</span><span class="token punctuation">)</span>old_values <span class="token operator">=</span> old_values<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span>action_len<span class="token punctuation">:</span><span class="token punctuation">]</span>values <span class="token operator">=</span> values<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span>action_len<span class="token punctuation">:</span><span class="token punctuation">]</span>rewards <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>rewards<span class="token punctuation">,</span> <span class="token string">'b -> b 1'</span><span class="token punctuation">)</span>normalize_kwargs <span class="token operator">=</span> <span class="token builtin">dict</span><span class="token punctuation">(</span>dim <span class="token operator">=</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> mask <span class="token operator">=</span> action_masks<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span>action_len<span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token keyword">if</span> values<span class="token punctuation">.</span>ndim <span class="token operator"><</span> rewards<span class="token punctuation">.</span>ndim<span class="token punctuation">:</span>values <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>values<span class="token punctuation">,</span> <span class="token string">'... -> ... 1'</span><span class="token punctuation">)</span><span class="token comment"># calculate clipped surrogate objective, classic PPO loss</span>ratios <span class="token operator">=</span> <span class="token punctuation">(</span>action_log_probs <span class="token operator">-</span> old_log_probs<span class="token punctuation">)</span><span class="token punctuation">.</span>exp<span class="token punctuation">(</span><span class="token punctuation">)</span>advantages <span class="token operator">=</span> masked_normalize<span class="token punctuation">(</span>rewards <span class="token operator">-</span> old_values<span class="token punctuation">,</span> <span class="token operator">**</span>normalize_kwargs<span class="token punctuation">)</span><span class="token keyword">if</span> advantages<span class="token punctuation">.</span>ndim <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">:</span>advantages <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>advantages<span class="token punctuation">,</span> <span class="token string">'b -> b 1'</span><span class="token punctuation">)</span>surr1 <span class="token operator">=</span> ratios <span class="token operator">*</span> advantagessurr2 <span class="token operator">=</span> ratios<span class="token punctuation">.</span>clamp<span class="token punctuation">(</span><span class="token number">1</span> <span class="token operator">-</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>eps_clip<span class="token punctuation">,</span> <span class="token number">1</span> <span class="token operator">+</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>eps_clip<span class="token punctuation">)</span> <span class="token operator">*</span> advantagespolicy_loss <span class="token operator">=</span> <span class="token operator">-</span> torch<span class="token punctuation">.</span><span class="token builtin">min</span><span class="token punctuation">(</span>surr1<span class="token punctuation">,</span> surr2<span class="token punctuation">)</span> <span class="token operator">-</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>beta_s <span class="token operator">*</span> entropies<span class="token comment"># actor loss (也称为 policy loss, 是最终要使用模型的 loss)</span><span class="token keyword">if</span> optimizer_idx <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span>actor_loss <span class="token operator">=</span> policy_loss<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">+</span> kl_div_loss<span class="token keyword">return</span> actor_loss<span class="token comment"># critic loss (也称为 value loss)</span><span class="token comment"># update value network separate from policy network</span><span class="token keyword">if</span> optimizer_idx <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">:</span>critic_loss <span class="token operator">=</span> clipped_value_loss<span class="token punctuation">(</span>values<span class="token punctuation">,</span> rewards<span class="token punctuation">,</span> old_values<span class="token punctuation">,</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>value_clip<span class="token punctuation">)</span>critic_loss <span class="token operator">=</span> critic_loss<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token keyword">return</span> critic_loss </code></pre> <p>b) gen_experience_dataset</p> <pre><code class="prism language-python"> <span class="token keyword">def</span> <span class="token function">gen_experience_dataset</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token triple-quoted-string string">''' 通过与 environment 交互产生训练数据'''</span>device <span class="token operator">=</span> self<span class="token punctuation">.</span>devicetime_cnt <span class="token operator">=</span> <span class="token number">0</span><span class="token keyword">for</span> eps <span class="token keyword">in</span> tqdm<span class="token punctuation">(</span><span class="token builtin">range</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>num_episodes<span class="token punctuation">)</span><span class="token punctuation">,</span> desc <span class="token operator">=</span> <span class="token string">'episodes'</span><span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token keyword">for</span> timestep <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>max_timesteps<span class="token punctuation">)</span><span class="token punctuation">:</span>time_cnt <span class="token operator">+=</span> <span class="token number">1</span><span class="token comment"># select a bunch of random states (prompts)</span><span class="token comment"># and get the action (sampled sequence from rwkv as well as the action probs)</span><span class="token comment"># also calculate the reward using reward model and store</span><span class="token comment"># 随机挑选一条 prompt</span>rand_prompt_index <span class="token operator">=</span> randrange<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>prompts<span class="token punctuation">)</span><span class="token punctuation">)</span>state <span class="token operator">=</span> self<span class="token punctuation">.</span>prompts<span class="token punctuation">[</span>rand_prompt_index<span class="token punctuation">]</span><span class="token comment"># remove padding from state</span>state_mask <span class="token operator">=</span> state <span class="token operator">!=</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>pad_valuestate <span class="token operator">=</span> state<span class="token punctuation">[</span>state_mask<span class="token punctuation">]</span><span class="token comment"># get predicted sequence</span><span class="token comment"># 与 environment 进行交互,其中返回的:</span><span class="token comment"># action 是 response,</span><span class="token comment"># sequence 是 prompt + response, </span><span class="token punctuation">(</span>actions<span class="token punctuation">,</span>sequence<span class="token punctuation">,</span>mask<span class="token punctuation">,</span>prompt_mask<span class="token punctuation">,</span>action_logits<span class="token punctuation">,</span>value<span class="token punctuation">)</span> <span class="token operator">=</span> self<span class="token punctuation">.</span>actor_critic<span class="token punctuation">.</span>generate<span class="token punctuation">(</span>rearrange<span class="token punctuation">(</span>state<span class="token punctuation">,</span> <span class="token string">'n -> 1 n'</span><span class="token punctuation">)</span><span class="token punctuation">,</span>max_seq_len <span class="token operator">=</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>ctx_len<span class="token punctuation">,</span>return_values <span class="token operator">=</span> <span class="token boolean">True</span><span class="token punctuation">)</span>action_logits <span class="token operator">=</span> shift<span class="token punctuation">(</span>action_logits<span class="token punctuation">,</span> shift <span class="token operator">=</span> <span class="token number">1</span><span class="token punctuation">,</span> dim <span class="token operator">=</span> <span class="token operator">-</span><span class="token number">2</span><span class="token punctuation">)</span> <span class="token comment"># need to shift along sequence dimension by 1, since actions start from the last prompt (state) token</span>action_prob <span class="token operator">=</span> action_logits<span class="token punctuation">.</span>softmax<span class="token punctuation">(</span>dim <span class="token operator">=</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>action_len <span class="token operator">=</span> actions<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span>action_log_prob <span class="token operator">=</span> log_prob<span class="token punctuation">(</span>action_prob<span class="token punctuation">,</span> sequence<span class="token punctuation">)</span>action_log_prob <span class="token operator">=</span> action_log_prob<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token operator">-</span>action_len<span class="token punctuation">:</span><span class="token punctuation">]</span>actions <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>actions<span class="token punctuation">,</span> <span class="token string">'1 ... -> ...'</span><span class="token punctuation">)</span><span class="token comment"># get reward as given by supervised trained reward model</span>sequence <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">(</span>state<span class="token punctuation">,</span> actions<span class="token punctuation">)</span><span class="token punctuation">,</span> dim <span class="token operator">=</span> <span class="token number">0</span><span class="token punctuation">)</span>prompt_length <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>state<span class="token punctuation">)</span>prompt_mask <span class="token operator">=</span> torch<span class="token punctuation">.</span>arange<span class="token punctuation">(</span>sequence<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> device <span class="token operator">=</span> device<span class="token punctuation">)</span> <span class="token operator"><</span> prompt_lengthsequence <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>sequence<span class="token punctuation">,</span> <span class="token string">'n -> 1 n'</span><span class="token punctuation">)</span>prompt_mask <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>prompt_mask<span class="token punctuation">,</span> <span class="token string">'n -> 1 n'</span><span class="token punctuation">)</span>mask <span class="token operator">=</span> rearrange<span class="token punctuation">(</span>mask<span class="token punctuation">,</span> <span class="token string">'n -> 1 n'</span><span class="token punctuation">)</span> <span class="token keyword">if</span> exists<span class="token punctuation">(</span>mask<span class="token punctuation">)</span> <span class="token keyword">else</span> torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span>sequence<span class="token punctuation">.</span>shape<span class="token punctuation">,</span> dtype <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">bool</span><span class="token punctuation">,</span> device <span class="token operator">=</span> device<span class="token punctuation">)</span>reward <span class="token operator">=</span> self<span class="token punctuation">.</span>reward_model<span class="token punctuation">(</span>sequence<span class="token punctuation">,</span>prompt_mask <span class="token operator">=</span> prompt_mask<span class="token punctuation">,</span>mask <span class="token operator">=</span> mask<span class="token punctuation">,</span>sample <span class="token operator">=</span> <span class="token boolean">True</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>sequence_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>sequence<span class="token punctuation">)</span>self<span class="token punctuation">.</span>prompt_mask_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>prompt_mask<span class="token punctuation">)</span>self<span class="token punctuation">.</span>mask_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>mask<span class="token punctuation">)</span>self<span class="token punctuation">.</span>action_prob_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>action_prob<span class="token punctuation">)</span>self<span class="token punctuation">.</span>action_log_prob_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>action_log_prob<span class="token punctuation">)</span>self<span class="token punctuation">.</span>reward_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>reward<span class="token punctuation">)</span>self<span class="token punctuation">.</span>value_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>value<span class="token punctuation">)</span><span class="token keyword">if</span> time_cnt <span class="token operator">%</span> self<span class="token punctuation">.</span>args<span class="token punctuation">.</span>update_timesteps <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span>train_data <span class="token operator">=</span> <span class="token builtin">zip</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>sequence_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>prompt_mask_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>mask_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>action_prob_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>action_log_prob_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>reward_batch<span class="token punctuation">,</span> self<span class="token punctuation">.</span>value_batch<span class="token punctuation">)</span><span class="token keyword">for</span> _sequence<span class="token punctuation">,</span> _prompt_mask<span class="token punctuation">,</span> _mask<span class="token punctuation">,</span> _action_prob<span class="token punctuation">,</span> _action_log_prob<span class="token punctuation">,</span> _reward<span class="token punctuation">,</span> _value <span class="token keyword">in</span> train_data<span class="token punctuation">:</span><span class="token keyword">yield</span> _sequence<span class="token punctuation">,</span> _prompt_mask<span class="token punctuation">,</span> _mask<span class="token punctuation">,</span> _action_prob<span class="token punctuation">,</span> _action_log_prob<span class="token punctuation">,</span> _reward<span class="token punctuation">,</span> _valueself<span class="token punctuation">.</span>sequence_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>prompt_mask_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>mask_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>action_prob_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>action_log_prob_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>reward_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span>self<span class="token punctuation">.</span>value_batch<span class="token punctuation">.</span>clear<span class="token punctuation">(</span><span class="token punctuation">)</span> </code></pre> <h1> <a id="3__343"></a>3 总结</h1> <p>RLHF 可以根据用户反馈不断学习和优化对话,从而提高对话的质量和效果。但是由于算力资源的限制,我们只是简单调试并拉通了 RLHF 的训练流程,暂未在实际的数据集上训练模型。如若有纰漏指出,还请指正,感谢!</p> <h1> <a id="4__345"></a>4 参考</h1> <p>[1] InstructGPT<br> [2] ChatGPT 背后的“功臣”——RLHF 技术详解<br> [3] ColossalAI<br> [4] PaLM-rlhf-pytorch<br> [5] Promixal Policy Optimization with PyTorch<br> [6] How ChatGPT Works Part 2: The Reward Model</p> </div> <link href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/markdown_views-98b95bb57c.css" rel="stylesheet"> <link href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/style-c216769e99.css" rel="stylesheet"> </div> <div id="treeSkill"></div> </article>
ChatGPT 中的人类反馈强化学习 (RLHF) 实战
作者
sockstack
许可协议
CC BY 4.0
发布于
2024-02-28
修改于
2024-12-27
上一篇:软件:常用 Linux 软件汇总,值得收藏
下一篇:【生活工作经验 十】ChatGPT模型对话初探
尚未登录
登录 / 注册
文章分类
博客重构之路
5
Spring Boot简单入门
4
k8s 入门教程
0
MySQL 知识
1
NSQ 消息队列
0
ThinkPHP5 源码分析
5
使用 Docker 从零开始搭建私人代码仓库
3
日常开发汇总
4
标签列表
springboot
hyperf
swoole
webman
php
多线程
数据结构
docker
k8s
thinkphp
mysql
tailwindcss
flowbite
css
前端