BCQ: Batch-Constrained deep Q-learning

论文地址

alt text

在阅读本文之前,我们假设读者已经熟悉DDPG、TD3、VAE,对offline RL的任务设定有所了解,了解Behavioral Policy等基本术语的含义。

用VAE实现Behavior Cloning

抛开一切琐碎的细节,在看论文前,我们先从最简单的想法开始,思考一下我们拿到由Behavior Policy采集得到的offline dataset时,我们使用行为克隆(BC)会怎么做。

  • 方法一,直接用一个MLP,按照action reconstruction 损失进行训练,这种即使是在offline dataset全部是expert demos的情况下,表现也未必会很好(这个在Dagger论文已经给出证明,累计误差为O(T2)O(T^2)
  • 方法二,学一个条件VAE,跟方法一没有本质区别,好处是能够捕捉到offline dataset中针对同一个state的动作分布,毕竟offline dataset可能包含策略在不同学习阶段采集得到的data,针对同一个state,在训练初期和训练后期,可能会有不同的action,VAE的方式能够对action的分布进行建模,显然更加Robust一点

上面两个想法的本质都是监督学习,实际上在性能表现上没有太大区别,缺陷是缺乏探索性。我们不妨假设VAE为Gω(s)G_{\omega}(s)

action的ood问题:用VAE解决

我们回过头来看看DQN和DDPG的学习范式: TQ(s,a)=Es[r+γQ(s,π(s))] \mathcal{T} Q(s, a)=\mathbb{E}_{s'}[r + \gamma Q(s', \pi(s'))]

  • 在DQN中,π(s)=argmaxaQ(s,a) \pi(s') = \text{argmax}_{a'}Q(s', a')
  • 在DDPG中,π(s)=argmaxϕEsB[Qθ(s,πϕ(s))] \pi(s') = \text{argmax}_{\phi} \mathbb{E}_{s\in B}[Q_{\theta}(s, \pi_{\phi}(s))]

作者提到的所谓的外推误差的,无非就是因为a=π(s)a'=\pi(s')很有可能会导致(s,a)(s', a')没有出现在dataset里面,即Q(s,a)Q(s', a')完全是不准的,直接的后果便是Q(s,a)Q(s, a)的更新目标r+γQ(s,π(s))r + \gamma Q(s', \pi(s'))是不准的。我们称之为action的ood(out-of-distribution)问题。

ok,所以我们需要做的就是让aa'的选取尽量满足(s,a)(s', a')出现在数据集B里面。比较理想的更新方式便是下面这个方式了:

Q(s,a)(1α)Q(s,a)+α(r+γmaxa,s.t.(s,a)BQ(s,a)) Q(s, a)\rightarrow (1-\alpha)Q(s, a) + \alpha (r + \gamma \max_{a', s.t. (s', a')\in B}Q(s', a'))

但这里有一个问题,就是aa'怎么来,一个很直接的想法便是直接用VAE按BC的方式学一个出来,Gω(s)G_{\omega}(s')不就可以反映行为策略在面对s’时的action的分布吗?

稍微加点稳定训练的代码技巧,我们可以让VAE对同一个state多输出几个action,然后取Q值最大的一个作为Q(s,a)Q(s',a')的估计值。这样就可以很大程度的避免我们的策略在训练前期针对s’的a’的ood问题了。

按照这样的方式学到Q之后,其实策略也就可以直接得到了,策略更新的梯度往Q值大的方向走,跟DDPG的方式是一样的。

探索性

但是这样显然会存在一个问题,学出来的策略缺乏探索性,尽管学了一个Q函数,但是本质跟BC没有太大区别。因此BCQ在VAE的基础上加了一个扰动的网络:ξ(s,a)\xi(s, a), 最终的actor输出的action为Gω(s)+ξ(s,Gω(s))G_{\omega}(s) + \xi(s, G_{\omega}(s))。加扰动的方式是为了让策略有一定的探索性,这个在TD3中也有体现。

简单来说,BCQ作为offline RL早期的文章,核心思想是行为克隆 + 探索,BC的部分体现在VAE,探索部分体现在扰动网络的加入,整体来说,毕竟是TD3作者的续作,TD3也是在DDPG的确定性策略的部分加了一个扰动网络,作者这么加显然是经过实验的效果验证。一言以蔽之:在监督学习(纯BC)和RL之间做一个trade-off。

关键代码实现

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, phi=0.05):
        super(Actor, self).__init__()
        # (l1, l2, l3)对应前面提到的扰动网络 \xi(s, a)
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)
        self.max_action = max_action
        self.phi = phi

    def forward(self, state, action):
      # 这里的action来自于VAE的输出 G_{\omega}(s),state是s
        a = F.relu(self.l1(torch.cat([state, action], 1)))
        a = F.relu(self.l2(a))
        a = self.phi * self.max_action * torch.tanh(self.l3(a)) # \xi(s, G_{\omega}(s))
        # 输出的是 G_{\omega}(s) + \xi(s, G_{\omega}(s))
        return (a + action).clamp(-self.max_action, self.max_action)



class BCQ(object):

    # ... 此处省略非关键代码

    def train(self, replay_buffer, iterations, batch_size=100):

        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Variational Auto-Encoder Training,需要读者对VAE有所了解
            recon, mean, std = self.vae(state, action)
            recon_loss = F.mse_loss(recon, action)
            KL_loss    = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_optimizer.zero_grad()
            vae_loss.backward()
            self.vae_optimizer.step()

            # Critic Training
            with torch.no_grad():
                # Duplicate next state 10 times
                next_state = torch.repeat_interleave(next_state, 10, 0)

                # Compute value of perturbed actions sampled from the VAE
                target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

                # Soft Clipped Double Q-learning 
                target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
                # Take max over each action sampled from the VAE
                target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

                target_Q = reward + not_done * self.discount * target_Q

            current_Q1, current_Q2 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Pertubation Model / Action Training
            sampled_actions = self.vae.decode(state)
            perturbed_actions = self.actor(state, sampled_actions)

            # Update through DPG,需要读者对DDPG有所了解
            actor_loss = -self.critic.q1(state, perturbed_actions).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update Target Networks 
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

results matching ""

    No results matching ""