# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ncps import wirings
import numpy as np
import tensorflow as tf
from typing import Optional, Union
@tf.keras.utils.register_keras_serializable(package="ncps", name="LTCCell")
class LTCCell(tf.keras.layers.AbstractRNNCell):
    def __init__(
        self,
        wiring,
        input_mapping="affine",
        output_mapping="affine",
        ode_unfolds=6,
        epsilon=1e-8,
        initialization_ranges=None,
        **kwargs
    ):
        """A `Liquid time-constant (LTC) <https://ojs.aaai.org/index.php/AAAI/article/view/16936>`_ cell.

        .. Note::
            This is an RNNCell that processes single time-steps.
            To get a full RNN that can process sequences,
            see `ncps.tf.LTC` or wrap the cell with a `tf.keras.layers.RNN <https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN>`_.

        Examples::

            >>> import tensorflow as tf
            >>> import ncps
            >>> from ncps.tf import LTCCell
            >>>
            >>> wiring = ncps.wirings.Random(16, output_dim=2, sparsity_level=0.5)
            >>> cell = LTCCell(wiring)
            >>> rnn = tf.keras.layers.RNN(cell)
            >>> x = tf.random.uniform((1,4)) # (batch, features)
            >>> h0 = tf.zeros((1, 16))
            >>> y, next_state = cell(x, [h0])  # states are passed/returned as a list
            >>>
            >>> x_seq = tf.random.uniform((1,20,4)) # (batch, time, features)
            >>> y_seq = rnn(x_seq)

        :param wiring: Wiring object describing the network. The cell reads its
            ``units``, ``input_dim``, ``output_dim``, ``adjacency_matrix``,
            ``sensory_adjacency_matrix``, ``erev_initializer`` and
            ``sensory_erev_initializer`` and calls its ``build`` and
            ``get_config`` methods.
        :param input_mapping: ``"affine"`` applies a learnable element-wise scale
            and bias to the inputs, ``"linear"`` a learnable scale only. Any other
            value leaves the inputs unchanged.
        :param output_mapping: Same options as ``input_mapping``, applied to the
            motor-neuron outputs.
        :param ode_unfolds: Number of solver iterations executed per RNN step.
        :param epsilon: Small constant added to the solver's denominator to avoid
            a division by zero.
        :param initialization_ranges: Optional dict mapping a parameter name
            (``gleak``, ``vleak``, ``cm``, ``w``, ``sigma``, ``mu``, ``sensory_w``,
            ``sensory_sigma``, ``sensory_mu``) to a ``(min, max)`` pair that
            overrides the default initialization range of that parameter.
        :param kwargs: Forwarded to the base class constructor.
        :raises ValueError: If ``initialization_ranges`` contains an unknown
            parameter name, a negative lower bound for ``gleak``, ``cm``, ``w``
            or ``sensory_w``, or a range whose lower bound exceeds its upper bound.
        """
        super().__init__(**kwargs)
        # Default (min, max) ranges used to initialize the trainable parameters.
        self._init_ranges = {
            "gleak": (0.001, 1.0),
            "vleak": (-0.2, 0.2),
            "cm": (0.4, 0.6),
            "w": (0.001, 1.0),
            "sigma": (3, 8),
            "mu": (0.3, 0.8),
            "sensory_w": (0.001, 1.0),
            "sensory_sigma": (3, 8),
            "sensory_mu": (0.3, 0.8),
        }
        if initialization_ranges is not None:
            for param_name, bounds in initialization_ranges.items():
                if param_name not in self._init_ranges:
                    raise ValueError(
                        f"Unknown parameter '{param_name}' in initialization range dictionary! "
                        f"(Expected only {list(self._init_ranges.keys())})"
                    )
                # These parameters carry a NonNeg constraint in build(), so a
                # negative lower bound would be inconsistent with it.
                if param_name in ("gleak", "cm", "w", "sensory_w") and bounds[0] < 0:
                    raise ValueError(
                        f"Initialization range of parameter '{param_name}' must be non-negative!"
                    )
                if bounds[0] > bounds[1]:
                    raise ValueError(
                        f"Initialization range of parameter '{param_name}' is not a valid range"
                    )
                self._init_ranges[param_name] = bounds

        self._wiring = wiring
        self._input_mapping = input_mapping
        self._output_mapping = output_mapping
        self._ode_unfolds = ode_unfolds
        self._epsilon = epsilon

    @property
    def state_size(self):
        """Size of the hidden state, i.e. the number of units of the wiring."""
        return self._wiring.units

    @property
    def sensory_size(self):
        """Number of input features, i.e. the wiring's ``input_dim``."""
        return self._wiring.input_dim

    @property
    def motor_size(self):
        """Number of motor neurons, i.e. the wiring's ``output_dim``."""
        return self._wiring.output_dim

    @property
    def output_size(self):
        """Size of the cell output (equal to ``motor_size``)."""
        return self.motor_size

    def _get_initializer(self, param_name):
        """Return the Keras initializer for ``param_name``.

        A degenerate range (min == max) yields a constant initializer,
        otherwise values are drawn uniformly from the configured range.
        """
        minval, maxval = self._init_ranges[param_name]
        if minval == maxval:
            return tf.keras.initializers.Constant(minval)
        return tf.keras.initializers.RandomUniform(minval, maxval)

    def build(self, input_shape):
        """Build the wiring and create all weights of the cell.

        :param input_shape: Shape of the cell input. If its first item is
            itself a tuple or ``tf.TensorShape`` the shape is treated as nested
            and the feature dimension is taken from that first item.
        """
        first_item = input_shape[0]
        if isinstance(first_item, (tuple, tf.TensorShape)):
            # Nested shape -> first item represents the feature input
            input_dim = first_item[-1]
        else:
            input_dim = input_shape[-1]

        self._wiring.build(input_dim)

        def add_param(name, shape, initializer=None, non_negative=False):
            # All weights are float32; parameters without an explicit
            # initializer use the range-based one from _get_initializer.
            if initializer is None:
                initializer = self._get_initializer(name)
            return self.add_weight(
                name=name,
                shape=shape,
                dtype=tf.float32,
                constraint=tf.keras.constraints.NonNeg() if non_negative else None,
                initializer=initializer,
            )

        # Shapes are only known once the wiring has been built.
        neuron_shape = (self.state_size,)
        synapse_shape = (self.state_size, self.state_size)
        sensory_shape = (self.sensory_size, self.state_size)

        # NOTE: the creation order below determines the order of the layer's
        # weights and must not be changed.
        self._params = {}
        self._params["gleak"] = add_param("gleak", neuron_shape, non_negative=True)
        self._params["vleak"] = add_param("vleak", neuron_shape)
        self._params["cm"] = add_param("cm", neuron_shape, non_negative=True)
        self._params["sigma"] = add_param("sigma", synapse_shape)
        self._params["mu"] = add_param("mu", synapse_shape)
        self._params["w"] = add_param("w", synapse_shape, non_negative=True)
        self._params["erev"] = add_param(
            "erev", synapse_shape, initializer=self._wiring.erev_initializer
        )
        self._params["sensory_sigma"] = add_param("sensory_sigma", sensory_shape)
        self._params["sensory_mu"] = add_param("sensory_mu", sensory_shape)
        self._params["sensory_w"] = add_param(
            "sensory_w", sensory_shape, non_negative=True
        )
        self._params["sensory_erev"] = add_param(
            "sensory_erev",
            sensory_shape,
            initializer=self._wiring.sensory_erev_initializer,
        )

        # Non-trainable masks derived from the absolute values of the
        # wiring's adjacency matrices.
        self._params["sparsity_mask"] = tf.constant(
            np.abs(self._wiring.adjacency_matrix), dtype=tf.float32
        )
        self._params["sensory_sparsity_mask"] = tf.constant(
            np.abs(self._wiring.sensory_adjacency_matrix), dtype=tf.float32
        )

        if self._input_mapping in ("affine", "linear"):
            self._params["input_w"] = add_param(
                "input_w",
                (self.sensory_size,),
                initializer=tf.keras.initializers.Constant(1),
            )
        if self._input_mapping == "affine":
            self._params["input_b"] = add_param(
                "input_b",
                (self.sensory_size,),
                initializer=tf.keras.initializers.Constant(0),
            )
        if self._output_mapping in ("affine", "linear"):
            self._params["output_w"] = add_param(
                "output_w",
                (self.motor_size,),
                initializer=tf.keras.initializers.Constant(1),
            )
        if self._output_mapping == "affine":
            self._params["output_b"] = add_param(
                "output_b",
                (self.motor_size,),
                initializer=tf.keras.initializers.Constant(0),
            )
        self.built = True

    def _sigmoid(self, v_pre, mu, sigma):
        """Return ``sigmoid(sigma * (v_pre - mu))``.

        ``v_pre`` gets a trailing axis appended so that it broadcasts against
        the 2-D ``mu`` and ``sigma`` parameters.
        """
        v_pre = tf.expand_dims(v_pre, axis=-1)  # For broadcasting
        return tf.nn.sigmoid(sigma * (v_pre - mu))

    def _ode_solver(self, inputs, state, elapsed_time):
        """Advance the neuron state by ``elapsed_time`` using ``ode_unfolds`` solver steps.

        :param inputs: Already mapped sensory inputs (presumably of shape
            (batch, sensory_size) -- confirm against callers).
        :param state: Neuron state before the step.
        :param elapsed_time: Time covered by this RNN step; it is divided
            evenly among the unfolded solver iterations.
        :return: The neuron state after the step.
        """
        v_pre = state

        # We can pre-compute the effects of the sensory neurons here
        sensory_w_activation = self._params["sensory_w"] * self._sigmoid(
            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
        )
        sensory_w_activation *= self._params["sensory_sparsity_mask"]
        sensory_rev_activation = sensory_w_activation * self._params["sensory_erev"]

        # Reduce over dimension 1 (=source sensory neurons)
        w_numerator_sensory = tf.reduce_sum(sensory_rev_activation, axis=1)
        w_denominator_sensory = tf.reduce_sum(sensory_w_activation, axis=1)

        # The following terms are loop invariant: cm/t, the leak current and
        # the state-independent part of the denominator.
        cm_t = self._params["cm"] / tf.cast(
            elapsed_time / self._ode_unfolds, dtype=tf.float32
        )
        leak_current = self._params["gleak"] * self._params["vleak"]
        base_denominator = cm_t + self._params["gleak"]

        # Unfold the ODE multiple times into one RNN step
        for _ in range(self._ode_unfolds):
            w_activation = self._params["w"] * self._sigmoid(
                v_pre, self._params["mu"], self._params["sigma"]
            )
            w_activation *= self._params["sparsity_mask"]
            rev_activation = w_activation * self._params["erev"]

            # Reduce over dimension 1 (=source neurons)
            w_numerator = tf.reduce_sum(rev_activation, axis=1) + w_numerator_sensory
            w_denominator = tf.reduce_sum(w_activation, axis=1) + w_denominator_sensory

            numerator = cm_t * v_pre + leak_current + w_numerator
            denominator = base_denominator + w_denominator

            # Avoid dividing by 0
            v_pre = numerator / (denominator + self._epsilon)

        return v_pre

    def _map_inputs(self, inputs):
        """Apply the configured input mapping (scale, then optional bias)."""
        if self._input_mapping in ("affine", "linear"):
            inputs = inputs * self._params["input_w"]
        if self._input_mapping == "affine":
            inputs = inputs + self._params["input_b"]
        return inputs

    def _map_outputs(self, state):
        """Select the motor neurons from ``state`` and apply the output mapping."""
        output = state
        if self.motor_size < self.state_size:
            # The first `motor_size` entries of the state are used as output.
            output = output[:, 0 : self.motor_size]
        if self._output_mapping in ("affine", "linear"):
            output = output * self._params["output_w"]
        if self._output_mapping == "affine":
            output = output + self._params["output_b"]
        return output

    def call(self, inputs, states):
        """Process a single time-step.

        :param inputs: Either the input features, or a tuple/list
            ``(features, elapsed_time)`` for irregularly sampled data.
        :param states: Sequence whose first item is the current neuron state.
        :return: Tuple ``(outputs, [next_state])``.
        """
        if isinstance(inputs, (tuple, list)):
            # Irregularly sampled mode
            inputs, elapsed_time = inputs
        else:
            # Regularly sampled mode (elapsed time = 1 second)
            elapsed_time = 1.0
        inputs = self._map_inputs(inputs)

        next_state = self._ode_solver(inputs, states[0], elapsed_time)

        outputs = self._map_outputs(next_state)

        return outputs, [next_state]

    def get_config(self):
        """Return the wiring's config extended with the cell's hyperparameters.

        NOTE(review): ``initialization_ranges`` and the base layer's config
        (``super().get_config()``) are not added here, so they are presumably
        not restored by ``from_config`` -- confirm whether that is intended.
        """
        config = self._wiring.get_config()
        config["input_mapping"] = self._input_mapping
        config["output_mapping"] = self._output_mapping
        config["ode_unfolds"] = self._ode_unfolds
        config["epsilon"] = self._epsilon
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate a cell from a dict produced by ``get_config``.

        NOTE(review): the full ``config`` is forwarded as keyword arguments
        after the wiring has been rebuilt from it; unless
        ``Wiring.from_config`` removes the wiring-specific keys they end up in
        ``**kwargs`` of ``__init__`` -- verify against the wiring implementation.
        """
        wiring = wirings.Wiring.from_config(config)
        return cls(wiring=wiring, **config)