Merge branch 'ManimCommunity:main' into main

This commit is contained in:
Theo Barollet 2025-06-12 10:12:37 +02:00 committed by GitHub
commit 4131ea1451
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 928 additions and 152 deletions

2
.github/codeql.yml vendored
View file

@ -9,6 +9,8 @@ query-filters:
id: py/multiple-calls-to-init
- exclude:
id: py/missing-call-to-init
- exclude:
id: py/method-first-arg-is-not-self
paths:
- manim
paths-ignore:

View file

@ -13,7 +13,7 @@ repos:
- id: check-toml
name: Validate pyproject.toml
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.7
rev: v0.11.0
hooks:
- id: ruff
name: ruff lint

View file

@ -22,17 +22,17 @@
Manim is an animation engine for explanatory math videos. It's used to create precise animations programmatically, as demonstrated in the videos of [3Blue1Brown](https://www.3blue1brown.com/).
> [!NOTE]
> The community edition of Manim has been forked from 3b1b/manim, a tool originally created and open-sourced by Grant Sanderson, also creator of the 3Blue1Brown educational math videos. While Grant Sandersons repository continues to be maintained separately by him, he is not among the maintainers of the community edition. We recommend this version for its continued development, improved features, enhanced documentation, and more active community-driven maintenance. If you would like to study how Grant makes his videos, head over to his repository ([3b1b/manim](https://github.com/3b1b/manim)).
> The community edition of Manim (ManimCE) is a version maintained and developed by the community. It was forked from 3b1b/manim, a tool originally created and open-sourced by Grant Sanderson, also creator of the 3Blue1Brown educational math videos. While Grant Sanderson continues to maintain his own repository, we recommend this version for its continued development, improved features, enhanced documentation, and more active community-driven maintenance. If you would like to study how Grant makes his videos, head over to his repository ([3b1b/manim](https://github.com/3b1b/manim)).
## Table of Contents:
- [Installation](#installation)
- [Usage](#usage)
- [Documentation](#documentation)
- [Docker](#docker)
- [Help with Manim](#help-with-manim)
- [Contributing](#contributing)
- [License](#license)
- [Installation](#installation)
- [Usage](#usage)
- [Documentation](#documentation)
- [Docker](#docker)
- [Help with Manim](#help-with-manim)
- [Contributing](#contributing)
- [License](#license)
## Installation
@ -90,9 +90,9 @@ The `-p` flag in the command above is for previewing, meaning the video file wil
Some other useful flags include:
- `-s` to skip to the end and just show the final frame.
- `-n <number>` to skip ahead to the `n`'th animation of a scene.
- `-f` show the file in the file browser.
- `-s` to skip to the end and just show the final frame.
- `-n <number>` to skip ahead to the `n`'th animation of a scene.
- `-f` show the file in the file browser.
For a thorough list of command line arguments, visit the [documentation](https://docs.manim.community/en/stable/guides/configuration.html).
@ -120,8 +120,8 @@ The contribution guide may become outdated quickly; we highly recommend joining
[Discord server](https://www.manim.community/discord/) to discuss any potential
contributions and keep up to date with the latest developments.
Most developers on the project use `poetry` for management. You'll want to have poetry installed and available in your environment.
Learn more about `poetry` at its [documentation](https://python-poetry.org/docs/) and find out how to install manim with poetry at the [manim dev-installation guide](https://docs.manim.community/en/stable/contributing/development.html) in the manim documentation.
Most developers on the project use `uv` for management. You'll want to have uv installed and available in your environment.
Learn more about `uv` at its [documentation](https://docs.astral.sh/uv/) and find out how to install manim with uv at the [manim dev-installation guide](https://docs.manim.community/en/latest/contributing/development.html) in the manim documentation.
## How to Cite Manim

View file

@ -50,7 +50,7 @@ For example:
)
self.add(text)
.. _Pango library: https://pango.gnome.org
.. _Pango library: https://pango.org
Working with :class:`~.Text`
============================

View file

@ -327,6 +327,13 @@ Generally, you start with the starting number and add only some part of the valu
So, the logic of calculating the number to display at each step will be ``50 + alpha * (100 - 50)``.
Once you set the calculated value for the :class:`~.DecimalNumber`, you are done.
.. note::
If you're creating a custom animation and want to use a ``rate_func``, you must explicitly apply
``self.rate_func(alpha)`` to the parameter you're animating. For example, try switching the rate
function to ``rate_functions.there_and_back`` to observe how it affects the counting behavior.
Once you have defined your ``Count`` animation, you can play it in your :class:`~.Scene` for any duration you want for any :class:`~.DecimalNumber` with any rate function.
.. manim:: CountingScene
@ -343,7 +350,7 @@ Once you have defined your ``Count`` animation, you can play it in your :class:`
def interpolate_mobject(self, alpha: float) -> None:
# Set value of DecimalNumber according to alpha
value = self.start + (alpha * (self.end - self.start))
value = self.start + (self.rate_func(alpha) * (self.end - self.start))
self.mobject.set_value(value)

View file

@ -16,7 +16,7 @@ import configparser
import copy
import json
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from rich import color, errors
from rich import print as printf
@ -91,7 +91,7 @@ def make_logger(
# set the rich handler
rich_handler = RichHandler(
console=console,
show_time=parser.getboolean("log_timestamps"),
show_time=parser.getboolean("log_timestamps", fallback=False),
keywords=HIGHLIGHTED_KEYWORDS,
)
@ -108,7 +108,7 @@ def make_logger(
return logger, console, error_console
def parse_theme(parser: configparser.SectionProxy) -> Theme:
def parse_theme(parser: configparser.SectionProxy) -> Theme | None:
"""Configure the rich style of logger and console output.
Parameters
@ -126,7 +126,7 @@ def parse_theme(parser: configparser.SectionProxy) -> Theme:
:func:`make_logger`.
"""
theme = {key.replace("_", "."): parser[key] for key in parser}
theme: dict[str, Any] = {key.replace("_", "."): parser[key] for key in parser}
theme["log.width"] = None if theme["log.width"] == "-1" else int(theme["log.width"])
theme["log.height"] = (
@ -188,8 +188,11 @@ class JSONFormatter(logging.Formatter):
"""Format the record in a custom JSON format."""
record_c = copy.deepcopy(record)
if record_c.args:
for arg in record_c.args:
record_c.args[arg] = "<>"
if isinstance(record_c.args, dict):
for arg in record_c.args:
record_c.args[arg] = "<>"
else:
record_c.args = ("<>",) * len(record_c.args)
return json.dumps(
{
"levelname": record_c.levelname,

View file

@ -1448,7 +1448,7 @@ class ManimConfig(MutableMapping):
@property
def gui_location(self) -> tuple[Any]:
"""Enable GUI interaction."""
"""Location parameters for the GUI window (e.g., screen coordinates or layout settings)."""
return self._d["gui_location"]
@gui_location.setter

View file

@ -123,6 +123,10 @@ class Transform(Animation):
self.play(*anims, run_time=2)
self.wait()
See also
--------
:class:`~.ReplacementTransform`, :meth:`~.Mobject.interpolate`, :meth:`~.Mobject.align_data`
"""
def __init__(

View file

@ -65,12 +65,41 @@ if TYPE_CHECKING:
class Line(TipableVMobject):
"""A straight or curved line segment between two points or mobjects.
Parameters
----------
start
The starting point or Mobject of the line.
end
The ending point or Mobject of the line.
buff
The distance to shorten the line from both ends.
path_arc
If nonzero, the line will be curved into an arc with this angle (in radians).
kwargs
Additional arguments to be passed to :class:`TipableVMobject`
Examples
--------
.. manim:: LineExample
:save_last_frame:
class LineExample(Scene):
def construct(self):
line1 = Line(LEFT*2, RIGHT*2)
line2 = Line(LEFT*2, RIGHT*2, buff=0.5)
line3 = Line(LEFT*2, RIGHT*2, path_arc=PI/2)
grp = VGroup(line1,line2,line3).arrange(DOWN, buff=2)
self.add(grp)
"""
def __init__(
self,
start: Point3DLike | Mobject = LEFT,
end: Point3DLike | Mobject = RIGHT,
buff: float = 0,
path_arc: float | None = None,
path_arc: float = 0,
**kwargs: Any,
) -> None:
self.dim = 3
@ -78,14 +107,13 @@ class Line(TipableVMobject):
self.path_arc = path_arc
self._set_start_and_end_attrs(start, end)
super().__init__(**kwargs)
# TODO: Deal with the situation where path_arc is None
def generate_points(self) -> None:
self.set_points_by_ends(
start=self.start,
end=self.end,
buff=self.buff,
path_arc=self.path_arc, # type: ignore[arg-type]
path_arc=self.path_arc,
)
def set_points_by_ends(
@ -112,9 +140,6 @@ class Line(TipableVMobject):
"""
self._set_start_and_end_attrs(start, end)
if path_arc:
# self.path_arc could potentially be None, which is not accepted
# as parameter.
assert self.path_arc is not None
arc = ArcBetweenPoints(self.start, self.end, angle=self.path_arc)
self.set_points(arc.points)
else:
@ -125,16 +150,13 @@ class Line(TipableVMobject):
init_points = generate_points
def _account_for_buff(self, buff: float) -> None:
if buff == 0:
if buff <= 0:
return
#
length = self.get_length() if self.path_arc == 0 else self.get_arc_length()
#
if length < 2 * buff:
return
buff_proportion = buff / length
self.pointwise_become_partial(self, buff_proportion, 1 - buff_proportion)
return
def _set_start_and_end_attrs(
self, start: Point3DLike | Mobject, end: Point3DLike | Mobject

View file

@ -38,7 +38,6 @@ if TYPE_CHECKING:
from typing_extensions import Self
from manim.typing import (
ManimFloat,
Point3D,
Point3D_Array,
Point3DLike,
@ -122,39 +121,45 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
"""
return self.get_start_anchors()
def get_vertex_groups(self) -> npt.NDArray[ManimFloat]:
def get_vertex_groups(self) -> list[Point3D_Array]:
"""Gets the vertex groups of the :class:`Polygram`.
Returns
-------
:class:`numpy.ndarray`
The vertex groups of the :class:`Polygram`.
list[Point3D_Array]
The list of vertex groups of the :class:`Polygram`.
Examples
--------
::
>>> poly = Polygram([ORIGIN, RIGHT, UP], [LEFT, LEFT + UP, 2 * LEFT])
>>> poly.get_vertex_groups()
array([[[ 0., 0., 0.],
[ 1., 0., 0.],
[ 0., 1., 0.]],
<BLANKLINE>
[[-1., 0., 0.],
[-1., 1., 0.],
[-2., 0., 0.]]])
>>> poly = Polygram([ORIGIN, RIGHT, UP, LEFT + UP], [LEFT, LEFT + UP, 2 * LEFT])
>>> groups = poly.get_vertex_groups()
>>> len(groups)
2
>>> groups[0]
array([[ 0., 0., 0.],
[ 1., 0., 0.],
[ 0., 1., 0.],
[-1., 1., 0.]])
>>> groups[1]
array([[-1., 0., 0.],
[-1., 1., 0.],
[-2., 0., 0.]])
"""
vertex_groups = []
# TODO: If any of the original vertex groups contained the starting vertex N
# times, then .get_vertex_groups() splits it into N vertex groups.
group = []
for start, end in zip(self.get_start_anchors(), self.get_end_anchors()):
group.append(start)
if self.consider_points_equals(end, group[0]):
vertex_groups.append(group)
vertex_groups.append(np.array(group))
group = []
return np.array(vertex_groups)
return vertex_groups
def round_corners(
self,
@ -223,18 +228,18 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
new_points: list[Point3D] = []
for vertices in self.get_vertex_groups():
for vertex_group in self.get_vertex_groups():
arcs = []
# Repeat the radius list as necessary in order to provide a radius
# for each vertex.
if isinstance(radius, (int, float)):
radius_list = [radius] * len(vertices)
radius_list = [radius] * len(vertex_group)
else:
radius_list = radius * ceil(len(vertices) / len(radius))
radius_list = radius * ceil(len(vertex_group) / len(radius))
for currentRadius, (v1, v2, v3) in zip(
radius_list, adjacent_n_tuples(vertices, 3)
for current_radius, (v1, v2, v3) in zip(
radius_list, adjacent_n_tuples(vertex_group, 3)
):
vect1 = v2 - v1
vect2 = v3 - v2
@ -243,10 +248,10 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
angle = angle_between_vectors(vect1, vect2)
# Negative radius gives concave curves
angle *= np.sign(currentRadius)
angle *= np.sign(current_radius)
# Distance between vertex and start of the arc
cut_off_length = currentRadius * np.tan(angle / 2)
cut_off_length = current_radius * np.tan(angle / 2)
# Determines counterclockwise vs. clockwise
sign = np.sign(np.cross(vect1, vect2)[2])
@ -261,17 +266,17 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
if evenly_distribute_anchors:
# Determine the average length of each curve
nonZeroLengthArcs = [arc for arc in arcs if len(arc.points) > 4]
if len(nonZeroLengthArcs):
totalArcLength = sum(
[arc.get_arc_length() for arc in nonZeroLengthArcs]
nonzero_length_arcs = [arc for arc in arcs if len(arc.points) > 4]
if len(nonzero_length_arcs) > 0:
total_arc_length = sum(
[arc.get_arc_length() for arc in nonzero_length_arcs]
)
totalCurveCount = (
sum([len(arc.points) for arc in nonZeroLengthArcs]) / 4
num_curves = (
sum([len(arc.points) for arc in nonzero_length_arcs]) / 4
)
averageLengthPerCurve = totalArcLength / totalCurveCount
average_arc_length = total_arc_length / num_curves
else:
averageLengthPerCurve = 1
average_arc_length = 1.0
# To ensure that we loop through starting with last
arcs = [arcs[-1], *arcs[:-1]]
@ -284,9 +289,7 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
# Make sure anchors are evenly distributed, if necessary
if evenly_distribute_anchors:
line.insert_n_curves(
ceil(line.get_length() / averageLengthPerCurve)
)
line.insert_n_curves(ceil(line.get_length() / average_arc_length))
new_points.extend(line.points)

View file

@ -335,7 +335,7 @@ def _tree_layout(
# Always make a copy of the children because they get eaten
stack = [list(children[root_vertex]).copy()]
stick = [root_vertex]
parent = {u: root_vertex for u in children[root_vertex]}
parent = dict.fromkeys(children[root_vertex], root_vertex)
pos = {}
obstruction = [0.0] * len(T)
o = -1 if orientation == "down" else 1
@ -808,12 +808,12 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
vertex_mobjects = {}
graph_center = self.get_center()
base_positions = {v: graph_center for v in vertices}
base_positions = dict.fromkeys(vertices, graph_center)
base_positions.update(positions)
positions = base_positions
if isinstance(labels, bool):
labels = {v: labels for v in vertices}
labels = dict.fromkeys(vertices, labels)
else:
assert isinstance(labels, dict)
base_labels = dict.fromkeys(vertices, False)
@ -1033,7 +1033,10 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
self._edge_config[(u, v)] = edge_config
edge_mobject = edge_type(
self[u].get_center(), self[v].get_center(), z_index=-1, **edge_config
start=self[u].get_center(),
end=self[v].get_center(),
z_index=-1,
**edge_config,
)
self.edges[(u, v)] = edge_mobject
@ -1541,8 +1544,8 @@ class Graph(GenericGraph):
):
self.edges = {
(u, v): edge_type(
self[u].get_center(),
self[v].get_center(),
start=self[u].get_center(),
end=self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)
@ -1748,8 +1751,8 @@ class DiGraph(GenericGraph):
):
self.edges = {
(u, v): edge_type(
self[u],
self[v],
start=self[u],
end=self[v],
z_index=-1,
**self._edge_config[(u, v)],
)

View file

@ -2319,7 +2319,7 @@ class Mobject:
Returns
-------
list
list[Mobject]
A list of mobjects in the family of the given mobject.
Examples
@ -2333,12 +2333,39 @@ class Mobject:
>>> gr.get_family()
[Group, VGroup(Square, Rectangle), Square, Rectangle, Mobject, VMobject]
See also
--------
:meth:`~.Mobject.family_members_with_points`, :meth:`~.Mobject.align_data`
"""
sub_families = [x.get_family() for x in self.submobjects]
all_mobjects = [self] + list(it.chain(*sub_families))
return remove_list_redundancies(all_mobjects)
def family_members_with_points(self) -> list[Self]:
"""Filters the list of family members (generated by :meth:`.get_family`) to include only mobjects with points.
Returns
-------
list[Mobject]
A list of mobjects that have points.
Examples
--------
::
>>> from manim import Square, Rectangle, VGroup, Group, Mobject, VMobject
>>> s, r, m, v = Square(), Rectangle(), Mobject(), VMobject()
>>> vg = VGroup(s, r)
>>> gr = Group(vg, m, v)
>>> gr.family_members_with_points()
[Square, Rectangle]
See also
--------
:meth:`~.Mobject.get_family`
"""
return [m for m in self.get_family() if m.get_num_points() > 0]
def arrange(
@ -2705,11 +2732,11 @@ class Mobject:
# Alignment
def align_data(self, mobject: Mobject, skip_point_alignment: bool = False) -> None:
"""Aligns the data of this mobject with another mobject.
"""Aligns the family structure and data of this mobject with another mobject.
Afterwards, the two mobjects will have the same number of submobjects
(see :meth:`.align_submobjects`), the same parent structure (see
:meth:`.null_point_align`). If ``skip_point_alignment`` is false,
(see :meth:`.align_submobjects`) and the same parent structure (see
:meth:`.null_point_align`). If ``skip_point_alignment`` is ``False``,
they will also have the same number of points (see :meth:`.align_points`).
Parameters
@ -2718,7 +2745,32 @@ class Mobject:
The other mobject this mobject should be aligned to.
skip_point_alignment
Controls whether or not the computationally expensive
point alignment is skipped (default: False).
point alignment is skipped (default: ``False``).
.. note::
This method is primarily used internally by :meth:`.become` and the
:class:`~.Transform` animation to ensure that mobjects are structurally
compatible before transformation.
Examples
--------
::
>>> from manim import Rectangle, Line, ORIGIN, RIGHT
>>> rect = Rectangle(width=4.0, height=2.0, grid_xstep=1.0, grid_ystep=0.5)
>>> line = Line(start=ORIGIN,end=RIGHT)
>>> line.align_data(rect)
>>> len(line.get_family()) == len(rect.get_family())
True
>>> line.get_num_points() == rect.get_num_points()
True
See also
--------
:class:`~.Transform`, :meth:`~.Mobject.become`, :meth:`~.VMobject.align_points`, :meth:`~.Mobject.get_family`
"""
self.null_point_align(mobject)
self.align_submobjects(mobject)
@ -2815,22 +2867,64 @@ class Mobject:
"""Turns this :class:`~.Mobject` into an interpolation between ``mobject1``
and ``mobject2``.
The interpolation is applied to the points and color of the mobject.
Parameters
----------
mobject1
The starting Mobject.
mobject2
The target Mobject.
alpha
Interpolation factor between 0 (at ``mobject1``) and 1 (at ``mobject2``).
path_func
The function defining the interpolation path. Defaults to a straight path.
Returns
-------
:class:`Mobject`
``self``
.. note::
- Both mobjects must have the same number of points. If not, this will raise an error.
Use :meth:`~.VMobject.align_points` to match point counts beforehand if needed.
- This method is used internally by the :class:`~.Transform` animation
to interpolate between two mobjects during a transformation.
Examples
--------
.. manim:: DotInterpolation
.. manim:: InterpolateExample
:save_last_frame:
class DotInterpolation(Scene):
class InterpolateExample(Scene):
def construct(self):
dotR = Dot(color=DARK_GREY)
dotR.shift(2 * RIGHT)
dotL = Dot(color=WHITE)
dotL.shift(2 * LEFT)
# No need for point alignment:
dotL = Dot(color=DARK_GREY).to_edge(LEFT)
dotR = Dot(color=YELLOW).scale(10).to_edge(RIGHT)
dotMid1 = VMobject().interpolate(dotL, dotR, alpha=0.1)
dotMid2 = VMobject().interpolate(dotL, dotR, alpha=0.25)
dotMid3 = VMobject().interpolate(dotL, dotR, alpha=0.5)
dotMid4 = VMobject().interpolate(dotL, dotR, alpha=0.75)
dots = VGroup(dotL, dotR, dotMid1, dotMid2, dotMid3, dotMid4)
dotMiddle = VMobject().interpolate(dotL, dotR, alpha=0.3)
# Needs point alignment:
line = Line(ORIGIN, UP).to_edge(LEFT)
sq = Square(color=RED, fill_opacity=1, stroke_color=BLUE).to_edge(RIGHT)
line.align_points(sq)
mid1 = VMobject().interpolate(line, sq, alpha=0.1)
mid2 = VMobject().interpolate(line, sq, alpha=0.25)
mid3 = VMobject().interpolate(line, sq, alpha=0.5)
mid4 = VMobject().interpolate(line, sq, alpha=0.75)
linesquares = VGroup(line, sq, mid1, mid2, mid3, mid4)
self.add(VGroup(dots, linesquares).arrange(DOWN, buff=1))
See also
--------
:class:`~.Transform`, :meth:`~.VMobject.align_points`, :meth:`~.VMobject.interpolate_color`
self.add(dotL, dotR, dotMiddle)
"""
self.points = path_func(mobject1.points, mobject2.points, alpha)
self.interpolate_color(mobject1, mobject2, alpha)
@ -2943,6 +3037,10 @@ class Mobject:
>>> result = rect.copy().become(circ, match_center=True)
>>> np.allclose(rect.get_center(), result.get_center())
True
See also
--------
:meth:`~.Mobject.align_data`, :meth:`~.VMobject.interpolate_color`
"""
mobject = mobject.copy()
if stretch:

View file

@ -21,7 +21,7 @@ class ConvertToOpenGL(ABCMeta):
_converted_classes = []
def __new__(mcls, name, bases, namespace): # noqa: B902
def __new__(mcls, name, bases, namespace):
if config.renderer == RendererType.OPENGL:
# Must check class names to prevent
# cyclic importing.
@ -40,6 +40,6 @@ class ConvertToOpenGL(ABCMeta):
return super().__new__(mcls, name, bases, namespace)
def __init__(cls, name, bases, namespace): # noqa: B902
def __init__(cls, name, bases, namespace):
super().__init__(name, bases, namespace)
cls._converted_classes.append(cls)

View file

@ -1908,15 +1908,16 @@ class OpenGLMobject:
::
>>> from manim import *
>>> import numpy as np
>>> sq = Square()
>>> sq.height
2.0
np.float64(2.0)
>>> sq.stretch_to_fit_width(5)
Square
>>> sq.width
5.0
np.float64(5.0)
>>> sq.height
2.0
np.float64(2.0)
"""
return self.rescale_to_fit(width, 0, stretch=True, **kwargs)
@ -1941,15 +1942,16 @@ class OpenGLMobject:
::
>>> from manim import *
>>> import numpy as np
>>> sq = Square()
>>> sq.height
2.0
np.float64(2.0)
>>> sq.scale_to_fit_width(5)
Square
>>> sq.width
5.0
np.float64(5.0)
>>> sq.height
5.0
np.float64(5.0)
"""
return self.rescale_to_fit(width, 0, stretch=stretch, **kwargs)

View file

@ -301,7 +301,7 @@ class Paragraph(VGroup):
class Text(SVGMobject):
r"""Display (non-LaTeX) text rendered using `Pango <https://pango.gnome.org/>`_.
r"""Display (non-LaTeX) text rendered using `Pango <https://pango.org/>`_.
Text objects behave like a :class:`.VGroup`-like iterable of all characters
in the given text. In particular, slicing is possible.
@ -864,7 +864,7 @@ class Text(SVGMobject):
class MarkupText(SVGMobject):
r"""Display (non-LaTeX) text rendered using `Pango <https://pango.gnome.org/>`_.
r"""Display (non-LaTeX) text rendered using `Pango <https://pango.org/>`_.
Text objects behave like a :class:`.VGroup`-like iterable of all characters
in the given text. In particular, slicing is possible.

View file

@ -1047,22 +1047,21 @@ class VMobject(Mobject):
The VMobject itself, after appending the straight lines to its
path.
"""
self.throw_error_if_no_points()
points = np.asarray(points).reshape(-1, self.dim)
num_points = points.shape[0]
if num_points == 0:
return self
start_corners = np.empty((num_points, self.dim))
start_corners[0] = self.points[-1]
start_corners[1:] = points[:-1]
end_corners = points
if self.has_new_path_started():
# Pop the last point from self.points and
# add it to start_corners
start_corners = np.empty((num_points, self.dim))
start_corners[0] = self.points[-1]
start_corners[1:] = points[:-1]
end_corners = points
# Remove the last point from the new path
self.points = self.points[:-1]
else:
start_corners = points[:-1]
end_corners = points[1:]
nppcc = self.n_points_per_cubic_curve
new_points = np.empty((nppcc * start_corners.shape[0], self.dim))
@ -1720,6 +1719,10 @@ class VMobject(Mobject):
-------
:class:`VMobject`
``self``
See also
--------
:meth:`~.Mobject.interpolate`, :meth:`~.Mobject.align_data`
"""
self.align_rgbas(vmobject)
# TODO: This shortcut can be a bit over eager. What if they have the same length, but different subpath lengths?
@ -1928,12 +1931,18 @@ class VMobject(Mobject):
upper_index, upper_residue = integer_interpolate(0, num_curves, b)
nppc = self.n_points_per_curve
# Copy vmobject.points if vmobject is self to prevent unintended in-place modification
vmobject_points = (
vmobject.points.copy() if self is vmobject else vmobject.points
)
# If both indices coincide, get a part of a single Bézier curve.
if lower_index == upper_index:
# Look at the "lower_index"-th Bézier curve and select its part from
# t=lower_residue to t=upper_residue.
self.points = partial_bezier_points(
vmobject.points[nppc * lower_index : nppc * (lower_index + 1)],
vmobject_points[nppc * lower_index : nppc * (lower_index + 1)],
lower_residue,
upper_residue,
)
@ -1943,19 +1952,19 @@ class VMobject(Mobject):
# Look at the "lower_index"-th Bezier curve and select its part from
# t=lower_residue to t=1. This is the first curve in self.points.
self.points[:nppc] = partial_bezier_points(
vmobject.points[nppc * lower_index : nppc * (lower_index + 1)],
vmobject_points[nppc * lower_index : nppc * (lower_index + 1)],
lower_residue,
1,
)
# If there are more curves between the "lower_index"-th and the
# "upper_index"-th Béziers, add them all to self.points.
self.points[nppc:-nppc] = vmobject.points[
self.points[nppc:-nppc] = vmobject_points[
nppc * (lower_index + 1) : nppc * upper_index
]
# Look at the "upper_index"-th Bézier curve and select its part from
# t=0 to t=upper_residue. This is the last curve in self.points.
self.points[-nppc:] = partial_bezier_points(
vmobject.points[nppc * upper_index : nppc * (upper_index + 1)],
vmobject_points[nppc * upper_index : nppc * (upper_index + 1)],
0,
upper_residue,
)

View file

@ -31,34 +31,35 @@ all other color types in Manim.
To implement a custom color space, you must subclass :class:`ManimColor` and implement
three important methods:
- :attr:`~.ManimColor._internal_value`: a ``@property`` implemented on
:class:`ManimColor` with the goal of keeping a consistent internal representation
which can be referenced by other functions in :class:`ManimColor`. This property acts
as a proxy to whatever representation you need in your class.
- :attr:`~.ManimColor._internal_value`: a ``@property`` implemented on
:class:`ManimColor` with the goal of keeping a consistent internal representation
which can be referenced by other functions in :class:`ManimColor`. This property acts
as a proxy to whatever representation you need in your class.
- The getter should always return a NumPy array in the format ``[r,g,b,a]``, in
accordance with the type :class:`ManimColorInternal`.
- The getter should always return a NumPy array in the format ``[r,g,b,a]``, in
accordance with the type :class:`ManimColorInternal`.
- The setter should always accept a value in the format ``[r,g,b,a]`` which can be
converted to whatever attributes you need.
- The setter should always accept a value in the format ``[r,g,b,a]`` which can be
converted to whatever attributes you need.
- :attr:`~ManimColor._internal_space`: a read-only ``@property`` implemented on
:class:`ManimColor` with the goal of providing a useful representation which can be
used by operators, interpolation and color transform functions.
- :attr:`~ManimColor._internal_space`: a read-only ``@property`` implemented on
:class:`ManimColor` with the goal of providing a useful representation which can be
used by operators, interpolation and color transform functions.
The only constraints on this value are:
The only constraints on this value are:
- It must be a NumPy array.
- It must be a NumPy array.
- The last value must be the opacity in a range ``0.0`` to ``1.0``.
- The last value must be the opacity in a range ``0.0`` to ``1.0``.
Additionally, your ``__init__`` must support this format as an initialization value
without additional parameters to ensure correct functionality of all other methods in
:class:`ManimColor`.
Additionally, your ``__init__`` must support this format as an initialization value
without additional parameters to ensure correct functionality of all other methods in
:class:`ManimColor`.
- :meth:`~ManimColor._from_internal`: a ``@classmethod`` which converts an
``[r,g,b,a]`` value into suitable parameters for your ``__init__`` method and calls
the ``cls`` parameter.
- :meth:`~ManimColor._from_internal`: a ``@classmethod`` which converts an
``[r,g,b,a]`` value into suitable parameters for your ``__init__`` method and calls
the ``cls`` parameter.
"""
from __future__ import annotations
@ -601,7 +602,8 @@ class ManimColor:
HSL_Array_Float
An HSL array of 3 floats from 0.0 to 1.0.
"""
return np.array(colorsys.rgb_to_hls(*self.to_rgb()))
hls = colorsys.rgb_to_hls(*self.to_rgb())
return np.array([hls[0], hls[2], hls[1]])
def invert(self, with_alpha: bool = False) -> Self:
"""Return a new, linearly inverted version of this :class:`ManimColor` (no
@ -906,7 +908,7 @@ class ManimColor:
The :class:`ManimColor` with the corresponding RGB values to the given HSL
array.
"""
rgb = colorsys.hls_to_rgb(*hsl)
rgb = colorsys.hls_to_rgb(hsl[0], hsl[2], hsl[1])
return cls._from_internal(ManimColor(rgb, alpha)._internal_value)
@overload

View file

@ -25,7 +25,8 @@ class Polygon:
Parameters
----------
rings
A collection of closed polygonal ring.
A sequence of points, where each sequence represents the rings of the polygon.
Typically, multiple rings indicate holes in the polygon.
"""
def __init__(self, rings: Sequence[Point2DLike_Array]) -> None:
@ -63,18 +64,84 @@ class Polygon:
)
return d if self.inside(point) else -d
def inside(self, point: Point2DLike) -> bool:
"""Check if a point is inside the polygon."""
# Views
px, py = point
x, y = self.start[:, 0], self.start[:, 1]
xr, yr = self.stop[:, 0], self.stop[:, 1]
def _is_point_on_segment(
self,
x_point: float,
y_point: float,
x0: float,
y0: float,
x1: float,
y1: float,
) -> bool:
"""
Check if a point is on the segment.
# Count Crossings (enforce short-circuit)
c = (y > py) != (yr > py)
c = px < x[c] + (py - y[c]) * (xr[c] - x[c]) / (yr[c] - y[c])
c_sum: int = np.sum(c)
return c_sum % 2 == 1
The segment is defined by (x0, y0) to (x1, y1).
"""
if min(x0, x1) <= x_point <= max(x0, x1) and min(y0, y1) <= y_point <= max(
y0, y1
):
dx = x1 - x0
dy = y1 - y0
cross = dx * (y_point - y0) - dy * (x_point - x0)
return bool(np.isclose(cross, 0.0))
return False
def _ray_crosses_segment(
self,
x_point: float,
y_point: float,
x0: float,
y0: float,
x1: float,
y1: float,
) -> bool:
"""
Check if a horizontal ray to the right from point (x_point, y_point) crosses the segment.
The segment is defined by (x0, y0) to (x1, y1).
"""
if (y0 > y_point) != (y1 > y_point):
slope = (x1 - x0) / (y1 - y0)
x_intersect = slope * (y_point - y0) + x0
return bool(x_point < x_intersect)
return False
def inside(self, point: Point2DLike) -> bool:
"""
Check if a point is inside the polygon.
Uses ray casting algorithm and checks boundary points consistently.
"""
point_x, point_y = point
start_x, start_y = self.start[:, 0], self.start[:, 1]
stop_x, stop_y = self.stop[:, 0], self.stop[:, 1]
segment_count = len(start_x)
for i in range(segment_count):
if self._is_point_on_segment(
point_x,
point_y,
start_x[i],
start_y[i],
stop_x[i],
stop_y[i],
):
return True
crossings = 0
for i in range(segment_count):
if self._ray_crosses_segment(
point_x,
point_y,
start_x[i],
start_y[i],
stop_x[i],
stop_y[i],
):
crossings += 1
return crossings % 2 == 1
class Cell:

View file

@ -98,7 +98,7 @@ class Comparable(Protocol):
def __gt__(self, other: Any) -> bool: ...
ComparableT = TypeVar("ComparableT", bound=Comparable) # noqa: Y001
ComparableT = TypeVar("ComparableT", bound=Comparable)
def clip(a: ComparableT, min_a: ComparableT, max_a: ComparableT) -> ComparableT:

View file

@ -52,6 +52,9 @@ warn_return_any = True
ignore_errors = True
disable_error_code = return-value
[mypy-manim._config.logger_utils]
ignore_errors = False
[mypy-manim.animation.*]
ignore_errors = True

View file

@ -128,7 +128,7 @@ profile = "black"
omit = ["*tests*"]
[tool.coverage.report]
exclude_lines = ["pragma: no cover"]
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:"]
[tool.ruff]
line-length = 88
@ -151,6 +151,7 @@ select = [
"C4",
"D",
"E",
"W",
"F",
"I",
"PGH",

View file

@ -4,7 +4,18 @@ import logging
import numpy as np
from manim import BackgroundRectangle, Circle, Sector, Square, SurroundingRectangle
from manim import (
DEGREES,
LEFT,
RIGHT,
BackgroundRectangle,
Circle,
Line,
Polygram,
Sector,
Square,
SurroundingRectangle,
)
logger = logging.getLogger(__name__)
@ -15,6 +26,88 @@ def test_get_arc_center():
)
def test_Polygram_get_vertex_groups():
# Test that, once a Polygram polygram is created with some vertex groups,
# polygram.get_vertex_groups() (usually) returns the same vertex groups.
vertex_groups_arr = [
# 2 vertex groups for polygram 1
[
# Group 1: Triangle
np.array(
[
[2, 1, 0],
[0, 2, 0],
[-2, 1, 0],
]
),
# Group 2: Square
np.array(
[
[1, 0, 0],
[0, 1, 0],
[-1, 0, 0],
[0, -1, 0],
]
),
],
# 3 vertex groups for polygram 1
[
# Group 1: Quadrilateral
np.array(
[
[2, 0, 0],
[0, -1, 0],
[0, 0, -2],
[0, 1, 0],
]
),
# Group 2: Triangle
np.array(
[
[3, 1, 0],
[0, 0, 2],
[2, 0, 0],
]
),
# Group 3: Pentagon
np.array(
[
[1, -1, 0],
[1, 1, 0],
[0, 2, 0],
[-1, 1, 0],
[-1, -1, 0],
]
),
],
]
for vertex_groups in vertex_groups_arr:
polygram = Polygram(*vertex_groups)
poly_vertex_groups = polygram.get_vertex_groups()
for poly_group, group in zip(poly_vertex_groups, vertex_groups):
np.testing.assert_array_equal(poly_group, group)
# If polygram is a Polygram of a vertex group containing the start vertex N times,
# then polygram.get_vertex_groups() splits it into N vertex groups.
splittable_vertex_group = np.array(
[
[0, 1, 0],
[1, -2, 0],
[1, 2, 0],
[0, 1, 0], # same vertex as start
[-1, 2, 0],
[-1, -2, 0],
[0, 1, 0], # same vertex as start
[0.5, 2, 0],
[-0.5, 2, 0],
]
)
polygram = Polygram(splittable_vertex_group)
assert len(polygram.get_vertex_groups()) == 3
def test_SurroundingRectangle():
circle = Circle()
square = Square()
@ -53,3 +146,36 @@ def test_changing_Square_side_length_updates_the_square_appropriately():
def test_Square_side_length_consistent_after_scale_and_rotation():
sq = Square(side_length=1).scale(3).rotate(np.pi / 4)
assert np.isclose(sq.side_length, 3)
def test_line_with_buff_and_path_arc():
line = Line(LEFT, RIGHT, path_arc=60 * DEGREES, buff=0.3)
expected_points = np.array(
[
[-0.7299265, -0.12999304, 0.0],
[-0.6605293, -0.15719695, 0.0],
[-0.58965623, -0.18050364, 0.0],
[-0.51763809, -0.19980085, 0.0],
[-0.51763809, -0.19980085, 0.0],
[-0.43331506, -0.22239513, 0.0],
[-0.34760317, -0.23944429, 0.0],
[-0.26105238, -0.25083892, 0.0],
[-0.26105238, -0.25083892, 0.0],
[-0.1745016, -0.26223354, 0.0],
[-0.08729763, -0.26794919, 0.0],
[0.0, -0.26794919, 0.0],
[0.0, -0.26794919, 0.0],
[0.08729763, -0.26794919, 0.0],
[0.1745016, -0.26223354, 0.0],
[0.26105238, -0.25083892, 0.0],
[0.26105238, -0.25083892, 0.0],
[0.34760317, -0.23944429, 0.0],
[0.43331506, -0.22239513, 0.0],
[0.51763809, -0.19980085, 0.0],
[0.51763809, -0.19980085, 0.0],
[0.58965623, -0.18050364, 0.0],
[0.6605293, -0.15719695, 0.0],
[0.7299265, -0.12999304, 0.0],
]
)
np.testing.assert_allclose(line.points, expected_points)

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from manim import DiGraph, Graph, Scene, Text, tempconfig
from manim import DiGraph, Graph, LabeledLine, Scene, Text, tempconfig
from manim.mobject.graph import _layouts
@ -91,6 +91,29 @@ def test_graph_remove_edges():
assert set(G.edges.keys()) == set()
def test_graph_accepts_labeledline_as_edge_type():
vertices = [1, 2, 3, 4]
edges = [(1, 2), (2, 3), (3, 4), (4, 1)]
edge_config = {
(1, 2): {"label": "A"},
(2, 3): {"label": "B"},
(3, 4): {"label": "C"},
(4, 1): {"label": "D"},
}
G_manual = Graph(vertices, edges, edge_type=LabeledLine, edge_config=edge_config)
G_directed = DiGraph(
vertices, edges, edge_type=LabeledLine, edge_config=edge_config
)
for _edge_key, edge_obj in G_manual.edges.items():
assert isinstance(edge_obj, LabeledLine)
assert hasattr(edge_obj, "label")
for _edge_key, edge_obj in G_directed.edges.items():
assert isinstance(edge_obj, LabeledLine)
assert hasattr(edge_obj, "label")
def test_custom_animation_mobject_list():
G = Graph([1, 2, 3], [(1, 2), (2, 3)])
scene = Scene()

View file

@ -0,0 +1,212 @@
from __future__ import annotations
import numpy as np
import pytest
from manim.mobject.matrix import (
DecimalMatrix,
IntegerMatrix,
Matrix,
)
from manim.mobject.text.tex_mobject import MathTex
from manim.mobject.types.vectorized_mobject import VGroup
class TestMatrix:
@pytest.mark.parametrize(
(
"matrix_elements",
"left_bracket",
"right_bracket",
"expected_rows",
"expected_columns",
),
[
([[1, 2], [3, 4]], "[", "]", 2, 2),
([[1, 2, 3]], "[", "]", 1, 3),
([[1], [2], [3]], "[", "]", 3, 1),
([[5]], "[", "]", 1, 1),
([[1, 0], [0, 1]], "(", ")", 2, 2),
([["a", "b"], ["c", "d"]], "[", "]", 2, 2),
(np.array([[10, 20], [30, 40]]), "[", "]", 2, 2),
],
ids=[
"2x2_default",
"1x3_default",
"3x1_default",
"1x1_default",
"2x2_parentheses",
"2x2_strings",
"2x2_numpy",
],
)
def test_matrix_init_valid(
self,
matrix_elements,
left_bracket,
right_bracket,
expected_rows,
expected_columns,
):
matrix = Matrix(
matrix_elements, left_bracket=left_bracket, right_bracket=right_bracket
)
assert isinstance(matrix, Matrix)
assert matrix.left_bracket == left_bracket
assert matrix.right_bracket == right_bracket
assert len(matrix.get_rows()) == expected_rows
assert len(matrix.get_columns()) == expected_columns
@pytest.mark.parametrize(
("invalid_elements", "expected_error"),
[
(10, TypeError),
(10.4, TypeError),
([1, 2, 3], TypeError),
],
ids=[
"integer",
"float",
"flat_list",
],
)
def test_matrix_init_invalid(self, invalid_elements, expected_error):
with pytest.raises(expected_error):
Matrix(invalid_elements)
@pytest.mark.parametrize(
("matrix_elements", "expected_columns"),
[
([[1, 2], [3, 4]], 2),
([[1, 2, 3]], 3),
([[1], [2], [3]], 1),
],
ids=["2x2", "1x3", "3x1"],
)
def test_get_columns(self, matrix_elements, expected_columns):
matrix = Matrix(matrix_elements)
assert isinstance(matrix, Matrix)
assert len(matrix.get_columns()) == expected_columns
for column in matrix.get_columns():
assert isinstance(column, VGroup)
@pytest.mark.parametrize(
("matrix_elements", "expected_rows"),
[
([[1, 2], [3, 4]], 2),
([[1, 2, 3]], 1),
([[1], [2], [3]], 3),
],
ids=["2x2", "1x3", "3x1"],
)
def test_get_rows(self, matrix_elements, expected_rows):
matrix = Matrix(matrix_elements)
assert isinstance(matrix, Matrix)
assert len(matrix.get_rows()) == expected_rows
for row in matrix.get_rows():
assert isinstance(row, VGroup)
@pytest.mark.parametrize(
("matrix_elements", "expected_entries_tex_string", "expected_entries_count"),
[
([[1, 2], [3, 4]], ["1", "2", "3", "4"], 4),
([[1, 2, 3]], ["1", "2", "3"], 3),
],
ids=["2x2", "1x3"],
)
def test_get_entries(
self, matrix_elements, expected_entries_tex_string, expected_entries_count
):
matrix = Matrix(matrix_elements)
entries = matrix.get_entries()
assert isinstance(matrix, Matrix)
assert len(entries) == expected_entries_count
for index_entry, entry in enumerate(entries):
assert isinstance(entry, MathTex)
assert expected_entries_tex_string[index_entry] == entry.tex_string
@pytest.mark.parametrize(
("matrix_elements", "row", "column", "expected_value_str"),
[
([[1, 2], [3, 4]], 0, 0, "1"),
([[1, 2], [3, 4]], 1, 1, "4"),
([[1, 2, 3]], 0, 2, "3"),
([[1], [2], [3]], 2, 0, "3"),
],
ids=["2x2_00", "2x2_11", "1x3_02", "3x1_20"],
)
def test_get_element(self, matrix_elements, row, column, expected_value_str):
matrix = Matrix(matrix_elements)
assert isinstance(matrix.get_columns()[column][row], MathTex)
assert isinstance(matrix.get_rows()[row][column], MathTex)
assert matrix.get_columns()[column][row].tex_string == expected_value_str
assert matrix.get_rows()[row][column].tex_string == expected_value_str
@pytest.mark.parametrize(
("matrix_elements", "row", "column", "expected_error"),
[
([[1, 2]], 1, 0, IndexError),
([[1, 2]], 0, 2, IndexError),
],
ids=["row_out_of_bounds", "col_out_of_bounds"],
)
def test_get_element_invalid(self, matrix_elements, row, column, expected_error):
matrix = Matrix(matrix_elements)
with pytest.raises(expected_error):
matrix.get_columns()[column][row]
with pytest.raises(expected_error):
matrix.get_rows()[row][column]
class TestDecimalMatrix:
@pytest.mark.parametrize(
("matrix_elements", "num_decimal_places", "expected_elements"),
[
([[1.234, 5.678], [9.012, 3.456]], 2, [[1.234, 5.678], [9.012, 3.456]]),
([[1.0, 2.0], [3.0, 4.0]], 0, [[1, 2], [3, 4]]),
([[1, 2.3], [4.567, 7]], 1, [[1.0, 2.3], [4.567, 7.0]]),
],
ids=[
"basic_2_decimal_points",
"basic_0_decimal_points",
"mixed_1_decimal_points",
],
)
def test_decimal_matrix_init(
self, matrix_elements, num_decimal_places, expected_elements
):
matrix = DecimalMatrix(
matrix_elements,
element_to_mobject_config={"num_decimal_places": num_decimal_places},
)
assert isinstance(matrix, DecimalMatrix)
for column_index, column in enumerate(matrix.get_columns()):
for row_index, element in enumerate(column):
assert element.number == expected_elements[row_index][column_index]
assert element.num_decimal_places == num_decimal_places
class TestIntegerMatrix:
@pytest.mark.parametrize(
("matrix_elements", "expected_elements"),
[
([[1, 2], [3, 4]], [[1, 2], [3, 4]]),
([[1.2, 2.8], [3.5, 4]], [[1.2, 2.8], [3.5, 4]]),
],
ids=["basic_int", "mixed_float_int"],
)
def test_integer_matrix_init(self, matrix_elements, expected_elements):
matrix = IntegerMatrix(matrix_elements)
assert isinstance(matrix, IntegerMatrix)
for row_index, row in enumerate(matrix.get_rows()):
for column_index, element in enumerate(row):
assert element.number == expected_elements[row_index][column_index]

View file

@ -526,3 +526,25 @@ def test_proportion_from_point():
abc.scale(0.8)
props = [abc.proportion_from_point(p) for p in abc.get_vertices()]
np.testing.assert_allclose(props, [0, 1 / 3, 2 / 3])
def test_pointwise_become_partial_where_vmobject_is_self():
sq = Square()
sq.pointwise_become_partial(vmobject=sq, a=0.2, b=0.7)
expected_points = np.array(
[
[-0.6, 1.0, 0.0],
[-0.73333333, 1.0, 0.0],
[-0.86666667, 1.0, 0.0],
[-1.0, 1.0, 0.0],
[-1.0, 1.0, 0.0],
[-1.0, 0.33333333, 0.0],
[-1.0, -0.33333333, 0.0],
[-1.0, -1.0, 0.0],
[-1.0, -1.0, 0.0],
[-0.46666667, -1.0, 0.0],
[0.06666667, -1.0, 0.0],
[0.6, -1.0, 0.0],
]
)
np.testing.assert_allclose(sq.points, expected_points)

View file

@ -116,9 +116,19 @@ def test_to_hsv() -> None:
def test_to_hsl() -> None:
color = ManimColor((0x1, 0x2, 0x3, 0x4))
nt.assert_array_equal(
color.to_hsl(), colorsys.rgb_to_hls(0x1 / 255, 0x2 / 255, 0x3 / 255)
)
hls = colorsys.rgb_to_hls(0x1 / 255, 0x2 / 255, 0x3 / 255)
nt.assert_array_equal(color.to_hsl(), np.array([hls[0], hls[2], hls[1]]))
def test_from_hsl() -> None:
hls = colorsys.rgb_to_hls(0x1 / 255, 0x2 / 255, 0x3 / 255)
hsl = np.array([hls[0], hls[2], hls[1]])
color = ManimColor.from_hsl(hsl)
rgb = np.array([0x1 / 255, 0x2 / 255, 0x3 / 255])
nt.assert_allclose(color.to_rgb(), rgb)
def test_invert() -> None:

View file

@ -0,0 +1,157 @@
import numpy as np
import pytest
from manim.utils.polylabel import Cell, Polygon, polylabel
# Test simple square and square with a hole for inside/outside logic
@pytest.mark.parametrize(
("rings", "inside_points", "outside_points"),
[
(
# Simple square: basic convex polygon
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
[
[2, 2],
[1, 1],
[3.9, 3.9],
[0, 0],
[2, 0],
[0, 2],
[0, 4],
[4, 0],
[4, 2],
[2, 4],
[4, 4],
], # inside points
[[-1, -1], [5, 5], [4.1, 2]], # outside points
),
(
# Square with a square hole (donut shape): tests handling of interior voids
[
[[1, 1], [5, 1], [5, 5], [1, 5], [1, 1]],
[[2, 2], [2, 4], [4, 4], [4, 2], [2, 2]],
], # rings
[[1.5, 1.5], [3, 1.5], [1.5, 3]], # inside points
[[3, 3], [6, 6], [0, 0]], # outside points
),
(
# Non-convex polygon (same shape as flags used in Brazilian june festivals)
[[[0, 0], [2, 2], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
[[1, 3], [3.9, 3.9], [2, 3.5]], # inside points
[
[0.1, 0],
[1, 0],
[2, 0],
[2, 1],
[2, 1.9],
[3, 0],
[3.9, 0],
], # outside points
),
],
)
def test_polygon_inside_outside(rings, inside_points, outside_points):
polygon = Polygon(rings)
for point in inside_points:
assert polygon.inside(point)
for point in outside_points:
assert not polygon.inside(point)
# Test distance calculation with known expected distances
@pytest.mark.parametrize(
("rings", "points", "expected_distance"),
[
(
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
[[2, 2]], # points
2.0, # Distance from center to closest edge in square
),
(
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
[[0, 0], [2, 0], [4, 2], [2, 4], [0, 2]], # points
0.0, # On the edge
),
(
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
[[5, 5]], # points
-np.sqrt(2), # Outside and diagonally offset
),
],
)
def test_polygon_compute_distance(rings, points, expected_distance):
polygon = Polygon(rings)
for point in points:
result = polygon.compute_distance(np.array(point))
assert pytest.approx(result, rel=1e-3) == expected_distance
@pytest.mark.parametrize(
("center", "h", "rings"),
[
(
[2, 2], # center
1.0, # h
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]], # rings
),
(
[3, 1.5], # center
0.5, # h
[
[[1, 1], [5, 1], [5, 5], [1, 5], [1, 1]],
[[2, 2], [2, 4], [4, 4], [4, 2], [2, 2]],
], # rings
),
],
)
def test_cell(center, h, rings):
polygon = Polygon(rings)
cell = Cell(center, h, polygon)
assert isinstance(cell.d, float)
assert isinstance(cell.p, float)
assert np.allclose(cell.c, center)
assert cell.h == h
other = Cell(np.add(center, [0.1, 0.1]), h, polygon)
assert (cell < other) == (cell.d < other.d)
assert (cell > other) == (cell.d > other.d)
assert (cell <= other) == (cell.d <= other.d)
assert (cell >= other) == (cell.d >= other.d)
@pytest.mark.parametrize(
("rings", "expected_centers"),
[
(
# Simple square: basic convex polygon
[[[0, 0], [4, 0], [4, 4], [0, 4], [0, 0]]],
[[2.0, 2.0]], # single correct pole of inaccessibility
),
(
# Square with a square hole (donut shape): tests handling of interior voids
[
[[1, 1], [5, 1], [5, 5], [1, 5], [1, 1]],
[[2, 2], [2, 4], [4, 4], [4, 2], [2, 2]],
],
[ # any of the four pole of inaccessibility options
[1.5, 1.5],
[1.5, 4.5],
[4.5, 1.5],
[4.5, 4.5],
],
),
],
)
def test_polylabel(rings, expected_centers):
# Add third dimension to conform to polylabel input format
rings_3d = [np.column_stack([ring, np.zeros(len(ring))]) for ring in rings]
result = polylabel(rings_3d, precision=0.01)
assert isinstance(result, Cell)
assert result.h <= 0.01
assert result.d >= 0.0
match_found = any(np.allclose(result.c, ec, atol=0.1) for ec in expected_centers)
assert match_found, f"Expected one of {expected_centers}, but got {result.c}"