from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from torch import nn
from torch.nn import functional as F

from stable_baselines3 import DQN
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
)
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.utils import is_vectorized_observation
from stable_baselines3.dqn.policies import DQNPolicy, QNetwork
class CustomQNetwork(QNetwork):
    """
    Action-Value (Q-Value) network for DQN with action masking.

    Expects the last ``action_space.n`` entries of the (flattened) observation
    to be a binary mask of the currently valid actions.

    :param observation_space: Observation space
    :param action_space: Action space
    :param features_extractor: Features extractor module
    :param features_dim: Dimension of the extracted features
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
features_extractor: nn.Module,
features_dim: int,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(CustomQNetwork, self).__init__(
observation_space,
action_space,
features_extractor,
features_dim,
net_arch,
activation_fn,
normalize_images,
)
    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Predict the Q-values.

        :param obs: Observation, with the action mask in its last ``action_space.n`` entries
        :return: The estimated Q-Value for each action, with invalid actions masked out.
        """
        # The last `n_actions` entries of the observation encode the action mask
        action_mask = obs[:, -self.action_space.n :].bool()
        # Fill the Q-values of invalid actions with -1.0 so the greedy argmax never
        # selects them (this assumes valid Q-values stay above -1.0).
        return self.q_net(self.extract_features(obs)).masked_fill_(~action_mask, -1.0)
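# A minimal sketch of the masking convention assumed above: the observation is a flat
# vector whose last `n_actions` entries are 0/1 validity flags. The numbers below are
# made up purely for illustration.
#
#     >>> q_values = th.tensor([[0.3, 0.7, 0.1]])       # raw Q-values from q_net
#     >>> action_mask = th.tensor([[1, 0, 1]]).bool()   # action 1 is invalid
#     >>> q_values.masked_fill(~action_mask, -1.0)
#     tensor([[ 0.3000, -1.0000,  0.1000]])
#
# so a greedy argmax over the masked Q-values can only pick action 0 or 2.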
class CustomDQNPolicy(DQNPolicy):
    """
    Policy class with Q-Value Net and target net for DQN, building the
    action-masked ``CustomQNetwork``.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(CustomDQNPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
)
    def make_q_net(self) -> CustomQNetwork:
        # Make sure we always have separate networks for features extractors etc.
        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
        return CustomQNetwork(**net_args).to(self.device)
class CustomDQN(DQN):
    """
    DQN variant that restricts both random exploration and greedy action selection
    to the valid actions indicated by the action mask appended to the observation.
    """
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 50000,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
max_grad_norm: float = 10,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super(CustomDQN, self).__init__(
policy,
env,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
replay_buffer_class,
replay_buffer_kwargs,
optimize_memory_usage,
target_update_interval,
exploration_fraction,
exploration_initial_eps,
exploration_final_eps,
max_grad_norm,
tensorboard_log,
create_eval_env,
policy_kwargs,
verbose,
seed,
device,
_init_setup_model
)
def _sample_action(
self,
learning_starts: int,
action_noise: Optional[ActionNoise] = None,
n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
"""
Sample an action according to the exploration policy.
This is either done by sampling the probability distribution of the policy,
or sampling a random action (from a uniform distribution over the action space)
or by adding noise to the deterministic output.
:param action_noise: Action noise that will be used for exploration
Required for deterministic policy (e.g. TD3). This can also be used
in addition to the stochastic policy for SAC.
:param learning_starts: Number of steps before learning for the warm-up phase.
:param n_envs:
:return: action to take in the environment
and scaled action that will be stored in the replay buffer.
The two differs when the action space is not normalized (bounds are not [-1, 1]).
"""
        # Select action randomly or according to policy
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # Sample uniformly among the valid actions of each environment, using the
            # action mask stored in the last `n_actions` entries of the observation
            all_actions = np.arange(self.action_space.n)
            unscaled_action = np.array(
                [
                    np.random.choice(
                        all_actions,
                        1,
                        p=self._last_obs[i, -self.action_space.n :]
                        / self._last_obs[i, -self.action_space.n :].sum(),
                    )
                    for i in range(n_envs)
                ]
            )
else:
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
# Rescale the action from [low, high] to [-1, 1]
if isinstance(self.action_space, gym.spaces.Box):
scaled_action = self.policy.scale_action(unscaled_action)
# Add noise to the action (improve exploration)
if action_noise is not None:
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
# We store the scaled action in the buffer
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
else:
# Discrete case, no need to normalize or clip
buffer_action = unscaled_action
action = buffer_action
return action, buffer_action
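    # A minimal illustration of the warm-up sampling above, with made-up numbers:
    # an action mask of [1, 0, 1, 1] yields sampling weights [1, 0, 1, 1] / 3,
    # i.e. a uniform draw over the three valid actions, so the invalid action
    # can never be picked:
    #
    #     >>> mask = np.array([1.0, 0.0, 1.0, 1.0])
    #     >>> np.random.choice(np.arange(4), 1, p=mask / mask.sum())
    #     array([2])   # always one of 0, 2 or 3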
    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)
losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
with th.no_grad():
# Compute the next Q-values using the target network
next_q_values = self.q_net_target(replay_data.next_observations)
# Follow greedy policy: use the one with the highest value
next_q_values, _ = next_q_values.max(dim=1)
# Avoid potential broadcast issue
next_q_values = next_q_values.reshape(-1, 1)
# 1-step TD target
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q-values estimates
current_q_values = self.q_net(replay_data.observations)
# Retrieve the q-values for the actions from the replay buffer
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
            # Fail fast if the loss diverges to NaN
            if th.isnan(loss).any():
                raise ValueError("Encountered NaN loss during DQN training")
losses.append(loss.item())
# Optimize the policy
self.policy.optimizer.zero_grad()
loss.backward()
# Clip gradient norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# Increase update counter
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/loss", np.mean(losses))
    def predict(
        self,
        observation: np.ndarray,
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base-class predict function to include epsilon-greedy exploration.
        Random exploration actions are drawn only among the valid actions indicated
        by the action mask appended to the observation (in the vectorized case).

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, gym.spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]
all_actions = np.arange(self.action_space.n)
action = np.array([np.random.choice(all_actions, 1, p=observation[i, -self.action_space.n:]/observation[i, -self.action_space.n:].sum()) for i in range(n_batch)])
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state
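

# --- Usage sketch (not part of the original module) -------------------------------
# A minimal, hypothetical example of how these classes fit together. It assumes an
# environment whose flattened observation ends with `action_space.n` binary entries
# marking the currently valid actions (e.g. via a wrapper you would write yourself,
# here called `MaskedObsWrapper`); neither the wrapper nor the environment id below
# exist in this module.
#
# if __name__ == "__main__":
#     env = MaskedObsWrapper(gym.make("YourMaskedEnv-v0"))  # hypothetical env + wrapper
#     model = CustomDQN(
#         CustomDQNPolicy,
#         env,
#         learning_rate=1e-4,
#         buffer_size=100_000,
#         learning_starts=1_000,
#         verbose=1,
#     )
#     model.learn(total_timesteps=10_000)
#     obs = env.reset()
#     action, _ = model.predict(obs, deterministic=True)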