sign_language_translator.models.language_models.transformer_language_model.train module
This module contains classes to train transformer language models.
- Classes:
LM_Dataset(torch.utils.data.Dataset): Dataset subclass for language modeling. Its prepare function converts a text file into a list of tensors.
LM_Trainer: Trainer class for the language model. Runs the training loop, prints metrics and saves model checkpoints.
- class sign_language_translator.models.language_models.transformer_language_model.train.LM_Dataset(data: Tensor)
Bases:
Dataset
- static prepare(file_path: str, text_to_token_ids: Callable[[str], List[int]], max_sequence_length: int = 32, encoding='utf-8', dtype=torch.int32) → List[Tensor]
Process a text file into a list of 2D torch tensors, each of shape (n_examples, n_tokens).
- Parameters:
file_path (str) – where the input file is stored
text_to_token_ids (Callable[[str], List[int]]) – a function that can process a line from file and convert it into a list of token ids.
max_sequence_length (int, optional) – sequences longer than this are broken into n-grams of size max_sequence_length. Defaults to 32.
encoding (str, optional) – the encoding used in the text file. Defaults to “utf-8”.
dtype (torch.dtype, optional) – the data type of the returned torch tensors. Check the range of values each type can hold and choose the smallest sufficient one to save space. Defaults to torch.int32.
- Returns:
a list of 2D tensors of token ids, each of shape (n_examples, n_tokens)
- Return type:
List[torch.Tensor]
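The n-gram behaviour described for max_sequence_length can be illustrated with a small pure-Python sketch. It is not the library's implementation: the stride-1 sliding window, the grouping of sequences by length, and the names make_ngrams, prepare_sketch and char_tokenizer are all assumptions made for illustration.

```python
from collections import defaultdict
from typing import Callable, Dict, List


def make_ngrams(token_ids: List[int], n: int) -> List[List[int]]:
    """Return every contiguous window of length n (assumed stride of 1)."""
    return [token_ids[i : i + n] for i in range(len(token_ids) - n + 1)]


def prepare_sketch(
    lines: List[str],
    text_to_token_ids: Callable[[str], List[int]],
    max_sequence_length: int = 32,
) -> Dict[int, List[List[int]]]:
    """Group tokenized lines by length, splitting long ones into n-grams."""
    groups: Dict[int, List[List[int]]] = defaultdict(list)
    for line in lines:
        ids = text_to_token_ids(line)
        if not ids:
            continue  # skip empty lines
        if len(ids) > max_sequence_length:
            # Too long: replace the sequence by its fixed-size windows.
            groups[max_sequence_length].extend(make_ngrams(ids, max_sequence_length))
        else:
            groups[len(ids)].append(ids)
    return dict(groups)


def char_tokenizer(text: str) -> List[int]:
    """Hypothetical tokenizer: one token id per character."""
    return [ord(ch) for ch in text.strip()]


grouped = prepare_sketch(["abc", "abcdef", "xyz"], char_tokenizer, max_sequence_length=4)
```

Each value in grouped is a rectangular table of shape (n_examples, n_tokens), so it could be turned into one tensor per group, for example with torch.tensor(rows, dtype=torch.int32).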
- class sign_language_translator.models.language_models.transformer_language_model.train.LM_Trainer(model: TransformerLanguageModel, device: str = 'cpu', epochs: int = 10, learning_rate: float = 0.001, lr_lambda: Callable[[int, float, float], float] = <function LM_Trainer.<lambda>>, lr_update_step_count: Optional[int] = None, optimizer='adamw', seed: int = 0, model_output_renderer: Optional[Callable[[TransformerLanguageModel], str]] = None, epoch_unfreeze_map: Optional[Dict[int, List[str]]] = None, class_weights: Optional[torch.Tensor] = None, max_gradient_norm: Optional[float] = None)
Bases:
object
Contains functions to train a language model built with PyTorch. It is not designed to be generic; it is specific to the TransformerLanguageModel.
- checkpoint(checkpoint_dir: str, losses, epoch, steps_fraction) → None
Save the metrics in the model and save the model to disk.
- run(train_batches: Iterable[Tuple[Tensor, Tensor]], validation_batches: Iterable[Tuple[Tensor, Tensor]], early_stop: bool = False, checkpoint_dir: str = '', checkpoint_step_count: int = 1000, model_output_step_count: int = 100, start_epoch_number: int = 0) → Dict[str, List[float]]
Run the training/validation loop and generate output & checkpoints.
- Returns:
the tracked metrics
- Return type:
Dict[str, List[float]]
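The overall shape of such a loop can be sketched in plain Python. This is only an illustration under stated assumptions: run_sketch, train_step, validate_step and on_checkpoint are hypothetical names, the metric keys "train_loss" and "validation_loss" are invented, and lists of ints stand in for tensors. Early stopping and model output rendering are left out.

```python
from typing import Callable, Dict, Iterable, List, Tuple

Batch = Tuple[List[int], List[int]]  # stand-in for (input tensor, target tensor)
Step = Callable[[List[int], List[int]], Dict[str, float]]


def run_sketch(
    train_batches: Iterable[Batch],
    validation_batches: Iterable[Batch],
    train_step: Step,
    validate_step: Step,
    epochs: int = 2,
    checkpoint_step_count: int = 1000,
    on_checkpoint: Callable[[int, int], None] = lambda epoch, step: None,
) -> Dict[str, List[float]]:
    """Alternate training and validation, tracking one list per metric."""
    metrics: Dict[str, List[float]] = {"train_loss": [], "validation_loss": []}
    step = 0
    for epoch in range(epochs):
        for inputs, targets in train_batches:
            metrics["train_loss"].append(train_step(inputs, targets)["loss"])
            step += 1
            # Checkpoint every checkpoint_step_count training steps.
            if step % checkpoint_step_count == 0:
                on_checkpoint(epoch, step)
        # Average the validation loss over all validation batches.
        losses = [validate_step(x, y)["loss"] for x, y in validation_batches]
        if losses:
            metrics["validation_loss"].append(sum(losses) / len(losses))
    return metrics


saved: List[Tuple[int, int]] = []
batches = [([1, 2], [2, 3]), ([4, 5], [5, 6]), ([7, 8], [8, 9])]
history = run_sketch(
    batches,
    batches[:1],
    train_step=lambda x, y: {"loss": 1.0},
    validate_step=lambda x, y: {"loss": 0.5},
    epochs=2,
    checkpoint_step_count=2,
    on_checkpoint=lambda epoch, step: saved.append((epoch, step)),
)
```

The batches must be re-iterable (a list here, or a DataLoader in practice), because they are consumed once per epoch.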
- train(input_sequences, outputs) → Dict[str, float]
Training step. Calculates the loss on non-padding tokens. Multiplies the loss by class weights to scale it for imbalanced token distributions. Clips gradients with norm > max_gradient_norm to avoid exploding gradients.
- Parameters:
input_sequences (torch.Tensor) – batch of input sequences
outputs (torch.Tensor) – batch of target sequences in which each position contains the target token for the input sequence up to that position.
- Returns:
the metrics tracked e.g. loss
- Return type:
Dict[str, float]
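The two ideas in this step, a class-weighted loss over non-padding positions and gradient clipping by norm, can be sketched without torch. This is a minimal illustration, not the trainer's code: the function names and padding_id are hypothetical, and averaging by the sum of the applied weights is an assumption (it mirrors the mean reduction of torch.nn.CrossEntropyLoss with a weight argument).

```python
import math
from typing import List, Optional, Sequence


def masked_weighted_cross_entropy(
    logits: Sequence[Sequence[float]],  # one row of vocabulary scores per position
    targets: Sequence[int],             # target token id per position
    padding_id: int,
    class_weights: Optional[Sequence[float]] = None,
) -> float:
    """Weighted mean negative log-likelihood over non-padding positions."""
    total, weight_sum = 0.0, 0.0
    for row, target in zip(logits, targets):
        if target == padding_id:
            continue  # padding positions do not contribute to the loss
        # log-sum-exp with the maximum subtracted for numerical stability
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        weight = 1.0 if class_weights is None else class_weights[target]
        total += weight * (log_norm - row[target])
        weight_sum += weight
    return total / weight_sum if weight_sum else 0.0


def clip_by_norm(gradients: List[float], max_norm: float) -> List[float]:
    """Rescale the gradient vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in gradients))
    if norm <= max_norm:
        return list(gradients)
    scale = max_norm / norm
    return [g * scale for g in gradients]


# Second position is padding (id 0), so only the first one is scored.
loss = masked_weighted_cross_entropy(
    logits=[[0.0, 0.0, 0.0], [9.0, 1.0, 1.0]],
    targets=[1, 0],
    padding_id=0,
    class_weights=[1.0, 2.0, 1.0],
)
clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)
```

In PyTorch the same effect is usually obtained with torch.nn.CrossEntropyLoss(weight=..., ignore_index=...) and torch.nn.utils.clip_grad_norm_; whether LM_Trainer uses exactly these calls is not stated here.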
- validate(input_sequences: Tensor, outputs: Tensor) → Dict[str, float]
Validation step. Runs the model on validation data without gradients or backpropagation and calculates metrics.
- Parameters:
input_sequences (torch.Tensor) – batch of input tokens
outputs (torch.Tensor) – batch of target sequences
- Returns:
the tracked metrics
- Return type:
Dict[str, float]