import torch
import torch.nn as nn
from torch_geometric.nn import DenseSAGEConv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class GNN(BaseFeaturesExtractor):
"""
Graph Convolution network: adopted from Zhao et. al "Robogrammar"
Args:
observation_space (gym.observation): The observation space of the gym environment
max_nodes (int): maximum number of nodes for linkage graph
num_features (int): number of points in the trajectory to describe the node features
hidden_channels (int, optional): hidden channels for the Dense SAGE convolutions. Defaults to 64.
out_channels (int, optional): number of output features. Defaults to 64.
normalize (bool, optional): normalization used in Dense SAGE. Defaults to False.
batch_normalization (bool, optional): Batch Normalization used. Defaults to False.
lin (bool, optional): Add linear layer to the end. Defaults to True.
add_loop (bool, optional): Add self loops. Defaults to False.
"""
    def __init__(self, observation_space, max_nodes, num_features, hidden_channels=64,
                 out_channels=64, normalize=False, batch_normalization=False, lin=True,
                 add_loop=False):
        super(GNN, self).__init__(observation_space, features_dim=1)
        self.max_nodes = max_nodes
        self.num_features = num_features
        in_channels = num_features  # per-node feature dimension
        self.add_loop = add_loop  # kept for compatibility; not passed to DenseSAGEConv here
        self.batch_normalization = batch_normalization
        ## Three stacked DenseSAGE convolutions
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        if self.batch_normalization:
            self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn3 = torch.nn.BatchNorm1d(out_channels)
        if lin:
            ## Linear layer applied to the concatenation of all three layer outputs
            self.lin = torch.nn.Linear(2 * hidden_channels + out_channels, out_channels)
        else:
            self.lin = None
        self.relu = nn.ReLU()
        ## Dimension of the features returned by forward() (summed node embeddings)
        self._features_dim = out_channels if lin else 2 * hidden_channels + out_channels
    def bn(self, i, x):
        """Apply the i-th BatchNorm1d over the flattened (batch * nodes) dimension."""
        batch_size, num_nodes, in_channels = x.size()
        x = x.view(-1, in_channels)
        x = getattr(self, 'bn{}'.format(i))(x)
        x = x.view(batch_size, num_nodes, in_channels)
        return x
    def forward(self, observations):
        ## Sizes of the flattened observation components
        shape_x = self.max_nodes * self.num_features
        shape_adj = self.max_nodes ** 2
        shape_mask = self.max_nodes
        ## Split the flat observation into node features, adjacency matrix and node mask
        x = observations[:, :shape_x]
        adj = observations[:, shape_x:shape_x + shape_adj]
        mask = observations[:, shape_x + shape_adj:shape_x + shape_adj + shape_mask]
        if len(x.size()) > 1:
            batch_size, _ = x.size()
        else:
            batch_size = 1
        ## Reshape for the model
        x = x.view(batch_size, self.max_nodes, -1)  # (batch, nodes, features)
        adj = adj.view(batch_size, self.max_nodes, self.max_nodes)
        ## Forward pass through the three DenseSAGE layers
        if self.batch_normalization:
            x1 = self.bn(1, self.relu(self.conv1(x, adj, mask)))
            x2 = self.bn(2, self.relu(self.conv2(x1, adj, mask)))
            x3 = self.bn(3, self.relu(self.conv3(x2, adj, mask)))
        else:
            x1 = self.relu(self.conv1(x, adj, mask))
            x2 = self.relu(self.conv2(x1, adj, mask))
            x3 = self.relu(self.conv3(x2, adj, mask))
        ## Concatenate latent representations from all three layers
        x = torch.cat([x1, x2, x3], dim=-1)
        ## Optional extra linear layer
        if self.lin is not None:
            x = self.relu(self.lin(x))
        ## Aggregate node features by summation to get the graph latent representation
        return x.sum(1)
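

## Minimal usage sketch (not part of the original module): build a dummy flat
## observation and run the extractor once to check output shapes. The gym Box
## space and the max_nodes / num_features values below are assumptions chosen
## for illustration only.
if __name__ == "__main__":
    import numpy as np
    from gym import spaces  # assuming the gym API used by stable-baselines3

    max_nodes, num_features = 10, 6
    obs_dim = max_nodes * num_features + max_nodes ** 2 + max_nodes  # 170
    obs_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

    model = GNN(obs_space, max_nodes=max_nodes, num_features=num_features)
    dummy_obs = torch.zeros(4, obs_dim)  # batch of 4 flattened graph observations
    features = model(dummy_obs)
    print(features.shape)  # torch.Size([4, 64]) with the default out_channels

    ## With stable-baselines3 this extractor would typically be passed to a policy
    ## via policy_kwargs, e.g.:
    ##   policy_kwargs = dict(
    ##       features_extractor_class=GNN,
    ##       features_extractor_kwargs=dict(max_nodes=max_nodes, num_features=num_features),
    ##   )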