Diffusion QL: Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning
Keywords: Offline Reinforcement Learning, Diffusion Model, Behavior Cloning, Q-Learning
(1) Given an offline dataset, how would you implement behavior cloning (BC) with a diffusion model?
First, recall which earlier offline RL algorithms already build on BC:
- BCQ: a CVAE for behavior cloning (acting as a policy constraint) + a perturbation network (for limited exploration around the cloned actions) + Q-learning
- TD3+BC: TD3 with a simple behavior-cloning regularizer added to the actor loss (see the sketch right after this list)
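For reference, the TD3+BC actor update is nothing more than TD3's Q-maximization term plus an MSE behavior-cloning term. The sketch below is illustrative only; the function name and the `actor`/`critic_q1` callables are assumptions, not code from either repository:

```python
import torch.nn.functional as F

def td3_bc_actor_loss(actor, critic_q1, state, action, alpha=2.5):
    """TD3+BC actor objective (Fujimoto & Gu, 2021): maximize Q while staying close
    to the dataset action. `actor` and `critic_q1` are hypothetical callables."""
    pi = actor(state)                          # deterministic action proposal
    q = critic_q1(state, pi)                   # Q-value of the proposed action
    lmbda = alpha / q.abs().mean().detach()    # normalize the Q scale
    return -lmbda * q.mean() + F.mse_loss(pi, action)
```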
Without further ado, let's look at the code; it is largely self-explanatory.
Diffusion_BC
```python
import numpy as np
import torch
import torch.nn as nn

# SinusoidalPosEmb, extract, the *_beta_schedule functions, Losses, Progress and Silent
# are helper utilities from the official Diffusion-QL repository (agents/helpers.py).


class MLP(nn.Module):
    """
    MLP noise-prediction network: takes the noisy action, the diffusion timestep and
    the state, and outputs a tensor with the shape of the action.
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 device,
                 t_dim=16):
        super(MLP, self).__init__()
        self.device = device

        # sinusoidal embedding of the diffusion timestep, refined by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 2),
            nn.Mish(),
            nn.Linear(t_dim * 2, t_dim),
        )

        input_dim = state_dim + action_dim + t_dim
        self.mid_layer = nn.Sequential(nn.Linear(input_dim, 256),
                                       nn.Mish(),
                                       nn.Linear(256, 256),
                                       nn.Mish(),
                                       nn.Linear(256, 256),
                                       nn.Mish())
        self.final_layer = nn.Linear(256, action_dim)

    def forward(self, x, time, state):
        t = self.time_mlp(time)
        x = torch.cat([x, t, state], dim=1)
        x = self.mid_layer(x)
        return self.final_layer(x)
class Diffusion(nn.Module):
    def __init__(self, state_dim, action_dim, model, max_action,
                 beta_schedule='linear', n_timesteps=100,
                 loss_type='l2', clip_denoised=True, predict_epsilon=True):
        super(Diffusion, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.model = model

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(n_timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(n_timesteps)
        elif beta_schedule == 'vp':
            betas = vp_beta_schedule(n_timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        ## log calculation clipped because the posterior variance
        ## is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.loss_fn = Losses[loss_type]()

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        '''
        if self.predict_epsilon, model output is (scaled) noise;
        otherwise, model predicts x0 directly
        '''
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, s):
        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, t, s))

        if self.clip_denoised:
            x_recon.clamp_(-self.max_action, self.max_action)
        else:
            raise RuntimeError()  # clip_denoised=False is not supported here

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    # @torch.no_grad()
    def p_sample(self, x, t, s):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    # @torch.no_grad()
    def p_sample_loop(self, state, shape, verbose=False, return_diffusion=False):
        device = self.betas.device

        batch_size = shape[0]
        x = torch.randn(shape, device=device)

        if return_diffusion: diffusion = [x]

        progress = Progress(self.n_timesteps) if verbose else Silent()
        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, timesteps, state)
            progress.update({'t': i})

            if return_diffusion: diffusion.append(x)

        progress.close()

        if return_diffusion:
            return x, torch.stack(diffusion, dim=1)
        else:
            return x

    # @torch.no_grad()
    def sample(self, state, *args, **kwargs):
        batch_size = state.shape[0]
        shape = (batch_size, self.action_dim)
        action = self.p_sample_loop(state, shape, *args, **kwargs)
        return action.clamp_(-self.max_action, self.max_action)

    # ------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return sample

    def p_losses(self, x_start, state, t, weights=1.0):
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.model(x_noisy, t, state)

        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            loss = self.loss_fn(x_recon, noise, weights)
        else:
            loss = self.loss_fn(x_recon, x_start, weights)
        return loss

    def loss(self, x, state, weights=1.0):
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        return self.p_losses(x, state, t, weights)

    def forward(self, state, *args, **kwargs):
        return self.sample(state, *args, **kwargs)
class Diffusion_BC(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 beta_schedule='linear',
                 n_timesteps=100,
                 lr=2e-4,
                 ):
        self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
        self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action,
                               beta_schedule=beta_schedule, n_timesteps=n_timesteps,
                               ).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        metric = {'bc_loss': [], 'ql_loss': [], 'actor_loss': [], 'critic_loss': []}
        for _ in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            loss = self.actor.loss(action, state)

            self.actor_optimizer.zero_grad()
            loss.backward()
            self.actor_optimizer.step()

            metric['actor_loss'].append(0.)
            metric['bc_loss'].append(loss.item())
            metric['ql_loss'].append(0.)
            metric['critic_loss'].append(0.)

        return metric

    def sample_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def save_model(self, dir, id=None):
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir, id=None):
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth'))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
```
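A minimal usage sketch, assuming the classes above plus the repo's helpers are importable. `OfflineBuffer`, the random data and the dimensions below are illustrative placeholders; any object exposing the same `sample(batch_size)` interface works:

```python
import numpy as np
import torch

class OfflineBuffer:
    """Illustrative stand-in for an offline dataset; only .sample() is required."""
    def __init__(self, states, actions, device):
        self.states = torch.as_tensor(states, dtype=torch.float32, device=device)
        self.actions = torch.as_tensor(actions, dtype=torch.float32, device=device)

    def sample(self, batch_size):
        idx = torch.randint(0, len(self.states), (batch_size,))
        s, a = self.states[idx], self.actions[idx]
        # Diffusion_BC.train only reads state and action; the rest are placeholders.
        zeros = torch.zeros(batch_size, 1, device=s.device)
        return s, a, s, zeros, zeros

device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dim, action_dim, max_action = 17, 6, 1.0   # e.g. HalfCheetah-sized dimensions
buffer = OfflineBuffer(np.random.randn(10_000, state_dim),
                       np.random.uniform(-1, 1, (10_000, action_dim)), device)

agent = Diffusion_BC(state_dim=state_dim, action_dim=action_dim, max_action=max_action,
                     device=device, discount=0.99, tau=0.005, n_timesteps=100)
metric = agent.train(buffer, iterations=1000, batch_size=256)
print('final bc_loss:', metric['bc_loss'][-1])
action = agent.sample_action(np.random.randn(state_dim))   # one action for one state
```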
The Diffusion class here is almost identical to the one in Diffuser. The difference is that Diffuser diffuses over whole trajectory segments, so its overall approach is closer to model-based planning, whereas here it is plain behavior cloning: the input is a single state and the output is an action for that state. The mathematical form is given below, and the training recipe itself is identical to DDPM's.
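In the paper's notation, the conditional diffusion policy and its behavior-cloning objective are (with $a^i$ the noised action at diffusion step $i$, $N$ the number of steps, and $\bar\alpha_i = \prod_{j \le i} \alpha_j$):

$$\pi_\theta(a \mid s) = p_\theta(a^{0:N} \mid s) = \mathcal{N}(a^N; 0, I)\prod_{i=1}^{N} p_\theta(a^{i-1} \mid a^i, s),$$

$$\mathcal{L}_d(\theta) = \mathbb{E}_{i \sim \mathcal{U}\{1,\dots,N\},\ \epsilon \sim \mathcal{N}(0, I),\ (s,a) \sim \mathcal{D}} \left[ \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_i}\, a + \sqrt{1-\bar\alpha_i}\, \epsilon,\ s,\ i\right) \right\|^2 \right],$$

which is exactly what `q_sample` and `p_losses` above implement, with `MLP` playing the role of $\epsilon_\theta(a^i, s, i)$.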
Now the question: why not just use the simplest form of BC, e.g., a Gaussian policy trained by maximizing the log-probability of the dataset actions? Or why not try a CVAE? The authors run a toy experiment in the paper in which the behavior policy to be imitated is a mixture of four policies.
Result: only Diffusion BC recovers the multi-modal action distribution well; the simpler policy classes fail to capture all of the modes.
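A minimal sketch of such a multi-modal behavior dataset (the means and noise scale here are assumptions for illustration, not the paper's exact setup) shows why a unimodal Gaussian BC struggles:

```python
import numpy as np

rng = np.random.default_rng(0)
modes = np.array([[0.8, 0.8], [0.8, -0.8], [-0.8, 0.8], [-0.8, -0.8]])   # 4 behavior modes
idx = rng.integers(0, 4, size=10_000)
actions = modes[idx] + 0.05 * rng.standard_normal((10_000, 2))           # mixture samples

# A unimodal Gaussian fit recovers roughly the mixture mean -- near (0, 0), where the
# data has almost no mass -- while a diffusion policy can cover all four modes.
print(actions.mean(axis=0))   # ~[0, 0]: a poor "cloned" action
```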
(2) BC alone seemingly cannot learn a policy better than the dataset's behavior policy; could we add Q-learning?
Exactly! The workflow then becomes the same as in BCQ or TD3+BC. In principle, if you have already read BCQ, the only change here is the policy class: it is replaced by the Diffusion BC policy above.
To summarize, DQL is designed for offline RL, so its loss has two terms: a behavior-cloning term that fits the behavior policy in the dataset with the diffusion model, and a Q-learning term that maximizes the Q-function to improve the policy.
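In the paper's notation, the actor objective is

$$\mathcal{L}(\theta) = \mathcal{L}_d(\theta) - \alpha\, \mathbb{E}_{s \sim \mathcal{D},\ a^0 \sim \pi_\theta}\!\left[ Q_\phi(s, a^0) \right], \qquad \alpha = \frac{\eta}{\mathbb{E}_{(s,a)\sim\mathcal{D}}\left[ |Q_\phi(s,a)| \right]},$$

where the Q normalizer is treated as a constant. In code, the actor update looks roughly like the sketch below; this is a paraphrase rather than the repo's `Diffusion_QL` class, and the double-Q `critic` interface is an assumption:

```python
def diffusion_ql_actor_loss(actor, critic, state, action, eta=1.0):
    """Diffusion-QL actor objective: diffusion BC term + normalized Q-maximization.
    `actor` is the Diffusion policy above; `critic(state, a)` is assumed to return
    two Q heads (q1, q2). Sketch only, not the official implementation."""
    bc_loss = actor.loss(action, state)             # L_d: denoising / behavior-cloning term
    new_action = actor(state)                       # a^0 ~ pi_theta(.|s); gradients flow through sampling
    q1, q2 = critic(state, new_action)
    q_loss = -q1.mean() / q2.abs().mean().detach()  # maximize Q, rescaled to balance the two terms
    return bc_loss + eta * q_loss
```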
(3) Limitations of the paper
The combined objective has no precise statistical interpretation: the learned policy cannot be identified with any well-defined probability distribution. Later works such as SfBC and QGPO improve on this.