sign_language_translator.models.language_models.transformer_language_model.train module

This module contains classes to train transformer language models.

Classes:

  • LM_Dataset(torch.utils.data.Dataset): Dataset subclass for language models. Its static prepare method converts a text file into a list of tensors.

  • LM_Trainer: trainer class for language models. Runs the training loop, prints metrics and saves model checkpoints.

class sign_language_translator.models.language_models.transformer_language_model.train.LM_Dataset(data: Tensor)

Bases: Dataset
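A minimal sketch of what a Dataset wrapping a 2D token tensor can look like. The class name TokenWindowDataset is hypothetical, and the (input, target) split by shifting one position is an assumption, chosen because LM_Trainer.run consumes batches of Tuple[Tensor, Tensor]; it is not taken from the library's source.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TokenWindowDataset(Dataset):
    """Hypothetical stand-in for LM_Dataset that wraps a 2D tensor of token ids."""

    def __init__(self, data: torch.Tensor):
        # data: (n_examples, n_tokens), e.g. one of the tensors produced by prepare()
        self.data = data

    def __len__(self) -> int:
        return self.data.shape[0]

    def __getitem__(self, index: int):
        # Cast to int64 because embedding lookups and cross-entropy targets expect long
        row = self.data[index].long()
        # Assumed next-token setup: the target at each position is the following token
        return row[:-1], row[1:]


dataset = TokenWindowDataset(torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32))
loader = DataLoader(dataset, batch_size=2, shuffle=False)
inputs, targets = next(iter(loader))  # default collate stacks inputs and targets separately
```

Storing the data compactly (e.g. int32) and casting per item keeps memory low while still handing the model int64 ids.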

static prepare(file_path: str, text_to_token_ids: Callable[[str], List[int]], max_sequence_length: int = 32, encoding='utf-8', dtype=torch.int32) → List[Tensor]

Process a text file into a list of 2D torch tensors of shape (n_examples, n_tokens).

Parameters:
  • file_path (str) – path to the input text file.

  • text_to_token_ids (Callable[[str], List[int]]) – a function that converts one line of the file into a list of token ids.

  • max_sequence_length (int, optional) – sequences longer than this are split into n-grams of length max_sequence_length. Defaults to 32.

  • encoding (str, optional) – the encoding used in the text file. Defaults to “utf-8”.

  • dtype (torch.dtype, optional) – the dtype of the returned tensors. Check the range of values each dtype can hold and choose the smallest one that fits all token ids, to save memory. Defaults to torch.int32.

Returns:

A list of 2D token-id tensors, each of shape (n_examples, n_tokens).

Return type:

List[torch.Tensor]
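The following is a hypothetical re-creation of prepare, useful for seeing how lines become 2D tensors; it is not the library's code. Two behaviours are assumptions: long sequences are cut into stride-1 sliding windows (n-grams) of size max_sequence_length, and sequences are grouped by length so each group stacks into one rectangular tensor (one plausible reason the result is a list). The helper smallest_int_dtype is also hypothetical and illustrates the advice on choosing dtype with torch.iinfo.

```python
import os
import tempfile
from typing import Callable, Dict, List

import torch


def prepare_sketch(
    file_path: str,
    text_to_token_ids: Callable[[str], List[int]],
    max_sequence_length: int = 32,
    encoding: str = "utf-8",
    dtype: torch.dtype = torch.int32,
) -> List[torch.Tensor]:
    """Hypothetical re-creation of LM_Dataset.prepare; not the library's implementation."""
    groups: Dict[int, List[List[int]]] = {}
    with open(file_path, encoding=encoding) as f:
        for line in f.read().splitlines():
            ids = text_to_token_ids(line)
            if not ids:
                continue
            if len(ids) > max_sequence_length:
                # Assumed: stride-1 sliding windows (n-grams) of size max_sequence_length
                windows = [ids[i : i + max_sequence_length] for i in range(len(ids) - max_sequence_length + 1)]
            else:
                windows = [ids]
            # Assumed: group by length so every group stacks into one 2D tensor
            for window in windows:
                groups.setdefault(len(window), []).append(window)
    return [torch.tensor(rows, dtype=dtype) for _, rows in sorted(groups.items())]


def smallest_int_dtype(vocab_size: int) -> torch.dtype:
    """Pick the smallest signed integer dtype that can hold every id in [0, vocab_size)."""
    for dtype in (torch.int16, torch.int32, torch.int64):
        if vocab_size - 1 <= torch.iinfo(dtype).max:
            return dtype
    raise ValueError("vocab_size is too large for int64")


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "corpus.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("1 2 3 4 5\n6 7 8\n9 10\n")
    tensors = prepare_sketch(
        path,
        text_to_token_ids=lambda line: [int(tok) for tok in line.split()],  # toy tokenizer
        max_sequence_length=3,
        dtype=smallest_int_dtype(vocab_size=11),
    )
```

Here the five-token line yields three windows, which are stacked with the three-token line into one (4, 3) tensor, while the two-token line ends up in its own (1, 2) tensor.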

class sign_language_translator.models.language_models.transformer_language_model.train.LM_Trainer(model: TransformerLanguageModel, device: str = 'cpu', epochs: int = 10, learning_rate: float = 0.001, lr_lambda: Callable[[int, float, float], float] = <function LM_Trainer.<lambda>>, lr_update_step_count: Optional[int] = None, optimizer='adamw', seed: int = 0, model_output_renderer: Optional[Callable[[TransformerLanguageModel], str]] = None, epoch_unfreeze_map: Optional[Dict[int, List[str]]] = None, class_weights: Optional[torch.Tensor] = None, max_gradient_norm: Optional[float] = None)

Bases: object

Contains methods to train a language model built with PyTorch. It is not designed to be generic; rather, it is specific to TransformerLanguageModel.
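The epoch_unfreeze_map parameter in the signature above is not documented here. A minimal sketch of one way such a map could work, assuming it maps an epoch number to parameter-name prefixes that become trainable at that epoch; the helper apply_unfreeze_map and the prefix-matching rule are hypothetical, not the library's behaviour.

```python
from typing import Dict, List

from torch import nn


def apply_unfreeze_map(model: nn.Module, epoch: int, epoch_unfreeze_map: Dict[int, List[str]]) -> List[str]:
    """Hypothetical helper: unfreeze parameters whose names start with a prefix listed for `epoch`."""
    prefixes = epoch_unfreeze_map.get(epoch, [])
    unfrozen = []
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in prefixes):
            param.requires_grad_(True)
            unfrozen.append(name)
    return unfrozen


model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 10))
# Freeze everything except the output layer, whose parameter names start with "1."
for name, param in model.named_parameters():
    param.requires_grad_(name.startswith("1."))

unfrozen_at_epoch_1 = apply_unfreeze_map(model, epoch=1, epoch_unfreeze_map={2: ["0."]})
unfrozen_at_epoch_2 = apply_unfreeze_map(model, epoch=2, epoch_unfreeze_map={2: ["0."]})
```

If the optimizer was built only from parameters that were trainable at the start, newly unfrozen ones would also have to be added to it; building it over model.parameters() avoids that, since parameters without gradients are skipped.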

checkpoint(checkpoint_dir: str, losses, epoch, steps_fraction) → None

Store the metrics in the model and save the model to disk.
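A minimal sketch of a checkpoint that keeps the metrics next to the weights, using torch.save on a plain dictionary. The function name save_checkpoint, the file-name pattern and the dictionary keys are hypothetical; the library may store metrics on the model object and serialize it differently.

```python
import os
import tempfile
from typing import Dict, List

import torch
from torch import nn


def save_checkpoint(
    model: nn.Module,
    checkpoint_dir: str,
    losses: Dict[str, List[float]],
    epoch: int,
    steps_fraction: float,
) -> str:
    """Hypothetical sketch: bundle metrics with the weights and write one file per checkpoint."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"model_epoch{epoch}_{steps_fraction:.2f}.pt")
    torch.save(
        {"state_dict": model.state_dict(), "losses": losses, "epoch": epoch, "steps_fraction": steps_fraction},
        path,
    )
    return path


model = nn.Linear(3, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = save_checkpoint(model, tmp, {"train_loss": [2.5, 1.75]}, epoch=1, steps_fraction=0.5)
    restored = torch.load(path)  # only tensors and builtin types, so safe with weights_only loading

restored_model = nn.Linear(3, 2)
restored_model.load_state_dict(restored["state_dict"])
```

Saving the state_dict rather than the whole module keeps the file loadable even if the model class is later moved or renamed.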

run(train_batches: Iterable[Tuple[Tensor, Tensor]], validation_batches: Iterable[Tuple[Tensor, Tensor]], early_stop: bool = False, checkpoint_dir: str = '', checkpoint_step_count: int = 1000, model_output_step_count: int = 100, start_epoch_number: int = 0) → Dict[str, List[float]]

Run the training/validation loop and generate output & checkpoints.

Returns:

The tracked metrics, mapping each metric name to the list of its recorded values.

Return type:

Dict[str, List[float]]
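A minimal outline of an epoch loop that alternates training and validation and accumulates a Dict[str, List[float]] history, with fake step functions so the control flow is easy to check. The name run_sketch, the metric keys and the early-stopping rule (stop as soon as validation loss fails to improve) are assumptions; the sketch also omits the checkpointing and model-output rendering that run performs every checkpoint_step_count and model_output_step_count steps.

```python
from typing import Callable, Dict, Iterable, List, Tuple

import torch

Batch = Tuple[torch.Tensor, torch.Tensor]


def run_sketch(
    train_step: Callable[[torch.Tensor, torch.Tensor], Dict[str, float]],
    validate_step: Callable[[torch.Tensor, torch.Tensor], Dict[str, float]],
    train_batches: Iterable[Batch],
    validation_batches: Iterable[Batch],
    epochs: int,
    early_stop: bool = False,
) -> Dict[str, List[float]]:
    """Hypothetical outline of a train/validate loop; batches must be re-iterable (e.g. lists)."""
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for _ in range(epochs):
        train_losses = [train_step(x, y)["loss"] for x, y in train_batches]
        val_losses = [validate_step(x, y)["loss"] for x, y in validation_batches]
        history["train_loss"].append(sum(train_losses) / len(train_losses))
        history["val_loss"].append(sum(val_losses) / len(val_losses))
        # Assumed early-stopping rule: stop once validation loss stops improving
        if early_stop and len(history["val_loss"]) > 1 and history["val_loss"][-1] >= history["val_loss"][-2]:
            break
    return history


fake_val_losses = iter([3.0, 2.0, 2.5, 1.0])
history = run_sketch(
    train_step=lambda x, y: {"loss": 1.0},
    validate_step=lambda x, y: {"loss": next(fake_val_losses)},
    train_batches=[(torch.zeros(1), torch.zeros(1))],
    validation_batches=[(torch.zeros(1), torch.zeros(1))],
    epochs=4,
    early_stop=True,
)
```

With these fake losses the loop stops after the third epoch, because validation loss rises from 2.0 to 2.5.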

train(input_sequences, outputs) → Dict[str, float]

Training step on one batch. Calculates the loss on non-padding tokens only and multiplies it by class weights to rescale it for imbalanced token distributions. Clips gradients whose norm exceeds max_gradient_norm to avoid exploding gradients.

Parameters:
  • input_sequences (torch.Tensor) – batch of input token sequences

  • outputs (torch.Tensor) – batch of target sequences in which each position contains the target token for the input sequence up to that position.

Returns:

The tracked metrics, e.g. the loss.

Return type:

Dict[str, float]
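A minimal sketch of the three ideas in train, using standard PyTorch calls: padding positions are excluded via the ignore_index argument of F.cross_entropy, class weights are passed through its weight argument, and torch.nn.utils.clip_grad_norm_ rescales gradients. PAD_ID = 0, the toy TinyLM model and the function name train_step are hypothetical; whether the library masks padding this way is an assumption.

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

PAD_ID = 0  # hypothetical padding token id


class TinyLM(nn.Module):
    """Toy stand-in for TransformerLanguageModel: embedding followed by a linear head."""

    def __init__(self, vocab_size: int, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(tokens))  # (batch, seq, vocab)


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    input_sequences: torch.Tensor,
    outputs: torch.Tensor,
    class_weights: Optional[torch.Tensor] = None,
    max_gradient_norm: Optional[float] = None,
) -> Dict[str, float]:
    model.train()
    optimizer.zero_grad()
    logits = model(input_sequences.long())
    # ignore_index drops padding positions; weight rescales each target class's contribution
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        outputs.long().reshape(-1),
        weight=class_weights,
        ignore_index=PAD_ID,
    )
    loss.backward()
    if max_gradient_norm is not None:
        # Scales all gradients together so their global L2 norm is at most max_gradient_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
    optimizer.step()
    return {"loss": loss.item()}


torch.manual_seed(0)
model = TinyLM(vocab_size=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-2)
inputs = torch.tensor([[1, 2, 3, 0]])
targets = torch.tensor([[2, 3, 4, 0]])  # last position is padding and is ignored
weights = torch.tensor([0.0, 1.0, 1.0, 1.0, 2.0])  # up-weights the rarer token 4
first = train_step(model, optimizer, inputs, targets, weights, max_gradient_norm=1.0)
for _ in range(50):
    last = train_step(model, optimizer, inputs, targets, weights, max_gradient_norm=1.0)
```

With reduction="mean" (the default), the weighted losses are divided by the sum of weights of the non-ignored targets, so up-weighting a class does not simply inflate the overall loss scale.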

validate(input_sequences: Tensor, outputs: Tensor) → Dict[str, float]

Validation step on one batch. Runs the model on validation data without computing gradients or backpropagating, and calculates the metrics.

Parameters:
  • input_sequences (torch.Tensor) – batch of input tokens

  • outputs (torch.Tensor) – batch of target sequences

Returns:

the tracked metrics

Return type:

Dict[str, float]
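A minimal sketch of a gradient-free validation step using torch.no_grad() and model.eval(). The function name validate_step, PAD_ID = 0 and the extra accuracy metric over non-padding positions are hypothetical; the documented method only promises a dictionary of tracked metrics.

```python
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

PAD_ID = 0  # hypothetical padding token id


@torch.no_grad()  # no autograd graph is built, so nothing can be backpropagated
def validate_step(model: nn.Module, input_sequences: torch.Tensor, outputs: torch.Tensor) -> Dict[str, float]:
    model.eval()  # switch dropout and similar layers to inference behaviour
    logits = model(input_sequences.long())  # (batch, seq, vocab)
    targets = outputs.long()
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), ignore_index=PAD_ID)
    # Assumed extra metric: next-token accuracy over non-padding positions
    mask = targets != PAD_ID
    correct = (logits.argmax(dim=-1) == targets) & mask
    return {"loss": loss.item(), "accuracy": correct.sum().item() / max(mask.sum().item(), 1)}


torch.manual_seed(0)
model = nn.Sequential(nn.Embedding(5, 4), nn.Linear(4, 5))
metrics = validate_step(model, torch.tensor([[1, 2, 3, 0]]), torch.tensor([[2, 3, 4, 0]]))
```

Because no graph is recorded, the parameters' .grad fields stay untouched, and a training step that follows should switch the model back with model.train().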