Source code for ncps.tf.cfc_cell

# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import tensorflow as tf
from typing import Optional, Union


# LeCun's improved tanh activation (1.7159 * tanh(2/3 * x); the constant
# 0.666 below approximates 2/3):
# http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
def lecun_tanh(x):
    return 1.7159 * tf.nn.tanh(0.666 * x)

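# Illustrative example (not part of the original module): lecun_tanh is
# roughly linear near 0 and saturates toward +/-1.7159, e.g.
#
#   lecun_tanh(tf.constant([-3.0, 0.0, 3.0]))  # ~[-1.654, 0.0, 1.654]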

@tf.keras.utils.register_keras_serializable(package="ncps", name="CfCCell")
class CfCCell(tf.keras.layers.AbstractRNNCell):
    def __init__(
        self,
        units,
        input_sparsity=None,
        recurrent_sparsity=None,
        mode="default",
        activation="lecun_tanh",
        backbone_units=128,
        backbone_layers=1,
        backbone_dropout=0.1,
        **kwargs,
    ):
        """A `Closed-form Continuous-time <https://arxiv.org/abs/2106.13898>`_ cell.

        .. Note::
            This is an RNNCell that processes single time-steps.
            To get a full RNN that can process sequences,
            see `ncps.tf.CfC` or wrap the cell with a
            `tf.keras.layers.RNN <https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN>`_.

        :param units: Number of hidden units
        :param input_sparsity: Optional binary mask for the input-to-hidden weights (requires ``backbone_units=0``)
        :param recurrent_sparsity: Optional binary mask for the hidden-to-hidden weights (requires ``backbone_units=0``)
        :param mode: Either "default", "pure" (direct solution approximation), or "no_gate" (without second gate)
        :param activation: Activation function used in the backbone layers
        :param backbone_units: Number of hidden units in the backbone layers (default 128)
        :param backbone_layers: Number of backbone layers (default 1)
        :param backbone_dropout: Dropout rate in the backbone layers (default 0.1)
        :param kwargs:
        """
        super().__init__(**kwargs)
        self.units = units
        self.sparsity_mask = None
        if input_sparsity is not None or recurrent_sparsity is not None:
            # Sparsity masks are incompatible with a backbone
            if backbone_units > 0:
                raise ValueError(
                    "If sparsity of a CfC cell is set, then no backbone is allowed"
                )
            # Both masks need to be set
            if input_sparsity is None or recurrent_sparsity is None:
                raise ValueError(
                    "If sparsity of a CfC cell is set, then both input and recurrent sparsity need to be defined"
                )
            self.sparsity_mask = tf.constant(
                np.concatenate([input_sparsity, recurrent_sparsity], axis=0),
                dtype=tf.float32,
            )

        allowed_modes = ["default", "pure", "no_gate"]
        if mode not in allowed_modes:
            raise ValueError(
                "Unknown mode '{}', valid options are {}".format(
                    mode, str(allowed_modes)
                )
            )
        self.mode = mode
        self.backbone_fn = None
        if activation == "lecun_tanh":
            activation = lecun_tanh
        self._activation = activation
        self._backbone_units = backbone_units
        self._backbone_layers = backbone_layers
        self._backbone_dropout = backbone_dropout
        self._cfc_layers = []

    @property
    def state_size(self):
        return self.units

    def build(self, input_shape):
        if isinstance(input_shape[0], tuple) or isinstance(
            input_shape[0], tf.TensorShape
        ):
            # Nested tuple -> first item represents the feature dimension
            input_dim = input_shape[0][-1]
        else:
            input_dim = input_shape[-1]

        backbone_layers = []
        for i in range(self._backbone_layers):
            backbone_layers.append(
                tf.keras.layers.Dense(
                    self._backbone_units, self._activation, name=f"backbone{i}"
                )
            )
            backbone_layers.append(tf.keras.layers.Dropout(self._backbone_dropout))
        self.backbone_fn = tf.keras.models.Sequential(backbone_layers)

        cat_shape = int(
            self.state_size + input_dim
            if self._backbone_layers == 0
            else self._backbone_units
        )
        if self.mode == "pure":
            self.ff1_kernel = self.add_weight(
                shape=(cat_shape, self.state_size),
                initializer="glorot_uniform",
                name="ff1_weight",
            )
            self.ff1_bias = self.add_weight(
                shape=(self.state_size,),
                initializer="zeros",
                name="ff1_bias",
            )
            self.w_tau = self.add_weight(
                shape=(1, self.state_size),
                initializer=tf.keras.initializers.Zeros(),
                name="w_tau",
            )
            self.A = self.add_weight(
                shape=(1, self.state_size),
                initializer=tf.keras.initializers.Ones(),
                name="A",
            )
        else:
            self.ff1_kernel = self.add_weight(
                shape=(cat_shape, self.state_size),
                initializer="glorot_uniform",
                name="ff1_weight",
            )
            self.ff1_bias = self.add_weight(
                shape=(self.state_size,),
                initializer="zeros",
                name="ff1_bias",
            )
            self.ff2_kernel = self.add_weight(
                shape=(cat_shape, self.state_size),
                initializer="glorot_uniform",
                name="ff2_weight",
            )
            self.ff2_bias = self.add_weight(
                shape=(self.state_size,),
                initializer="zeros",
                name="ff2_bias",
            )
        self.time_a = tf.keras.layers.Dense(self.state_size, name="time_a")
        self.time_b = tf.keras.layers.Dense(self.state_size, name="time_b")
        self.built = True
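
    # In "pure" mode the next hidden state follows the closed-form solution
    #     h = -A * exp(-t * (|w_tau| + |ff1|)) * ff1 + A
    # while "default" and "no_gate" blend the two feed-forward heads ff1 and
    # ff2 with a time-dependent sigmoid gate; see call() below.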
    def call(self, inputs, states, **kwargs):
        if isinstance(inputs, (tuple, list)):
            # Irregularly sampled mode
            inputs, t = inputs
            t = tf.reshape(t, [-1, 1])
        else:
            # Regularly sampled mode (elapsed time = 1 second)
            t = 1.0
        x = tf.keras.layers.Concatenate()([inputs, states[0]])
        x = self.backbone_fn(x)
        if self.sparsity_mask is not None:
            ff1_kernel = self.ff1_kernel * self.sparsity_mask
            ff1 = tf.matmul(x, ff1_kernel) + self.ff1_bias
        else:
            ff1 = tf.matmul(x, self.ff1_kernel) + self.ff1_bias
        if self.mode == "pure":
            # Solution
            new_hidden = (
                -self.A
                * tf.math.exp(-t * (tf.math.abs(self.w_tau) + tf.math.abs(ff1)))
                * ff1
                + self.A
            )
        else:
            # CfC
            if self.sparsity_mask is not None:
                ff2_kernel = self.ff2_kernel * self.sparsity_mask
                ff2 = tf.matmul(x, ff2_kernel) + self.ff2_bias
            else:
                ff2 = tf.matmul(x, self.ff2_kernel) + self.ff2_bias
            t_a = self.time_a(x)
            t_b = self.time_b(x)
            t_interp = tf.nn.sigmoid(-t_a * t + t_b)
            if self.mode == "no_gate":
                new_hidden = ff1 + t_interp * ff2
            else:
                new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2
        return new_hidden, [new_hidden]
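

# Usage sketch (illustrative; the shapes and sizes below are example
# assumptions, not part of the library API). A CfCCell processes a single
# time-step; wrap it in tf.keras.layers.RNN to run over whole sequences,
# as the docstring above notes.
if __name__ == "__main__":
    cell = CfCCell(units=32)
    rnn = tf.keras.layers.RNN(cell)

    # Regularly sampled sequences: (batch, time, features);
    # the elapsed time defaults to 1.0 per step.
    x = tf.random.normal((8, 10, 4))
    y = rnn(x)
    print(y.shape)  # (8, 32)

    # Irregularly sampled single step: pass (inputs, elapsed_time) as a tuple.
    state = [tf.zeros((8, 32))]
    step_in = tf.random.normal((8, 4))
    elapsed = tf.fill((8, 1), 0.5)  # 0.5 time units since the previous sample
    out, state = cell((step_in, elapsed), state)
    print(out.shape)  # (8, 32)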