# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from typing import Optional, Union
import ncps
from . import CfCCell, WiredCfCCell
from .lstm import LSTMCell
class CfC(nn.Module):
    def __init__(
        self,
        input_size: int,
        units: Union[int, ncps.wirings.Wiring],
        proj_size: Optional[int] = None,
        return_sequences: bool = True,
        batch_first: bool = True,
        mixed_memory: bool = False,
        mode: str = "default",
        activation: str = "lecun_tanh",
        backbone_units: Optional[int] = None,
        backbone_layers: Optional[int] = None,
        backbone_dropout: Optional[float] = None,
    ):
        """Applies a `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ RNN to an input sequence.

        Examples::

            >>> from ncps.torch import CfC
            >>>
            >>> rnn = CfC(20, 50)
            >>> x = torch.randn(2, 3, 20)  # (batch, time, features)
            >>> h0 = torch.zeros(2, 50)  # (batch, units)
            >>> output, hn = rnn(x, h0)

        :param input_size: Number of input features
        :param units: Number of hidden units, or an ``ncps.wirings.Wiring`` instance (wired mode)
        :param proj_size: If not None, the output of the RNN will be projected to a tensor with dimension proj_size (i.e., an implicit linear output layer)
        :param return_sequences: Whether to return the full sequence or just the last output
        :param batch_first: Whether the batch or time dimension is the first (0-th) dimension
        :param mixed_memory: Whether to augment the RNN by a `memory-cell <https://arxiv.org/abs/2006.04418>`_ to help learn long-term dependencies in the data
        :param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate).
        :param activation: Activation function used in the backbone layers
        :param backbone_units: Number of hidden units in the backbone layer (default 128)
        :param backbone_layers: Number of backbone layers (default 1)
        :param backbone_dropout: Dropout rate in the backbone layers (default 0)
        """
        super(CfC, self).__init__()
        self.input_size = input_size
        self.wiring_or_units = units
        self.proj_size = proj_size
        self.batch_first = batch_first
        self.return_sequences = return_sequences

        if isinstance(units, ncps.wirings.Wiring):
            # Wired mode: the wiring object fully determines the cell's
            # structure, so the backbone_* arguments are not applicable.
            self.wired_mode = True
            if backbone_units is not None:
                raise ValueError("Cannot use backbone_units in wired mode")
            if backbone_layers is not None:
                raise ValueError("Cannot use backbone_layers in wired mode")
            if backbone_dropout is not None:
                raise ValueError("Cannot use backbone_dropout in wired mode")
            self.wiring = units
            self.state_size = self.wiring.units
            self.output_size = self.wiring.output_dim
            self.rnn_cell = WiredCfCCell(
                input_size,
                self.wiring_or_units,
                mode,
            )
        else:
            # Fully-connected mode: `units` is a plain hidden-state size and
            # the backbone hyperparameters fall back to their defaults.
            self.wired_false = True
            backbone_units = 128 if backbone_units is None else backbone_units
            backbone_layers = 1 if backbone_layers is None else backbone_layers
            backbone_dropout = 0.0 if backbone_dropout is None else backbone_dropout
            self.state_size = units
            self.output_size = self.state_size
            self.rnn_cell = CfCCell(
                input_size,
                self.wiring_or_units,
                mode,
                activation,
                backbone_units,
                backbone_layers,
                backbone_dropout,
            )
        self.use_mixed = mixed_memory
        if self.use_mixed:
            # Augment the CfC with an LSTM memory cell; its hidden state is
            # fed into the CfC cell at every step (see forward()).
            self.lstm = LSTMCell(input_size, self.state_size)
        # Optional linear readout; Identity keeps the raw cell output.
        if proj_size is None:
            self.fc = nn.Identity()
        else:
            self.fc = nn.Linear(self.output_size, self.proj_size)

    def forward(self, input, hx=None, timespans=None):
        """Run the RNN over a (possibly batchless) input sequence.

        :param input: Input tensor of shape (L,C) in batchless mode, or (B,L,C) if batch_first was set to True and (L,B,C) if batch_first is False
        :param hx: Initial hidden state of the RNN of shape (B,H) if mixed_memory is False and a tuple ((B,H),(B,H)) if mixed_memory is True. If None, the hidden states are initialized with all zeros.
        :param timespans: Optional tensor of elapsed time between steps, indexed along the sequence dimension; if None a constant timespan of 1.0 is used.
        :return: A pair (output, hx), where output is the readout (full sequence or last step, depending on return_sequences) and hx the final hidden state of the RNN
        """
        device = input.device
        is_batched = input.dim() == 3
        batch_dim = 0 if self.batch_first else 1
        seq_dim = 1 if self.batch_first else 0
        if not is_batched:
            # Promote (L,C) to a batch of size 1 so the loop below is uniform;
            # the batch dimension is squeezed out again before returning.
            input = input.unsqueeze(batch_dim)
            if timespans is not None:
                timespans = timespans.unsqueeze(batch_dim)

        batch_size, seq_len = input.size(batch_dim), input.size(seq_dim)

        if hx is None:
            # Zero-initialize the hidden (and, in mixed mode, cell) state.
            h_state = torch.zeros((batch_size, self.state_size), device=device)
            c_state = (
                torch.zeros((batch_size, self.state_size), device=device)
                if self.use_mixed
                else None
            )
        else:
            if self.use_mixed and isinstance(hx, torch.Tensor):
                raise RuntimeError(
                    "Running a CfC with mixed_memory=True, requires a tuple (h0,c0) to be passed as state (got torch.Tensor instead)"
                )
            h_state, c_state = hx if self.use_mixed else (hx, None)
            if is_batched:
                if h_state.dim() != 2:
                    msg = (
                        "For batched 2-D input, hx and cx should "
                        f"also be 2-D but got ({h_state.dim()}-D) tensor"
                    )
                    raise RuntimeError(msg)
            else:
                # batchless mode
                if h_state.dim() != 1:
                    msg = (
                        "For unbatched 1-D input, hx and cx should "
                        f"also be 1-D but got ({h_state.dim()}-D) tensor"
                    )
                    raise RuntimeError(msg)
                h_state = h_state.unsqueeze(0)
                c_state = c_state.unsqueeze(0) if c_state is not None else None

        output_sequence = []
        for t in range(seq_len):
            if self.batch_first:
                inputs = input[:, t]
                # NOTE(review): .squeeze() drops ALL size-1 dims, so with
                # batch_size == 1 this yields a scalar ts — confirm the cells
                # broadcast it as intended.
                ts = 1.0 if timespans is None else timespans[:, t].squeeze()
            else:
                inputs = input[t]
                ts = 1.0 if timespans is None else timespans[t].squeeze()

            if self.use_mixed:
                h_state, c_state = self.lstm(inputs, (h_state, c_state))
            # Call the module (not .forward directly) so registered hooks run.
            h_out, h_state = self.rnn_cell(inputs, h_state, ts)
            if self.return_sequences:
                output_sequence.append(self.fc(h_out))

        if self.return_sequences:
            stack_dim = 1 if self.batch_first else 0
            readout = torch.stack(output_sequence, dim=stack_dim)
        else:
            readout = self.fc(h_out)
        hx = (h_state, c_state) if self.use_mixed else h_state
        if not is_batched:
            # batchless mode: strip the synthetic batch dimension again
            readout = readout.squeeze(batch_dim)
            hx = (h_state[0], c_state[0]) if self.use_mixed else h_state[0]

        return readout, hx