Source code for models.gcpn

import torch
import torch.nn as nn
from torch_geometric.nn import DenseSAGEConv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class GNN(BaseFeaturesExtractor):
    """Graph convolution feature extractor, adapted from Zhao et al. "RoboGrammar".

    The flat observation vector is expected to be laid out as
    ``[x | adj | mask]`` with ``max_nodes * num_features``, ``max_nodes ** 2``
    and ``max_nodes`` entries respectively. Three Dense SAGE convolutions are
    applied, their outputs are concatenated per node, optionally projected by
    a linear layer, and finally summed over the node axis.

    Args:
        observation_space (gym.observation): The observation space of the gym
            environment.
        max_nodes (int): Maximum number of nodes for the linkage graph.
        num_features (int): Number of points in the trajectory used to
            describe the node features.
        hidden_channels (int, optional): Hidden channels for the Dense SAGE
            convolutions. Defaults to 64.
        out_channels (int, optional): Number of output features. Defaults to 64.
        normalize (bool, optional): Normalization used in Dense SAGE.
            Defaults to False.
        batch_normalization (bool, optional): Apply batch normalization after
            each convolution. Defaults to False.
        lin (bool, optional): Add a linear layer at the end. When it is not
            ``True`` the extractor outputs the concatenated embedding of
            width ``2 * hidden_channels + out_channels``. Defaults to True.
        add_loop (bool, optional): Add self loops. Stored on the instance but
            not forwarded to the convolutions. Defaults to False.
    """

    def __init__(self, observation_space, max_nodes: int, num_features: int,
                 hidden_channels: int = 64, out_channels: int = 64,
                 normalize: bool = False, batch_normalization: bool = False,
                 lin: bool = True, add_loop: bool = False):
        # Per-node width after concatenating the three convolution outputs.
        concat_channels = 2 * hidden_channels + out_channels
        use_lin = lin is True
        # The advertised feature size has to match what forward() returns:
        # without the final linear layer the concatenated width is emitted.
        features_dim = out_channels if use_lin else concat_channels
        super().__init__(observation_space, features_dim=features_dim)

        self.max_nodes = max_nodes
        self.num_features = num_features
        self.add_loop = add_loop
        self.batch_normalization = batch_normalization

        in_channels = num_features
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)

        if self.batch_normalization:
            self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn3 = torch.nn.BatchNorm1d(out_channels)

        self.lin = torch.nn.Linear(concat_channels, out_channels) if use_lin else None
        self.relu = nn.ReLU()

    def bn(self, i: int, x: torch.Tensor) -> torch.Tensor:
        """Apply the ``i``-th batch-norm layer to a ``(batch, nodes, channels)`` tensor.

        ``BatchNorm1d`` is fed a 2-D tensor here, so the batch and node axes
        are merged before the call and restored afterwards.
        """
        batch_size, num_nodes, channels = x.size()
        flat = x.view(-1, channels)
        flat = getattr(self, f'bn{i}')(flat)
        return flat.view(batch_size, num_nodes, channels)

    def _embed(self, index, conv, x, adj, mask):
        """Run one convolution followed by ReLU and, if enabled, batch norm."""
        hidden = self.relu(conv(x, adj, mask))
        return self.bn(index, hidden) if self.batch_normalization else hidden

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode flat graph observations into one latent vector per graph.

        Args:
            observations: Tensor of shape ``(batch, D)`` or ``(D,)`` whose
                leading entries are ``[x | adj | mask]`` as described on the
                class. A 1-D tensor is treated as a batch of one.

        Returns:
            Tensor of shape ``(batch, features_dim)``.
        """
        # Promote a single un-batched observation so the column slicing
        # below is valid.
        if observations.dim() == 1:
            observations = observations.unsqueeze(0)
        batch_size = observations.size(0)

        # Column boundaries of the three segments in the flat observation.
        x_end = self.max_nodes * self.num_features
        adj_end = x_end + self.max_nodes ** 2
        mask_end = adj_end + self.max_nodes

        x = observations[:, :x_end].reshape(batch_size, self.max_nodes, -1)  # B, nodes, features
        adj = observations[:, x_end:adj_end].reshape(
            batch_size, self.max_nodes, self.max_nodes)
        mask = observations[:, adj_end:mask_end]

        x1 = self._embed(1, self.conv1, x, adj, mask)
        x2 = self._embed(2, self.conv2, x1, adj, mask)
        x3 = self._embed(3, self.conv3, x2, adj, mask)

        # Concatenate the latent representations of all three layers.
        x = torch.cat([x1, x2, x3], dim=-1)

        if self.lin is not None:
            x = self.relu(self.lin(x))

        # NOTE(review): the sum runs over all max_nodes rows and the mask is
        # only passed to the convolutions, not re-applied after batch norm or
        # the linear layer - confirm padded rows contributing here is intended.
        return x.sum(1)