Note changed families in some places, and fix colors (#3894)

This commit is contained in:
Aarush Deshpande 2024-08-31 11:46:41 -04:00 committed by GitHub
commit 3f431a12f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 96 additions and 86 deletions

View file

@ -394,6 +394,7 @@ class ApplyWave(Homotopy):
time_width: float = 1,
ripples: int = 1,
run_time: float = 2,
introducer: bool = True,
**kwargs,
) -> None:
x_min = mobject.get_left()[0]
@ -468,7 +469,9 @@ class ApplyWave(Homotopy):
nudge = wave(wave_phase) * vect
return np.array([x, y, z]) + nudge
super().__init__(homotopy, mobject, run_time=run_time, **kwargs)
super().__init__(
homotopy, mobject, run_time=run_time, introducer=introducer, **kwargs
)
class Wiggle(Animation):

View file

@ -56,12 +56,12 @@ class Homotopy(Animation):
**kwargs,
) -> None:
self.homotopy = homotopy
self.apply_function_kwargs = (
apply_function_kwargs if apply_function_kwargs is not None else {}
)
self.apply_function_kwargs = apply_function_kwargs or {}
super().__init__(mobject, run_time=run_time, **kwargs)
def function_at_time_t(self, t: float) -> tuple[float, float, float]:
def function_at_time_t(
self, t: float
) -> Callable[[tuple[float, float, float]], tuple[float, float, float]]:
return lambda p: self.homotopy(*p, t)
def interpolate_submobject(
@ -70,7 +70,7 @@ class Homotopy(Animation):
starting_submobject: Mobject,
alpha: float,
) -> None:
submobject.points = starting_submobject.points
submobject.match_points(starting_submobject)
submobject.apply_function(
self.function_at_time_t(alpha), **self.apply_function_kwargs
)

View file

@ -1318,11 +1318,12 @@ class Mobject:
def apply_function(self, function: MappingFunction, **kwargs) -> Self:
# Default to applying matrix about the origin, not mobjects center
if len(kwargs) == 0:
if not kwargs:
kwargs["about_point"] = ORIGIN
self.apply_points_function_about_point(
lambda points: np.apply_along_axis(function, 1, points), **kwargs
)
self.note_changed_family()
return self
def apply_function_to_position(self, function: MappingFunction) -> Self:

View file

@ -53,7 +53,7 @@ if TYPE_CHECKING:
from manim.animation.animation import Animation
from manim.renderer.renderer import RendererData
from manim.typing import PathFuncType, Point3D, Point3D_Array
from manim.typing import ManimFloat, PathFuncType, Point3D, Point3D_Array
TimeBasedUpdater: TypeAlias = Callable[
["OpenGLMobject", float], "OpenGLMobject | None"
@ -157,12 +157,12 @@ class OpenGLMobject:
self.name = self.__class__.__name__ if name is None else name
# internal_state
self.points = np.zeros((0, 3))
self.points: npt.NDArray[ManimFloat] = np.zeros((0, 3))
self.submobjects: list[OpenGLMobject] = []
self.parents: list[OpenGLMobject] = []
self.family: list[OpenGLMobject] = [self]
self.needs_new_bounding_box: bool = True
self._bounding_box = np.zeros((3, 3))
self._bounding_box: npt.NDArray[ManimFloat] = np.zeros((3, 3))
self._is_animating: bool = False
self.saved_state: OpenGLMobject | None = None
self.target: OpenGLMobject | None = None
@ -232,12 +232,12 @@ class OpenGLMobject:
>>> from manim import Square, GREEN
>>> Square.set_default(color=GREEN, fill_opacity=0.25)
>>> s = Square()
>>> s.color, s.fill_opacity
(ManimColor('#83C167'), 0.25)
>>> s.get_color().to_hex(with_alpha=True)
'#83C1673F'
>>> Square.set_default()
>>> s = Square()
>>> s.color, s.fill_opacity
(ManimColor('#FFFFFF'), 0.0)
>>> s.get_color().to_hex(with_alpha=True)
'#FFFFFFFF'
.. manim:: ChangedDefaultTextcolor
:save_last_frame:
@ -370,7 +370,7 @@ class OpenGLMobject:
self.refresh_bounding_box()
return self
def set_points(self, points):
def set_points(self, points: npt.NDArray[ManimFloat]) -> Self:
if len(points) == len(self.points):
self.points[:] = points
elif isinstance(points, np.ndarray):
@ -555,7 +555,7 @@ class OpenGLMobject:
# Others related to points
def match_points(self, mobject):
def match_points(self, mobject: OpenGLMobject) -> Self:
"""Edit points, positions, and submobjects to be identical
to another :class:`~.OpenGLMobject`, while keeping the style unchanged.
@ -667,15 +667,7 @@ class OpenGLMobject:
return self.submobjects
def note_changed_family(self) -> Self:
"""Updates bounding boxes and updater statuses.
This used to be called ``assemble_family``
.. warning::
Remove the above remark about ``assemble_family`` before experimental
is merged, it's a note to MrDiver and other devs
"""
"""Updates bounding boxes and updater statuses."""
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *uniq_chain(*sub_families)]
self.refresh_has_updater_status()
@ -684,16 +676,16 @@ class OpenGLMobject:
parent.note_changed_family()
return self
def get_family(self, recurse=True) -> list[OpenGLMobject]:
if recurse and hasattr(self, "family"):
def get_family(self, recurse: bool = True) -> Sequence[OpenGLMobject]:
if recurse:
return self.family
else:
return [self]
def family_members_with_points(self) -> list[OpenGLMobject]:
def family_members_with_points(self) -> Sequence[OpenGLMobject]:
return [m for m in self.get_family() if m.has_points()]
def get_ancestors(self, extended: bool = False) -> list[OpenGLMobject]:
def get_ancestors(self, extended: bool = False) -> Sequence[OpenGLMobject]:
"""
Returns parents, grandparents, etc.
Order of result should be from higher members of the hierarchy down.
@ -1260,6 +1252,7 @@ class OpenGLMobject:
self.submobjects.sort(key=submob_func)
else:
self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
self.note_changed_family()
return self
def shuffle(self, recurse=False):
@ -1368,9 +1361,12 @@ class OpenGLMobject:
sm.parents = [result]
result.note_changed_family()
for current, copy_ in zip(self.get_family(), result.get_family()):
copy_.points = np.array(current.points)
copy_.match_color(current)
# this seems correct, but is not needed in 3b1b manim - investigate
# for current, copy_ in zip(self.get_family(), result.get_family()):
# copy_.points = np.array(current.points)
# copy_.match_color(current)
# Similarly, instead of calling match_updaters, since we know the status
# won't have changed, just directly match with shallow copies.
result.non_time_updaters = self.non_time_updaters.copy()
@ -1727,7 +1723,7 @@ class OpenGLMobject:
def apply_function(self, function: PointUpdateFunction, **kwargs) -> Self:
# Default to applying matrix about the origin, not mobjects center
if len(kwargs) == 0:
if not kwargs:
kwargs["about_point"] = ORIGIN
self.apply_points_function(
lambda points: np.array([function(p) for p in points]), **kwargs
@ -2522,6 +2518,7 @@ class OpenGLMobject:
null_mob = self.copy()
null_mob.set_points([self.get_center()])
self.submobjects = [null_mob.copy() for k in range(n)]
self.note_changed_family()
return self
target = curr + n
repeat_indices = (np.arange(target) * curr) // target
@ -2537,6 +2534,7 @@ class OpenGLMobject:
new_submob.set_opacity(0)
new_submobs.append(new_submob)
self.submobjects = new_submobs
self.note_changed_family()
return self
# Interpolate
@ -2647,7 +2645,7 @@ class OpenGLMobject:
for sm1, sm2 in zip(family1, family2):
sm1.depth_test = sm2.depth_test
# Make sure named family members carry over
for attr, value in list(mobject.__dict__.items()):
for attr, value in mobject.__dict__.items():
if isinstance(value, OpenGLMobject) and value in family2:
setattr(self, attr, family1[family2.index(value)])
self.refresh_bounding_box(recurse_down=True)

View file

@ -85,10 +85,9 @@ class OpenGLVMobject(OpenGLMobject):
# so users can get autocomplete
def __init__(
self,
color: ParsableManimColor | list[ParsableManimColor] | None = None,
fill_color: ParsableManimColor | list[ParsableManimColor] | None = None,
fill_color: ParsableManimColor | Sequence[ParsableManimColor] | None = None,
fill_opacity: float | None = None,
stroke_color: ParsableManimColor | list[ParsableManimColor] | None = None,
stroke_color: ParsableManimColor | Sequence[ParsableManimColor] | None = None,
stroke_opacity: float | None = None,
stroke_width: float = DEFAULT_STROKE_WIDTH,
draw_stroke_behind_fill: bool = False,
@ -100,23 +99,17 @@ class OpenGLVMobject(OpenGLMobject):
):
super().__init__(**kwargs)
if fill_color is None:
fill_color = color
fill_color = self.color
if stroke_color is None:
stroke_color = color
self.fill_color: Sequence[ManimColor] = listify(ManimColor.parse(fill_color))
self.set_fill(opacity=fill_opacity)
self.stroke_color: Sequence[ManimColor] = listify(
ManimColor.parse(stroke_color)
)
self.set_stroke(opacity=stroke_opacity)
stroke_color = self.color
self.set_fill(color=fill_color, opacity=fill_opacity)
self.set_stroke(color=stroke_color, opacity=stroke_opacity)
self.stroke_width = listify(stroke_width)
self.draw_stroke_behind_fill = draw_stroke_behind_fill
self.background_image_file = background_image_file
self.long_lines = long_lines
self.joint_type = joint_type
self.flat_stroke = flat_stroke
# TODO: Remove this because the new shader doesn't need it
self.anti_alias_width = 1.0
self.needs_new_triangulation = True
self.triangulation = np.zeros(0, dtype="i4")
@ -134,10 +127,10 @@ class OpenGLVMobject(OpenGLMobject):
return OpenGLVMobject
# These are here just to make type checkers happy
def get_family(self, recurse: bool = True) -> list[OpenGLVMobject]: # type: ignore
def get_family(self, recurse: bool = True) -> Sequence[OpenGLVMobject]:
return super().get_family(recurse) # type: ignore
def family_members_with_points(self) -> list[OpenGLVMobject]: # type: ignore
def family_members_with_points(self) -> Sequence[OpenGLVMobject]: # type: ignore
return super().family_members_with_points() # type: ignore
def replicate(self, n: int) -> OpenGLVGroup: # type: ignore
@ -216,7 +209,7 @@ class OpenGLVMobject(OpenGLMobject):
for submob in self.submobjects:
submob.set_fill(color, opacity, recurse=True)
if color is not None:
self.fill_color = listify(ManimColor.parse(color))
self.fill_color: list[ManimColor] = listify(ManimColor.parse(color))
if opacity is not None:
self.fill_color = [c.opacity(opacity) for c in self.fill_color]
return self
@ -1583,6 +1576,7 @@ class OpenGLVGroup(OpenGLVMobject):
"""
self._assert_valid_submobjects(tuplify(value))
self.submobjects[key] = value # type: ignore
self.note_changed_family()
class OpenGLVectorizedPoint(OpenGLPoint, OpenGLVMobject):

View file

@ -154,6 +154,8 @@ class DecimalNumber(VMobject, metaclass=ConvertToOpenGL):
def _set_submobjects_from_number(self, number):
self.number = number
# the self.add below will recalculate the family,
# no need to do it here.
self.submobjects = []
num_string = self._get_num_string(number)

View file

@ -33,7 +33,6 @@ from textwrap import dedent
from manim import config, logger
from manim.constants import *
from manim.mobject.geometry.line import Line
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.mobject.svg.svg_mobject import SVGMobject
from manim.mobject.types.vectorized_mobject import VGroup
from manim.utils.tex import TexTemplate
@ -65,11 +64,6 @@ class SingleStringMathTex(SVGMobject):
font_size: float = DEFAULT_FONT_SIZE,
**kwargs,
):
if kwargs.get("color") is None:
# makes it so that color isn't explicitly passed for these mobs,
# and can instead inherit from the parent
kwargs["color"] = OpenGLVMobject().color
self._font_size = font_size
self.organize_left_to_right = organize_left_to_right
self.tex_environment = tex_environment
@ -248,7 +242,7 @@ class MathTex(SingleStringMathTex):
*tex_strings,
arg_separator: str = " ",
substrings_to_isolate: Iterable[str] | None = None,
tex_to_color_map: dict[str, ManimColor] = None,
tex_to_color_map: dict[str, ManimColor] | None = None,
tex_environment: str = "align*",
**kwargs,
):
@ -271,7 +265,7 @@ class MathTex(SingleStringMathTex):
**kwargs,
)
self._break_up_by_substrings()
except ValueError as compilation_error:
except ValueError:
if self.brace_notation_split_occurred:
logger.error(
dedent(
@ -285,17 +279,11 @@ class MathTex(SingleStringMathTex):
""",
),
)
raise compilation_error
raise
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
if self.organize_left_to_right:
self._organize_submobjects_left_to_right()
self.note_changed_family()
# 5 hours of work went into this line
# and it's still not perfect
# July 18, 2024
self.note_changed_family()
def _break_up_tex_strings(self, tex_strings):
# Separate out anything surrounded in double braces
@ -349,9 +337,16 @@ class MathTex(SingleStringMathTex):
sub_tex_mob.move_to(self.submobjects[last_submob_index], RIGHT)
else:
sub_tex_mob.submobjects = self.submobjects[curr_index:new_index]
sub_tex_mob.note_changed_family()
new_submobjects.append(sub_tex_mob)
curr_index = new_index
self.submobjects = new_submobjects
# 5 hours of work went into this line
# and it's still not perfect
# July 18, 2024
self.note_changed_family()
return self
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
@ -423,6 +418,7 @@ class MathTex(SingleStringMathTex):
def sort_alphabetically(self):
self.submobjects.sort(key=lambda m: m.get_tex_string())
self.note_changed_family()
class Tex(MathTex):

View file

@ -525,6 +525,7 @@ class Text(SVGMobject):
self.text = text
if self.disable_ligatures:
self.submobjects = [*self._gen_chars()]
self.note_changed_family()
self.chars = self.get_group_class()(*self.submobjects)
self.text = text_without_tabs.replace(" ", "").replace("\n", "")
nppc = self.n_points_per_curve

View file

@ -176,6 +176,7 @@ class PMobject(Mobject, metaclass=ConvertToOpenGL):
for attr, array in zip(attrs, arrays):
setattr(self, attr, array)
self.submobjects = []
self.note_changed_family()
return self
def get_color(self):

View file

@ -321,10 +321,11 @@ class VMobject(Mobject):
if family:
for submobject in self.submobjects:
submobject.set_fill(color, opacity, family)
self.update_rgbas_array("fill_rgbas", color, opacity)
self.fill_rgbas: RGBA_Array_Float
if color is not None:
self.fill_color = ManimColor.parse(color)
if opacity is not None:
self.fill_opacity = opacity
self.fill_color = [c.opacity(opacity) for c in self.fill_color]
return self
def set_stroke(

View file

@ -512,6 +512,7 @@ class SpecialThreeDScene(ThreeDScene):
piece.shade_in_3d = True
new_pieces.match_style(axis.pieces)
axis.pieces.submobjects = new_pieces.submobjects
axis.pieces.note_changed_family()
for tick in axis.tick_marks:
tick.add(VectorizedPoint(1.5 * tick.get_center()))
return axes

View file

@ -505,7 +505,7 @@ class ManimColor:
tmp[3] = alpha * 255
return tmp.astype(int)
def to_hex(self, with_alpha: bool = False) -> str:
def to_hex(self, *, with_alpha: bool = False) -> str:
"""Converts the manim color to a hexadecimal representation of the color
Parameters
@ -560,7 +560,7 @@ class ManimColor:
"""
return np.array(colorsys.rgb_to_hls(*self.to_rgb()))
def invert(self, with_alpha=False) -> Self:
def invert(self, *, with_alpha: bool = False) -> Self:
"""Returns an linearly inverted version of the color (no inplace changes)
Parameters

View file

@ -31,16 +31,16 @@ class _FramesTester:
def testing(self):
with np.load(self._file_path) as data:
self._frames = data["frame_data"]
# For backward compatibility, when the control data contains only one frame (<= v0.8.0)
if len(self._frames.shape) != 4:
self._frames = np.expand_dims(self._frames, axis=0)
logger.debug(self._frames.shape)
self._number_frames = np.ma.size(self._frames, axis=0)
yield
assert self._frames_compared == self._number_frames, (
f"The scene tested contained {self._frames_compared} frames, "
f"when there are {self._number_frames} control frames for this test."
)
# For backward compatibility, when the control data contains only one frame (<= v0.8.0)
if len(self._frames.shape) != 4:
self._frames = np.expand_dims(self._frames, axis=0)
logger.debug(self._frames.shape)
self._number_frames = np.ma.size(self._frames, axis=0)
yield
assert self._frames_compared == self._number_frames, (
f"The scene tested contained {self._frames_compared} frames, "
f"when there are {self._number_frames} control frames for this test."
)
def check_frame(self, frame_number: int, frame: PixelArray):
assert frame_number < self._number_frames, (
@ -56,7 +56,7 @@ class _FramesTester:
verbose=False,
)
self._frames_compared += 1
except AssertionError as e:
except AssertionError:
number_of_matches = np.isclose(
frame, self._frames[frame_number], atol=FRAME_ABSOLUTE_TOLERANCE
).sum()
@ -80,7 +80,7 @@ class _FramesTester:
self._frames[frame_number],
self._file_path.name,
)
raise e from e
raise
class _ControlDataWriter(_FramesTester):

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import inspect
from pathlib import Path
from typing import Callable
from typing import TYPE_CHECKING
import pytest
@ -18,6 +18,13 @@ from ._test_class_makers import (
_make_test_scene_class,
)
if TYPE_CHECKING:
from collections.abc import Callable
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
__all__ = ["frames_comparison"]
SCENE_PARAMETER_NAME = "scene"
@ -27,12 +34,12 @@ MIN_CAIRO_VERSION = 11800
def frames_comparison(
func: Callable[..., object] | None = None,
func: Callable[P, object] | None = None,
*,
last_frame: bool = True,
base_scene: type[Scene] = Scene,
**custom_config,
):
) -> Callable[Concatenate[pytest.FixtureRequest, Path, P], object]:
"""Compares the frames generated by the test with control frames previously registered.
If there is no control frames for this test, the test will fail. To generate
@ -77,7 +84,12 @@ def frames_comparison(
@functools.wraps(tested_scene_construct)
# The "request" parameter is meant to be used as a fixture by pytest. See below.
def wrapper(*args, request: pytest.FixtureRequest, tmp_path, **kwargs):
def wrapper(
request: pytest.FixtureRequest,
tmp_path: Path,
*args: P.args,
**kwargs: P.kwargs,
):
# Wraps the test_function to a construct method, to "freeze" the eventual additional arguments (parametrizations fixtures).
construct = functools.partial(tested_scene_construct, *args, **kwargs)