mirror of
https://github.com/ManimCommunity/manim.git
synced 2026-06-22 10:01:47 +00:00
Note changed families in some places, and fix colors (#3894)
This commit is contained in:
parent
7e8c5c8144
commit
3f431a12f7
14 changed files with 96 additions and 86 deletions
|
|
@ -394,6 +394,7 @@ class ApplyWave(Homotopy):
|
|||
time_width: float = 1,
|
||||
ripples: int = 1,
|
||||
run_time: float = 2,
|
||||
introducer: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
x_min = mobject.get_left()[0]
|
||||
|
|
@ -468,7 +469,9 @@ class ApplyWave(Homotopy):
|
|||
nudge = wave(wave_phase) * vect
|
||||
return np.array([x, y, z]) + nudge
|
||||
|
||||
super().__init__(homotopy, mobject, run_time=run_time, **kwargs)
|
||||
super().__init__(
|
||||
homotopy, mobject, run_time=run_time, introducer=introducer, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class Wiggle(Animation):
|
||||
|
|
|
|||
|
|
@ -56,12 +56,12 @@ class Homotopy(Animation):
|
|||
**kwargs,
|
||||
) -> None:
|
||||
self.homotopy = homotopy
|
||||
self.apply_function_kwargs = (
|
||||
apply_function_kwargs if apply_function_kwargs is not None else {}
|
||||
)
|
||||
self.apply_function_kwargs = apply_function_kwargs or {}
|
||||
super().__init__(mobject, run_time=run_time, **kwargs)
|
||||
|
||||
def function_at_time_t(self, t: float) -> tuple[float, float, float]:
|
||||
def function_at_time_t(
|
||||
self, t: float
|
||||
) -> Callable[[tuple[float, float, float]], tuple[float, float, float]]:
|
||||
return lambda p: self.homotopy(*p, t)
|
||||
|
||||
def interpolate_submobject(
|
||||
|
|
@ -70,7 +70,7 @@ class Homotopy(Animation):
|
|||
starting_submobject: Mobject,
|
||||
alpha: float,
|
||||
) -> None:
|
||||
submobject.points = starting_submobject.points
|
||||
submobject.match_points(starting_submobject)
|
||||
submobject.apply_function(
|
||||
self.function_at_time_t(alpha), **self.apply_function_kwargs
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1318,11 +1318,12 @@ class Mobject:
|
|||
|
||||
def apply_function(self, function: MappingFunction, **kwargs) -> Self:
|
||||
# Default to applying matrix about the origin, not mobjects center
|
||||
if len(kwargs) == 0:
|
||||
if not kwargs:
|
||||
kwargs["about_point"] = ORIGIN
|
||||
self.apply_points_function_about_point(
|
||||
lambda points: np.apply_along_axis(function, 1, points), **kwargs
|
||||
)
|
||||
self.note_changed_family()
|
||||
return self
|
||||
|
||||
def apply_function_to_position(self, function: MappingFunction) -> Self:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from manim.animation.animation import Animation
|
||||
from manim.renderer.renderer import RendererData
|
||||
from manim.typing import PathFuncType, Point3D, Point3D_Array
|
||||
from manim.typing import ManimFloat, PathFuncType, Point3D, Point3D_Array
|
||||
|
||||
TimeBasedUpdater: TypeAlias = Callable[
|
||||
["OpenGLMobject", float], "OpenGLMobject | None"
|
||||
|
|
@ -157,12 +157,12 @@ class OpenGLMobject:
|
|||
self.name = self.__class__.__name__ if name is None else name
|
||||
|
||||
# internal_state
|
||||
self.points = np.zeros((0, 3))
|
||||
self.points: npt.NDArray[ManimFloat] = np.zeros((0, 3))
|
||||
self.submobjects: list[OpenGLMobject] = []
|
||||
self.parents: list[OpenGLMobject] = []
|
||||
self.family: list[OpenGLMobject] = [self]
|
||||
self.needs_new_bounding_box: bool = True
|
||||
self._bounding_box = np.zeros((3, 3))
|
||||
self._bounding_box: npt.NDArray[ManimFloat] = np.zeros((3, 3))
|
||||
self._is_animating: bool = False
|
||||
self.saved_state: OpenGLMobject | None = None
|
||||
self.target: OpenGLMobject | None = None
|
||||
|
|
@ -232,12 +232,12 @@ class OpenGLMobject:
|
|||
>>> from manim import Square, GREEN
|
||||
>>> Square.set_default(color=GREEN, fill_opacity=0.25)
|
||||
>>> s = Square()
|
||||
>>> s.color, s.fill_opacity
|
||||
(ManimColor('#83C167'), 0.25)
|
||||
>>> s.get_color().to_hex(with_alpha=True)
|
||||
'#83C1673F'
|
||||
>>> Square.set_default()
|
||||
>>> s = Square()
|
||||
>>> s.color, s.fill_opacity
|
||||
(ManimColor('#FFFFFF'), 0.0)
|
||||
>>> s.get_color().to_hex(with_alpha=True)
|
||||
'#FFFFFFFF'
|
||||
|
||||
.. manim:: ChangedDefaultTextcolor
|
||||
:save_last_frame:
|
||||
|
|
@ -370,7 +370,7 @@ class OpenGLMobject:
|
|||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
||||
def set_points(self, points):
|
||||
def set_points(self, points: npt.NDArray[ManimFloat]) -> Self:
|
||||
if len(points) == len(self.points):
|
||||
self.points[:] = points
|
||||
elif isinstance(points, np.ndarray):
|
||||
|
|
@ -555,7 +555,7 @@ class OpenGLMobject:
|
|||
|
||||
# Others related to points
|
||||
|
||||
def match_points(self, mobject):
|
||||
def match_points(self, mobject: OpenGLMobject) -> Self:
|
||||
"""Edit points, positions, and submobjects to be identical
|
||||
to another :class:`~.OpenGLMobject`, while keeping the style unchanged.
|
||||
|
||||
|
|
@ -667,15 +667,7 @@ class OpenGLMobject:
|
|||
return self.submobjects
|
||||
|
||||
def note_changed_family(self) -> Self:
|
||||
"""Updates bounding boxes and updater statuses.
|
||||
|
||||
This used to be called ``assemble_family``
|
||||
|
||||
.. warning::
|
||||
|
||||
Remove the above remark about ``assemble_family`` before experimental
|
||||
is merged, it's a note to MrDiver and other devs
|
||||
"""
|
||||
"""Updates bounding boxes and updater statuses."""
|
||||
sub_families = (sm.get_family() for sm in self.submobjects)
|
||||
self.family = [self, *uniq_chain(*sub_families)]
|
||||
self.refresh_has_updater_status()
|
||||
|
|
@ -684,16 +676,16 @@ class OpenGLMobject:
|
|||
parent.note_changed_family()
|
||||
return self
|
||||
|
||||
def get_family(self, recurse=True) -> list[OpenGLMobject]:
|
||||
if recurse and hasattr(self, "family"):
|
||||
def get_family(self, recurse: bool = True) -> Sequence[OpenGLMobject]:
|
||||
if recurse:
|
||||
return self.family
|
||||
else:
|
||||
return [self]
|
||||
|
||||
def family_members_with_points(self) -> list[OpenGLMobject]:
|
||||
def family_members_with_points(self) -> Sequence[OpenGLMobject]:
|
||||
return [m for m in self.get_family() if m.has_points()]
|
||||
|
||||
def get_ancestors(self, extended: bool = False) -> list[OpenGLMobject]:
|
||||
def get_ancestors(self, extended: bool = False) -> Sequence[OpenGLMobject]:
|
||||
"""
|
||||
Returns parents, grandparents, etc.
|
||||
Order of result should be from higher members of the hierarchy down.
|
||||
|
|
@ -1260,6 +1252,7 @@ class OpenGLMobject:
|
|||
self.submobjects.sort(key=submob_func)
|
||||
else:
|
||||
self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
|
||||
self.note_changed_family()
|
||||
return self
|
||||
|
||||
def shuffle(self, recurse=False):
|
||||
|
|
@ -1368,9 +1361,12 @@ class OpenGLMobject:
|
|||
sm.parents = [result]
|
||||
|
||||
result.note_changed_family()
|
||||
for current, copy_ in zip(self.get_family(), result.get_family()):
|
||||
copy_.points = np.array(current.points)
|
||||
copy_.match_color(current)
|
||||
|
||||
# this seems correct, but is not needed in 3b1b manim - investigate
|
||||
# for current, copy_ in zip(self.get_family(), result.get_family()):
|
||||
# copy_.points = np.array(current.points)
|
||||
# copy_.match_color(current)
|
||||
|
||||
# Similarly, instead of calling match_updaters, since we know the status
|
||||
# won't have changed, just directly match with shallow copies.
|
||||
result.non_time_updaters = self.non_time_updaters.copy()
|
||||
|
|
@ -1727,7 +1723,7 @@ class OpenGLMobject:
|
|||
|
||||
def apply_function(self, function: PointUpdateFunction, **kwargs) -> Self:
|
||||
# Default to applying matrix about the origin, not mobjects center
|
||||
if len(kwargs) == 0:
|
||||
if not kwargs:
|
||||
kwargs["about_point"] = ORIGIN
|
||||
self.apply_points_function(
|
||||
lambda points: np.array([function(p) for p in points]), **kwargs
|
||||
|
|
@ -2522,6 +2518,7 @@ class OpenGLMobject:
|
|||
null_mob = self.copy()
|
||||
null_mob.set_points([self.get_center()])
|
||||
self.submobjects = [null_mob.copy() for k in range(n)]
|
||||
self.note_changed_family()
|
||||
return self
|
||||
target = curr + n
|
||||
repeat_indices = (np.arange(target) * curr) // target
|
||||
|
|
@ -2537,6 +2534,7 @@ class OpenGLMobject:
|
|||
new_submob.set_opacity(0)
|
||||
new_submobs.append(new_submob)
|
||||
self.submobjects = new_submobs
|
||||
self.note_changed_family()
|
||||
return self
|
||||
|
||||
# Interpolate
|
||||
|
|
@ -2647,7 +2645,7 @@ class OpenGLMobject:
|
|||
for sm1, sm2 in zip(family1, family2):
|
||||
sm1.depth_test = sm2.depth_test
|
||||
# Make sure named family members carry over
|
||||
for attr, value in list(mobject.__dict__.items()):
|
||||
for attr, value in mobject.__dict__.items():
|
||||
if isinstance(value, OpenGLMobject) and value in family2:
|
||||
setattr(self, attr, family1[family2.index(value)])
|
||||
self.refresh_bounding_box(recurse_down=True)
|
||||
|
|
|
|||
|
|
@ -85,10 +85,9 @@ class OpenGLVMobject(OpenGLMobject):
|
|||
# so users can get autocomplete
|
||||
def __init__(
|
||||
self,
|
||||
color: ParsableManimColor | list[ParsableManimColor] | None = None,
|
||||
fill_color: ParsableManimColor | list[ParsableManimColor] | None = None,
|
||||
fill_color: ParsableManimColor | Sequence[ParsableManimColor] | None = None,
|
||||
fill_opacity: float | None = None,
|
||||
stroke_color: ParsableManimColor | list[ParsableManimColor] | None = None,
|
||||
stroke_color: ParsableManimColor | Sequence[ParsableManimColor] | None = None,
|
||||
stroke_opacity: float | None = None,
|
||||
stroke_width: float = DEFAULT_STROKE_WIDTH,
|
||||
draw_stroke_behind_fill: bool = False,
|
||||
|
|
@ -100,23 +99,17 @@ class OpenGLVMobject(OpenGLMobject):
|
|||
):
|
||||
super().__init__(**kwargs)
|
||||
if fill_color is None:
|
||||
fill_color = color
|
||||
fill_color = self.color
|
||||
if stroke_color is None:
|
||||
stroke_color = color
|
||||
self.fill_color: Sequence[ManimColor] = listify(ManimColor.parse(fill_color))
|
||||
self.set_fill(opacity=fill_opacity)
|
||||
self.stroke_color: Sequence[ManimColor] = listify(
|
||||
ManimColor.parse(stroke_color)
|
||||
)
|
||||
self.set_stroke(opacity=stroke_opacity)
|
||||
stroke_color = self.color
|
||||
self.set_fill(color=fill_color, opacity=fill_opacity)
|
||||
self.set_stroke(color=stroke_color, opacity=stroke_opacity)
|
||||
self.stroke_width = listify(stroke_width)
|
||||
self.draw_stroke_behind_fill = draw_stroke_behind_fill
|
||||
self.background_image_file = background_image_file
|
||||
self.long_lines = long_lines
|
||||
self.joint_type = joint_type
|
||||
self.flat_stroke = flat_stroke
|
||||
# TODO: Remove this because the new shader doesn't need it
|
||||
self.anti_alias_width = 1.0
|
||||
|
||||
self.needs_new_triangulation = True
|
||||
self.triangulation = np.zeros(0, dtype="i4")
|
||||
|
|
@ -134,10 +127,10 @@ class OpenGLVMobject(OpenGLMobject):
|
|||
return OpenGLVMobject
|
||||
|
||||
# These are here just to make type checkers happy
|
||||
def get_family(self, recurse: bool = True) -> list[OpenGLVMobject]: # type: ignore
|
||||
def get_family(self, recurse: bool = True) -> Sequence[OpenGLVMobject]:
|
||||
return super().get_family(recurse) # type: ignore
|
||||
|
||||
def family_members_with_points(self) -> list[OpenGLVMobject]: # type: ignore
|
||||
def family_members_with_points(self) -> Sequence[OpenGLVMobject]: # type: ignore
|
||||
return super().family_members_with_points() # type: ignore
|
||||
|
||||
def replicate(self, n: int) -> OpenGLVGroup: # type: ignore
|
||||
|
|
@ -216,7 +209,7 @@ class OpenGLVMobject(OpenGLMobject):
|
|||
for submob in self.submobjects:
|
||||
submob.set_fill(color, opacity, recurse=True)
|
||||
if color is not None:
|
||||
self.fill_color = listify(ManimColor.parse(color))
|
||||
self.fill_color: list[ManimColor] = listify(ManimColor.parse(color))
|
||||
if opacity is not None:
|
||||
self.fill_color = [c.opacity(opacity) for c in self.fill_color]
|
||||
return self
|
||||
|
|
@ -1583,6 +1576,7 @@ class OpenGLVGroup(OpenGLVMobject):
|
|||
"""
|
||||
self._assert_valid_submobjects(tuplify(value))
|
||||
self.submobjects[key] = value # type: ignore
|
||||
self.note_changed_family()
|
||||
|
||||
|
||||
class OpenGLVectorizedPoint(OpenGLPoint, OpenGLVMobject):
|
||||
|
|
|
|||
|
|
@ -154,6 +154,8 @@ class DecimalNumber(VMobject, metaclass=ConvertToOpenGL):
|
|||
|
||||
def _set_submobjects_from_number(self, number):
|
||||
self.number = number
|
||||
# the self.add below will recalculate the family,
|
||||
# no need to do it here.
|
||||
self.submobjects = []
|
||||
|
||||
num_string = self._get_num_string(number)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from textwrap import dedent
|
|||
from manim import config, logger
|
||||
from manim.constants import *
|
||||
from manim.mobject.geometry.line import Line
|
||||
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
|
||||
from manim.mobject.svg.svg_mobject import SVGMobject
|
||||
from manim.mobject.types.vectorized_mobject import VGroup
|
||||
from manim.utils.tex import TexTemplate
|
||||
|
|
@ -65,11 +64,6 @@ class SingleStringMathTex(SVGMobject):
|
|||
font_size: float = DEFAULT_FONT_SIZE,
|
||||
**kwargs,
|
||||
):
|
||||
if kwargs.get("color") is None:
|
||||
# makes it so that color isn't explicitly passed for these mobs,
|
||||
# and can instead inherit from the parent
|
||||
kwargs["color"] = OpenGLVMobject().color
|
||||
|
||||
self._font_size = font_size
|
||||
self.organize_left_to_right = organize_left_to_right
|
||||
self.tex_environment = tex_environment
|
||||
|
|
@ -248,7 +242,7 @@ class MathTex(SingleStringMathTex):
|
|||
*tex_strings,
|
||||
arg_separator: str = " ",
|
||||
substrings_to_isolate: Iterable[str] | None = None,
|
||||
tex_to_color_map: dict[str, ManimColor] = None,
|
||||
tex_to_color_map: dict[str, ManimColor] | None = None,
|
||||
tex_environment: str = "align*",
|
||||
**kwargs,
|
||||
):
|
||||
|
|
@ -271,7 +265,7 @@ class MathTex(SingleStringMathTex):
|
|||
**kwargs,
|
||||
)
|
||||
self._break_up_by_substrings()
|
||||
except ValueError as compilation_error:
|
||||
except ValueError:
|
||||
if self.brace_notation_split_occurred:
|
||||
logger.error(
|
||||
dedent(
|
||||
|
|
@ -285,17 +279,11 @@ class MathTex(SingleStringMathTex):
|
|||
""",
|
||||
),
|
||||
)
|
||||
raise compilation_error
|
||||
raise
|
||||
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
|
||||
|
||||
if self.organize_left_to_right:
|
||||
self._organize_submobjects_left_to_right()
|
||||
self.note_changed_family()
|
||||
|
||||
# 5 hours of work went into this line
|
||||
# and it's still not perfect
|
||||
# July 18, 2024
|
||||
self.note_changed_family()
|
||||
|
||||
def _break_up_tex_strings(self, tex_strings):
|
||||
# Separate out anything surrounded in double braces
|
||||
|
|
@ -349,9 +337,16 @@ class MathTex(SingleStringMathTex):
|
|||
sub_tex_mob.move_to(self.submobjects[last_submob_index], RIGHT)
|
||||
else:
|
||||
sub_tex_mob.submobjects = self.submobjects[curr_index:new_index]
|
||||
sub_tex_mob.note_changed_family()
|
||||
new_submobjects.append(sub_tex_mob)
|
||||
curr_index = new_index
|
||||
self.submobjects = new_submobjects
|
||||
|
||||
# 5 hours of work went into this line
|
||||
# and it's still not perfect
|
||||
# July 18, 2024
|
||||
self.note_changed_family()
|
||||
|
||||
return self
|
||||
|
||||
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
|
||||
|
|
@ -423,6 +418,7 @@ class MathTex(SingleStringMathTex):
|
|||
|
||||
def sort_alphabetically(self):
|
||||
self.submobjects.sort(key=lambda m: m.get_tex_string())
|
||||
self.note_changed_family()
|
||||
|
||||
|
||||
class Tex(MathTex):
|
||||
|
|
|
|||
|
|
@ -525,6 +525,7 @@ class Text(SVGMobject):
|
|||
self.text = text
|
||||
if self.disable_ligatures:
|
||||
self.submobjects = [*self._gen_chars()]
|
||||
self.note_changed_family()
|
||||
self.chars = self.get_group_class()(*self.submobjects)
|
||||
self.text = text_without_tabs.replace(" ", "").replace("\n", "")
|
||||
nppc = self.n_points_per_curve
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ class PMobject(Mobject, metaclass=ConvertToOpenGL):
|
|||
for attr, array in zip(attrs, arrays):
|
||||
setattr(self, attr, array)
|
||||
self.submobjects = []
|
||||
self.note_changed_family()
|
||||
return self
|
||||
|
||||
def get_color(self):
|
||||
|
|
|
|||
|
|
@ -321,10 +321,11 @@ class VMobject(Mobject):
|
|||
if family:
|
||||
for submobject in self.submobjects:
|
||||
submobject.set_fill(color, opacity, family)
|
||||
self.update_rgbas_array("fill_rgbas", color, opacity)
|
||||
self.fill_rgbas: RGBA_Array_Float
|
||||
|
||||
if color is not None:
|
||||
self.fill_color = ManimColor.parse(color)
|
||||
if opacity is not None:
|
||||
self.fill_opacity = opacity
|
||||
self.fill_color = [c.opacity(opacity) for c in self.fill_color]
|
||||
return self
|
||||
|
||||
def set_stroke(
|
||||
|
|
|
|||
|
|
@ -512,6 +512,7 @@ class SpecialThreeDScene(ThreeDScene):
|
|||
piece.shade_in_3d = True
|
||||
new_pieces.match_style(axis.pieces)
|
||||
axis.pieces.submobjects = new_pieces.submobjects
|
||||
axis.pieces.note_changed_family()
|
||||
for tick in axis.tick_marks:
|
||||
tick.add(VectorizedPoint(1.5 * tick.get_center()))
|
||||
return axes
|
||||
|
|
|
|||
|
|
@ -505,7 +505,7 @@ class ManimColor:
|
|||
tmp[3] = alpha * 255
|
||||
return tmp.astype(int)
|
||||
|
||||
def to_hex(self, with_alpha: bool = False) -> str:
|
||||
def to_hex(self, *, with_alpha: bool = False) -> str:
|
||||
"""Converts the manim color to a hexadecimal representation of the color
|
||||
|
||||
Parameters
|
||||
|
|
@ -560,7 +560,7 @@ class ManimColor:
|
|||
"""
|
||||
return np.array(colorsys.rgb_to_hls(*self.to_rgb()))
|
||||
|
||||
def invert(self, with_alpha=False) -> Self:
|
||||
def invert(self, *, with_alpha: bool = False) -> Self:
|
||||
"""Returns an linearly inverted version of the color (no inplace changes)
|
||||
|
||||
Parameters
|
||||
|
|
|
|||
|
|
@ -31,16 +31,16 @@ class _FramesTester:
|
|||
def testing(self):
|
||||
with np.load(self._file_path) as data:
|
||||
self._frames = data["frame_data"]
|
||||
# For backward compatibility, when the control data contains only one frame (<= v0.8.0)
|
||||
if len(self._frames.shape) != 4:
|
||||
self._frames = np.expand_dims(self._frames, axis=0)
|
||||
logger.debug(self._frames.shape)
|
||||
self._number_frames = np.ma.size(self._frames, axis=0)
|
||||
yield
|
||||
assert self._frames_compared == self._number_frames, (
|
||||
f"The scene tested contained {self._frames_compared} frames, "
|
||||
f"when there are {self._number_frames} control frames for this test."
|
||||
)
|
||||
# For backward compatibility, when the control data contains only one frame (<= v0.8.0)
|
||||
if len(self._frames.shape) != 4:
|
||||
self._frames = np.expand_dims(self._frames, axis=0)
|
||||
logger.debug(self._frames.shape)
|
||||
self._number_frames = np.ma.size(self._frames, axis=0)
|
||||
yield
|
||||
assert self._frames_compared == self._number_frames, (
|
||||
f"The scene tested contained {self._frames_compared} frames, "
|
||||
f"when there are {self._number_frames} control frames for this test."
|
||||
)
|
||||
|
||||
def check_frame(self, frame_number: int, frame: PixelArray):
|
||||
assert frame_number < self._number_frames, (
|
||||
|
|
@ -56,7 +56,7 @@ class _FramesTester:
|
|||
verbose=False,
|
||||
)
|
||||
self._frames_compared += 1
|
||||
except AssertionError as e:
|
||||
except AssertionError:
|
||||
number_of_matches = np.isclose(
|
||||
frame, self._frames[frame_number], atol=FRAME_ABSOLUTE_TOLERANCE
|
||||
).sum()
|
||||
|
|
@ -80,7 +80,7 @@ class _FramesTester:
|
|||
self._frames[frame_number],
|
||||
self._file_path.name,
|
||||
)
|
||||
raise e from e
|
||||
raise
|
||||
|
||||
|
||||
class _ControlDataWriter(_FramesTester):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import functools
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -18,6 +18,13 @@ from ._test_class_makers import (
|
|||
_make_test_scene_class,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
__all__ = ["frames_comparison"]
|
||||
|
||||
SCENE_PARAMETER_NAME = "scene"
|
||||
|
|
@ -27,12 +34,12 @@ MIN_CAIRO_VERSION = 11800
|
|||
|
||||
|
||||
def frames_comparison(
|
||||
func: Callable[..., object] | None = None,
|
||||
func: Callable[P, object] | None = None,
|
||||
*,
|
||||
last_frame: bool = True,
|
||||
base_scene: type[Scene] = Scene,
|
||||
**custom_config,
|
||||
):
|
||||
) -> Callable[Concatenate[pytest.FixtureRequest, Path, P], object]:
|
||||
"""Compares the frames generated by the test with control frames previously registered.
|
||||
|
||||
If there is no control frames for this test, the test will fail. To generate
|
||||
|
|
@ -77,7 +84,12 @@ def frames_comparison(
|
|||
|
||||
@functools.wraps(tested_scene_construct)
|
||||
# The "request" parameter is meant to be used as a fixture by pytest. See below.
|
||||
def wrapper(*args, request: pytest.FixtureRequest, tmp_path, **kwargs):
|
||||
def wrapper(
|
||||
request: pytest.FixtureRequest,
|
||||
tmp_path: Path,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
):
|
||||
# Wraps the test_function to a construct method, to "freeze" the eventual additional arguments (parametrizations fixtures).
|
||||
construct = functools.partial(tested_scene_construct, *args, **kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue