Source code for models.a2c

from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import gym
import torch as th
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    Distribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution,
)


class CustomActorCriticPolicy(ActorCriticPolicy):
    """Custom actor-critic policy for GCP-HOLO."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        *args,
        **kwargs,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to the base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False
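    # Illustrative usage sketch (not part of the original module): the policy
    # class can be passed directly to a stable-baselines3 algorithm. ``env``
    # stands for whichever Gym environment GCP-HOLO provides; the only
    # assumption made here is that its flat observation ends with a binary
    # action mask of length ``action_space.n``, as required by
    # ``_get_action_dist_from_latent`` below.
    #
    #     from stable_baselines3 import A2C
    #
    #     env = ...  # masked-action Gym environment
    #     model = A2C(CustomActorCriticPolicy, env, verbose=1)
    #     model.learn(total_timesteps=10_000)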
    def get_distribution(self, obs: th.Tensor):
        """
        Get the current policy distribution given the observations.

        :param obs: Observation (with the action mask appended in its last
            ``action_space.n`` entries)
        :return: the action distribution
        """
        features = self.extract_features(obs)
        latent_pi = self.mlp_extractor.forward_actor(features)
        return self._get_action_dist_from_latent(latent_pi, obs)
    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic).

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)  # GNN features extractor
        latent_pi, latent_vf = self.mlp_extractor(features)  # latent vectors for the policy and value nets
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)  # value head
        distribution = self._get_action_dist_from_latent(latent_pi, obs)  # masked action head
        actions = distribution.get_actions(deterministic=deterministic)  # select action
        log_prob = distribution.log_prob(actions)  # log-probability of the selected action
        return actions, values, log_prob
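    # A hedged sketch of calling ``forward`` directly (A2C normally does this
    # itself during rollout collection). The Box/Discrete spaces and the
    # constant learning-rate schedule below are invented for illustration; in
    # GCP-HOLO the features extractor is a GNN and the observation encodes a
    # graph with the action mask appended in its last ``action_space.n``
    # entries.
    #
    #     policy = CustomActorCriticPolicy(
    #         gym.spaces.Box(low=-1.0, high=1.0, shape=(16,)),
    #         gym.spaces.Discrete(4),
    #         lr_schedule=lambda _: 3e-4,
    #     )
    #     obs = th.zeros((2, 16))
    #     obs[:, -4:] = th.tensor([1.0, 0.0, 1.0, 1.0])  # action 1 is masked out
    #     actions, values, log_prob = policy(obs, deterministic=True)
    #     # actions: shape (2,), values: shape (2, 1), log_prob: shape (2,)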
    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy, given the observations.

        :param obs: Observation
        :param actions: Actions whose value, log likelihood and entropy are evaluated
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        distribution = self._get_action_dist_from_latent(latent_pi, obs)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()
    def _get_action_dist_from_latent(self, latent_pi: th.Tensor, obs: th.Tensor) -> Distribution:
        """
        Retrieve the action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :param obs: Observation, whose last ``action_space.n`` entries hold the binary action mask
        :return: Action distribution
        """
        # Mitch edit: mask invalid actions by filling their logits with a large
        # negative value, so they get (effectively) zero probability after the softmax
        mask = obs[:, -self.action_space.n:].bool()
        mean_actions = self.action_net(latent_pi).masked_fill_(~mask, -1e9)  # action head

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        else:
            raise ValueError("Invalid action distribution")
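The masking above works by giving invalid actions a very large negative logit, so the softmax inside the categorical distribution assigns them essentially zero probability. A minimal standalone sketch of that step, with made-up logits and mask (not taken from GCP-HOLO):

import torch as th

logits = th.tensor([[1.0, 2.0, 3.0, 4.0]])      # raw logits from an action head
mask = th.tensor([[True, False, True, False]])  # True = action currently valid
masked_logits = logits.masked_fill(~mask, -1e9)
probs = th.softmax(masked_logits, dim=-1)
# probs ~ [[0.1192, 0.0000, 0.8808, 0.0000]]: masked actions can never be sampled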