Skip to content

Adding type annotations to polyhedra.py and matrix.py #4322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions manim/mobject/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def construct(self):

import itertools as it
from collections.abc import Iterable, Sequence
from typing import Any, Callable

import numpy as np
from typing_extensions import Self

from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
Expand All @@ -56,7 +58,7 @@ def construct(self):
# Not sure if we should keep it or not.


def matrix_to_tex_string(matrix):
def matrix_to_tex_string(matrix: np.ndarray) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
Expand All @@ -67,7 +69,7 @@ def matrix_to_tex_string(matrix):
return prefix + " \\\\ ".join(rows) + suffix


def matrix_to_mobject(matrix):
def matrix_to_mobject(matrix: np.ndarray) -> MathTex:
return MathTex(matrix_to_tex_string(matrix))


Expand Down Expand Up @@ -170,14 +172,14 @@ def __init__(
bracket_v_buff: float = MED_SMALL_BUFF,
add_background_rectangles_to_entries: bool = False,
include_background_rectangle: bool = False,
element_to_mobject: type[MathTex] = MathTex,
element_to_mobject: type[Mobject] | Callable[..., Mobject] = MathTex,
element_to_mobject_config: dict = {},
element_alignment_corner: Sequence[float] = DR,
left_bracket: str = "[",
right_bracket: str = "]",
stretch_brackets: bool = True,
bracket_config: dict = {},
**kwargs,
**kwargs: Any,
):
self.v_buff = v_buff
self.h_buff = h_buff
Expand Down Expand Up @@ -205,7 +207,7 @@ def __init__(
if self.include_background_rectangle:
self.add_background_rectangle()

def _matrix_to_mob_matrix(self, matrix):
def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]:
return [
[
self.element_to_mobject(item, **self.element_to_mobject_config)
Expand All @@ -214,7 +216,7 @@ def _matrix_to_mob_matrix(self, matrix):
for row in matrix
]

def _organize_mob_matrix(self, matrix):
def _organize_mob_matrix(self, matrix: list[list[Mobject]]) -> Self:
for i, row in enumerate(matrix):
for j, _ in enumerate(row):
mob = matrix[i][j]
Expand All @@ -224,7 +226,7 @@ def _organize_mob_matrix(self, matrix):
)
return self

def _add_brackets(self, left: str = "[", right: str = "]", **kwargs):
def _add_brackets(self, left: str = "[", right: str = "]", **kwargs: Any) -> Self:
"""Adds the brackets to the Matrix mobject.

See Latex document for various bracket types.
Expand Down Expand Up @@ -278,13 +280,13 @@ def _add_brackets(self, left: str = "[", right: str = "]", **kwargs):
self.add(l_bracket, r_bracket)
return self

def get_columns(self):
def get_columns(self) -> VGroup:
r"""Return columns of the matrix as VGroups.

Returns
--------
List[:class:`~.VGroup`]
Each VGroup contains a column of the matrix.
:class:`~.VGroup`
The VGroup contains a nested VGroup for each column of the matrix.

Examples
--------
Expand All @@ -305,7 +307,7 @@ def construct(self):
)
)

def set_column_colors(self, *colors: str):
def set_column_colors(self, *colors: str) -> Self:
r"""Set individual colors for each columns of the matrix.

Parameters
Expand Down Expand Up @@ -335,13 +337,13 @@ def construct(self):
column.set_color(color)
return self

def get_rows(self):
def get_rows(self) -> VGroup:
r"""Return rows of the matrix as VGroups.

Returns
--------
List[:class:`~.VGroup`]
Each VGroup contains a row of the matrix.
:class:`~.VGroup`
The VGroup contains a nested VGroup for each row of the matrix.

Examples
--------
Expand All @@ -357,7 +359,7 @@ def construct(self):
"""
return VGroup(*(VGroup(*row) for row in self.mob_matrix))

def set_row_colors(self, *colors: str):
def set_row_colors(self, *colors: str) -> Self:
r"""Set individual colors for each row of the matrix.

Parameters
Expand Down Expand Up @@ -387,7 +389,7 @@ def construct(self):
row.set_color(color)
return self

def add_background_to_entries(self):
def add_background_to_entries(self) -> Self:
"""Add a black background rectangle to the matrix,
see above for an example.

Expand All @@ -400,7 +402,7 @@ def add_background_to_entries(self):
mob.add_background_rectangle()
return self

def get_mob_matrix(self) -> list[list[MathTex]]:
def get_mob_matrix(self) -> list[list[Mobject]]:
"""Return the underlying mob matrix mobjects.

Returns
Expand All @@ -410,7 +412,7 @@ def get_mob_matrix(self) -> list[list[MathTex]]:
"""
return self.mob_matrix

def get_entries(self):
def get_entries(self) -> VGroup:
"""Return the individual entries of the matrix.

Returns
Expand Down Expand Up @@ -483,9 +485,9 @@ def construct(self):
def __init__(
self,
matrix: Iterable,
element_to_mobject: Mobject = DecimalNumber,
element_to_mobject_config: dict[str, Mobject] = {"num_decimal_places": 1},
**kwargs,
element_to_mobject: type[Mobject] = DecimalNumber,
element_to_mobject_config: dict[str, Any] = {"num_decimal_places": 1},
**kwargs: Any,
):
"""
Will round/truncate the decimal places as per the provided config.
Expand Down Expand Up @@ -526,7 +528,10 @@ def construct(self):
"""

def __init__(
self, matrix: Iterable, element_to_mobject: Mobject = Integer, **kwargs
self,
matrix: Iterable,
element_to_mobject: type[Mobject] = Integer,
**kwargs: Any,
):
"""
Will round if there are decimal entries in the matrix.
Expand Down Expand Up @@ -560,7 +565,12 @@ def construct(self):
self.add(m0)
"""

def __init__(self, matrix, element_to_mobject=lambda m: m, **kwargs):
def __init__(
self,
matrix: Iterable,
element_to_mobject: type[Mobject] | Callable[..., Mobject] = lambda m: m,
**kwargs: Any,
):
super().__init__(matrix, element_to_mobject=element_to_mobject, **kwargs)


Expand All @@ -569,7 +579,7 @@ def get_det_text(
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: float = 2,
):
) -> VGroup:
r"""Helper function to create determinant.

Parameters
Expand Down
29 changes: 15 additions & 14 deletions manim/mobject/three_d/polyhedra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from collections.abc import Hashable
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -14,7 +15,7 @@

if TYPE_CHECKING:
from manim.mobject.mobject import Mobject
from manim.typing import Point3D
from manim.typing import Point3D, Point3DLike_Array

__all__ = [
"Polyhedron",
Expand Down Expand Up @@ -96,10 +97,10 @@ def construct(self):

def __init__(
self,
vertex_coords: list[list[float] | np.ndarray],
vertex_coords: Point3DLike_Array,
faces_list: list[list[int]],
faces_config: dict[str, str | int | float | bool] = {},
graph_config: dict[str, str | int | float | bool] = {},
graph_config: dict[str, Any] = {},
):
super().__init__()
self.faces_config = dict(
Expand All @@ -116,7 +117,7 @@ def __init__(
)
self.vertex_coords = vertex_coords
self.vertex_indices = list(range(len(self.vertex_coords)))
self.layout = dict(enumerate(self.vertex_coords))
self.layout: dict[Hashable, Any] = dict(enumerate(self.vertex_coords))
self.faces_list = faces_list
self.face_coords = [[self.layout[j] for j in i] for i in faces_list]
self.edges = self.get_edges(self.faces_list)
Expand All @@ -129,27 +130,27 @@ def __init__(

def get_edges(self, faces_list: list[list[int]]) -> list[tuple[int, int]]:
"""Creates list of cyclic pairwise tuples."""
edges = []
edges: list[tuple[int, int]] = []
for face in faces_list:
edges += zip(face, face[1:] + face[:1])
return edges

def create_faces(
self,
face_coords: list[list[list | np.ndarray]],
face_coords: Point3DLike_Array,
) -> VGroup:
"""Creates VGroup of faces from a list of face coordinates."""
face_group = VGroup()
for face in face_coords:
face_group.add(Polygon(*face, **self.faces_config))
return face_group

def update_faces(self, m: Mobject):
def update_faces(self, m: Mobject) -> None:
face_coords = self.extract_face_coords()
new_faces = self.create_faces(face_coords)
self.faces.match_points(new_faces)

def extract_face_coords(self) -> list[list[np.ndarray]]:
def extract_face_coords(self) -> Point3DLike_Array:
"""Extracts the coordinates of the vertices in the graph.
Used for updating faces.
"""
Expand Down Expand Up @@ -181,7 +182,7 @@ def construct(self):
self.add(obj)
"""

def __init__(self, edge_length: float = 1, **kwargs):
def __init__(self, edge_length: float = 1, **kwargs: Any):
unit = edge_length * np.sqrt(2) / 4
super().__init__(
vertex_coords=[
Expand Down Expand Up @@ -216,7 +217,7 @@ def construct(self):
self.add(obj)
"""

def __init__(self, edge_length: float = 1, **kwargs):
def __init__(self, edge_length: float = 1, **kwargs: Any):
unit = edge_length * np.sqrt(2) / 2
super().__init__(
vertex_coords=[
Expand Down Expand Up @@ -262,7 +263,7 @@ def construct(self):
self.add(obj)
"""

def __init__(self, edge_length: float = 1, **kwargs):
def __init__(self, edge_length: float = 1, **kwargs: Any):
unit_a = edge_length * ((1 + np.sqrt(5)) / 4)
unit_b = edge_length * (1 / 2)
super().__init__(
Expand Down Expand Up @@ -327,7 +328,7 @@ def construct(self):
self.add(obj)
"""

def __init__(self, edge_length: float = 1, **kwargs):
def __init__(self, edge_length: float = 1, **kwargs: Any):
unit_a = edge_length * ((1 + np.sqrt(5)) / 4)
unit_b = edge_length * ((3 + np.sqrt(5)) / 4)
unit_c = edge_length * (1 / 2)
Expand Down Expand Up @@ -427,7 +428,7 @@ def construct(self):
self.add(dots)
"""

def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs):
def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs: Any):
# Build Convex Hull
array = np.array(points)
hull = QuickHull(tolerance)
Expand Down
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,6 @@ ignore_errors = True
[mypy-manim.mobject.logo]
ignore_errors = True

[mypy-manim.mobject.matrix]
ignore_errors = True

[mypy-manim.mobject.mobject]
ignore_errors = True

Expand Down Expand Up @@ -171,9 +168,6 @@ ignore_errors = True
[mypy-manim.mobject.text.text_mobject]
ignore_errors = True

[mypy-manim.mobject.three_d.polyhedra]
ignore_errors = True

[mypy-manim.mobject.three_d.three_dimensions]
ignore_errors = True

Expand Down