Implementing and Applying Deep Reinforcement Learning Algorithms in PyTorch
In recent years, deep reinforcement learning (DRL) has made enormous progress in artificial intelligence and has become a powerful tool for solving complex decision-making problems. This article walks through PyTorch implementations of several popular DRL algorithms, including Q-learning, DQN (Deep Q-Network), PPO (Proximal Policy Optimization), DDPG (Deep Deterministic Policy Gradient), TD3 (Twin Delayed Deep Deterministic Policy Gradient), and SAC (Soft Actor-Critic). We discuss the principles behind these algorithms, their strengths and weaknesses, and how they can be applied to practical problems.
Overview of DRL Algorithms
Deep reinforcement learning combines the strengths of deep learning and reinforcement learning, making it possible to learn optimal policies in complex environments. The main algorithms covered here are:
- Q-learning: a classic tabular, value-based algorithm for discrete state and action spaces.
- DQN: combines a deep neural network with Q-learning, enabling it to handle high-dimensional state spaces.
- PPO: a policy-gradient algorithm that stabilizes training by limiting how far the new policy can move away from the old one.
- DDPG: a deterministic policy-gradient algorithm for continuous action spaces.
- TD3: an improved version of DDPG that uses several tricks to increase performance and stability.
- SAC: incorporates the maximum-entropy reinforcement learning framework to strike a good balance between exploration and exploitation.
PyTorch Implementations
PyTorch is a flexible and powerful deep learning framework that is well suited to implementing DRL algorithms. The key points of each implementation are outlined below.
Q-learning
Q-learning is one of the most fundamental reinforcement learning algorithms; it learns an optimal policy by iteratively updating a table of Q-values. Although Q-learning itself does not require deep learning, understanding it helps in understanding the more complex DRL algorithms that follow.
import torch

class QLearning:
    def __init__(self, state_dim, action_dim, learning_rate=0.1, gamma=0.99):
        # Tabular Q-values: one row per state, one column per action
        self.Q = torch.zeros((state_dim, action_dim))
        self.lr = learning_rate
        self.gamma = gamma

    def update(self, state, action, reward, next_state):
        # One-step temporal-difference update toward the bootstrapped target
        target = reward + self.gamma * torch.max(self.Q[next_state])
        self.Q[state, action] += self.lr * (target - self.Q[state, action])

    def get_action(self, state):
        # Greedy action with respect to the current Q-table
        return torch.argmax(self.Q[state]).item()
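A minimal training-loop sketch for this tabular agent. The environment (gymnasium's FrozenLake-v1), the episode count, and the epsilon-greedy exploration are illustrative assumptions and are not part of the original listing:

import random
import gymnasium as gym  # assumed environment library

env = gym.make("FrozenLake-v1")
agent = QLearning(env.observation_space.n, env.action_space.n)
epsilon = 0.1  # simple epsilon-greedy exploration

for episode in range(5000):
    state, _ = env.reset()
    done = False
    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        agent.update(state, action, reward, next_state)
        state = next_state
        done = terminated or truncated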
DQN (Deep Q-Network)
DQN uses a deep neural network to approximate the Q-function, which greatly extends Q-learning's ability to cope with complex, high-dimensional environments.
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class DQNAgent:
    def __init__(self, state_dim, action_dim, learning_rate=1e-3, gamma=0.99):
        self.q_network = DQN(state_dim, action_dim)
        self.target_network = DQN(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.gamma = gamma

    def update(self, state, action, reward, next_state, done):
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        q_values = self.q_network(state)
        next_q_values = self.target_network(next_state).detach()
        # Build the TD target for the taken action only; detach so the
        # target is treated as a constant during backpropagation
        target = q_values.clone().detach()
        target[action] = reward + (1 - done) * self.gamma * next_q_values.max()
        loss = nn.MSELoss()(q_values, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def get_action(self, state):
        state = torch.FloatTensor(state)
        q_values = self.q_network(state)
        return torch.argmax(q_values).item()
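A minimal training-loop sketch for the DQNAgent above. The CartPole-v1 environment, the epsilon-greedy exploration, and the target-network sync interval are illustrative assumptions; a full DQN would additionally use an experience replay buffer with mini-batch updates rather than per-step updates:

import random
import gymnasium as gym  # assumed environment library

env = gym.make("CartPole-v1")
agent = DQNAgent(env.observation_space.shape[0], env.action_space.n)
epsilon, sync_every, step_count = 0.1, 500, 0

for episode in range(300):
    state, _ = env.reset()
    done = False
    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.update(state, action, reward, next_state, float(done))
        state = next_state
        step_count += 1
        if step_count % sync_every == 0:  # periodically sync the target network
            agent.target_network.load_state_dict(agent.q_network.state_dict())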
PPO (Proximal Policy Optimization)
PPO is a policy-gradient algorithm that improves training stability by imposing a trust-region-style constraint on how much the policy may change in a single update.
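At the heart of PPO is the clipped surrogate objective (the standard formulation from the PPO paper, which the actor loss in the code below follows):

$$L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}$$

where \hat{A}_t is an advantage estimate and \epsilon is the clipping range (epsilon in the code).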
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPO, self).__init__()
        # Actor outputs a categorical distribution over discrete actions
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic outputs a scalar state-value estimate
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, state):
        return self.actor(state), self.critic(state)

class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2):
        self.ppo = PPO(state_dim, action_dim)
        self.optimizer = optim.Adam(self.ppo.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon

    def update(self, states, actions, rewards, next_states, dones, old_probs):
        # `rewards` is expected to hold the (discounted) returns of the batch
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        old_probs = torch.FloatTensor(old_probs)
        for _ in range(10):  # several optimization epochs over the same batch
            new_probs, state_values = self.ppo(states)
            state_values = state_values.squeeze(-1)
            new_probs = new_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
            # Probability ratio between the new and the old policy
            ratio = new_probs / old_probs
            # The returns are used directly as the advantage signal here;
            # in practice an estimator such as GAE is usually preferred
            surr1 = ratio * rewards
            surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * rewards
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = nn.MSELoss()(state_values, rewards)
            loss = actor_loss + 0.5 * critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def get_action(self, state):
        state = torch.FloatTensor(state)
        probs, _ = self.ppo(state)
        return torch.multinomial(probs, 1).item()
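A rollout-collection sketch showing how the update above can be driven. The CartPole-v1 environment, the episode count, and the use of discounted episode returns in place of raw rewards are illustrative assumptions:

import gymnasium as gym  # assumed environment library

env = gym.make("CartPole-v1")
agent = PPOAgent(env.observation_space.shape[0], env.action_space.n)

for episode in range(500):
    states, actions, rewards, next_states, dones, old_probs = [], [], [], [], [], []
    state, _ = env.reset()
    done = False
    while not done:
        with torch.no_grad():
            probs, _ = agent.ppo(torch.FloatTensor(state))
        action = torch.multinomial(probs, 1).item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        states.append(state); actions.append(action); rewards.append(reward)
        next_states.append(next_state); dones.append(float(done))
        old_probs.append(probs[action].item())
        state = next_state
    # Discounted returns, computed backwards over the episode
    returns, G = [], 0.0
    for r in reversed(rewards):
        G = r + agent.gamma * G
        returns.insert(0, G)
    agent.update(states, actions, returns, next_states, dones, old_probs)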
DDPG (Deep Deterministic Policy Gradient)
DDPG is an algorithm for continuous action spaces that combines ideas from DQN (replay buffer and target networks) with the deterministic policy gradient.
import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        a = torch.relu(self.fc1(state))
        a = torch.relu(self.fc2(a))
        # tanh squashes the output to [-1, 1], then scale to the action range
        return self.max_action * torch.tanh(self.fc3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, state, action):
        q = torch.cat([state, action], 1)
        q = torch.relu(self.fc1(q))
        q = torch.relu(self.fc2(q))
        return self.fc3(q)

class DDPGAgent:
    def __init__(self, state_dim, action_dim, max_action, lr=1e-4, gamma=0.99, tau=0.001):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma
        self.tau = tau

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=100):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        # Compute the target Q-value with the target networks
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + (1 - done) * self.gamma * target_Q.detach()
        # Update the critic
        current_Q = self.critic(state, action)
        critic_loss = nn.MSELoss()(current_Q, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Update the actor by maximizing the critic's value estimate
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # Soft-update the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
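The update above assumes a replay buffer whose sample() method returns batched tensors; no such buffer is defined in the listing, so here is a minimal sketch of one possible implementation (the class itself is not part of the original code):

import numpy as np

class ReplayBuffer:
    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, 1), dtype=np.float32)
        self.done = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action, next_state, reward, done):
        # Overwrite the oldest entry once the buffer is full (circular buffer)
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.done[self.ptr] = float(done)
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        return (torch.FloatTensor(self.state[idx]),
                torch.FloatTensor(self.action[idx]),
                torch.FloatTensor(self.next_state[idx]),
                torch.FloatTensor(self.reward[idx]),
                torch.FloatTensor(self.done[idx]))

Transitions gathered with select_action (plus exploration noise) are stored via add(), and agent.update(buffer) is called once the buffer holds at least a batch of samples.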
TD3 (Twin Delayed DDPG)
TD3 is an improved version of DDPG that achieves better performance and stability through three tricks: clipped double-Q learning, delayed policy updates, and target policy smoothing.
import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        a = torch.relu(self.fc1(state))
        a = torch.relu(self.fc2(a))
        return self.max_action * torch.tanh(self.fc3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        q = torch.relu(self.fc1(sa))
        q = torch.relu(self.fc2(q))
        return self.fc3(q)

class TD3Agent:
    def __init__(self, state_dim, action_dim, max_action, lr=1e-3, gamma=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        # Twin critics: the smaller of the two target estimates is used,
        # which counteracts Q-value overestimation
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.critic1_target = Critic(state_dim, action_dim)
        self.critic2_target = Critic(state_dim, action_dim)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)
        self.max_action = max_action
        self.gamma = gamma
        self.tau = tau
        self.policy_noise = policy_noise  # std of the target policy smoothing noise
        self.noise_clip = noise_clip      # clipping range of that noise
        self.policy_freq = policy_freq    # actor is updated once every policy_freq critic updates
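As a reference, a sketch of one TD3 training step illustrating the three tricks above: clipped double-Q targets, target policy smoothing, and delayed actor updates. It is written as a standalone function (not part of the original listing), assumes the ReplayBuffer interface from the DDPG example, and expects a caller-managed iteration counter it for the delayed policy update:

def td3_update(agent, replay_buffer, it, batch_size=100):
    state, action, next_state, reward, done = replay_buffer.sample(batch_size)

    with torch.no_grad():
        # Target policy smoothing: add clipped noise to the target action
        noise = (torch.randn_like(action) * agent.policy_noise).clamp(
            -agent.noise_clip, agent.noise_clip)
        next_action = (agent.actor_target(next_state) + noise).clamp(
            -agent.max_action, agent.max_action)
        # Clipped double-Q: take the minimum of the two target critics
        target_Q1 = agent.critic1_target(next_state, next_action)
        target_Q2 = agent.critic2_target(next_state, next_action)
        target_Q = reward + (1 - done) * agent.gamma * torch.min(target_Q1, target_Q2)

    # Update both critics toward the shared target
    for critic, opt in ((agent.critic1, agent.critic1_optimizer),
                        (agent.critic2, agent.critic2_optimizer)):
        critic_loss = nn.MSELoss()(critic(state, action), target_Q)
        opt.zero_grad()
        critic_loss.backward()
        opt.step()

    # Delayed policy update: refresh the actor and all targets every policy_freq steps
    if it % agent.policy_freq == 0:
        actor_loss = -agent.critic1(state, agent.actor(state)).mean()
        agent.actor_optimizer.zero_grad()
        actor_loss.backward()
        agent.actor_optimizer.step()
        for net, target in ((agent.critic1, agent.critic1_target),
                            (agent.critic2, agent.critic2_target),
                            (agent.actor, agent.actor_target)):
            for p, tp in zip(net.parameters(), target.parameters()):
                tp.data.copy_(agent.tau * p.data + (1 - agent.tau) * tp.data)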