Probability Sampling + Backpropagation for Policy Gradient (PG) in PyTorch

date
Jun 4, 2021
slug
pytorch-distributions
status
Published
tags
DeepLearning
PyTorch
ReinforcementLearning
Notes
summary
The core problem of Policy Gradient
type
Post
The core problem of implementing Policy Gradient in PyTorch: computing the policy gradient, i.e. backpropagating through sampled actions. The focus here is how to actually use torch.distributions.
  • torch.distributions provides parameterizable probability distributions and sampling functions, used to construct stochastic computation graphs and stochastic gradient estimators for optimization.
  • You cannot backpropagate directly through random samples. There are two ways to create surrogate functions that can be backpropagated through; RL mainly uses the first one, the score function estimator, which is the basis of policy gradient methods. If the PDF is differentiable with respect to its parameters, only sample() and log_prob() are needed:
    • In policy gradient methods we need not only the sampled action but also that action's log_prob. Multiplying each log_prob by the corresponding reward (i.e. weight) and summing gives a loss that can be backpropagated to update the parameters (see the episode-level sketch after the code listing below).
  • Categorical:
    • Which distribution to use depends on the problem. For example, CartPole's action space is {0, 1}, so a Bernoulli distribution fits; traffic-signal control is a choose-one-of-many problem, so the network output is first normalized with softmax and then fed to Categorical (i.e. a multinomial distribution). See the shape sketch after this list.
    • If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.
    • If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.
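A minimal sketch of that shape behavior (the probability values below are made up for illustration and are independent of the policy network that follows):

import torch
from torch.distributions import Categorical

# 1-D probs of length K=3: a single distribution over 3 classes
d1 = Categorical(probs=torch.tensor([0.1, 0.2, 0.7]))
a1 = d1.sample()        # 0-dim tensor, e.g. tensor(2)
lp1 = d1.log_prob(a1)   # 0-dim tensor: log-probability of the sampled class

# 2-D probs of shape (N, K) = (2, 3): a batch of 2 distributions
d2 = Categorical(probs=torch.tensor([[0.1, 0.2, 0.7],
                                     [0.3, 0.3, 0.4]]))
a2 = d2.sample()        # shape (2,): one sampled class index per row
lp2 = d2.log_prob(a2)   # shape (2,): one log-probability per row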
Backpropagation code implemented with this approach:
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(2021)
torch.cuda.manual_seed(2021)

class PolicyNet(nn.Module):
    # simple 3-layer MLP mapping a state vector to action probabilities
    def __init__(self, input_size, output_size):
        super(PolicyNet, self).__init__()

        self.fc1 = nn.Linear(input_size, 24)
        self.fc2 = nn.Linear(24, 36)
        self.fc3 = nn.Linear(36, output_size)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # softmax over action logits; dim=0 because x is a single unbatched
        # state vector here (use dim=-1 for batched input)
        x = nn.functional.softmax(self.fc3(x), dim=0)
        return x

p = PolicyNet(24, 8)
optimizer = torch.optim.Adam(p.parameters(), lr=0.01)
i = torch.rand(24).float()   # a fixed random "state" used to probe the network
prob = p(i)                  # action probabilities before any update
prob
>> tensor([0.1312, 0.1242, 0.1192, 0.1465, 0.1253, 0.1037, 0.1214, 0.1285],
       grad_fn=<SoftmaxBackward>)
# -----------------------------------------
torch.autograd.set_detect_anomaly(True)
for epoch in range(5):
    optimizer.zero_grad()
    prob = p(torch.rand(24))           # action probabilities for a random "state"
    m = Categorical(prob)              # categorical distribution over actions
    action = m.sample()                # sampling itself is not differentiable
    loss = -m.log_prob(action) * 50    # score-function surrogate: -log_prob * reward (constant 50 here)
    loss.backward()                    # gradients flow through log_prob into the network
    optimizer.step()
    print(f"epoch: {epoch}, loss: {loss.item()}")
>> epoch: 0, loss: 101.52423095703125
>> epoch: 1, loss: 101.901611328125
>> epoch: 2, loss: 109.515380859375
>> epoch: 3, loss: 94.08329772949219
>> epoch: 4, loss: 106.5773696899414
# -----------------------------------------
prob = p(i)
prob
>> tensor([0.0918, 0.1821, 0.0951, 0.1145, 0.1298, 0.0858, 0.1164, 0.1843],
       grad_fn=<SoftmaxBackward>) # the probabilities have changed: backpropagation has updated the network
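In a real policy-gradient update the constant weight of 50 above would be replaced by a per-step return: collect the log_prob of every action taken in an episode together with its reward, compute discounted returns, and backpropagate their weighted sum. A minimal sketch of that bookkeeping, reusing p and optimizer from the listing above (the random states, constant rewards, and 10-step episode length are placeholders standing in for a real environment):

gamma = 0.99
log_probs, rewards = [], []

# roll out a fake 10-step episode
for t in range(10):
    state = torch.rand(24)              # placeholder for a real observation
    m = Categorical(p(state))
    action = m.sample()
    log_probs.append(m.log_prob(action))
    rewards.append(1.0)                 # placeholder reward from the environment

# discounted returns, accumulated backwards from the final step
returns, G = [], 0.0
for r in reversed(rewards):
    G = r + gamma * G
    returns.insert(0, G)
returns = torch.tensor(returns)

# REINFORCE surrogate loss: sum of -log_prob * return, then one update
optimizer.zero_grad()
loss = (-torch.stack(log_probs) * returns).sum()
loss.backward()
optimizer.step()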
 

© Phillip Gu 2021