Source code for ncps.torch.ltc

# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn
from typing import Optional, Union
import ncps
from . import CfCCell, LTCCell
from .lstm import LSTMCell


class LTC(nn.Module):
    def __init__(
        self,
        input_size: int,
        units,
        return_sequences: bool = True,
        batch_first: bool = True,
        mixed_memory: bool = False,
        input_mapping="affine",
        output_mapping="affine",
        ode_unfolds=6,
        epsilon=1e-8,
        implicit_param_constraints=True,
    ):
        """Applies a `Liquid time-constant (LTC) <https://ojs.aaai.org/index.php/AAAI/article/view/16936>`_ RNN to an input sequence.

        Examples::

            >>> from ncps.torch import LTC
            >>>
            >>> rnn = LTC(20, 50)
            >>> x = torch.randn(2, 3, 20)  # (batch, time, features)
            >>> h0 = torch.zeros(2, 50)    # (batch, units)
            >>> output, hn = rnn(x, h0)

        .. Note::
            For creating a wired `Neural circuit policy (NCP) <https://publik.tuwien.ac.at/files/publik_292280.pdf>`_ you can pass a `ncps.wirings.NCP` object instead of the number of units.

        Examples::

            >>> from ncps.torch import LTC
            >>> from ncps.wirings import NCP
            >>>
            >>> wiring = NCP(10, 10, 8, 6, 6, 4, 6)
            >>> rnn = LTC(20, wiring)
            >>> x = torch.randn(2, 3, 20)  # (batch, time, features)
            >>> h0 = torch.zeros(2, 28)    # (batch, units)
            >>> output, hn = rnn(x, h0)

        :param input_size: Number of input features
        :param units: Wiring (ncps.wirings.Wiring instance) or integer representing the number of (fully-connected) hidden units
        :param return_sequences: Whether to return the full sequence or just the last output
        :param batch_first: Whether the batch or the time dimension is the first (0-th) dimension
        :param mixed_memory: Whether to augment the RNN by a `memory-cell <https://arxiv.org/abs/2006.04418>`_ to help learn long-term dependencies in the data
        :param input_mapping: Mapping applied to the inputs before the LTC cell ("affine", "linear", or None)
        :param output_mapping: Mapping applied to the motor-neuron outputs ("affine", "linear", or None)
        :param ode_unfolds: Number of ODE-solver unfoldings per time step
        :param epsilon: Small constant added to denominators to avoid divisions by zero
        :param implicit_param_constraints: Whether to enforce the parameter constraints implicitly via a reparametrization (True) or explicitly, which requires calling apply_weight_constraints() after each optimizer step (False)
        """
        super(LTC, self).__init__()
        self.input_size = input_size
        self.wiring_or_units = units
        self.batch_first = batch_first
        self.return_sequences = return_sequences

        if isinstance(units, ncps.wirings.Wiring):
            wiring = units
        else:
            wiring = ncps.wirings.FullyConnected(units)

        self.rnn_cell = LTCCell(
            wiring=wiring,
            in_features=input_size,
            input_mapping=input_mapping,
            output_mapping=output_mapping,
            ode_unfolds=ode_unfolds,
            epsilon=epsilon,
            implicit_param_constraints=implicit_param_constraints,
        )
        self._wiring = wiring
        self.use_mixed = mixed_memory
        if self.use_mixed:
            self.lstm = LSTMCell(input_size, self.state_size)

    @property
    def state_size(self):
        return self._wiring.units

    @property
    def sensory_size(self):
        return self._wiring.input_dim

    @property
    def motor_size(self):
        return self._wiring.output_dim

    @property
    def output_size(self):
        return self.motor_size

    @property
    def synapse_count(self):
        return np.sum(np.abs(self._wiring.adjacency_matrix))

    @property
    def sensory_synapse_count(self):
        # Sum over the sensory (input) adjacency matrix, not the inter-neuron one
        return np.sum(np.abs(self._wiring.sensory_adjacency_matrix))
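
    # Note on the properties above: state_size is the total number of neurons in
    # the wiring, while output_size (== motor_size) counts only the motor neurons
    # that are read out. With the default FullyConnected wiring the two coincide;
    # with an NCP wiring the readout is typically smaller. When mixed_memory=True,
    # the LSTMCell created in __init__ runs in front of the LTC cell in forward()
    # below and carries an additional cell state c_state.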

    def forward(self, input, hx=None, timespans=None):
        """

        :param input: Input tensor of shape (L,C) in batchless mode, (B,L,C) if batch_first is True, or (L,B,C) if batch_first is False
        :param hx: Initial hidden state of the RNN of shape (B,H) if mixed_memory is False, or a tuple ((B,H),(B,H)) if mixed_memory is True. If None, the hidden states are initialized with all zeros.
        :param timespans: Optional tensor of elapsed time between consecutive samples (one entry per time step); if None, a constant time step of 1.0 is used
        :return: A pair (output, hx), where output is the output sequence (or the last output if return_sequences is False) and hx is the final hidden state of the RNN
        """
        device = input.device
        is_batched = input.dim() == 3
        batch_dim = 0 if self.batch_first else 1
        seq_dim = 1 if self.batch_first else 0
        if not is_batched:
            input = input.unsqueeze(batch_dim)
            if timespans is not None:
                timespans = timespans.unsqueeze(batch_dim)

        batch_size, seq_len = input.size(batch_dim), input.size(seq_dim)

        if hx is None:
            h_state = torch.zeros((batch_size, self.state_size), device=device)
            c_state = (
                torch.zeros((batch_size, self.state_size), device=device)
                if self.use_mixed
                else None
            )
        else:
            if self.use_mixed and isinstance(hx, torch.Tensor):
                raise RuntimeError(
                    "Running an LTC with mixed_memory=True requires a tuple (h0,c0) to be passed as state (got torch.Tensor instead)"
                )
            h_state, c_state = hx if self.use_mixed else (hx, None)
            if is_batched:
                if h_state.dim() != 2:
                    msg = (
                        "For batched 3-D input, hx and cx should "
                        f"be 2-D but got ({h_state.dim()}-D) tensor"
                    )
                    raise RuntimeError(msg)
            else:
                # batchless mode
                if h_state.dim() != 1:
                    msg = (
                        "For unbatched 2-D input, hx and cx should "
                        f"be 1-D but got ({h_state.dim()}-D) tensor"
                    )
                    raise RuntimeError(msg)
                h_state = h_state.unsqueeze(0)
                c_state = c_state.unsqueeze(0) if c_state is not None else None

        output_sequence = []
        for t in range(seq_len):
            if self.batch_first:
                inputs = input[:, t]
                ts = 1.0 if timespans is None else timespans[:, t].squeeze()
            else:
                inputs = input[t]
                ts = 1.0 if timespans is None else timespans[t].squeeze()

            if self.use_mixed:
                h_state, c_state = self.lstm(inputs, (h_state, c_state))
            h_out, h_state = self.rnn_cell.forward(inputs, h_state, ts)
            if self.return_sequences:
                output_sequence.append(h_out)

        if self.return_sequences:
            stack_dim = 1 if self.batch_first else 0
            readout = torch.stack(output_sequence, dim=stack_dim)
        else:
            readout = h_out
        hx = (h_state, c_state) if self.use_mixed else h_state

        if not is_batched:
            # batchless mode
            readout = readout.squeeze(batch_dim)
            hx = (h_state[0], c_state[0]) if self.use_mixed else h_state[0]

        return readout, hx
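

# A minimal usage sketch, not part of the module above: it assumes the AutoNCP
# wiring from ncps.wirings is available and demonstrates the mixed-memory state
# tuple as well as per-step timespans for irregularly sampled data.
if __name__ == "__main__":
    from ncps.wirings import AutoNCP

    wiring = AutoNCP(28, 4)  # 28 neurons in total, 4 of them motor (output) neurons
    rnn = LTC(20, wiring, mixed_memory=True)

    x = torch.randn(1, 3, 20)                      # (batch, time, features)
    t = torch.tensor([[1.0, 0.5, 2.0]])            # elapsed time between samples
    h0 = (torch.zeros(1, 28), torch.zeros(1, 28))  # (h0, c0) because mixed_memory=True

    output, (hn, cn) = rnn(x, h0, timespans=t)
    print(output.shape)  # torch.Size([1, 3, 4]) -- the motor neurons form the readout
    print(hn.shape)      # torch.Size([1, 28])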