Source code for models.dqn

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
import torch as th
from torch import nn
from torch.nn import functional as F

from stable_baselines3 import DQN
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
)
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.utils import is_vectorized_observation
from stable_baselines3.dqn.policies import DQNPolicy, QNetwork


class CustomQNetwork(QNetwork):
    """
    Action-Value (Q-Value) network for DQN, with invalid actions masked out.

    :param observation_space: Observation space
    :param action_space: Action space
    :param features_extractor: Network used to extract the features from the observation
    :param features_dim: Number of features extracted by the features extractor
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        features_extractor: nn.Module,
        features_dim: int,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super(CustomQNetwork, self).__init__(
            observation_space,
            action_space,
            features_extractor,
            features_dim,
            net_arch,
            activation_fn,
            normalize_images,
        )

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Predict the q-values.

        :param obs: Observation, with the action mask appended as the last
            ``action_space.n`` entries.
        :return: The estimated Q-Value for each action, with invalid actions
            filled with -1.0.
        """
        # The last `action_space.n` entries of the observation encode which actions are valid
        action_mask = obs[:, -self.action_space.n :].bool()
        # Fill the Q-values of invalid actions with a low value (-1.0) so that valid
        # actions are preferred by argmax (assumes valid Q-values stay above -1.0)
        return self.q_net(self.extract_features(obs)).masked_fill_(~action_mask, -1.0)
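

# Illustrative sketch (not part of the original model): ``CustomQNetwork`` assumes
# the action mask is appended to the *end* of each flattened observation, i.e.
# obs = [features..., mask_0, ..., mask_{n-1}] with mask_i in {0, 1}. The toy
# numbers below show how masking interacts with a greedy argmax, assuming valid
# Q-values stay above the fill value of -1.0.
def _mask_layout_example() -> th.Tensor:
    q_values = th.tensor([[0.3, 0.9, 0.1]])            # raw Q-values for 3 actions
    action_mask = th.tensor([[1.0, 0.0, 1.0]]).bool()  # action 1 is invalid
    masked_q_values = q_values.masked_fill(~action_mask, -1.0)
    return masked_q_values.argmax(dim=1)               # tensor([0]): action 1 is never picked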


class CustomDQNPolicy(DQNPolicy):
    """
    Policy class with Q-Value Net and target net for DQN

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(CustomDQNPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def make_q_net(self) -> CustomQNetwork:
        # Make sure we always have separate networks for features extractors etc
        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
        return CustomQNetwork(**net_args).to(self.device)


class CustomDQN(DQN):
    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[ReplayBuffer] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super(CustomDQN, self).__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            replay_buffer_class,
            replay_buffer_kwargs,
            optimize_memory_usage,
            target_update_interval,
            exploration_fraction,
            exploration_initial_eps,
            exploration_final_eps,
            max_grad_norm,
            tensorboard_log,
            create_eval_env,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
        )

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample an action according to the exploration policy.
        This is either done by sampling the probability distribution of the policy,
        or sampling a random action (from a uniform distribution over the action space)
        or by adding noise to the deterministic output.

        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param n_envs:
        :return: action to take in the environment
            and scaled action that will be stored in the replay buffer.
            The two differ when the action space is not normalized (bounds are not [-1, 1]).
        """
        # Select action randomly or according to policy
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # During warm-up, sample only among valid actions by renormalizing
            # the action mask stored at the end of the last observation
            all_actions = np.arange(self.action_space.n)
            unscaled_action = np.array(
                [
                    np.random.choice(
                        all_actions,
                        1,
                        p=self._last_obs[i, -self.action_space.n :]
                        / self._last_obs[i, -self.action_space.n :].sum(),
                    )
                    for i in range(n_envs)
                ]
            )
        else:
            # Note: when using continuous actions,
            # we assume that the policy uses tanh to scale the action
            # We use non-deterministic action in the case of SAC, for TD3, it does not matter
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        # Rescale the action from [low, high] to [-1, 1]
        if isinstance(self.action_space, gym.spaces.Box):
            scaled_action = self.policy.scale_action(unscaled_action)

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

            # We store the scaled action in the buffer
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action)
        else:
            # Discrete case, no need to normalize or clip
            buffer_action = unscaled_action
            action = buffer_action
        return action, buffer_action
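
    @staticmethod
    def _masked_sampling_sketch(n_actions: int = 4) -> np.ndarray:
        """
        Illustrative helper (not used by the algorithm): shows how the warm-up
        branch of ``_sample_action`` turns the mask slice of an observation into
        a probability distribution, so that only valid actions can be drawn.
        The observation layout below is an assumption made for the example.
        """
        obs = np.array([0.2, 0.7, 1.0, 0.0, 1.0, 1.0])  # 2 features + 4-entry action mask
        mask = obs[-n_actions:]
        probs = mask / mask.sum()  # uniform over the valid actions {0, 2, 3}
        return np.random.choice(np.arange(n_actions), 1, p=probs)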

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            if np.isnan(loss.detach().cpu().numpy()).any():
                raise ValueError("NaN loss encountered during DQN training")
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))
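
    @staticmethod
    def _td_target_sketch() -> th.Tensor:
        """
        Minimal sketch (not used by the algorithm) of the 1-step TD target
        computed in ``train``: rewards + (1 - dones) * gamma * max_a' Q_target(s', a'),
        worked out on toy numbers.
        """
        gamma = 0.99
        rewards = th.tensor([[1.0], [0.0]])
        dones = th.tensor([[0.0], [1.0]])
        next_q_values = th.tensor([[0.5, 2.0], [1.0, -1.0]])
        # Greedy bootstrap: take the best next action, then build the target
        max_next_q, _ = next_q_values.max(dim=1)
        return rewards + (1 - dones) * gamma * max_next_q.reshape(-1, 1)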

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
                if isinstance(self.observation_space, gym.spaces.Dict):
                    n_batch = observation[list(observation.keys())[0]].shape[0]
                else:
                    n_batch = observation.shape[0]
                all_actions = np.arange(self.action_space.n)
                action = np.array(
                    [
                        np.random.choice(
                            all_actions,
                            1,
                            p=observation[i, -self.action_space.n :]
                            / observation[i, -self.action_space.n :].sum(),
                        )
                        for i in range(n_batch)
                    ]
                )
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)

        return action, state
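

if __name__ == "__main__":
    # Usage sketch, with assumptions flagged: ``MaskedEnv-v0`` is a hypothetical
    # Gym environment whose flattened observations end with an action mask of
    # length ``action_space.n``; any environment with that layout should work.
    env = gym.make("MaskedEnv-v0")
    model = CustomDQN(CustomDQNPolicy, env, learning_rate=1e-4, verbose=1)
    model.learn(total_timesteps=10_000)

    obs = env.reset()
    action, _ = model.predict(obs, deterministic=True)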