Source code for ncps.torch.cfc_cell

# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
    import torch
except ImportError:
    raise ImportError(
        "It seems like the PyTorch package is not installed\n"
        "Installation instructions: https://pytorch.org/get-started/locally/\n",
    )
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from typing import Optional, Union


class LeCun(nn.Module):
    def __init__(self):
        super(LeCun, self).__init__()
        self.tanh = nn.Tanh()

    def forward(self, x):
        return 1.7159 * self.tanh(0.666 * x)
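
# The constants above implement LeCun's scaled tanh, f(x) = 1.7159 * tanh(2/3 * x)
# (the 2/3 factor is written as 0.666 here), scaled so that f(±1) ≈ ±1.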


class CfCCell(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        mode="default",
        backbone_activation="lecun_tanh",
        backbone_units=128,
        backbone_layers=1,
        backbone_dropout=0.0,
        sparsity_mask=None,
    ):
        """A `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ cell.

        .. Note::
            This is an RNNCell that processes single time-steps.
            To get a full RNN that can process sequences, see `ncps.torch.CfC`.

        :param input_size: Number of input features
        :param hidden_size: Number of hidden units
        :param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate)
        :param backbone_activation: Name of the activation function used in the backbone layers
        :param backbone_units: Number of hidden units per backbone layer
        :param backbone_layers: Number of backbone layers
        :param backbone_dropout: Dropout rate applied in the backbone layers
        :param sparsity_mask: Optional numpy mask that sparsifies the ff1/ff2 weight matrices
        """
        super(CfCCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        allowed_modes = ["default", "pure", "no_gate"]
        if mode not in allowed_modes:
            raise ValueError(
                f"Unknown mode '{mode}', valid options are {str(allowed_modes)}"
            )
        self.sparsity_mask = (
            None
            if sparsity_mask is None
            else torch.nn.Parameter(
                data=torch.from_numpy(np.abs(sparsity_mask.T).astype(np.float32)),
                requires_grad=False,
            )
        )
        self.mode = mode

        if backbone_activation == "silu":
            backbone_activation = nn.SiLU
        elif backbone_activation == "relu":
            backbone_activation = nn.ReLU
        elif backbone_activation == "tanh":
            backbone_activation = nn.Tanh
        elif backbone_activation == "gelu":
            backbone_activation = nn.GELU
        elif backbone_activation == "lecun_tanh":
            backbone_activation = LeCun
        else:
            raise ValueError(f"Unknown activation {backbone_activation}")

        self.backbone = None
        self.backbone_layers = backbone_layers
        if backbone_layers > 0:
            layer_list = [
                nn.Linear(input_size + hidden_size, backbone_units),
                backbone_activation(),
            ]
            for i in range(1, backbone_layers):
                layer_list.append(nn.Linear(backbone_units, backbone_units))
                layer_list.append(backbone_activation())
                if backbone_dropout > 0.0:
                    layer_list.append(torch.nn.Dropout(backbone_dropout))
            self.backbone = nn.Sequential(*layer_list)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        cat_shape = int(
            self.hidden_size + input_size if backbone_layers == 0 else backbone_units
        )

        self.ff1 = nn.Linear(cat_shape, hidden_size)
        if self.mode == "pure":
            self.w_tau = torch.nn.Parameter(
                data=torch.zeros(1, self.hidden_size), requires_grad=True
            )
            self.A = torch.nn.Parameter(
                data=torch.ones(1, self.hidden_size), requires_grad=True
            )
        else:
            self.ff2 = nn.Linear(cat_shape, hidden_size)
            self.time_a = nn.Linear(cat_shape, hidden_size)
            self.time_b = nn.Linear(cat_shape, hidden_size)
        self.init_weights()
    def init_weights(self):
        for w in self.parameters():
            if w.dim() == 2 and w.requires_grad:
                torch.nn.init.xavier_uniform_(w)
    def forward(self, input, hx, ts):
        x = torch.cat([input, hx], 1)
        if self.backbone_layers > 0:
            x = self.backbone(x)
        if self.sparsity_mask is not None:
            ff1 = F.linear(x, self.ff1.weight * self.sparsity_mask, self.ff1.bias)
        else:
            ff1 = self.ff1(x)
        if self.mode == "pure":
            # Closed-form solution: -A * exp(-ts * (|w_tau| + |ff1|)) * ff1 + A
            new_hidden = (
                -self.A
                * torch.exp(-ts * (torch.abs(self.w_tau) + torch.abs(ff1)))
                * ff1
                + self.A
            )
        else:
            # CfC: interpolate between the two branches with a time-dependent gate
            if self.sparsity_mask is not None:
                ff2 = F.linear(x, self.ff2.weight * self.sparsity_mask, self.ff2.bias)
            else:
                ff2 = self.ff2(x)
            ff1 = self.tanh(ff1)
            ff2 = self.tanh(ff2)
            t_a = self.time_a(x)
            t_b = self.time_b(x)
            t_interp = self.sigmoid(t_a * ts + t_b)
            if self.mode == "no_gate":
                new_hidden = ff1 + t_interp * ff2
            else:
                new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2
        return new_hidden, new_hidden
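

# Minimal usage sketch (illustrative, with hypothetical shapes): CfCCell processes a
# single time-step; for whole sequences the ncps.torch.CfC wrapper mentioned in the
# docstring is the intended entry point.
if __name__ == "__main__":
    cell = CfCCell(input_size=8, hidden_size=32)
    x = torch.randn(4, 8)  # batch of 4 input vectors
    hx = torch.zeros(4, 32)  # initial hidden state
    ts = torch.ones(4, 1)  # elapsed time since the previous step, broadcast over units
    output, new_hx = cell(x, hx, ts)  # both return values are the updated hidden state
    print(output.shape)  # torch.Size([4, 32])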