Source code for ncps.torch.ltc_cell

# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union


class LTCCell(nn.Module):
    def __init__(
        self,
        wiring,
        in_features=None,
        input_mapping="affine",
        output_mapping="affine",
        ode_unfolds=6,
        epsilon=1e-8,
        implicit_param_constraints=False,
    ):
        """A `Liquid time-constant (LTC) <https://ojs.aaai.org/index.php/AAAI/article/view/16936>`_ cell.

        .. Note::
            This is an RNNCell that processes single time-steps.
            To get a full RNN that can process sequences, see `ncps.torch.LTC`.

        :param wiring: A wiring instance (e.g., from `ncps.wirings`) that defines the connectivity of the cell
        :param in_features: Number of input features; may be omitted if the wiring is already built
        :param input_mapping: Preprocessing applied to the inputs ("affine", "linear", or None)
        :param output_mapping: Postprocessing applied to the outputs ("affine", "linear", or None)
        :param ode_unfolds: Number of ODE-solver unfoldings per time step
        :param epsilon: Small constant added to the denominator to avoid division by zero
        :param implicit_param_constraints: If True, enforce the positivity constraints of the parameters
            implicitly via a softplus; if False, the parameters must be clipped explicitly by calling
            `apply_weight_constraints` after each parameter update
        """
        super(LTCCell, self).__init__()
        if in_features is not None:
            wiring.build(in_features)
        if not wiring.is_built():
            raise ValueError(
                "Wiring error! Unknown number of input features. "
                "Please pass the parameter 'in_features' or call 'wiring.build()' first."
            )
        self.make_positive_fn = (
            nn.Softplus() if implicit_param_constraints else nn.Identity()
        )
        self._implicit_param_constraints = implicit_param_constraints
        self._init_ranges = {
            "gleak": (0.001, 1.0),
            "vleak": (-0.2, 0.2),
            "cm": (0.4, 0.6),
            "w": (0.001, 1.0),
            "sigma": (3, 8),
            "mu": (0.3, 0.8),
            "sensory_w": (0.001, 1.0),
            "sensory_sigma": (3, 8),
            "sensory_mu": (0.3, 0.8),
        }
        self._wiring = wiring
        self._input_mapping = input_mapping
        self._output_mapping = output_mapping
        self._ode_unfolds = ode_unfolds
        self._epsilon = epsilon
        self._clip = torch.nn.ReLU()
        self._allocate_parameters()

    @property
    def state_size(self):
        return self._wiring.units

    @property
    def sensory_size(self):
        return self._wiring.input_dim

    @property
    def motor_size(self):
        return self._wiring.output_dim

    @property
    def output_size(self):
        return self.motor_size

    @property
    def synapse_count(self):
        return np.sum(np.abs(self._wiring.adjacency_matrix))

    @property
    def sensory_synapse_count(self):
        return np.sum(np.abs(self._wiring.sensory_adjacency_matrix))
    def add_weight(self, name, init_value, requires_grad=True):
        param = torch.nn.Parameter(init_value, requires_grad=requires_grad)
        self.register_parameter(name, param)
        return param
    def _get_init_value(self, shape, param_name):
        minval, maxval = self._init_ranges[param_name]
        if minval == maxval:
            return torch.ones(shape) * minval
        else:
            return torch.rand(*shape) * (maxval - minval) + minval

    def _allocate_parameters(self):
        self._params = {}
        self._params["gleak"] = self.add_weight(
            name="gleak", init_value=self._get_init_value((self.state_size,), "gleak")
        )
        self._params["vleak"] = self.add_weight(
            name="vleak", init_value=self._get_init_value((self.state_size,), "vleak")
        )
        self._params["cm"] = self.add_weight(
            name="cm", init_value=self._get_init_value((self.state_size,), "cm")
        )
        self._params["sigma"] = self.add_weight(
            name="sigma",
            init_value=self._get_init_value(
                (self.state_size, self.state_size), "sigma"
            ),
        )
        self._params["mu"] = self.add_weight(
            name="mu",
            init_value=self._get_init_value((self.state_size, self.state_size), "mu"),
        )
        self._params["w"] = self.add_weight(
            name="w",
            init_value=self._get_init_value((self.state_size, self.state_size), "w"),
        )
        self._params["erev"] = self.add_weight(
            name="erev",
            init_value=torch.Tensor(self._wiring.erev_initializer()),
        )
        self._params["sensory_sigma"] = self.add_weight(
            name="sensory_sigma",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_sigma"
            ),
        )
        self._params["sensory_mu"] = self.add_weight(
            name="sensory_mu",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_mu"
            ),
        )
        self._params["sensory_w"] = self.add_weight(
            name="sensory_w",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_w"
            ),
        )
        self._params["sensory_erev"] = self.add_weight(
            name="sensory_erev",
            init_value=torch.Tensor(self._wiring.sensory_erev_initializer()),
        )
        self._params["sparsity_mask"] = self.add_weight(
            "sparsity_mask",
            torch.Tensor(np.abs(self._wiring.adjacency_matrix)),
            requires_grad=False,
        )
        self._params["sensory_sparsity_mask"] = self.add_weight(
            "sensory_sparsity_mask",
            torch.Tensor(np.abs(self._wiring.sensory_adjacency_matrix)),
            requires_grad=False,
        )
        if self._input_mapping in ["affine", "linear"]:
            self._params["input_w"] = self.add_weight(
                name="input_w",
                init_value=torch.ones((self.sensory_size,)),
            )
        if self._input_mapping == "affine":
            self._params["input_b"] = self.add_weight(
                name="input_b",
                init_value=torch.zeros((self.sensory_size,)),
            )
        if self._output_mapping in ["affine", "linear"]:
            self._params["output_w"] = self.add_weight(
                name="output_w",
                init_value=torch.ones((self.motor_size,)),
            )
        if self._output_mapping == "affine":
            self._params["output_b"] = self.add_weight(
                name="output_b",
                init_value=torch.zeros((self.motor_size,)),
            )

    def _sigmoid(self, v_pre, mu, sigma):
        v_pre = torch.unsqueeze(v_pre, -1)  # For broadcasting
        mues = v_pre - mu
        x = sigma * mues
        return torch.sigmoid(x)

    def _ode_solver(self, inputs, state, elapsed_time):
        v_pre = state

        # We can pre-compute the effects of the sensory neurons here
        sensory_w_activation = self.make_positive_fn(
            self._params["sensory_w"]
        ) * self._sigmoid(
            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
        )
        sensory_w_activation = (
            sensory_w_activation * self._params["sensory_sparsity_mask"]
        )

        sensory_rev_activation = sensory_w_activation * self._params["sensory_erev"]

        # Reduce over dimension 1 (=source sensory neurons)
        w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
        w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)

        # cm/t is loop invariant
        cm_t = self.make_positive_fn(self._params["cm"]) / (
            elapsed_time / self._ode_unfolds
        )

        # Unfold the ODE multiple times into one RNN step
        w_param = self.make_positive_fn(self._params["w"])
        for t in range(self._ode_unfolds):
            w_activation = w_param * self._sigmoid(
                v_pre, self._params["mu"], self._params["sigma"]
            )

            w_activation = w_activation * self._params["sparsity_mask"]

            rev_activation = w_activation * self._params["erev"]

            # Reduce over dimension 1 (=source neurons)
            w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
            w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory

            gleak = self.make_positive_fn(self._params["gleak"])
            numerator = cm_t * v_pre + gleak * self._params["vleak"] + w_numerator
            denominator = cm_t + gleak + w_denominator

            # Avoid dividing by 0
            v_pre = numerator / (denominator + self._epsilon)

        return v_pre

    def _map_inputs(self, inputs):
        if self._input_mapping in ["affine", "linear"]:
            inputs = inputs * self._params["input_w"]
        if self._input_mapping == "affine":
            inputs = inputs + self._params["input_b"]
        return inputs

    def _map_outputs(self, state):
        output = state
        if self.motor_size < self.state_size:
            output = output[:, 0 : self.motor_size]  # slice

        if self._output_mapping in ["affine", "linear"]:
            output = output * self._params["output_w"]
        if self._output_mapping == "affine":
            output = output + self._params["output_b"]
        return output
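    # Added annotation (not part of the original source): each of the
    # `ode_unfolds` iterations in `_ode_solver` above performs one fused
    # semi-implicit Euler step of the LTC dynamics,
    #
    #   v_i <- (cm_t_i * v_i + gleak_i * vleak_i + Sum_j w_ji * s_ji * erev_ji)
    #          / (cm_t_i + gleak_i + Sum_j w_ji * s_ji + epsilon)
    #
    # with synaptic gate s_ji = sigmoid((v_j - mu_ji) * sigma_ji) and
    # cm_t_i = cm_i / (elapsed_time / ode_unfolds); the sums over source
    # neurons j also include the pre-computed sensory terms. When the
    # positivity constraints on cm, gleak, and w hold, the new state is a
    # weighted average of v_i, vleak_i, and the reversal potentials, which
    # keeps the unfolded solver numerically stable.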
    def apply_weight_constraints(self):
        if not self._implicit_param_constraints:
            # In implicit mode, the parameter constraints are implemented via
            # a softplus function at runtime
            self._params["w"].data = self._clip(self._params["w"].data)
            self._params["sensory_w"].data = self._clip(self._params["sensory_w"].data)
            self._params["cm"].data = self._clip(self._params["cm"].data)
            self._params["gleak"].data = self._clip(self._params["gleak"].data)
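    # Added usage note: with implicit_param_constraints=False, the positivity
    # of w, sensory_w, cm, and gleak is not enforced during the forward pass,
    # so a training loop would typically call apply_weight_constraints() after
    # each optimizer.step() to project the parameters back into their valid
    # range.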
    def forward(self, inputs, states, elapsed_time=1.0):
        # Regularly sampled mode (elapsed time = 1 second)
        inputs = self._map_inputs(inputs)

        next_state = self._ode_solver(inputs, states, elapsed_time)

        outputs = self._map_outputs(next_state)

        return outputs, next_state
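
A minimal usage sketch (added for illustration, not part of the module source), assuming the `AutoNCP` wiring from `ncps.wirings`; the concrete sizes (16 units, 4 outputs, 8 input features) are arbitrary. It drives the cell one time step at a time, which is what the `ncps.torch.LTC` wrapper automates for full sequences:

    import torch
    from ncps.wirings import AutoNCP
    from ncps.torch import LTCCell

    wiring = AutoNCP(units=16, output_size=4)  # 16 neurons, 4 of them motor neurons
    cell = LTCCell(wiring, in_features=8)      # builds the wiring for 8 input features

    batch_size, seq_len = 2, 10
    x = torch.randn(batch_size, seq_len, 8)           # (batch, time, features)
    state = torch.zeros(batch_size, cell.state_size)  # initial hidden state

    for t in range(seq_len):
        output, state = cell(x[:, t], state)  # one single-time-step RNN call
    print(output.shape)  # torch.Size([2, 4]) == (batch, motor_size)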