Source code for sign_language_translator.models.video_embedding.mediapipe_landmarks_model
"""
This module contains the `MediaPipeLandmarksModel` class, which is a deep learning-based
video embedding model utilizing the MediaPipe framework for extracting pose and hand
landmarks from video frames.
Classes:
MediaPipeLandmarksModel: A video embedding model that utilizes MediaPipe for pose and hand landmark extraction.
Example:
.. code-block:: python
from sign_language_translator.models import MediaPipeLandmarksModel
from sign_language_translator.vision.utils import iter_frames_with_opencv
mediapipe_model = MediaPipeLandmarksModel(number_of_persons=1)
frame_sequence = iter_frames_with_opencv("video.mp4")
embedding = mediapipe_model.embed(frame_sequence, landmark_type="world")
print(embedding.shape)
"""
from os.path import join
from typing import Dict, Iterable, List, Optional, Union
try:
import mediapipe
except ImportError:
mediapipe = None
import numpy as np
import torch
from numpy.typing import NDArray
from sign_language_translator.config.assets import Assets
from sign_language_translator.models.video_embedding.video_embedding_model import (
VideoEmbeddingModel,
)
from sign_language_translator.utils import ProgressStatusCallback
[docs]
class MediaPipeLandmarksModel(VideoEmbeddingModel):
"""
A video embedding model using MediaPipe to extract pose and hand landmarks from video frames.
Args:
pose_model_name (str): The name of the pose estimation model.
hand_model_name (str): The name of the hand estimation model.
number_of_persons (int): The maximum number of persons to detect in each frame.
Attributes:
n_persons (int): The maximum number of persons to detect in each frame.
Methods:
embed: Embeds a sequence of frames using pose and hand landmarks.
"""
def __init__(
self,
pose_model_name="pose_landmarker_heavy.task",
hand_model_name="hand_landmarker.task",
number_of_persons: int = 1,
) -> None:
if mediapipe is None:
raise ImportError(
"The 'mediapipe' package is required to use the 'MediaPipeLandmarksModel'. "
"Install it using `pip install sign-language-translator[mediapipe]`. "
"(also make sure if your python version is compatible with mediapipe)."
)
self._pose_class = mediapipe.tasks.vision.PoseLandmarker
self._hand_class = mediapipe.tasks.vision.HandLandmarker
path = self.__download_and_get_model_path(f"models/mediapipe/{pose_model_name}")
self._pose_options = mediapipe.tasks.vision.PoseLandmarkerOptions(
base_options=mediapipe.tasks.BaseOptions(model_asset_path=path),
running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,
output_segmentation_masks=False,
num_poses=number_of_persons,
)
path = self.__download_and_get_model_path(f"models/mediapipe/{hand_model_name}")
self._hand_options = mediapipe.tasks.vision.HandLandmarkerOptions(
base_options=mediapipe.tasks.BaseOptions(model_asset_path=path),
running_mode=mediapipe.tasks.vision.RunningMode.VIDEO,
num_hands=number_of_persons * 2,
)
self.n_persons = number_of_persons
[docs]
def embed(
self,
frame_sequence: Iterable[Union[torch.Tensor, NDArray[np.uint8]]],
landmark_type: str = "world" or "image" or "all",
progress_callback: Optional[ProgressStatusCallback] = None,
total_frames: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
"""
Embed a sequence of frames (video) into a sequence of pose & hand landmarks.
Args:
frame_sequence (Iterable[torch.Tensor | NDArray[np.uint8]]): A sequence of video frames as 3D arrays (W, H, c).
landmark_type (str): The type of landmarks to include in the embedding ("world", "image", "all").
Returns:
torch.Tensor: A tensor containing the frame embeddings.
"""
if mediapipe is None:
raise ImportError(
"The 'mediapipe' package is required to use the 'MediaPipeLandmarksModel'. "
"Install it using `pip install sign-language-translator[mediapipe]`."
)
if landmark_type not in ("world", "image", "all"):
raise ValueError(
"landmark_type not supported, use 'world', 'image' or 'all'."
)
# TODO: Pose only or hands only
if hasattr(frame_sequence, "__len__"):
total_frames = len(frame_sequence) # type: ignore
embeddings = []
# TODO: create here or in __init__ ??
with self._pose_class.create_from_options(
self._pose_options
) as pose_landmarker, self._hand_class.create_from_options(
self._hand_options
) as hand_landmarker:
for i, frame in enumerate(frame_sequence):
# convert frame to mediapipe image
mp_image = mediapipe.Image(
image_format=mediapipe.ImageFormat.SRGB,
data=np.array(frame),
)
# infer through models
pose_result = pose_landmarker.detect_for_video(mp_image, i)
hand_result = hand_landmarker.detect_for_video(mp_image, i)
# create & append the frame embedding
poses = self._extract_from_pose_results(pose_result)
hands = self._extract_from_hand_results(hand_result)
persons = self._arange_pose_and_hands(poses, hands)
frame_embedding = self._create_frame_embedding(persons, landmark_type)
embeddings.append(frame_embedding)
if progress_callback and total_frames:
progress_callback(
{"file": f"{i / total_frames:.1%}" if total_frames else "?%"}
)
return torch.Tensor(embeddings)
def _flatten_landmarks(self, landmarks) -> List[float]:
return [
value
for lm in landmarks
for value in [lm.x, lm.y, lm.z, lm.visibility, lm.presence]
]
def _extract_from_pose_results(self, pose_result) -> Dict[str, List[List[float]]]:
poses = {"image": [], "world": []}
for pose_image, pose_world in zip(
pose_result.pose_landmarks, pose_result.pose_world_landmarks
):
poses["image"].append(self._flatten_landmarks(pose_image))
poses["world"].append(self._flatten_landmarks(pose_world))
return poses
def _extract_from_hand_results(self, hand_result) -> Dict[str, List[List[float]]]:
hands = {
"Left_image": [],
"Left_world": [],
"Right_image": [],
"Right_world": [],
}
for hnd, image, world in zip(
hand_result.handedness,
hand_result.hand_landmarks,
hand_result.hand_world_landmarks,
):
# flatten & separate
hands[hnd[0].display_name + "_image"].append(self._flatten_landmarks(image))
hands[hnd[0].display_name + "_world"].append(self._flatten_landmarks(world))
return hands
def _arange_pose_and_hands(
self,
poses: Dict[str, List[List[float]]],
hands: Dict[str, List[List[float]]],
) -> Dict[str, List[List[float]]]:
# TODO: Match left & right hands to poses
# by using minimum distance between hand image centers
# np.linalg.norm(pose[left_hand_ids].mean(axis=...), hands.mean(axis=...).T).argmin(axis=...)
default_hand = [0.0] * 5 * 21
default_pose = [0.0] * 5 * 33
for k in poses.keys():
poses[k] += [default_pose] * (self.n_persons - len(poses[k]))
for k in hands.keys():
hands[k] += [default_hand] * (self.n_persons - len(hands[k]))
return {
key: [
poses[key][p] + hands["Left_" + key][p] + hands["Right_" + key][p]
for p in range(self.n_persons)
] # TODO: order of persons should be the same across frames
for key in ["image", "world"]
}
def _create_frame_embedding(
self, persons: Dict[str, List[List[float]]], landmark_type: str
) -> List[float]:
embedding = []
# flatten & concat
if landmark_type in ("world", "all"):
embedding.extend([value for person in persons["world"] for value in person])
if landmark_type in ("image", "all"):
embedding.extend([value for person in persons["image"] for value in person])
return embedding
def __download_and_get_model_path(self, model_local_path: str):
Assets.download(
model_local_path,
progress_bar=True,
leave=False,
chunk_size=1048576,
)
return join(Assets.ROOT_DIR, model_local_path)