BCQ: Batch-Constrained deep Q-learning
用VAE实现Behavior Cloning
抛开一切琐碎的细节,在看论文前,我们先从最简单的想法开始,思考一下我们拿到由Behavior Policy采集得到的offline dataset时,我们使用行为克隆(BC)会怎么做。
- 方法一,直接用一个MLP,按照action reconstruction 损失进行训练,这种即使是在offline dataset全部是expert demos的情况下,表现也未必会很好(这个在Dagger论文已经给出证明,累计误差为)
- 方法二,学一个条件VAE,跟方法一没有本质区别,好处是能够捕捉到offline dataset中针对同一个state的动作分布,毕竟offline dataset可能包含策略在不同学习阶段采集得到的data,针对同一个state,在训练初期和训练后期,可能会有不同的action,VAE的方式能够对action的分布进行建模,显然更加Robust一点
上面两个想法的本质都是监督学习,实际上在性能表现上没有太大区别,缺陷是缺乏探索性。我们不妨假设VAE为。
action的ood问题:用VAE解决
我们回过头来看看DQN和DDPG的学习范式:
- 在DQN中,
- 在DDPG中,
作者提到的所谓的外推误差的,无非就是因为很有可能会导致没有出现在dataset里面,即完全是不准的,直接的后果便是的更新目标是不准的。我们称之为action的ood(out-of-distribution)问题。
ok,所以我们需要做的就是让的选取尽量满足出现在数据集B里面。比较理想的更新方式便是下面这个方式了:
但这里有一个问题,就是怎么来,一个很直接的想法便是直接用VAE按BC的方式学一个出来,不就可以反映行为策略在面对s’时的action的分布吗?
稍微加点稳定训练的代码技巧,我们可以让VAE对同一个state多输出几个action,然后取Q值最大的一个作为的估计值。这样就可以很大程度的避免我们的策略在训练前期针对s’的a’的ood问题了。
按照这样的方式学到Q之后,其实策略也就可以直接得到了,策略更新的梯度往Q值大的方向走,跟DDPG的方式是一样的。
探索性
但是这样显然会存在一个问题,学出来的策略缺乏探索性,尽管学了一个Q函数,但是本质跟BC没有太大区别。因此BCQ在VAE的基础上加了一个扰动的网络:, 最终的actor输出的action为。加扰动的方式是为了让策略有一定的探索性,这个在TD3中也有体现。
简单来说,BCQ作为offline RL早期的文章,核心思想是行为克隆 + 探索,BC的部分体现在VAE,探索部分体现在扰动网络的加入,整体来说,毕竟是TD3作者的续作,TD3也是在DDPG的确定性策略的部分加了一个扰动网络,作者这么加显然是经过实验的效果验证。一言以蔽之:在监督学习(纯BC)和RL之间做一个trade-off。
关键代码实现
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action, phi=0.05):
super(Actor, self).__init__()
# (l1, l2, l3)对应前面提到的扰动网络 \xi(s, a)
self.l1 = nn.Linear(state_dim + action_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, action_dim)
self.max_action = max_action
self.phi = phi
def forward(self, state, action):
# 这里的action来自于VAE的输出 G_{\omega}(s),state是s
a = F.relu(self.l1(torch.cat([state, action], 1)))
a = F.relu(self.l2(a))
a = self.phi * self.max_action * torch.tanh(self.l3(a)) # \xi(s, G_{\omega}(s))
# 输出的是 G_{\omega}(s) + \xi(s, G_{\omega}(s))
return (a + action).clamp(-self.max_action, self.max_action)
class BCQ(object):
# ... 此处省略非关键代码
def train(self, replay_buffer, iterations, batch_size=100):
for it in range(iterations):
# Sample replay buffer / batch
state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
# Variational Auto-Encoder Training,需要读者对VAE有所了解
recon, mean, std = self.vae(state, action)
recon_loss = F.mse_loss(recon, action)
KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
vae_loss = recon_loss + 0.5 * KL_loss
self.vae_optimizer.zero_grad()
vae_loss.backward()
self.vae_optimizer.step()
# Critic Training
with torch.no_grad():
# Duplicate next state 10 times
next_state = torch.repeat_interleave(next_state, 10, 0)
# Compute value of perturbed actions sampled from the VAE
target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))
# Soft Clipped Double Q-learning
target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
# Take max over each action sampled from the VAE
target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)
target_Q = reward + not_done * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Pertubation Model / Action Training
sampled_actions = self.vae.decode(state)
perturbed_actions = self.actor(state, sampled_actions)
# Update through DPG,需要读者对DDPG有所了解
actor_loss = -self.critic.q1(state, perturbed_actions).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Update Target Networks
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)