diff --git a/manim/mobject/matrix.py b/manim/mobject/matrix.py index 71b878b0f7..f10483f5c9 100644 --- a/manim/mobject/matrix.py +++ b/manim/mobject/matrix.py @@ -41,8 +41,10 @@ def construct(self): import itertools as it from collections.abc import Iterable, Sequence +from typing import Any, Callable import numpy as np +from typing_extensions import Self from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL @@ -56,7 +58,7 @@ def construct(self): # Not sure if we should keep it or not. -def matrix_to_tex_string(matrix): +def matrix_to_tex_string(matrix: np.ndarray) -> str: matrix = np.array(matrix).astype("str") if matrix.ndim == 1: matrix = matrix.reshape((matrix.size, 1)) @@ -67,7 +69,7 @@ def matrix_to_tex_string(matrix): return prefix + " \\\\ ".join(rows) + suffix -def matrix_to_mobject(matrix): +def matrix_to_mobject(matrix: np.ndarray) -> MathTex: return MathTex(matrix_to_tex_string(matrix)) @@ -170,14 +172,14 @@ def __init__( bracket_v_buff: float = MED_SMALL_BUFF, add_background_rectangles_to_entries: bool = False, include_background_rectangle: bool = False, - element_to_mobject: type[MathTex] = MathTex, + element_to_mobject: type[Mobject] | Callable[..., Mobject] = MathTex, element_to_mobject_config: dict = {}, element_alignment_corner: Sequence[float] = DR, left_bracket: str = "[", right_bracket: str = "]", stretch_brackets: bool = True, bracket_config: dict = {}, - **kwargs, + **kwargs: Any, ): self.v_buff = v_buff self.h_buff = h_buff @@ -205,7 +207,7 @@ def __init__( if self.include_background_rectangle: self.add_background_rectangle() - def _matrix_to_mob_matrix(self, matrix): + def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]: return [ [ self.element_to_mobject(item, **self.element_to_mobject_config) @@ -214,7 +216,7 @@ def _matrix_to_mob_matrix(self, matrix): for row in matrix ] - def _organize_mob_matrix(self, matrix): + def _organize_mob_matrix(self, matrix: list[list[Mobject]]) -> Self: for i, row in enumerate(matrix): for j, _ in enumerate(row): mob = matrix[i][j] @@ -224,7 +226,7 @@ def _organize_mob_matrix(self, matrix): ) return self - def _add_brackets(self, left: str = "[", right: str = "]", **kwargs): + def _add_brackets(self, left: str = "[", right: str = "]", **kwargs: Any) -> Self: """Adds the brackets to the Matrix mobject. See Latex document for various bracket types. @@ -278,13 +280,13 @@ def _add_brackets(self, left: str = "[", right: str = "]", **kwargs): self.add(l_bracket, r_bracket) return self - def get_columns(self): + def get_columns(self) -> VGroup: r"""Return columns of the matrix as VGroups. Returns -------- - List[:class:`~.VGroup`] - Each VGroup contains a column of the matrix. + :class:`~.VGroup` + The VGroup contains a nested VGroup for each column of the matrix. Examples -------- @@ -305,7 +307,7 @@ def construct(self): ) ) - def set_column_colors(self, *colors: str): + def set_column_colors(self, *colors: str) -> Self: r"""Set individual colors for each columns of the matrix. Parameters @@ -335,13 +337,13 @@ def construct(self): column.set_color(color) return self - def get_rows(self): + def get_rows(self) -> VGroup: r"""Return rows of the matrix as VGroups. Returns -------- - List[:class:`~.VGroup`] - Each VGroup contains a row of the matrix. + :class:`~.VGroup` + The VGroup contains a nested VGroup for each row of the matrix. Examples -------- @@ -357,7 +359,7 @@ def construct(self): """ return VGroup(*(VGroup(*row) for row in self.mob_matrix)) - def set_row_colors(self, *colors: str): + def set_row_colors(self, *colors: str) -> Self: r"""Set individual colors for each row of the matrix. Parameters @@ -387,7 +389,7 @@ def construct(self): row.set_color(color) return self - def add_background_to_entries(self): + def add_background_to_entries(self) -> Self: """Add a black background rectangle to the matrix, see above for an example. @@ -400,7 +402,7 @@ def add_background_to_entries(self): mob.add_background_rectangle() return self - def get_mob_matrix(self) -> list[list[MathTex]]: + def get_mob_matrix(self) -> list[list[Mobject]]: """Return the underlying mob matrix mobjects. Returns @@ -410,7 +412,7 @@ def get_mob_matrix(self) -> list[list[MathTex]]: """ return self.mob_matrix - def get_entries(self): + def get_entries(self) -> VGroup: """Return the individual entries of the matrix. Returns @@ -483,9 +485,9 @@ def construct(self): def __init__( self, matrix: Iterable, - element_to_mobject: Mobject = DecimalNumber, - element_to_mobject_config: dict[str, Mobject] = {"num_decimal_places": 1}, - **kwargs, + element_to_mobject: type[Mobject] = DecimalNumber, + element_to_mobject_config: dict[str, Any] = {"num_decimal_places": 1}, + **kwargs: Any, ): """ Will round/truncate the decimal places as per the provided config. @@ -526,7 +528,10 @@ def construct(self): """ def __init__( - self, matrix: Iterable, element_to_mobject: Mobject = Integer, **kwargs + self, + matrix: Iterable, + element_to_mobject: type[Mobject] = Integer, + **kwargs: Any, ): """ Will round if there are decimal entries in the matrix. @@ -560,7 +565,12 @@ def construct(self): self.add(m0) """ - def __init__(self, matrix, element_to_mobject=lambda m: m, **kwargs): + def __init__( + self, + matrix: Iterable, + element_to_mobject: type[Mobject] | Callable[..., Mobject] = lambda m: m, + **kwargs: Any, + ): super().__init__(matrix, element_to_mobject=element_to_mobject, **kwargs) @@ -569,7 +579,7 @@ def get_det_text( determinant: int | str | None = None, background_rect: bool = False, initial_scale_factor: float = 2, -): +) -> VGroup: r"""Helper function to create determinant. Parameters diff --git a/manim/mobject/three_d/polyhedra.py b/manim/mobject/three_d/polyhedra.py index 8046f6066c..1f72873f7b 100644 --- a/manim/mobject/three_d/polyhedra.py +++ b/manim/mobject/three_d/polyhedra.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Hashable +from typing import TYPE_CHECKING, Any import numpy as np @@ -14,7 +15,7 @@ if TYPE_CHECKING: from manim.mobject.mobject import Mobject - from manim.typing import Point3D + from manim.typing import Point3D, Point3DLike_Array __all__ = [ "Polyhedron", @@ -96,10 +97,10 @@ def construct(self): def __init__( self, - vertex_coords: list[list[float] | np.ndarray], + vertex_coords: Point3DLike_Array, faces_list: list[list[int]], faces_config: dict[str, str | int | float | bool] = {}, - graph_config: dict[str, str | int | float | bool] = {}, + graph_config: dict[str, Any] = {}, ): super().__init__() self.faces_config = dict( @@ -116,7 +117,7 @@ def __init__( ) self.vertex_coords = vertex_coords self.vertex_indices = list(range(len(self.vertex_coords))) - self.layout = dict(enumerate(self.vertex_coords)) + self.layout: dict[Hashable, Any] = dict(enumerate(self.vertex_coords)) self.faces_list = faces_list self.face_coords = [[self.layout[j] for j in i] for i in faces_list] self.edges = self.get_edges(self.faces_list) @@ -129,14 +130,14 @@ def __init__( def get_edges(self, faces_list: list[list[int]]) -> list[tuple[int, int]]: """Creates list of cyclic pairwise tuples.""" - edges = [] + edges: list[tuple[int, int]] = [] for face in faces_list: edges += zip(face, face[1:] + face[:1]) return edges def create_faces( self, - face_coords: list[list[list | np.ndarray]], + face_coords: Point3DLike_Array, ) -> VGroup: """Creates VGroup of faces from a list of face coordinates.""" face_group = VGroup() @@ -144,12 +145,12 @@ def create_faces( face_group.add(Polygon(*face, **self.faces_config)) return face_group - def update_faces(self, m: Mobject): + def update_faces(self, m: Mobject) -> None: face_coords = self.extract_face_coords() new_faces = self.create_faces(face_coords) self.faces.match_points(new_faces) - def extract_face_coords(self) -> list[list[np.ndarray]]: + def extract_face_coords(self) -> Point3DLike_Array: """Extracts the coordinates of the vertices in the graph. Used for updating faces. """ @@ -181,7 +182,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit = edge_length * np.sqrt(2) / 4 super().__init__( vertex_coords=[ @@ -216,7 +217,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit = edge_length * np.sqrt(2) / 2 super().__init__( vertex_coords=[ @@ -262,7 +263,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit_a = edge_length * ((1 + np.sqrt(5)) / 4) unit_b = edge_length * (1 / 2) super().__init__( @@ -327,7 +328,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit_a = edge_length * ((1 + np.sqrt(5)) / 4) unit_b = edge_length * ((3 + np.sqrt(5)) / 4) unit_c = edge_length * (1 / 2) @@ -427,7 +428,7 @@ def construct(self): self.add(dots) """ - def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs): + def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs: Any): # Build Convex Hull array = np.array(points) hull = QuickHull(tolerance) diff --git a/mypy.ini b/mypy.ini index a69ae2c470..1735b18a11 100644 --- a/mypy.ini +++ b/mypy.ini @@ -120,9 +120,6 @@ ignore_errors = True [mypy-manim.mobject.logo] ignore_errors = True -[mypy-manim.mobject.matrix] -ignore_errors = True - [mypy-manim.mobject.mobject] ignore_errors = True @@ -171,9 +168,6 @@ ignore_errors = True [mypy-manim.mobject.text.text_mobject] ignore_errors = True -[mypy-manim.mobject.three_d.polyhedra] -ignore_errors = True - [mypy-manim.mobject.three_d.three_dimensions] ignore_errors = True