from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import gym
import torch as th
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    Distribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution,
)
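
# Note: the masking logic in this policy assumes a Discrete action space whose
# binary action mask occupies the last ``action_space.n`` entries of each
# flattened observation; see ``_get_action_dist_from_latent`` below.
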
class CustomActorCriticPolicy(ActorCriticPolicy):
    """Custom actor-critic policy for GCP-HOLO.

    Subclasses ``ActorCriticPolicy`` to disable orthogonal initialization and
    to mask invalid actions using the action mask stored at the tail of each
    observation.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        *args,
        **kwargs,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to the base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False
    def get_distribution(self, obs: th.Tensor) -> Distribution:
        """
        Get the current policy distribution given the observations.

        :param obs: Observation
        :return: the action distribution.
        """
        features = self.extract_features(obs)
        latent_pi = self.mlp_extractor.forward_actor(features)
        # Pass the observation through so the action mask can be recovered
        return self._get_action_dist_from_latent(latent_pi, obs)
    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic).

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Extract features (the GNN feature extractor in GCP-HOLO)
        features = self.extract_features(obs)
        # Latent vectors for the policy and value networks
        latent_pi, latent_vf = self.mlp_extractor(features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        # Build the masked action distribution, then sample (or take the mode)
        distribution = self._get_action_dist_from_latent(latent_pi, obs)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob
    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy, given the observations.

        :param obs: Observation
        :param actions: Actions to evaluate
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        distribution = self._get_action_dist_from_latent(latent_pi, obs)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()
    def _get_action_dist_from_latent(self, latent_pi: th.Tensor, obs: th.Tensor) -> Distribution:
        """
        Retrieve the action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :param obs: Observation, whose last ``action_space.n`` entries hold
            the binary action mask
        :return: Action distribution
        """
        # Mask invalid actions: the tail of each flattened observation stores
        # a binary mask over the (discrete) action space
        mask = obs[:, -self.action_space.n:].bool()
        # Push the logits of invalid actions to a large negative value so the
        # softmax assigns them (near-)zero probability
        mean_actions = self.action_net(latent_pi).masked_fill_(~mask, -1e9)
        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        else:
            raise ValueError("Invalid action distribution")
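

# The demo below is an editorial sketch, not part of the original GCP-HOLO
# code. It shows why masking logits to -1e9 removes invalid actions: after
# the softmax, those actions receive (near-)zero probability.
if __name__ == "__main__":
    logits = th.zeros(1, 4)  # uniform logits over 4 actions
    mask = th.tensor([[True, False, True, False]])
    masked = logits.masked_fill(~mask, -1e9)
    probs = th.softmax(masked, dim=-1)
    print(probs)  # ~[0.5, 0.0, 0.5, 0.0]
    # To train with this policy, pass it to an SB3 algorithm such as PPO
    # (`MyMaskedEnv` is a hypothetical env that appends the action mask to
    # each observation, as `_get_action_dist_from_latent` expects):
    #     from stable_baselines3 import PPO
    #     model = PPO(CustomActorCriticPolicy, MyMaskedEnv(), verbose=1)
    #     model.learn(total_timesteps=10_000)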