# Source code for sign_language_translator.models.language_models.beam_sampling

"""This module provides BeamSampling class for generating completions using beam search sampling."""

from math import exp, log2
from random import random
from typing import Any, Callable, Iterable, Optional, Tuple

from sign_language_translator.models.language_models.abstract_language_model import (
    LanguageModel,
)
from sign_language_translator.utils import sample_one_index


class BeamSampling:
    """BeamSampling class for generating completions using beam search sampling.

    Args:
        model (LanguageModel): The language model used for generating completions.
            It must provide ``next(context) -> (token, probability)`` and an
            ``unknown_token`` attribute.
        beam_width (int, optional): The beam width for beam search. Defaults to 3.
        start_of_sequence_token (str, optional): The start of sequence token. Defaults to "[".
        end_of_sequence_token (str, optional): The end of sequence token. Defaults to "]".
        max_length (int, optional): The maximum length of the generated completions.
            Defaults to 37.
        scoring_function (Callable[[Iterable, float], float], optional): The scoring
            function used to rank and sample the completions. It should accept the
            generated sequence and its overall log2 probability as arguments.
            Defaults to ``10.0 + log_prob / len(seq)`` (length-normalized log2
            probability shifted by a constant).
        return_log_of_probability (bool, optional): If True, the score returned with a
            completion is its log2 probability; if False, it is the probability itself.
            Defaults to True.
    """

    def __init__(
        self,
        model: LanguageModel,
        beam_width: int = 3,
        start_of_sequence_token="[",
        end_of_sequence_token="]",
        max_length: int = 37,
        scoring_function: Callable[[Iterable, float], float] = (
            lambda seq, log_prob: 10.0 + log_prob / len(seq)  # type: ignore
        ),
        return_log_of_probability: bool = True,
    ) -> None:
        self.model = model
        self.start_of_sequence_token = start_of_sequence_token
        self.end_of_sequence_token = end_of_sequence_token
        self.beam_width = beam_width
        self.max_length = max_length
        self.scoring_function = scoring_function
        self.return_log_of_probability = return_log_of_probability

    def __call__(self, context: Optional[Iterable] = None) -> Tuple[Iterable, float]:
        """Generate one completion of ``context``; shortcut for ``complete(initial_context=context)``."""
        return self.complete(initial_context=context)

    def complete(
        self,
        initial_context: Optional[Iterable] = None,
        append_func: Callable[[Any, Any], Any] = lambda context, token: (
            (context + [token])
            if isinstance(context, list)
            else (
                (context + (token,))
                if isinstance(context, tuple)
                else (context + token)
            )
        ),
    ) -> Tuple[Iterable, float]:
        """Generate completions based on the given initial context.

        Each step expands every unfinished branch by sampling ``n_branches``
        next tokens from the model, then keeps the ``n_branches`` best branches
        according to ``scoring_function``. Expansion stops after ``max_length``
        steps or when no branch changes. Finally one branch is sampled with
        weights given by a softmax over the branches' scores.

        Args:
            initial_context (Iterable | None, optional): The initial context for
                completion generation. Defaults to ``[start_of_sequence_token]``.
            append_func (Callable[[Any, Any], Any], optional): a function that can
                append the generated next token to provided context. Defaults to a
                lambda function that can append to list, tuple & str.

        Returns:
            Tuple[Iterable, float]: One generated completion and its accumulated
                log2 probability (or the plain probability when
                ``return_log_of_probability`` is False).
        """
        if initial_context is None:
            initial_context = [self.start_of_sequence_token]

        # each branch is a (sequence, accumulated log2 probability) pair
        branches = [(initial_context, 0.0)]

        for _ in range(self.max_length):
            # jitter the width by [-0.4, 0.4) before rounding; for an integer
            # beam_width this always rounds back to beam_width
            n_branches = round(self.beam_width + random() * 0.8 - 0.4)

            # Expand
            new_branches = []
            for context_, score_ in branches:
                if (  # sequence completed
                    context_[-1] == self.end_of_sequence_token  # type: ignore
                    or len(context_) >= self.max_length  # type: ignore
                ):
                    # just as it was. no change.
                    new_branches.append((context_, score_))
                    continue

                # append next tokens
                for _ in range(n_branches):
                    next_token, prob = self.model.next(context_)
                    if next_token == self.model.unknown_token:
                        # the model cannot continue: keep the branch unchanged
                        score = score_
                        next_context = context_
                    else:
                        next_context = append_func(context_, next_token)
                        score = score_ + log2(prob)

                    # no repeats; for diversity.
                    if (next_context, score) not in new_branches:
                        new_branches.append((next_context, score))

            # Prune
            new_branches = sorted(
                new_branches,
                key=lambda item: self.scoring_function(*item),
                reverse=True,
            )[:n_branches]

            if branches == new_branches:  # no branch has grown further: stop
                break
            branches = new_branches

        # Numerically stable softmax over the branch scores: subtracting the
        # max prevents exp() overflow for large scores and an all-zero
        # denominator (ZeroDivisionError) when every score is very negative.
        scores = [self.scoring_function(seq, log_prob) for seq, log_prob in branches]
        max_score = max(scores, default=0.0)
        weights = [exp(s - max_score) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]

        # select one branch (sample_one_index presumably draws an index in
        # proportion to the weights — confirm against its implementation)
        selected_completion, score = branches[sample_one_index(weights)]

        # reformat score: log2 probability -> probability
        if not self.return_log_of_probability:
            score = 2**score  # math.exp2(score)

        return selected_completion, score