# Source code for ncps.tf.ltc

# Copyright 2022 Mathias Lechner and Ramin Hasani
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ncps
from . import LTCCell, MixedMemoryRNN
import tensorflow as tf
from typing import Optional, Union


@tf.keras.utils.register_keras_serializable(package="ncps", name="LTC")
class LTC(tf.keras.layers.RNN):
    def __init__(
        self,
        units,
        mixed_memory: bool = False,
        input_mapping="affine",
        output_mapping="affine",
        ode_unfolds=6,
        epsilon=1e-8,
        initialization_ranges=None,
        return_sequences: bool = False,
        return_state: bool = False,
        go_backwards: bool = False,
        stateful: bool = False,
        unroll: bool = False,
        time_major: bool = False,
        **kwargs,
    ):
        """Applies a `Liquid time-constant (LTC) <https://ojs.aaai.org/index.php/AAAI/article/view/16936>`_ RNN to an input sequence.

        Examples::

            >>> from ncps.tf import LTC
            >>>
            >>> rnn = LTC(50)
            >>> x = tf.random.uniform((2, 10, 20))  # (B,L,C)
            >>> y = rnn(x)

        .. Note::
            For creating a wired `Neural circuit policy (NCP) <https://publik.tuwien.ac.at/files/publik_292280.pdf>`_
            you can pass a `ncps.wirings.NCP` object instead of the number of units

        Examples::

            >>> from ncps.tf import LTC
            >>> from ncps.wirings import NCP
            >>>
            >>> wiring = NCP(10, 10, 8, 6, 6, 4, 4)
            >>> rnn = LTC(wiring)
            >>> x = tf.random.uniform((2, 10, 20))  # (B,L,C)
            >>> y = rnn(x)

        :param units: Wiring (ncps.wirings.Wiring instance) or integer representing the number of (fully-connected) hidden units
        :param mixed_memory: Whether to augment the RNN by a `memory-cell <https://arxiv.org/abs/2006.04418>`_ to help learn long-term dependencies in the data
        :param input_mapping: Mapping applied to the sensory neurons. Possible values None, "linear", "affine" (default "affine")
        :param output_mapping: Mapping applied to the motor neurons. Possible values None, "linear", "affine" (default "affine")
        :param ode_unfolds: Number of ODE-solver steps per time-step (default 6)
        :param epsilon: Auxiliary value to avoid dividing by 0 (default 1e-8)
        :param initialization_ranges: A dictionary for overwriting the range of the uniform weight initialization (default None)
        :param return_sequences: Whether to return the full sequence or just the last output (default False)
        :param return_state: Whether to return just the output of the RNN or a tuple (output, last_hidden_state) (default False)
        :param go_backwards: If True, the input sequence will be processed from back to front (default False)
        :param stateful: Whether to remember the last hidden state of the previous inference/training batch and use it as initial state for the next inference/training batch (default False)
        :param unroll: Whether to unroll the graph, i.e., may increase speed at the cost of more memory (default False)
        :param time_major: Whether the time or batch dimension is the first (0-th) dimension (default False)
        :param kwargs: Extra keyword arguments; the same set is forwarded to both the ``LTCCell`` constructor and the ``tf.keras.layers.RNN`` constructor
        """
        # A ready-made Wiring is used as-is; anything else is handed to
        # FullyConnected to build a fully-connected wiring.
        resolved_wiring = (
            units
            if isinstance(units, ncps.wirings.Wiring)
            else ncps.wirings.FullyConnected(units)
        )

        rnn_cell = LTCCell(
            wiring=resolved_wiring,
            input_mapping=input_mapping,
            output_mapping=output_mapping,
            ode_unfolds=ode_unfolds,
            epsilon=epsilon,
            initialization_ranges=initialization_ranges,
            **kwargs,
        )

        # Optionally wrap the LTC cell in MixedMemoryRNN.
        if mixed_memory:
            rnn_cell = MixedMemoryRNN(rnn_cell)

        # The RNN options are passed positionally, so their order here must
        # match the parameter order of tf.keras.layers.RNN.__init__.
        super().__init__(
            rnn_cell,
            return_sequences,
            return_state,
            go_backwards,
            stateful,
            unroll,
            time_major,
            **kwargs,
        )