Neural network layers
These are the basic building blocks for neural networks. To use them, include the header as in the example below:
#include <metalchat/nn.h>
using namespace metalchat::nn;
Attention
template<typename T, contiguous_container Container, cache_t<T> Cache = sink_cache<T>>
class attention : public metalchat::nn::basic_layer
Allows the model to jointly attend to information from different representation subspaces.
This attention layer implements the original architecture described in the paper "Attention Is All You Need" (Vaswani et al., 2017).
Public Functions
inline void enable_norm(float eps)
Enables RMS normalization of keys and queries.
template<immutable_tensor3_t<T> Input, immutable_tensor2_t<T> Mask>
inline auto operator()(Input input, std::optional<Mask> mask = std::nullopt, std::size_t start_pos = 0)
Computes multi-head attention over the input sequence.
Parameters:
input – the input embedding.
mask – if specified, a 2-D mask that prevents attention to certain positions.
start_pos – the starting position of the input sequence.
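Each attention head evaluates scaled dot-product attention, softmax(QK^T / sqrt(d) + M)V. The following is a minimal CPU sketch of a single head on plain row-major matrices, not the metalchat implementation: the `Matrix` type, the function names, and the additive mask convention (0 allows a position, negative infinity blocks it) are assumptions for illustration. It leaves out the split into heads, the key/value cache, and `start_pos` handling.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>
#include <vector>

// Row-major matrix: element (r, c) is stored at data[r * cols + c].
struct Matrix {
    std::size_t rows = 0;
    std::size_t cols = 0;
    std::vector<float> data;

    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d) + mask) V.
// q: [n, d], k: [m, d], v: [m, dv], optional additive mask: [n, m].
Matrix scaled_dot_product_attention(const Matrix& q, const Matrix& k, const Matrix& v,
                                    const std::optional<Matrix>& mask = std::nullopt) {
    assert(q.cols == k.cols && k.rows == v.rows);
    assert(!mask || (mask->rows == q.rows && mask->cols == k.rows));
    const float scale = 1.0f / std::sqrt(static_cast<float>(q.cols));
    Matrix out{q.rows, v.cols, std::vector<float>(q.rows * v.cols, 0.0f)};
    std::vector<float> scores(k.rows);

    for (std::size_t i = 0; i < q.rows; ++i) {
        // Scaled similarity of query i with every key, plus the mask term.
        float max_score = -std::numeric_limits<float>::infinity();
        for (std::size_t j = 0; j < k.rows; ++j) {
            float s = 0.0f;
            for (std::size_t t = 0; t < q.cols; ++t) s += q.at(i, t) * k.at(j, t);
            s *= scale;
            if (mask) s += mask->at(i, j);
            scores[j] = s;
            max_score = std::max(max_score, s);
        }
        // Numerically stable softmax: subtract the row maximum before exp.
        float sum = 0.0f;
        for (float& s : scores) {
            s = std::exp(s - max_score);
            sum += s;
        }
        // Output row i is the softmax-weighted sum of the value rows.
        for (std::size_t j = 0; j < k.rows; ++j)
            for (std::size_t t = 0; t < v.cols; ++t)
                out.at(i, t) += (scores[j] / sum) * v.at(j, t);
    }
    return out;
}

// Causal mask: query i may attend only to keys j <= i.
Matrix causal_mask(std::size_t n) {
    Matrix m{n, n, std::vector<float>(n * n, 0.0f)};
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = i + 1; j < n; ++j)
            m.at(i, j) = -std::numeric_limits<float>::infinity();
    return m;
}
```

A causal mask is the usual choice for autoregressive decoding. Every mask row must leave at least one position unblocked; otherwise the softmax for that row is undefined (NaN in this sketch).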
Embedding
template<typename T, contiguous_container Container = hardware_memory_container<T>>
class embedding : public metalchat::nn::basic_embedding<T, hardware_memory_container<T>>
A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
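Conceptually, the lookup is row selection from a weight matrix of shape [num_embeddings, embedding_dim]. The sketch below shows that on the CPU; `EmbeddingTable` and `lookup` are hypothetical names, and the real module operates on metalchat tensors rather than flat vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical embedding table of shape [num_embeddings, embedding_dim], row-major.
struct EmbeddingTable {
    std::size_t num_embeddings;
    std::size_t embedding_dim;
    std::vector<float> weight;

    // Returns a [indices.size(), embedding_dim] row-major matrix whose i-th row
    // is a copy of the weight row selected by indices[i].
    std::vector<float> lookup(const std::vector<std::int32_t>& indices) const {
        std::vector<float> out;
        out.reserve(indices.size() * embedding_dim);
        for (std::int32_t idx : indices) {
            assert(idx >= 0 && static_cast<std::size_t>(idx) < num_embeddings);
            auto row = weight.begin() +
                       static_cast<std::ptrdiff_t>(idx) * static_cast<std::ptrdiff_t>(embedding_dim);
            out.insert(out.end(), row, row + static_cast<std::ptrdiff_t>(embedding_dim));
        }
        return out;
    }
};
```

Repeated indices simply copy the same row again, which is why the output length depends only on the number of indices.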
Rotary positional embedding
template<typename T>
class rope : public metalchat::nn::basic_layer
This class implements Rotary Positional Embeddings (RoPE).
This implementation caches the frequencies for each position. When a user requests an embedding with a start position that is not present in the cache, the module recomputes the cached frequencies for the range [start_pos, start_pos + max_seq_len).
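The caching scheme can be sketched as follows. This is a CPU illustration under stated assumptions, not the library's code: it uses the common RoPE frequencies theta^(-2i/d) with theta = 10000, rotates adjacent pairs (x[2i], x[2i+1]), and refills the cos/sin cache for [pos, pos + max_seq_len) whenever a requested position falls outside the cached range. `RopeCache` and its members are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical RoPE helper that caches cos/sin for positions [start, start + max_seq_len).
class RopeCache {
public:
    RopeCache(std::size_t dim, std::size_t max_seq_len, float theta = 10000.0f)
        : dim_(dim), max_seq_len_(max_seq_len), theta_(theta) {
        assert(dim % 2 == 0);
        recompute(0);
    }

    // Rotates a vector x of length dim located at absolute position pos.
    void apply(std::vector<float>& x, std::size_t pos) {
        assert(x.size() == dim_);
        // Refill the cache when pos lies outside the cached range.
        if (pos < start_ || pos >= start_ + max_seq_len_) recompute(pos);
        const std::size_t half = dim_ / 2;
        const std::size_t row = (pos - start_) * half;
        for (std::size_t i = 0; i < half; ++i) {
            const float c = cos_[row + i];
            const float s = sin_[row + i];
            const float x0 = x[2 * i];
            const float x1 = x[2 * i + 1];
            // 2-D rotation of the pair (x0, x1) by the cached angle.
            x[2 * i] = x0 * c - x1 * s;
            x[2 * i + 1] = x0 * s + x1 * c;
        }
    }

    std::size_t cached_start() const { return start_; }

private:
    // Fills cos/sin tables for positions [start_pos, start_pos + max_seq_len).
    void recompute(std::size_t start_pos) {
        const std::size_t half = dim_ / 2;
        start_ = start_pos;
        cos_.assign(max_seq_len_ * half, 0.0f);
        sin_.assign(max_seq_len_ * half, 0.0f);
        for (std::size_t p = 0; p < max_seq_len_; ++p) {
            for (std::size_t i = 0; i < half; ++i) {
                // Frequency of pair i: theta^(-2i / dim).
                const double freq = std::pow(static_cast<double>(theta_), -2.0 * i / dim_);
                const double angle = static_cast<double>(start_pos + p) * freq;
                cos_[p * half + i] = static_cast<float>(std::cos(angle));
                sin_[p * half + i] = static_cast<float>(std::sin(angle));
            }
        }
    }

    std::size_t dim_;
    std::size_t max_seq_len_;
    float theta_;
    std::size_t start_ = 0;
    std::vector<float> cos_;
    std::vector<float> sin_;
};
```

Because every pair is rotated rather than shifted, the embedding preserves vector norms, and the dot product of two rotated vectors depends only on their relative position.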
Linear
template<typename T, contiguous_container Container = hardware_memory_container<T>>
class linear : public metalchat::nn::basic_linear<T, hardware_memory_container<T>>
Applies a linear transformation to the input data.
This module does not support a bias term: it only multiplies the input by the specified weight tensor, so it effectively performs a matrix multiplication.
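Without a bias term the layer reduces to y = xW^T. The sketch below assumes a PyTorch-style weight layout of [out_features, in_features]; that layout, the function name `linear_forward`, and the flat row-major buffers are assumptions, since this page does not state how metalchat stores the weight.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Bias-free linear layer: y = x * W^T.
// x: [n, in] row-major; w: [out, in] row-major (assumed layout); y: [n, out].
std::vector<float> linear_forward(const std::vector<float>& x, const std::vector<float>& w,
                                  std::size_t n, std::size_t in, std::size_t out) {
    assert(x.size() == n * in && w.size() == out * in);
    std::vector<float> y(n * out, 0.0f);
    for (std::size_t r = 0; r < n; ++r) {
        for (std::size_t o = 0; o < out; ++o) {
            // Dot product of input row r with weight row o.
            float acc = 0.0f;
            for (std::size_t k = 0; k < in; ++k) acc += x[r * in + k] * w[o * in + k];
            y[r * out + o] = acc;  // no bias term is added
        }
    }
    return y;
}
```

A consequence of dropping the bias is that a zero input always maps to a zero output.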
Root mean square normalization
template<typename T, contiguous_container Container = hardware_memory_container<T>>
class rmsnorm : public metalchat::nn::basic_layer
Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
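For each row x of length n (one element of the mini-batch), RMS normalization computes y_i = x_i / sqrt(mean(x^2) + eps) * g_i, where g is a learned per-feature gain; unlike layer normalization, it does not subtract the mean. The sketch below normalizes a single row. Placing eps inside the square root and the name `rmsnorm_row` are assumptions, not confirmed details of this module.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// RMS normalization of one row: y[i] = x[i] / sqrt(mean(x^2) + eps) * gain[i].
std::vector<float> rmsnorm_row(const std::vector<float>& x, const std::vector<float>& gain,
                               float eps = 1e-5f) {
    assert(!x.empty() && x.size() == gain.size());
    // Mean of squares over the feature dimension.
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
    // Scale every feature by 1/RMS and by its learned gain.
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] * inv_rms * gain[i];
    return y;
}
```

A mini-batch is handled by applying the same function to every row independently.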
Transformer
template<typename T, contiguous_container Container = hardware_memory_container<T>, typename Activation = kernel::silu<T>>
class transformer : public metalchat::nn::basic_layer
Public Functions
inline void enable_norm(float eps)
Enables pre-normalization of the attention and feed-forward layers.
The method registers two additional RMS normalization layers that are executed before the attention and feed-forward layers, respectively.
inline void enable_post_norm(float eps)
Enables post-normalization of the attention and feed-forward (also called MLP) layers.
The method registers two additional RMS normalization layers that are executed right after the attention and feed-forward layers, respectively.
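The two methods only change where normalization sits around each sub-layer. The sketch below shows that ordering with placeholder callables. It assumes a residual connection around both sub-layers and that post-normalization is applied to the sub-layer output before the residual addition; these details, and the names `residual_block` and `transformer_layer`, are assumptions rather than metalchat's documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

using Vec = std::vector<float>;
using Layer = std::function<Vec(const Vec&)>;

// One residual sub-block (attention or feed-forward) with optional norms:
//   y = x + post_norm(sublayer(pre_norm(x)))
// An empty optional means that normalization layer is not registered.
Vec residual_block(const Vec& x, const Layer& sublayer,
                   const std::optional<Layer>& pre_norm,
                   const std::optional<Layer>& post_norm) {
    Vec h = pre_norm ? (*pre_norm)(x) : x;  // enable_norm: runs before the sub-layer
    h = sublayer(h);
    if (post_norm) h = (*post_norm)(h);     // enable_post_norm: runs right after it
    assert(h.size() == x.size());
    Vec y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] + h[i];  // residual addition
    return y;
}

// Attention sub-block followed by the feed-forward sub-block.
Vec transformer_layer(const Vec& x, const Layer& attention, const Layer& feed_forward,
                      const std::optional<Layer>& attn_pre, const std::optional<Layer>& attn_post,
                      const std::optional<Layer>& ffn_pre, const std::optional<Layer>& ffn_post) {
    Vec h = residual_block(x, attention, attn_pre, attn_post);
    return residual_block(h, feed_forward, ffn_pre, ffn_post);
}
```

Both kinds of normalization can be enabled at once, in which case each sub-layer is wrapped on both sides.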