Source code for sign_language_translator.models.utils

"""Utility functions and classes for PyTorch models.

This module contains various utility functions and classes to assist with PyTorch models and their training.

Functions:
    top_p_top_k_indexes(probabilities, top_p=None, top_k=None):
        Perform top-p (nucleus) and top-k filtering based on the given probabilities.

    plot_lr_scheduler(*args, lr_scheduler_class=None, lr_scheduler_object=None, initial_lr=1e-3, n_steps=20, parameter_group_number=0, **kwargs):
        Plot the learning rate of a specific parameter group across training steps.

    downwards_wave(n_waves, n_steps_per_wave=9, start=1e-3, end=1e-7, amplitude=0.25):
        Generate a downwards wave pattern with a combination of sine wave and linear function.

    set_layers_trainability_(model, layers_to_unfreeze=None, layers_to_freeze=None):
        Set the trainability of specified layers in the given PyTorch model.

Classes:
    FullyLambdaLR(torch.optim.lr_scheduler.LRScheduler):
        Sets the learning rate of each parameter group to a given function that takes step_num, base_lr & last_lr as args.

    VideoEmbeddingPipeline(slt.models.VideoEmbeddingModel):
        With optional multiprocessing, reads video files from paths, performs forward pass of a model on them and saves the output in specified format.
"""

from __future__ import annotations

__all__ = [
    "top_p_top_k_indexes",
    "FullyLambdaLR",
    "plot_lr_scheduler",
    "downwards_wave",
    "set_layers_trainability_",
    "VideoEmbeddingPipeline",
]

import multiprocessing
from functools import partial
from glob import glob
from os import makedirs
from os.path import abspath, basename, exists, join
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Type, Union
from warnings import warn

import numpy
import torch
from numpy.typing import NDArray
from tqdm.auto import tqdm

from sign_language_translator.vision.utils import iter_frames_with_opencv

if TYPE_CHECKING:
    from sign_language_translator.models.video_embedding import VideoEmbeddingModel


[docs] def top_p_top_k_indexes( probabilities: Iterable[float], top_p: Optional[float] = None, top_k: Optional[int] = None, ) -> List[int]: """Perform top-p (nucleus) and top-k filtering based on the given probabilities. Top-k returns the indices of the top-k elements. Top-p returns the indices of the top elements whose sum does not exceed a certain value. Args: probs (Iterable[float]): A 1-D iterable containing the probabilities of each element. Probabilities must sum to 1. top_p (float or None): The threshold probability for top-p sampling (0 to 1). If None, top-p sampling is not applied. top_k (int or None): The maximum number of elements to keep for top-k sampling. If None, top-k sampling is not applied. Returns: torch.Tensor: The indices of the selected elements. Examples: .. code-block:: python selected_indices = top_p_top_k_indexes( probabilities=[0.1, 0.2, 0.15, 0.05, 0.3, 0.2], top_p=0.75, top_k=3, ) # [4, 1, 5] # You can then use the `selected_indices` to gather # the actual elements from the original tensor. sampled_elements = probs[selected_indices] # [0.3, 0.2, 0.2] """ if top_p is None and top_k is None: return sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True) # type: ignore probs = torch.Tensor(probabilities) assert probs.dim() == 1, "The input probabilities iterable must be 1-dimensional." assert torch.allclose( # pylint: disable = no-member probs.sum(), torch.Tensor([1]) ), "probabilities must sum to 1." assert top_p is None or 0 <= top_p <= 1, "top_p must be between 0 and 1, or None." assert top_k is None or top_k > 0, "top_k must be greater than 0, or None." sorted_probs, sorted_indices = torch.sort( # pylint: disable = no-member probs, descending=True ) cumulative_probs = torch.cumsum(sorted_probs, dim=0) # pylint: disable = no-member if top_p is not None: # Find indices of the elements whose cumulative probability is below top_p sorted_indices_to_remove_mask = cumulative_probs > top_p # [False, False, ..., True, True, ...] -> True means the probability too low and is below top_p # Shift the indices to the right to keep the first token above the threshold sorted_indices_to_remove_mask[1:] = sorted_indices_to_remove_mask[:-1].clone() # if the first element itself has probability above top_p, shifting step wouldn't help it from removal. sorted_indices_to_remove_mask[0] = False # Zero out the probabilities of elements whose index is marked for removal sorted_probs[sorted_indices_to_remove_mask] = 0.0 if top_k is not None: # Zero out the probabilities of elements beyond top_k sorted_probs[top_k:] = 0.0 # Get the indices of the selected elements selected_indices = sorted_indices[sorted_probs > 0] return selected_indices.tolist()
[docs] class FullyLambdaLR(torch.optim.lr_scheduler.LRScheduler): """Sets the learning rate of each parameter group to a given function that takes step_num, base_lr, and last_lr as parameters. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. lr_lambda (function): A function which computes the learning rate given an integer parameter step_num, initial learning rate and previous learning rate for each group in optimizer.param_groups. last_epoch (int): The index of last epoch. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. Example: .. code-block:: python scheduler = FullyLambdaLR( optimizer, lambda step_num, base_lr, last_lr: last_lr * (1.08 if step_num%2 == 0 else 0.8) ) for epoch in range(100): train(...) validate(...) scheduler.step() print(scheduler.get_last_lr()[0]) """ def __init__( self, optimizer, lr_lambda: Callable[[int, float, float], float], last_epoch=-1, verbose=False, ): self.optimizer = optimizer self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) super().__init__(optimizer, last_epoch, verbose)
[docs] def get_lr(self): return [ lmbda(self.last_epoch, base_lr, group["lr"]) for lmbda, base_lr, group in zip( self.lr_lambdas, self.base_lrs, self.optimizer.param_groups, ) ]
[docs] def plot_lr_scheduler( *args, lr_scheduler_class: Optional[Type[torch.optim.lr_scheduler.LRScheduler]] = None, # lr_lambda: Optional[Callable[[Any], float]] = None, lr_scheduler_object: Optional[torch.optim.lr_scheduler.LRScheduler] = None, initial_lr: float = 1e-3, n_steps: int = 20, parameter_group_number: int = 0, save_fig: bool = False, fig_name: Optional[str] = None, **kwargs, ): """Plot the learning rate of a specific parameter group across training steps. This function generates a plot to visualize how the learning rate of a specified parameter group changes across training steps. It requires either an existing `lr_scheduler_object` or a combination of `lr_scheduler_class`, `args`, and `kwargs` to create a new learning rate scheduler. Args: lr_scheduler_class (Type[torch.optim.lr_scheduler._LRScheduler], optional): The class of the learning rate scheduler. Defaults to None. lr_scheduler_object (torch.optim.lr_scheduler._LRScheduler, optional): An existing object of a learning rate scheduler class. Defaults to None. initial_lr (float, optional): The initial learning rate for the new optimizer object needed in case lr_scheduler_object is None. Defaults to 1e-3. n_steps (int, optional): The number of epochs/steps to visualize the learning rate changes. Defaults to 20. parameter_group_number (int, optional): The index for the optimizer's parameter group to plot the learning rate for. Defaults to 0. save_fig (bool, optional): Whether to save the plot instead of showing. Defaults to False. fig_name (str, optional): The name of the file to save the plot to. Defaults to None (the class name of the lr_scheduler_class). *args: Additional arguments to pass to the lr_scheduler_class when creating a new scheduler. **kwargs: Additional keyword arguments to pass to the lr_scheduler_class when creating a new scheduler. Example: .. code-block:: python # Using an existing learning rate scheduler object lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) plot_lr_scheduler(lr_scheduler_object=lr_scheduler, n_steps=100) # Creating a new learning rate scheduler object plot_lr_scheduler( lr_scheduler_class=torch.optim.lr_scheduler.ExponentialLR, initial_lr=0.01, gamma=0.9, n_steps=50, ) """ import matplotlib.pyplot as plt optimizer = None if lr_scheduler_object is None: class TemporaryModel(torch.nn.Module): def __init__(self): super(TemporaryModel, self).__init__() self.linear = torch.nn.Linear(10, 10) def forward(self, x): return self.linear(x) model = TemporaryModel() optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr) if lr_scheduler_class is None: lr_scheduler_class = torch.optim.lr_scheduler.LambdaLR lr_scheduler_object = lr_scheduler_class(optimizer, *args, **kwargs) # type: ignore x_values = list(range(n_steps)) y_values = [] for _ in x_values: y_values.append(lr_scheduler_object.get_lr()[parameter_group_number]) # type: ignore if optimizer: optimizer.step() lr_scheduler_object.step() plt.plot( x_values, y_values, ".-" if len(x_values) < 50 else "-", label=f"lr, param_group_no:{parameter_group_number}", ) plt.xlabel("steps") plt.legend() if save_fig: plt.savefig(fig_name or f"{lr_scheduler_object.__class__.__name__}.png") else: plt.show() plt.close()
[docs] def downwards_wave( n_waves: int, n_steps_per_wave: int = 9, start: float = 1e-3, end: float = 1e-7, amplitude: float = 0.25, ) -> numpy.ndarray: """ Generate a downwards wave pattern with a combination of sine wave and linear function. The function generates a sequence of points forming a downward wave pattern, which consists of a combination of a sine wave and a linear function. The sine wave is modulated by the linear function to create a gradual decrease in axis of the waves. Parameters: n_waves (int): Number of peaks/cycles to generate. n_steps_per_wave (int, optional): Number of steps per wave. Default is 9. start (float, optional): Starting value for the linear function. Default is 1e-3. end (float, optional): Ending value for the linear function. Default is 1e-7. amplitude (float, optional): Amplitude of the sine wave. Default is 0.25. Returns: numpy.ndarray: Array containing the y-axis values of downwards wave pattern. """ x = numpy.linspace(0, n_waves, n_waves * n_steps_per_wave) downwards_line = numpy.linspace(start, end, len(x)) wave = numpy.sin(x * 2 * numpy.pi - numpy.pi / 2) + 1 wave *= abs(start - end) / n_waves * amplitude final = wave + downwards_line return final
[docs] def set_layers_trainability_( model: torch.nn.Module, layers_to_unfreeze: Optional[List[str]] = None, layers_to_freeze: Optional[List[str]] = None, ): """ Set the trainability of specified layers in the given PyTorch model. This function allows you to selectively freeze or unfreeze specific layers of a PyTorch model by setting their `requires_grad` attribute accordingly. Args: model (torch.nn.Module): The PyTorch model whose layers' requires_grad will be modified. layers_to_unfreeze (List[str] | None, optional): A list of layer names or prefixes for layers that you want to unfreeze. If None, no layers will be unfrozen. If [""], all layers will be unfrozen. Default is None. layers_to_freeze (List[str] | None, optional): A list of layer names or prefixes for layers that you want to freeze. If None, no layers will be frozen. If [""], all layers will be frozen. Default is None. Returns: None: This function modifies the model in-place. It does not return anything. Note: - If both `layers_to_unfreeze` and `layers_to_freeze` are None or empty, no action will be taken, and the function will return immediately. - The layers' names or prefixes specified in the lists should match the names as returned by `model.named_parameters()`. Examples: .. code-block:: python # To freeze all layers in the model: set_layers_trainability_(model, layers_to_freeze=[""]) # To unfreeze the layers with names starting with 'classifier' and 'fc': set_layers_trainability_(model, layers_to_unfreeze=["classifier", "fc"]) # To unfreeze all layers: set_layers_trainability_(model, layers_to_unfreeze=[""]) """ if not (layers_to_unfreeze or layers_to_freeze): return for layer_name, param in model.named_parameters(): for given_name in layers_to_freeze or []: if layer_name.startswith(given_name): param.requires_grad = False for given_name in layers_to_unfreeze or []: if layer_name.startswith(given_name): param.requires_grad = True
[docs] class VideoEmbeddingPipeline: """ A class for processing and embedding video data using a slt.models.VideoEmbeddingModel. Args: model (VideoEmbeddingModel): An instance of the VideoEmbeddingModel class or its child class used for embedding. Methods: process_video(path, save_format="csv", overwrite=False, output_dir=".", **kwargs): Load, embed, and save a video's embedding. kwargs are passed to model.embed(). process_videos_parallel(path_patterns, n_processes=multiprocessing.cpu_count(), save_format="csv", overwrite=False, output_dir=".", **kwargs): Process multiple videos in parallel using multiprocessing. kwargs are passed to model.embed(). Attributes: model (VideoEmbeddingModel): The VideoEmbeddingModel instance used for embedding. """ def __init__(self, model: VideoEmbeddingModel): self.model = model
[docs] def process_video( self, path, save_format="csv", overwrite=False, output_dir=".", **kwargs ): """ Load a video, embed its frames, and save the embedding. Args: path (str): The path to the video file. save_format (str, optional): Format for saving the embedding ("csv", "torch", "npy", "npz"). overwrite (bool, optional): Whether to overwrite existing embedding files. output_dir (str, optional): Directory to save the embedding file. **kwargs: Additional keyword arguments for embedding model. Returns: None """ # TODO: handle batched data video = self.__read_video(path) embedding = self.__embed_video(video, **kwargs) # TODO: frames progress callback self.__save_embedding( embedding, basename(path), output_dir=output_dir, file_format=save_format, overwrite=overwrite, )
[docs] def process_videos_parallel( self, path_patterns: List[str], n_processes=multiprocessing.cpu_count(), save_format="csv", overwrite=False, output_dir=".", **kwargs, ): """ Process multiple videos in parallel using multiprocessing, embedding their frames. Args: path_patterns (list): List of file path patterns to match videos e.g. ["dataset/*.mp4", "dataset/*.avi"]. n_processes (int, optional): Number of parallel processes. Defaults to multiprocessing.cpu_count(). save_format (str, optional): Format for saving the embeddings ("csv", "torch", "npy", "npz"). overwrite (bool, optional): Whether to overwrite existing embedding files. output_dir (str, optional): Directory to save the embedding files. **kwargs: Additional keyword arguments for embedding model. Returns: None """ paths = {abspath(path) for pattern in path_patterns for path in glob(pattern)} # warn if multiple paths have the same base name base_to_paths: Dict[str, List[str]] = {} for path in paths: if (base := basename(path)) not in base_to_paths: base_to_paths[base] = [] base_to_paths[base].append(path) clashing_paths = [ path for paths in base_to_paths.values() for path in paths if len(paths) > 1 ] if clashing_paths: warn( "Found multiple paths with the same base name" + f" (overwrite=True will prevent skipping). {clashing_paths = }" ) # optionally skip over existing targets if not overwrite: existing_targets = [ (join(output_dir, basename(path)) + f".{save_format}") for path in paths ] existing_targets = {path for path in existing_targets if exists(path)} for path in existing_targets: warn( f"Target file already exists at {path}. Use overwrite=True to prevent skipping." ) existing_sources = { basename(path)[: -len(save_format) - 1] for path in existing_targets } paths = {path for path in paths if basename(path) not in existing_sources} paths = sorted(paths) if len(paths) < 1: return # process n_processes = min(n_processes, len(paths), multiprocessing.cpu_count()) partial_process_video = partial( self.process_video, save_format=save_format, overwrite=overwrite, output_dir=output_dir, **kwargs, ) if len(paths) == 1: list(tqdm(map(partial_process_video, paths), total=1)) return with multiprocessing.Pool(processes=n_processes) as pool: list(tqdm(pool.imap(partial_process_video, paths), total=len(paths)))
def __read_video(self, path) -> Iterable[NDArray[numpy.uint8]]: return iter_frames_with_opencv(path) def __embed_video(self, video, **kwargs): return self.model.embed(video, **kwargs) def __save_embedding( self, embedding: Union[NDArray, torch.Tensor], filename: str, output_dir=".", file_format="csv", overwrite=False, ): target_path = abspath(join(output_dir, filename) + f".{file_format}") makedirs(output_dir, exist_ok=True) if exists(target_path) and not overwrite: warn(f"File already exists at {target_path}") return if file_format.lower() == "csv": numpy.savetxt(target_path, embedding, delimiter=",") elif file_format.lower() in ("torch", "pt"): torch.save(embedding, target_path) elif file_format.lower() == "npy": numpy.save(target_path, embedding) elif file_format.lower() == "npz": numpy.savez_compressed(target_path, **{filename: embedding})