Language models
These are the building blocks of language models. They can be used as follows:
#include <metalchat/nn.h>
using namespace metalchat::nn;
Meta Llama 3
-
template<typename T, contiguous_container Container = hardware_memory_container<T>, cache_t<T> Cache = sink_cache<T>>
class llama3 : public metalchat::nn::basic_layer
Llama 3 is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human preferences for helpfulness and safety.
Public Functions
-
inline llama3(const llama3_options &options, hardware_accelerator &accelerator)
Constructs a new Llama 3 model with uninitialized weights from the given options.
-
template<immutable_tensor2_t<index_type> Input>
inline auto operator()(Input input, std::size_t start_pos = 0)
Invokes the layer.
- Template Parameters:
Input – type of the input tensor.
- Parameters:
input – a 2-dimensional tensor with the indices of the input tokens.
start_pos – the position of the first input token in the sequence.
- Returns:
a future_tensor with logits over the model vocabulary.
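The start_pos argument suggests the usual prefill-then-decode pattern used with a KV cache: the prompt is passed once with start_pos = 0, and each later call passes only the newly sampled token, with start_pos equal to the number of tokens already processed. This pattern is an assumption, not something the documentation states. The sketch below only computes that call schedule; `model_call` and `decode_schedule` are hypothetical names, not part of metalchat.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical description of one llama3::operator() call:
// the start_pos argument and the number of tokens passed.
struct model_call {
    std::size_t start_pos;
    std::size_t length;
};

// Assumed prefill-then-decode schedule for generating new_tokens tokens.
std::vector<model_call> decode_schedule(std::size_t prompt_len, std::size_t new_tokens)
{
    assert(prompt_len > 0);
    std::vector<model_call> calls;

    // Prefill: the whole prompt at position 0; its last logits give the first new token.
    calls.push_back({0, prompt_len});

    // Decode: feed each sampled token back, one at a time.
    std::size_t pos = prompt_len;
    for (std::size_t i = 1; i < new_tokens; ++i) {
        calls.push_back({pos, 1});
        pos += 1;
    }
    return calls;
}
```

For a 5-token prompt and 3 generated tokens, this yields the calls (0, 5), (5, 1) and (6, 1).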
Key-value caching
In autoregressive language models, key-value tensors power the attention mechanism that determines how tokens relate to each other. These models generate text one token at a time, and each new prediction requires attention calculations across all previous tokens in the sequence. Without optimization, this creates a costly cycle of redundant work. Every time the model predicts the next token, it recalculates the same key-value tensors for tokens it has already processed.
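To make the saving concrete, here is a minimal, library-independent sketch of a growing key-value cache for a single attention head. Each step projects only the new token and appends its key and value, so earlier projections are reused instead of recomputed. The `kv_cache` type and its scalar "projection" are hypothetical simplifications. The real cache stores 4-dimensional tensors on the accelerator.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-head cache: one key and one value vector per token.
struct kv_cache {
    std::vector<std::vector<float>> keys;
    std::vector<std::vector<float>> values;
    std::size_t projections = 0; // number of key/value projections computed so far

    // Stand-in for a key or value projection of one token embedding.
    std::vector<float> project(const std::vector<float>& x, float scale)
    {
        ++projections;
        std::vector<float> out(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            out[i] = x[i] * scale;
        }
        return out;
    }

    // Projects only the new token; keys and values of earlier tokens are reused.
    void append(const std::vector<float>& token)
    {
        keys.push_back(project(token, 0.5f));
        values.push_back(project(token, 2.0f));
        assert(keys.size() == values.size());
    }
};
```

Appending four tokens performs 8 projections. Recomputing keys and values for the whole prefix at every step would take 2 × (1 + 2 + 3 + 4) = 20.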
The cache is itself a layer and is registered as a sublayer of the language model, so its tensors
can be accessed through the regular metalchat::nn::basic_layer API.
For example, to inspect the cache of the layer with index 2, use the following approach:
#include <iostream>
#include <metalchat/nn.h>
using namespace metalchat;
using namespace metalchat::dtype;
hardware_accelerator accelerator;
nn::llama3<bf16> llm(default_llama3_1b_options(), accelerator);
// Cache tensors are addressed by the parameter path "caches.<layer index>.keys" / ".values".
std::cout << llm.get_parameter("caches.2.keys")->sizes() << std::endl;
std::cout << llm.get_parameter("caches.2.values")->sizes() << std::endl;
// out:
// 1, 1024, 8, 64
// 1, 1024, 8, 64
-
struct caching_options
Caching options to configure the key-value cache of the large language model.
Public Members
-
std::size_t head_dim
Per-attention-head embedding dimension.
-
std::size_t n_heads
Number of query heads.
-
std::size_t n_kv_heads
Number of key and value heads.
-
std::size_t max_seq_len
Maximum sequence length the model will be run with.
-
std::size_t max_batch_size
Maximum batch size the model will be run with.
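These options determine how much memory the cache reserves. Assuming each cache tensor has the shape [max_batch_size, max_seq_len, n_kv_heads, head_dim] (which matches the 1, 1024, 8, 64 output in the cache example), the per-layer footprint can be estimated as below. The local `caching_options` struct only mirrors the documented fields, and `kv_cache_bytes` is a hypothetical helper. The real cache may allocate differently.

```cpp
#include <cassert>
#include <cstddef>

// Local mirror of the documented caching_options fields (illustration only).
struct caching_options {
    std::size_t head_dim;
    std::size_t n_heads;
    std::size_t n_kv_heads;
    std::size_t max_seq_len;
    std::size_t max_batch_size;
};

// Bytes used by the keys and values tensors of a single layer, assuming each
// has shape [max_batch_size, max_seq_len, n_kv_heads, head_dim].
std::size_t kv_cache_bytes(const caching_options& o, std::size_t element_size)
{
    // Grouped-query attention: several query heads share one key/value head.
    assert(o.n_kv_heads > 0 && o.n_heads % o.n_kv_heads == 0);
    const std::size_t elements = o.max_batch_size * o.max_seq_len * o.n_kv_heads * o.head_dim;
    return 2 * elements * element_size; // keys + values
}
```

With head_dim = 64, n_kv_heads = 8, max_seq_len = 1024, max_batch_size = 1 and 2-byte bf16 elements (n_heads = 32 is assumed for the check), this gives 2,097,152 bytes (2 MiB) per layer.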
-
template<typename T>
struct caching_result
The result of a cache update query. The caching_result::keys and caching_result::values tensors have the shape [bs, max_seq_len, n_kv_heads, head_dim].
- Template Parameters:
T – data type of the cache elements (float, int, etc.).
Public Members
-
future_tensor<T, 4> keys
A future tensor that will contain the result of the keys caching query.
-
future_tensor<T, 4> values
A future tensor that will contain the result of the values caching query.
-
std::optional<future_tensor<T, 2>> mask
An optional future tensor that will contain the additive causal mask. The mask is created only when the length of the keys (or values) is greater than 1.
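An additive causal mask is conventionally 0 where a query may attend to a key and negative infinity where it may not. Adding it to the attention scores before the softmax therefore removes future positions. The sketch below builds such a mask for `len` new tokens starting at `start_pos`, with shape [len, start_pos + len]. This layout follows the common Llama reference code and is an assumption about what caching_result::mask contains. `causal_mask` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Row-major additive causal mask of shape [len, start_pos + len].
// Query i (absolute position start_pos + i) may attend to keys 0..start_pos + i.
std::vector<float> causal_mask(std::size_t start_pos, std::size_t len)
{
    assert(len > 0);
    const std::size_t cols = start_pos + len;
    const float neg_inf = -std::numeric_limits<float>::infinity();

    std::vector<float> mask(len * cols, 0.0f);
    for (std::size_t i = 0; i < len; ++i) {
        for (std::size_t j = start_pos + i + 1; j < cols; ++j) {
            mask[i * cols + j] = neg_inf; // future key: blocked
        }
    }
    return mask;
}
```

With a single new token (len = 1), no position is ever blocked, which is consistent with the mask being created only for longer inputs.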
-
template<typename T>
class sink_cache : public metalchat::nn::basic_layer
Implementation of the KV cache introduced in the attention sinks paper. It allows the model to generate text beyond the length of its context window without losing fluency in the conversation.
This is done by always keeping the first few tokens (“sink tokens”) in the KV cache, as models often pay a large amount of attention to them during training. Because it discards past non-sink tokens, the model loses the ability to generate tokens that depend on the discarded context. The approach also bounds the memory footprint of the KV cache.
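The retention policy can be illustrated with plain token positions. After `total` tokens, a cache of size `capacity` keeps the first `pre_len` positions and fills the remaining slots with the most recent ones. This sketches the policy only. `retained_positions` is a hypothetical helper, and the real sink_cache operates on key and value tensors rather than position lists.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Positions kept by a sink cache after `total` tokens, given `pre_len`
// sink tokens and a cache capacity of `capacity` tokens.
std::vector<std::size_t> retained_positions(std::size_t pre_len, std::size_t capacity, std::size_t total)
{
    assert(pre_len < capacity);
    std::vector<std::size_t> kept;

    if (total <= capacity) {
        // Nothing has been evicted yet.
        for (std::size_t p = 0; p < total; ++p) kept.push_back(p);
        return kept;
    }

    // Sink tokens are never evicted.
    for (std::size_t p = 0; p < pre_len; ++p) kept.push_back(p);

    // The remaining capacity is a sliding window over the latest tokens.
    const std::size_t window = capacity - pre_len;
    for (std::size_t p = total - window; p < total; ++p) kept.push_back(p);
    return kept;
}
```

With 2 sink tokens, a capacity of 6 and 10 tokens seen, the kept positions are 0, 1, 6, 7, 8 and 9. Positions 2 to 5 are discarded, so any continuation that depends on them can no longer use them.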
Warning
The implementation does not track whether the specified start position corresponds to the most recently used start position. For example, if a user calls an attention layer with start_pos = 15 and a cache size of 16, and then makes a subsequent call with start_pos = 44, the implementation won’t complain, but the result might not be what the user expects.
- Template Parameters:
T – data type of the cache elements (float, int, etc.).
Public Functions
-
inline sink_cache(std::size_t pre_len, const caching_options &options, hardware_accelerator accelerator)
Constructs a new instance of the sink cache.
- Parameters:
pre_len – the number of sink tokens that are permanently kept in the cache.
options – caching options for the KV cache.
accelerator – a hardware accelerator instance.
-
inline sink_cache(const caching_options &options, hardware_accelerator accelerator)
Constructs a new instance of the sink cache with the number of sink tokens set to the base-2 logarithm of the maximum context window length.
- Parameters:
options – caching options for the KV cache.
accelerator – a hardware accelerator instance.
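For example, with a context window of 1024 tokens, the default constructor keeps log2(1024) = 10 sink tokens. The sketch below computes this value. Rounding down for lengths that are not powers of two is an assumption, since the documentation does not specify the rounding. `default_sink_tokens` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>

// Floor of the base-2 logarithm, used here as the assumed default
// number of sink tokens for a given maximum context length.
std::size_t default_sink_tokens(std::size_t max_seq_len)
{
    assert(max_seq_len > 0);
    std::size_t n = 0;
    while (max_seq_len > 1) {
        max_seq_len >>= 1;
        ++n;
    }
    return n;
}
```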
-
inline caching_result<T> update(input_tensor keys, input_tensor vals, std::size_t start_pos)
Updates the cache tensors with new keys and values and returns the resulting caching_result.
- Parameters:
keys – new keys to cache.
vals – new values to cache.
start_pos – position of the next token in an output sequence.
Sampling
-
template<typename T, typename Index>
struct basic_sampling_context
Provides access to the sampling state, consisting of raw logits and their positions in the model vocabulary.
Public Types
-
using logits_tensor = future_tensor<value_type, 2>
Logits tensor type.
-
using index_tensor = future_tensor<index_type, 2>
Index tensor type.
-
template<typename T>
struct basic_sampler
Subclassed by metalchat::nn::multinomial_sampler< T >, metalchat::nn::nucleus_sampler< T >, metalchat::nn::sequential_sampler< T >, metalchat::nn::topk_sampler< T >
Public Functions
-
virtual context_type sample(const context_type &context, hardware_accelerator &accelerator) = 0
Returns the subset of raw logits and their indices (the context) that should be considered during token sequence generation by a language transformer model.
-
virtual ~basic_sampler() = default
A default virtual destructor.
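The interface follows a simple pattern: a sampler receives logits together with their vocabulary indices and returns a (usually smaller) subset of both. The standalone sketch below mirrors that shape with hypothetical `sampling_context`, `sampler` and `greedy_sampler` types on std::vector. It omits the hardware_accelerator argument and future tensors and is not the metalchat API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU stand-in for a sampling context (one batch row).
struct sampling_context {
    std::vector<float> logits;
    std::vector<int> indices; // vocabulary position of each logit
};

// Mirrors the shape of basic_sampler: context in, filtered context out.
struct sampler {
    virtual sampling_context sample(const sampling_context& context) = 0;
    virtual ~sampler() = default;
};

// Example subclass: keeps only the single most likely token.
struct greedy_sampler : sampler {
    sampling_context sample(const sampling_context& context) override
    {
        assert(!context.logits.empty() && context.logits.size() == context.indices.size());
        std::size_t best = 0;
        for (std::size_t i = 1; i < context.logits.size(); ++i) {
            if (context.logits[i] > context.logits[best]) best = i;
        }
        return {{context.logits[best]}, {context.indices[best]}};
    }
};
```

Keeping the indices alongside the logits is what allows several samplers to filter the same row in turn without losing track of which vocabulary entry each logit belongs to.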
-
template<typename T>
class nucleus_sampler : public metalchat::nn::basic_sampler<T>
A sampler that selects the smallest set of elements whose cumulative probability exceeds the probability p. This sampler combines top-p sampling with temperature scaling.
Public Functions
-
inline nucleus_sampler(T temperature, T p)
The nucleus_sampler constructor.
- Parameters:
temperature – a positive value used to modulate the logits distribution.
p – the cumulative probability cutoff value.
-
inline nucleus_sampler()
The default nucleus_sampler constructor initializes the temperature parameter with 0.6 and the p parameter with 0.9.
-
inline virtual context_type sample(const context_type &context, hardware_accelerator &accelerator)
Returns the subset of raw logits and their indices (the context) that should be considered during token sequence generation by a language transformer model.
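The filtering can be sketched on a single row of logits. First scale by the temperature and apply a softmax. Then sort by probability and keep tokens until the cumulative probability exceeds p. The defaults mirror the documented 0.6 and 0.9. `nucleus_filter` is a hypothetical helper that returns vocabulary indices only. The real sampler returns a context of logits and indices and may order them differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Vocabulary indices kept by top-p (nucleus) filtering after temperature
// scaling, ordered from the most to the least likely token.
std::vector<std::size_t> nucleus_filter(const std::vector<float>& logits, float temperature = 0.6f, float p = 0.9f)
{
    assert(!logits.empty() && temperature > 0.0f);

    // Temperature-scaled softmax (subtracting the max for numerical stability).
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs(logits.size());
    double total = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp((logits[i] - max_logit) / temperature);
        total += probs[i];
    }
    for (double& pr : probs) pr /= total;

    // Sort indices by descending probability.
    std::vector<std::size_t> order(logits.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) { return probs[a] > probs[b]; });

    // Keep the smallest prefix whose cumulative probability exceeds p.
    std::vector<std::size_t> kept;
    double cumulative = 0.0;
    for (std::size_t idx : order) {
        kept.push_back(idx);
        cumulative += probs[idx];
        if (cumulative > p) break;
    }
    return kept;
}
```

Lowering the temperature sharpens the distribution, so fewer tokens are needed to pass the cutoff. For the logits {2, 1, 0, -1} with p = 0.9, temperature 1.0 keeps three tokens, while the default 0.6 keeps two.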
-
template<typename T>
class multinomial_sampler : public metalchat::nn::basic_sampler<T>
Draws samples by interpreting the logits array as the cumulative distribution function of a multinomial distribution.
Warning
To use this sampler, the logits must be sorted in descending order; otherwise, the result is undefined.
Public Functions
-
inline multinomial_sampler(std::size_t sample_size = 1)
The multinomial_sampler constructor.
- Parameters:
sample_size – the number of samples to draw from the distribution.
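One common way to draw from a multinomial distribution is inverse-CDF sampling: accumulate the weights and return the first position whose running sum exceeds a uniform random number. The sketch below assumes non-negative weights (for example, probabilities left by a previous sampler) and sampling with replacement. `inverse_cdf` and `multinomial_draw` are hypothetical helpers, not the metalchat implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Inverse-CDF lookup: the first position whose cumulative weight exceeds u,
// where u lies in [0, sum of weights).
std::size_t inverse_cdf(const std::vector<double>& weights, double u)
{
    assert(!weights.empty());
    double cumulative = 0.0;
    for (std::size_t i = 0; i < weights.size(); ++i) {
        cumulative += weights[i];
        if (u < cumulative) return i;
    }
    return weights.size() - 1; // guard against rounding at the upper end
}

// Draws sample_size positions (with replacement) from non-negative weights.
std::vector<std::size_t> multinomial_draw(const std::vector<double>& weights, std::size_t sample_size, std::mt19937& rng)
{
    double total = 0.0;
    for (double w : weights) total += w;
    std::uniform_real_distribution<double> uniform(0.0, total);

    std::vector<std::size_t> samples;
    for (std::size_t s = 0; s < sample_size; ++s) {
        samples.push_back(inverse_cdf(weights, uniform(rng)));
    }
    return samples;
}
```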
-
inline virtual context_type sample(const context_type &context, hardware_accelerator &accelerator)
Returns the subset of raw logits and their indices (the context) that should be considered during token sequence generation by a language transformer model.
-
template<typename T>
class topk_sampler : public metalchat::nn::basic_sampler<T>
A CPU-based top-k logits sampler. It restricts the pool of candidate tokens to the k most likely tokens.
The sampler processes each batch element independently, applying top-k filtering row-wise to the input logits tensor. If k exceeds the vocabulary size, all tokens are retained.
Warning
This implementation uses a CPU-based selection algorithm to find the top-k largest elements in the logits tensor. As a result, the pending command queue is submitted to the GPU for processing, and the calling thread blocks until the resulting logits are available.
- Template Parameters:
T – The data type of the logits (e.g., float, bf16)
Public Functions
-
inline topk_sampler(std::size_t k)
Constructs a top-k sampler with the specified k value.
- Parameters:
k – The number of top candidates to retain. Keeps all tokens if k is larger than the vocabulary size.
-
inline virtual context_type sample(const context_type &context, hardware_accelerator &accelerator)
Applies top-k sampling to the input logits tensor.
- Parameters:
context – input logits and indices of shape [batch_size, vocab_size].
accelerator – hardware accelerator used for tensor operations.
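The row-wise filtering can be sketched on the CPU with std::partial_sort over index arrays. `topk_rows` is a hypothetical helper that operates on a row-major std::vector rather than a metalchat tensor. The use of std::partial_sort is an assumption, since the documentation only says that a CPU-based selection algorithm is used.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Row-wise top-k over a row-major [batch_size, vocab_size] logits buffer.
// For each row, returns the indices of the k largest logits in descending
// order; all indices are kept when k >= vocab_size.
std::vector<std::vector<std::size_t>>
topk_rows(const std::vector<float>& logits, std::size_t batch_size, std::size_t vocab_size, std::size_t k)
{
    assert(logits.size() == batch_size * vocab_size);
    const std::size_t keep = std::min(k, vocab_size);

    std::vector<std::vector<std::size_t>> result(batch_size);
    for (std::size_t b = 0; b < batch_size; ++b) {
        const float* row = logits.data() + b * vocab_size;
        std::vector<std::size_t> order(vocab_size);
        std::iota(order.begin(), order.end(), std::size_t{0});
        // Only the first `keep` positions end up sorted; the rest are left unordered.
        std::partial_sort(order.begin(), order.begin() + keep, order.end(),
                          [row](std::size_t x, std::size_t y) { return row[x] > row[y]; });
        order.resize(keep);
        result[b] = std::move(order);
    }
    return result;
}
```

A partial sort costs roughly O(V log k) per row instead of O(V log V) for a full sort, which matters for vocabularies with over a hundred thousand entries.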
-
template<typename T>
class sequential_sampler : public metalchat::nn::basic_sampler<T>
A sampler that applies the provided samplers one after another, passing the output of the previous sampler to the next one as its input.
Here is an example of how to create the most common sampling strategy using a composition of nucleus_sampler and multinomial_sampler:
#include <memory>
#include <metalchat/nn.h>
using namespace metalchat::nn;
// Top-p filtering first, then a multinomial draw from the remaining tokens.
auto sampler = sequential_sampler<float>({
    std::make_shared<nucleus_sampler<float>>(),
    std::make_shared<multinomial_sampler<float>>()
});
Note
When a sequential_sampler is created from an empty range, it returns the sampling context unmodified.
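Conceptually, the sequential sampler is a left fold over its samplers, and an empty list acts as the identity. The sketch below shows that composition with hypothetical std::function stages that operate on a plain vector of (index, logit) pairs. It does not use the metalchat types.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical sampling context: (vocabulary index, logit) pairs.
using context = std::vector<std::pair<int, float>>;
using stage = std::function<context(const context&)>;

// Applies each stage to the output of the previous one.
// With no stages, the input context is returned unchanged.
context apply_sequentially(const std::vector<stage>& stages, context ctx)
{
    for (const stage& s : stages) {
        assert(s); // every stage must hold a callable
        ctx = s(ctx);
    }
    return ctx;
}
```

Because each stage only sees the previous stage's output, the order matters: filtering before sampling (as in the nucleus-then-multinomial example) is not equivalent to the reverse.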
Public Types
-
using sampler_type = basic_sampler<T>
A base type for samplers comprising a sequential_sampler.
-
using sampler_pointer = std::shared_ptr<sampler_type>
A shared pointer type to the sampler_type.
Public Functions
-
inline sequential_sampler(std::initializer_list<sampler_pointer> samplers)
Constructs the sequential_sampler from the list of samplers.
-
template<std::forward_iterator ForwardIt>
inline sequential_sampler(ForwardIt first, ForwardIt last)
Constructs the sequential_sampler by moving elements from the specified range.
- Parameters:
first, last – the pair of iterators defining the range of elements to move the samplers from.
-
inline sequential_sampler()
The default sequential_sampler constructor.
-
inline virtual context_type sample(const context_type &context, hardware_accelerator &accelerator)
Returns the subset of raw logits and their indices (the context) that should be considered during token sequence generation by a language transformer model.