跳转至

Actor-Critic 方法

REINFORCE 方差太大,DQN 只能处理离散动作。Actor-Critic 将两者的优点结合在一起:Actor(演员)负责选动作,Critic(评论员)负责打分——实现低方差、可处理连续动作的强化学习架构。


一、从 REINFORCE 到 Actor-Critic

1.1 REINFORCE 的问题回顾

REINFORCE 用完整轨迹的回报 \(G_t\) 作为策略梯度的信号:

\[ \nabla_\theta J \approx \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t \]

\(G_t\) 方差大 → 梯度估计噪声大 → 训练不稳定、收敛慢。

1.2 引入 Critic 的动机

使用基线 \(V(s_t)\) 可以降低方差。但如果 \(V(s_t)\) 也用蒙特卡洛估计,那方差问题又回来了。

解决方案:用一个独立的神经网络来学习 \(V(s_t)\)——这就是 Critic。


二、Actor-Critic 架构

2.1 两个角色

角色 网络 职责 类比
Actor(演员) 策略网络 \(\pi_\theta(a \mid s)\) 根据状态选择动作 演员在舞台上表演
Critic(评论员) 价值网络 \(V_\phi(s)\) 评估当前状态有多好 评论家给表演打分
状态 s ──→ Actor网络 ──→ 动作概率 π(a|s)
  |
  └──→ Critic网络 ──→ 状态价值 V(s)

2.2 优势函数(Advantage Function)

Actor 更新时使用的信号不再是 \(G_t\),而是优势函数 \(A(s_t, a_t)\)

\[ A(s_t, a_t) = Q(s_t, a_t) - V(s_t) \]

直觉理解

  • \(V(s_t)\):在状态 \(s_t\) 下"平均水平"能拿到多少分
  • \(Q(s_t, a_t)\):在状态 \(s_t\) 下执行动作 \(a_t\) 能拿到多少分
  • \(A(s_t, a_t)\)这个动作比平均水平好多少

\(A > 0\) → 动作比平均好 → 让 Actor 增大这个动作的概率
\(A < 0\) → 动作不如平均 → 让 Actor 减小这个动作的概率

2.3 TD 优势估计

在实践中,不需要单独估计 \(Q\),可以用 TD 误差来近似优势函数:

\[ A(s_t, a_t) \approx \delta_t = r_{t+1} + \gamma V_\phi(s_{t+1}) - V_\phi(s_t) \]
  • \(r_{t+1} + \gamma V_\phi(s_{t+1})\):TD 目标(Critic 对实际回报的估计)
  • \(V_\phi(s_t)\):Critic 对当前状态的估计
  • 差值 \(\delta_t\):Critic "意外"程度——实际比预期好还是差

三、A2C(Advantage Actor-Critic)

3.1 完整算法

初始化 Actor 网络 πθ,Critic 网络 Vφ

循环每一步:
  1. 观察状态 s
  2. Actor 选动作:a ~ πθ(·|s)
  3. 执行 a,得到 r, s', done
  4. 计算 TD 目标:
     若 done: y = r
     否则:   y = r + γ·Vφ(s')
  5. 计算优势:δ = y - Vφ(s)

  6. 更新 Critic(最小化 TD 误差):
     Lcritic = δ² = (y - Vφ(s))²
     φ ← φ - αc·∇φ Lcritic

  7. 更新 Actor(策略梯度):
     Lactor = -log πθ(a|s)·δ
     θ ← θ - αa·∇θ Lactor

3.2 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # 共享特征提取层
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU()
        )
        # Actor 头(输出动作概率)
        self.actor = nn.Sequential(
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic 头(输出状态价值)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        features = self.shared(x)
        return self.actor(features), self.critic(features)

model = ActorCritic(state_dim, action_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for episode in range(num_episodes):
    state = env.reset()
    done = False

    while not done:
        state_tensor = torch.FloatTensor(state)
        probs, value = model(state_tensor)

        # Actor 选动作
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        # 执行
        next_state, reward, done, _ = env.step(action.item())

        # Critic 评估
        _, next_value = model(torch.FloatTensor(next_state))
        target = reward + gamma * next_value * (1 - done)
        advantage = target - value

        # 联合损失
        actor_loss = -log_prob * advantage.detach()  # 注意 detach!
        critic_loss = advantage.pow(2)
        loss = actor_loss + 0.5 * critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state

关键细节:advantage.detach()

计算 Actor 损失时,必须将 advantage 从计算图中分离(detach),否则 Actor 的梯度会"穿透"到 Critic 中,导致 Critic 被错误更新。

3.3 共享网络 vs 独立网络

Actor 和 Critic 共享底层特征提取,只在输出头分叉。

优点:参数少,特征共享,训练快
缺点:两个任务可能"抢"特征,互相干扰

Actor 和 Critic 各有独立的完整网络。

优点:互不干扰,更稳定
缺点:参数多,训练稍慢


四、A3C(Asynchronous Advantage Actor-Critic)

4.1 核心思想:异步并行

A3C 的突破不在算法本身,而在训练方式:同时开多个环境副本并行训练

       全局网络(参数 θ, φ)
      ↗     ↑      ↑      ↖
  Worker1  Worker2 Worker3 Worker4
  (环境1)  (环境2) (环境3) (环境4)

4.2 每个 Worker 的流程

循环:
  1. 从全局网络复制参数到本地
  2. 在自己的环境中交互 n 步
  3. 计算本地梯度
  4. 将梯度异步推送到全局网络,更新全局参数

4.3 为什么并行有效?

好处 解释
打破数据相关性 不同 Worker 在不同环境中独立探索,产生的数据天然不相关
探索多样性 每个 Worker 走不同的路线,能更全面地探索环境
不需要经验回放 并行本身就解决了数据相关性问题(A3C 诞生的初衷之一)
加速训练 多个 CPU 并行,线性提速

A3C vs A2C

  • A3C(异步):各 Worker 独立更新全局网络,不等待彼此,异步会导致“脏更新”。比如工人 A 算好的梯度还没上传,服务器已经被工人 B 更新过了,导致 A 的经验变成了“过期的药”。
  • A2C(同步):所有 Worker 完成一批数据后,统一更新。实践中 A2C 更常用,因为同步更新更稳定,且现代 GPU 做批量计算很快

目前业界基本只用 A2C,A3C 已经被淘汰了。


五、GAE:广义优势估计(Generalized Advantage Estimation)

5.1 TD(0) vs MC 的优势估计

GAE 是为了解决一个深度学习里经典的难题:偏差与方差的权衡。

方法 优势估计 偏差 方差
TD(0) \(\hat{A}_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 高偏差 低方差
MC \(\hat{A}_t = G_t - V(s_t)\) 低偏差 高方差

能否在偏差和方差之间平滑插值

我们之前用 1 步的 TD 误差来算优势:\(\delta_t^{(1)} = r_t + \gamma V(s_{t+1}) - V(s_t)\) 好处: 只看未来一步,波动很小(低方差)。
坏处: 太依赖 Critic 的预估 \(V(s_{t+1})\)。如果 Critic 是个菜鸟,乱估一通,就会导致 Actor 被带偏(高偏差)。

那如果不用预估,直接把直到游戏结束的所有真实奖励都加起来呢(这就是蒙特卡洛法 MC)?
好处: 绝对真实,没有任何预估误差(低偏差)。
坏处: 一局游戏变数太多,波动极大,根本没法学(高方差)。

GAE 的伟大之处:它全都要!GAE 引入了一个衰减参数 \(\lambda\)(取值在 0 到 1 之间)。它的思想是:把 1步 TD 误差、2步 TD 误差、3步......一直到 n步 的 TD 误差,按照指数级衰减的权重,全部加起来揉在一起!

5.2 GAE 公式

\[ \hat{A}_t^{\text{GAE}} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} \]

其中 \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 是每步的 TD 误差。

参数 \(\lambda \in [0, 1]\) 控制偏差-方差的权衡:

  • \(\lambda = 0\):退化为 TD(0),低方差高偏差
  • \(\lambda = 1\):退化为蒙特卡洛,高方差低偏差
  • \(\lambda = 0.95\):常用值,取得良好平衡

5.3 递推计算(非常高效)

\[ \hat{A}_t^{\text{GAE}} = \delta_t + \gamma \lambda \hat{A}_{t+1}^{\text{GAE}} \]
def compute_gae(rewards, values, next_value, gamma=0.99, lam=0.95):
    advantages = []
    gae = 0
    values = values + [next_value]  # 追加最后的 V(s')

    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t+1] - values[t]
        gae = delta + gamma * lam * gae
        advantages.insert(0, gae)

    return advantages

GAE 是 PPO 的标配

几乎所有现代 Actor-Critic 方法(尤其是 PPO)都使用 GAE 来估计优势函数。它是 PPO 高性能的重要组成部分。


六、方法演进总结

REINFORCE(纯策略梯度,方差大)
    │ 引入 Critic 作为基线
Actor-Critic(单步更新,低方差,有偏差)
    │ 异步并行训练
    ├──→ A3C(异步)
    ├──→ A2C(同步,更常用)
    │ 优势估计优化
    ├──→ GAE(λ 插值偏差-方差)
    │ 限制策略更新步长
PPO(近端策略优化,工业界主流)→ 下一章

关键公式速查

名称 公式
优势函数 \(A(s, a) = Q(s, a) - V(s)\)
TD 优势 \(\hat{A}_t = r + \gamma V(s') - V(s)\)
Actor 损失 \(L_{\text{actor}} = -\log \pi_\theta(a \mid s) \cdot \hat{A}\)
Critic 损失 \(L_{\text{critic}} = (r + \gamma V(s') - V(s))^2\)
GAE \(\hat{A}_t^{\text{GAE}} = \sum_{l=0}^{\infty}(\gamma\lambda)^l \delta_{t+l}\)
GAE 递推 \(\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}\)