Cleanup legacy code in experimental (#3874)

* Remove data/uniforms
* Clean up some manager code.
This commit is contained in:
adeshpande 2024-07-24 21:13:46 -04:00 committed by GitHub
commit 424cb27c6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 161 additions and 504 deletions

View file

@ -6,6 +6,7 @@ class Test(Scene):
s = Square()
self.add(s)
self.play(Rotate(s, PI / 2))
self.wait(7)
self.play(FadeOut(s))
sq = RegularPolygon(6)
c = Circle()

View file

@ -6,8 +6,7 @@ import numpy as np
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
from .. import config, logger
from ..constants import RendererType
from .. import logger
from ..mobject import mobject
from ..mobject.mobject import Mobject
from ..mobject.opengl import opengl_mobject
@ -255,11 +254,7 @@ class Animation(AnimationProtocol):
return self.mobject, self.starting_mobject
def get_all_families_zipped(self) -> Iterable[tuple]:
if config["renderer"] == RendererType.OPENGL:
return zip(*(mob.get_family() for mob in self.get_all_mobjects()))
return zip(
*(mob.family_members_with_points() for mob in self.get_all_mobjects())
)
return zip(*(mob.get_family() for mob in self.get_all_mobjects()))
def update_mobjects(self, dt: float) -> None:
"""
@ -456,7 +451,7 @@ def prepare_animation(
| mobject._AnimationBuilder
| opengl_mobject._AnimationBuilder
| opengl_mobject.OpenGLMobject,
) -> Animation:
) -> AnimationProtocol:
r"""Returns either an unchanged animation, or the animation built
from a passed animation factory.
@ -526,7 +521,6 @@ class Wait(Animation):
if stop_condition and frozen_frame:
raise ValueError("A static Wait animation cannot have a stop condition.")
self.duration: float = run_time
self.stop_condition = stop_condition
self.is_static_wait: bool = bool(frozen_frame)
super().__init__(None, run_time=run_time, rate_func=rate_func, **kwargs)

View file

@ -3,6 +3,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from manim.typing import RateFunc
from .scene_buffer import SceneBuffer
@ -13,12 +15,19 @@ class AnimationProtocol(Protocol):
buffer: SceneBuffer
apply_buffer: bool
def begin(self) -> None: ...
def begin(self) -> object: ...
def finish(self) -> None: ...
def finish(self) -> object: ...
def update_mobjects(self, dt: float) -> None: ...
def update_mobjects(self, dt: float) -> object: ...
def interpolate(self, alpha: float) -> None: ...
def interpolate(self, alpha: float) -> object: ...
def get_run_time(self) -> float: ...
def update_rate_info(
self,
run_time: float | None,
rate_func: RateFunc | None,
lag_ratio: float | None,
) -> object: ...

View file

@ -30,7 +30,7 @@ __all__ = [
import inspect
import types
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable
import numpy as np
@ -185,15 +185,14 @@ class Transform(Animation):
self._path_func = path_func
def begin(self) -> None:
# Use a copy of target_mobject for the align_data
# call so that the actual target_mobject stays
# preserved.
self.target_mobject = self.create_target()
# Note, this potentially changes the structure
# of both mobject and target_mobject
if self.mobject.is_aligned_with(self.target_mobject):
self.target_copy = self.target_mobject
else:
# Use a copy of target_mobject for the align_data_and_family
# call so that the actual target_mobject stays
# preserved, since calling align_data will potentially
# change the structure of both arguments
self.target_copy = self.target_mobject.copy()
self.mobject.align_data_and_family(self.target_copy)
@ -220,7 +219,9 @@ class Transform(Animation):
self.target_copy,
]
def get_all_families_zipped(self) -> Iterable[tuple]: # more precise typing?
def get_all_families_zipped(
self,
) -> zip[tuple[OpenGLMobject, OpenGLMobject, OpenGLMobject]]:
mobs = [
self.mobject,
self.starting_mobject,

View file

@ -112,7 +112,7 @@ class Manager(Generic[Scene_co]):
-------
A file writer satisfying :class:`.FileWriterProtocol`
"""
return FileWriter(self.scene.get_default_scene_name())
return FileWriter(scene_name=self.scene.get_default_scene_name())
def setup(self) -> None:
"""Set up processes and manager"""
@ -122,6 +122,7 @@ class Manager(Generic[Scene_co]):
# these are used for making sure it feels like the correct
# amount of time has passed in the window instead of rendering
# at full speed
# See the docstring of :meth:`_wait_for_animation_time`
self.virtual_animation_start_time = 0.0
self.real_animation_start_time = time.perf_counter()
@ -176,7 +177,7 @@ class Manager(Generic[Scene_co]):
self.file_writer.finish()
# otherwise no animations were played
elif config.write_to_movie or config.save_last_frame:
self.render_state(write_to_file=False)
self.render_state(write_frame=False)
# FIXME: for some reason the OpenGLRenderer does not give out the
# correct frame values here
frame = self.renderer.get_pixels()
@ -190,9 +191,6 @@ class Manager(Generic[Scene_co]):
self.scene.tear_down()
if config.save_last_frame:
self._update_frame(0)
if self.window is not None:
self.window.close()
self.window = None
@ -213,13 +211,17 @@ class Manager(Generic[Scene_co]):
while not self.window.is_closing:
self._update_frame(dt)
def _update_frame(self, dt: float, *, write_to_file: bool | None = None) -> None:
# ----------------------------------#
# Animation Pipeline #
# ----------------------------------#
def _update_frame(self, dt: float, *, write_frame: bool | None = None) -> None:
"""Update the current frame by ``dt``
Parameters
----------
dt : the time in between frames
write_to_file : Whether to write the result to the output stream.
write_frame : Whether to write the result to the output stream (videos ONLY).
Default value checks :attr:`_write_files` to see if it should be written.
"""
self.time += dt
@ -229,20 +231,38 @@ class Manager(Generic[Scene_co]):
if self.window is not None:
self.window.clear()
# if it's closing, then any subsequent methods will
# raise an error because the internal C window pointer is nullptr.
if self.window.is_closing:
raise EndSceneEarlyException()
self.render_state(write_to_file=write_to_file)
self.render_state(write_frame=write_frame)
if self.window is not None:
self._wait_for_animation_time()
def _wait_for_animation_time(self) -> None:
"""Wait for the real time to catch up to the "virtual" animation time.
Animations can render faster than real time, so we have to
slow the window down for the correct amount of time, such
as during a wait animation.
"""
if self.window is None:
return
self.window.swap_buffers()
vt = self.time - self.virtual_animation_start_time
rt = time.perf_counter() - self.real_animation_start_time
# we can't sleep because we still need to poll for events,
# e.g. hitting Escape or close
while rt < vt:
if self.window.is_closing:
raise EndSceneEarlyException()
# make sure to poll for events
self.window.swap_buffers()
# This recursively updates the window with dt=0 until the correct
# amount of time has passed
# TODO: do ^ better with less overhead
vt = self.time - self.virtual_animation_start_time
rt = time.perf_counter() - self.real_animation_start_time
if rt < vt:
self._update_frame(0, write_to_file=False)
def _play(self, *animations: AnimationProtocol) -> None:
"""Play a bunch of animations"""
@ -298,7 +318,7 @@ class Manager(Generic[Scene_co]):
) -> tqdm | contextlib.nullcontext[NullProgressBar]:
"""Create a progressbar"""
if not config.write_to_movie or not config.progress_bar:
if not config.progress_bar:
return contextlib.nullcontext(NullProgressBar())
else:
return tqdm(
@ -322,28 +342,35 @@ class Manager(Generic[Scene_co]):
self._write_hashed_movie_file(animations=[])
if self.window is not None:
self.real_animation_start_time = time.perf_counter()
self.virtual_animation_start_time = self.time
update_mobjects = self.scene.should_update_mobjects()
condition = stop_condition or (lambda: False)
progression = self._calc_time_progression(duration)
state = self.scene.get_state()
with self._create_progressbar(
progression.shape[0], "Waiting %(num)d: "
) as progress:
last_t = 0
for t in progression:
dt, last_t = t - last_t, t
if update_mobjects:
if update_mobjects or stop_condition is not None:
self._update_frame(dt)
if condition():
progress.update(duration - t)
break
else:
# if we don't need to update mobjects
# we can just leave the mobjects on the window
# and increment the time
# but we still have to write frames
self.time += dt
self.write_frame()
# this fixes it, but at that point we might as well
# just not cache
self.renderer.render(self.scene.camera, state.mobjects)
if self.window is not None and self.window.is_closing:
raise EndSceneEarlyException()
self._wait_for_animation_time()
progress.update(1)
self.scene.post_play()
@ -379,27 +406,36 @@ class Manager(Generic[Scene_co]):
"""
return max(animation.get_run_time() for animation in animations)
def render_state(self, write_to_file: bool | None = None) -> None:
# -------------------------#
# Rendering #
# -------------------------#
def render_state(self, write_frame: bool | None = None) -> None:
"""Render the current state of the scene.
Any extra kwargs are passed to :meth:`_render_frame`.
"""
state = self.scene.get_state()
self._render_frame(state, write_file=write_to_file)
self._render_frame(state, write_frame=write_frame)
def _render_frame(
self, state: SceneState, *, write_file: bool | None = None
self, state: SceneState, *, write_frame: bool | None = None
) -> None:
"""Renders a frame based on a state, and writes it to a file.
"""Renders a frame based on a state, and writes it to the file writers stream.
Any extra kwargs are passed to :meth:`write_frame`.
This is used for writing a single frame. Any extra kwargs are passed to :meth:`write_frame`.
.. warning::
This method will not work if :meth:`.FileWriter.begin_animation` and
:meth:`.FileWriter.add_partial_movie_file` have not been called. Do NOT
use this to write a single frame!
"""
# render the frame to the window
# TODO: change self.scene.camera to state.camera
self.renderer.render(self.scene.camera, state.mobjects)
should_write = write_file if write_file is not None else self._write_files
should_write = write_frame if write_frame is not None else self._write_files
if should_write:
self.write_frame()

View file

@ -23,7 +23,6 @@ from manim.constants import *
from manim.event_handler import EVENT_DISPATCHER
from manim.event_handler.event_listener import EventListener
from manim.event_handler.event_type import EventType
from manim.renderer.shader_wrapper import ShaderWrapper, get_colormap_code
from manim.utils.bezier import integer_interpolate, interpolate
from manim.utils.color import *
from manim.utils.deprecation import deprecated
@ -31,11 +30,8 @@ from manim.utils.deprecation import deprecated
# from ..utils.iterables import batch_by_property
from manim.utils.iterables import (
list_update,
listify,
make_even,
resize_array,
resize_preserving_order,
resize_with_interpolation,
uniq_chain,
)
from manim.utils.paths import straight_path
@ -49,7 +45,7 @@ from manim.utils.space_ops import (
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Union
from typing import Callable
import numpy.typing as npt
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias
@ -62,7 +58,7 @@ if TYPE_CHECKING:
["OpenGLMobject", float], "OpenGLMobject | None"
]
NonTimeUpdater: TypeAlias = Callable[["OpenGLMobject"], "OpenGLMobject | None"]
Updater: TypeAlias = Union[TimeBasedUpdater, NonTimeUpdater]
Updater: TypeAlias = TimeBasedUpdater | NonTimeUpdater
PointUpdateFunction: TypeAlias = Callable[[np.ndarray], np.ndarray]
M = TypeVar("M", bound="OpenGLMobject")
@ -75,8 +71,6 @@ T_co = TypeVar("T_co", covariant=True, bound="OpenGLMobject")
logger = logging.getLogger("manim")
UNIFORM_DTYPE = np.float64
def stash_mobject_pointers(
func: Callable[Concatenate[M, P], T],
@ -169,24 +163,22 @@ class OpenGLMobject(Generic[R]):
self.name = self.__class__.__name__ if name is None else name
# internal_state
self.points = np.zeros((0, 3))
self.submobjects: list[OpenGLMobject] = []
self.parents: list[OpenGLMobject] = []
self.family: list[OpenGLMobject] = [self]
self.locked_data_keys: set[str] = set()
self.needs_new_bounding_box: bool = True
self._bounding_box = np.zeros((3, 3))
self._is_animating: bool = False
self.saved_state: OpenGLMobject | None = None
self.target: OpenGLMobject | None = None
self.data: dict[str, np.ndarray] = {}
self.uniforms: dict[str, float | np.ndarray] = {}
# TODO replace with protocol
self.renderer_data: R | None = None
# currently does nothing
self.status = MobjectStatus()
self.init_data()
self.init_uniforms()
self.init_updaters()
self.init_event_listeners()
self.init_points()
@ -214,6 +206,14 @@ class OpenGLMobject(Generic[R]):
raise TypeError(f"Only int can be multiplied to Mobjects not {type(other)}")
return self.replicate(other)
@property
def bounding_box(self) -> npt.NDArray[np.float64]:
return self._bounding_box
@bounding_box.setter
def bounding_box(self, box: npt.NDArray[np.float64]):
self._bounding_box = box
@classmethod
def set_default(cls, **kwargs):
"""Sets the default values of keyword arguments.
@ -262,50 +262,6 @@ class OpenGLMobject(Generic[R]):
else:
cls.__init__ = cls._original__init__
@property
def points(self):
return self.data["points"]
@points.setter
def points(self, value):
self.data["points"] = value
@property
def bounding_box(self):
return self.data["bounding_box"]
@bounding_box.setter
def bounding_box(self, value):
self.data["bounding_box"] = value
@property
def rgbas(self):
return self.data["rgbas"]
@rgbas.setter
def rgbas(self, value):
self.data["rgbas"] = value
def init_data(self):
"""Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data.
Subclasses can inherit and overwrite this method to extend `self.data`."""
self.data = {
"points": np.zeros((0, 3)),
"bounding_box": np.zeros((3, 3)),
"rgbas": np.zeros((1, 4)),
}
def init_uniforms(self):
"""Initializes the uniforms.
Gets called upon creation"""
self.uniforms = {
"is_fixed_in_frame": float(self.is_fixed_in_frame),
"gloss": float(self.gloss),
"shadow": float(self.shadow),
"reflectiveness": float(self.reflectiveness),
}
def init_colors(self):
"""Initializes the colors.
@ -320,18 +276,6 @@ class OpenGLMobject(Generic[R]):
# Typically implemented in subclass, unless purposefully left blank
pass
def set_data(self, data: dict[str, Any]):
for key, value in data.items():
self.data[key] = value
return self
def set_uniforms(self, uniforms):
for key, value in uniforms.items():
if isinstance(value, np.ndarray):
value = value.copy()
self.uniforms[key] = value
return self
# https://github.com/python/typing/issues/802
# so we hack around it by doing | Self
# but this causes issues in Scene.play which only
@ -446,12 +390,7 @@ class OpenGLMobject(Generic[R]):
return self
def reverse_points(self, recursive=False):
for key in self.data:
self.data[key] = self.data[key][::-1]
if recursive:
for mob in self.submobjects:
for key in mob.data:
mob.data[key] = mob.data[key][::-1]
self.points = self.points[::-1]
return self
def apply_points_function(
@ -1379,6 +1318,7 @@ class OpenGLMobject(Generic[R]):
# Copying
# TODO: don't use self
@stash_mobject_pointers
def serialize(self) -> bytes:
return pickle.dumps(self)
@ -1417,27 +1357,25 @@ class OpenGLMobject(Generic[R]):
result = copy.copy(self)
# The line above is only a shallow copy, so the internal
# data which are numpyu arrays or other mobjects still
# need to be further copied.
result.data = {k: np.array(v) for k, v in self.data.items()}
result.uniforms = {k: np.array(v) for k, v in self.uniforms.items()}
result.parents = []
result.target = None
result.saved_state = None
#
result.points = np.array(self.points)
#
# Instead of adding using result.add, which does some checks for updating
# updater statues and bounding box, just directly modify the family-related
# lists
result.submobjects = [sm.copy() for sm in self.submobjects]
for sm in result.submobjects:
sm.parents = [result]
result.family = [
result,
*it.chain(*(sm.get_family() for sm in result.submobjects)),
]
result.note_changed_family()
# Similarly, instead of calling match_updaters, since we know the status
# won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters)
result.time_based_updaters = list(self.time_based_updaters)
# won't have changed, just directly match with shallow copies.
result.non_time_updaters = self.non_time_updaters.copy()
result.time_based_updaters = self.time_based_updaters.copy()
family = self.get_family()
for attr, value in list(self.__dict__.items()):
@ -1449,8 +1387,6 @@ class OpenGLMobject(Generic[R]):
setattr(result, attr, result.family[self.family.index(value)])
if isinstance(value, np.ndarray):
setattr(result, attr, value.copy())
if isinstance(value, ShaderWrapper):
setattr(result, attr, value.copy())
return result
def generate_target(self, use_deepcopy: bool = False):
@ -2216,123 +2152,20 @@ class OpenGLMobject(Generic[R]):
# Color functions
def set_rgba_array_legacy(
self, color=None, opacity=None, name="rgbas", recurse=True
):
if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
if opacity is not None:
opacities = listify(opacity)
# Color only
if color is not None and opacity is None:
for mob in self.get_family(recurse):
mob.data[name] = resize_array(
mob.data[name] if name in mob.data else np.empty((1, 3)), len(rgbs)
)
mob.data[name][:, :3] = rgbs
# Opacity only
if color is None and opacity is not None:
for mob in self.get_family(recurse):
mob.data[name] = resize_array(
mob.data[name] if name in mob.data else np.empty((1, 3)),
len(opacities),
)
mob.data[name][:, 3] = opacities
# Color and opacity
if color is not None and opacity is not None:
rgbas = np.array([[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))])
for mob in self.get_family(recurse):
mob.data[name] = rgbas.copy()
return self
def set_rgba_array(
self, rgba_array: np.ndarray, name: str = "rgbas", recurse: bool = False
):
"""Directly set rgba data from `rgbas` and optionally do the same recursively
with submobjects. This can be used if the `rgbas` have already been generated
with the correct shape and simply need to be set.
Parameters
----------
rgbas
the rgba to be set as data
name
the name of the data attribute to be set
recurse
set to true to recursively apply this method to submobjects
"""
for mob in self.get_family(recurse):
mob.data[name] = np.array(rgba_array) # type: ignore
return self
def set_color_by_rgba_func(
self, func: Callable[[np.ndarray], np.ndarray], recurse: bool = True
):
"""
Func should take in a point in R3 and output an rgba value
"""
for mob in self.get_family(recurse):
rgba_array = np.asarray([func(point) for point in mob.points])
mob.set_rgba_array(rgba_array)
return self
def set_color_by_rgb_func(
self,
func: Callable[[np.ndarray], np.ndarray],
opacity: float = 1,
recurse: bool = True,
):
"""
Func should take in a point in R3 and output an rgb value
"""
for mob in self.get_family(recurse):
rgba_array = np.asarray([[*func(point), opacity] for point in mob.points])
mob.set_rgba_array(rgba_array)
return self
def set_rgba_array_by_color(
self,
color=None,
opacity: float | Iterable[float] | None = None,
name: str = "rgbas",
recurse: bool = True,
):
max_len = 0
if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
max_len = len(rgbs)
if opacity is not None:
opacities = np.array(listify(opacity))
max_len = max(max_len, len(opacities))
for mob in self.get_family(recurse):
if max_len > len(mob.data[name]): # type: ignore
mob.data[name] = resize_array(mob.data[name], max_len) # type: ignore
size = len(mob.data[name]) # type: ignore
if color is not None:
mob.data[name][:, :3] = resize_array(rgbs, size) # type: ignore
if opacity is not None:
mob.data[name][:, 3] = resize_array(opacities, size) # type: ignore
return self
def set_color(self, color: ParsableManimColor | None, opacity=None, recurse=True):
self.set_rgba_array(color, opacity, recurse=False)
# Recurse to submobjects differently from how set_rgba_array
# in case they implement set_color differently
if color is not None:
self.color: ManimColor = ManimColor.parse(color)
if opacity is not None:
self.opacity = opacity
self.color.set_opacity(opacity)
if recurse:
for submob in self.submobjects:
submob.set_color(color, recurse=True)
return self
def set_opacity(self, opacity, recurse=True):
self.set_rgba_array(color=None, opacity=opacity, recurse=False)
# self.set_rgba_array(color=None, opacity=opacity, recurse=False)
if recurse:
for submob in self.submobjects:
submob.set_opacity(opacity, recurse=True)
@ -2342,9 +2175,9 @@ class OpenGLMobject(Generic[R]):
return rgb_to_hex(self.rgbas[0, :3])
def get_opacity(self):
return self.data["rgbas"][0, 3]
return self.color._internal_value[3]
def set_color_by_gradient(self, *colors: Color):
def set_color_by_gradient(self, *colors: ParsableManimColor):
if self.has_points():
self.set_color(colors)
else:
@ -2368,30 +2201,6 @@ class OpenGLMobject(Generic[R]):
def fade(self, darkness=0.5, recurse=True):
self.set_opacity(1.0 - darkness, recurse=recurse)
def get_reflectiveness(self) -> np.ndarray:
return self.uniforms["reflectiveness"]
def set_reflectiveness(self, reflectiveness: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.uniforms["reflectiveness"] = float(reflectiveness)
return self
def get_shadow(self) -> np.ndarray:
return self.uniforms["shadow"]
def set_shadow(self, shadow: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.uniforms["shadow"] = float(shadow)
return self
def get_gloss(self) -> np.ndarray:
return self.uniforms["gloss"]
def set_gloss(self, gloss: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.uniforms["gloss"] = float(gloss)
return self
# Background rectangle
def add_background_rectangle(
@ -2668,13 +2477,9 @@ class OpenGLMobject(Generic[R]):
# Alignment
def is_aligned_with(self, mobject: OpenGLMobject) -> bool:
return (
len(self.data) == len(mobject.data)
and len(self.submobjects) == len(mobject.submobjects)
and all(
sm1.is_aligned_with(sm2)
for sm1, sm2 in zip(self.submobjects, mobject.submobjects)
)
return len(self.submobjects) == len(mobject.submobjects) and all(
sm1.is_aligned_with(sm2)
for sm1, sm2 in zip(self.submobjects, mobject.submobjects)
)
def align_data_and_family(self, mobject):
@ -2686,15 +2491,6 @@ class OpenGLMobject(Generic[R]):
# Separate out how points are treated so that subclasses
# can handle that case differently if they choose
mob1.align_points(mob2)
for key in mob1.data.keys() & mob2.data.keys():
if key == "points":
continue
arr1 = mob1.data[key] # type: ignore
arr2 = mob2.data[key]
if len(arr2) > len(arr1):
mob1.data[key] = resize_preserving_order(arr1, len(arr2)) # type: ignore
elif len(arr1) > len(arr2):
mob2.data[key] = resize_preserving_order(arr2, len(arr1))
def align_points(self, mobject) -> Self:
max_len = max(self.get_num_points(), mobject.get_num_points())
@ -2779,6 +2575,9 @@ class OpenGLMobject(Generic[R]):
self.interpolate_color(mobject1, mobject2, alpha)
return self
def interpolate_color(self, mobject1, mobject2, alpha):
raise NotImplementedError("Implemented in subclasses")
def pointwise_become_partial(self, mobject, a, b):
"""
Set points in such a way as to become only
@ -2788,41 +2587,6 @@ class OpenGLMobject(Generic[R]):
"""
pass # To implement in subclass
# Locking data
def lock_data(self, keys: Iterable[str]):
"""
To speed up some animations, particularly transformations,
it can be handy to acknowledge which pieces of data
won't change during the animation so that calls to
interpolate can skip this, and so that it's not
read into the shader_wrapper objects needlessly
"""
if self.has_updaters:
return
# Be sure shader data has most up to date information
self.refresh_shader_data()
self.locked_data_keys = set(keys)
def lock_matching_data(self, mobject1: OpenGLMobject, mobject2: OpenGLMobject):
for sm, sm1, sm2 in zip(
self.get_family(), mobject1.get_family(), mobject2.get_family()
):
keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys()
sm.lock_data(
list(
filter(
lambda key: np.all(sm1.data[key] == sm2.data[key]), # type: ignore
keys,
)
)
)
return self
def unlock_data(self):
for mob in self.get_family():
mob.locked_data_keys = set()
def become(
self,
mobject: OpenGLMobject,
@ -2888,8 +2652,6 @@ class OpenGLMobject(Generic[R]):
family1 = self.get_family()
family2 = mobject.get_family()
for sm1, sm2 in zip(family1, family2):
sm1.set_data(sm2.data)
sm1.set_uniforms(sm2.uniforms)
sm1.shader_folder = sm2.shader_folder
sm1.texture_paths = sm2.texture_paths
sm1.depth_test = sm2.depth_test
@ -2906,22 +2668,7 @@ class OpenGLMobject(Generic[R]):
def looks_identical(self, mobject: OpenGLMobject) -> bool:
fam1 = self.family_members_with_points()
fam2 = mobject.family_members_with_points()
if len(fam1) != len(fam2):
return False
for m1, m2 in zip(fam1, fam2):
for d1, d2 in [(m1.data, m2.data), (m1.uniforms, m2.uniforms)]:
if set(d1).difference(d2):
return False
for key in d1:
if (
isinstance(d1[key], np.ndarray)
and isinstance(d2[key], np.ndarray)
and (d1[key].size != d2[key].size)
):
return False
if not np.isclose(d1[key], d2[key]).all():
return False
return True
return len(fam1) == len(fam2)
def has_same_shape_as(self, mobject: OpenGLMobject) -> bool:
# Normalize both point sets by centering and making height 1
@ -2933,20 +2680,16 @@ class OpenGLMobject(Generic[R]):
return False
return bool(np.isclose(points1, points2).all())
# Operations touching shader uniforms
def fix_in_frame(self) -> Self:
self.uniforms["is_fixed_in_frame"] = 1.0
self.is_fixed_in_frame = True
return self
def fix_orientation(self) -> Self:
self.uniforms["is_fixed_orientation"] = 1.0
self.is_fixed_orientation = True
self.fixed_orientation_center = tuple(self.get_center())
return self
def unfix_from_frame(self) -> Self:
self.uniforms["is_fixed_in_frame"] = 0.0
self.is_fixed_in_frame = False
return self
@ -2963,58 +2706,6 @@ class OpenGLMobject(Generic[R]):
self.depth_test = False
return self
# Shader code manipulation
def replace_shader_code(self, old, new):
# TODO, will this work with VMobject structure, given
# that it does not simpler return shader_wrappers of
# family?
for wrapper in self.get_shader_wrapper_list():
wrapper.replace_code(old, new)
return self
def set_color_by_code(self, glsl_code):
"""
Takes a snippet of code and inserts it into a
context which has the following variables:
vec4 color, vec3 point, vec3 unit_normal.
The code should change the color variable
"""
self.replace_shader_code("///// INSERT COLOR FUNCTION HERE /////", glsl_code)
return self
def set_color_by_xyz_func(
self,
glsl_snippet,
min_value=-5.0,
max_value=5.0,
colormap="viridis",
):
"""
Pass in a glsl expression in terms of x, y and z which returns
a float.
"""
# TODO, add a version of this which changes the point data instead
# of the shader code
for char in "xyz":
glsl_snippet = glsl_snippet.replace(char, "point." + char)
rgb_list = get_colormap_list(colormap)
self.set_color_by_code(
f"color.rgb = float_to_color({glsl_snippet}, {float(min_value)}, {float(max_value)}, {get_colormap_code(rgb_list)});",
)
return self
def check_data_alignment(self, array, data_key):
# Makes sure that self.data[key] can be broadcast into
# the given array, meaning its length has to be either 1
# or the length of the array
d_len = len(self.data[data_key])
if d_len != 1 and d_len != len(array):
self.data[data_key] = resize_with_interpolation(
self.data[data_key],
len(array),
)
return self
# Event Handlers
"""
Event handling follows the Event Bubbling model of DOM in javascript.

View file

@ -3,11 +3,9 @@ from __future__ import annotations
import itertools as it
import operator as op
from functools import reduce
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
import moderngl
import numpy as np
from numpy.typing import NDArray
from manim.constants import *
from manim.mobject.opengl.opengl_mobject import (
@ -30,7 +28,6 @@ from manim.utils.deprecation import deprecated
from manim.utils.iterables import (
listify,
make_even,
resize_with_interpolation,
)
from manim.utils.space_ops import (
angle_between_vectors,
@ -42,6 +39,7 @@ from manim.utils.space_ops import (
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence
import numpy.typing as npt
from typing_extensions import Self
__all__ = [
@ -61,23 +59,6 @@ class OpenGLVMobject(OpenGLMobject):
"""A vectorized mobject."""
n_points_per_curve: int = 3
stroke_shader_folder = "quadratic_bezier_stroke"
fill_shader_folder = "quadratic_bezier_fill"
fill_dtype = [
("point", np.float32, (3,)),
("unit_normal", np.float32, (3,)),
("color", np.float32, (4,)),
("vert_index", np.float32, (1,)),
]
stroke_dtype = [
("point", np.float32, (3,)),
("prev_point", np.float32, (3,)),
("next_point", np.float32, (3,)),
("stroke_width", np.float32, (1,)),
("color", np.float32, (4,)),
]
render_primitive: int = moderngl.TRIANGLES
pre_function_handle_to_anchor_scale_factor: float = 0.01
make_smooth_after_applying_functions: bool = False
tolerance_for_point_equality: float = 1e-8
@ -97,6 +78,7 @@ class OpenGLVMobject(OpenGLMobject):
flat_stroke: bool = False,
**kwargs,
):
super().__init__(**kwargs)
if fill_color is None:
fill_color = color
if stroke_color is None:
@ -107,8 +89,6 @@ class OpenGLVMobject(OpenGLMobject):
ManimColor.parse(stroke_color)
)
self.set_stroke(opacity=stroke_opacity)
if stroke_width is None:
stroke_width = DEFAULT_STROKE_WIDTH
self.stroke_width = listify(stroke_width)
self.draw_stroke_behind_fill = draw_stroke_behind_fill
self.background_image_file = background_image_file
@ -121,7 +101,6 @@ class OpenGLVMobject(OpenGLMobject):
self.needs_new_triangulation = True
self.triangulation = np.zeros(0, dtype="i4")
super().__init__(**kwargs)
# self.refresh_unit_normal()
def _assert_valid_submobjects(self, submobjects: Iterable[OpenGLVMobject]) -> Self:
@ -134,36 +113,6 @@ class OpenGLVMobject(OpenGLMobject):
def get_mobject_type_class():
return OpenGLVMobject
@property
def rgbas(self):
raise NotImplementedError(
"rgbas is not implemented for OpenGLVMobject. please use fill_rgba and stroke_rgba."
)
@rgbas.setter
def rgbas(self, value):
raise NotImplementedError(
"rgbas is not implemented for OpenGLVMobject. please use fill_rgba and stroke_rgba."
)
def init_data(self):
super().init_data()
self.data.pop("rgbas")
self.data.update(
{
"fill_rgba": np.zeros((1, 4)),
"stroke_rgba": np.zeros((1, 4)),
"stroke_width": np.zeros((1, 1)),
"unit_normal": np.zeros((1, 3)),
}
)
def init_uniforms(self):
super().init_uniforms()
self.uniforms["anti_alias_width"] = float(self.anti_alias_width)
self.uniforms["joint_type"] = float(self.joint_type.value)
self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy
def get_family(self, recurse: bool = True) -> list[OpenGLVMobject]: # type: ignore
return super().get_family(recurse) # type: ignore
@ -202,21 +151,6 @@ class OpenGLVMobject(OpenGLMobject):
# self.color = self.get_color()
return self
def set_rgba_array(
self, rgba_array: np.ndarray, name: str | None = None, recurse: bool = False
) -> Self:
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
names = [name]
for name in names:
if name in self.data:
self.data[name] = rgba_array
else:
raise Exception(f"{name} is not a valid data name.")
return self
def set_fill(
self,
color: ParsableManimColor | Sequence[ParsableManimColor] | None = None,
@ -258,11 +192,13 @@ class OpenGLVMobject(OpenGLMobject):
--------
:meth:`~.OpenGLVMobject.set_style`
"""
for mob in self.get_family(recurse):
if color is not None:
mob.fill_color = listify(ManimColor.parse(color))
if opacity is not None:
mob.fill_color = [c.set_opacity(opacity) for c in mob.fill_color]
if recurse:
for submob in self.submobjects:
submob.set_fill(color, opacity, recurse=True)
if color is not None:
self.fill_color = listify(ManimColor.parse(color))
if opacity is not None:
self.fill_color = [c.set_opacity(opacity) for c in self.fill_color]
return self
def set_stroke(
@ -295,12 +231,6 @@ class OpenGLVMobject(OpenGLMobject):
self.set_stroke(color, width, background=background)
return self
def align_stroke_width_data_to_points(self, recurse: bool = True) -> None:
for mob in self.get_family(recurse):
mob.data["stroke_width"] = resize_with_interpolation(
mob.data["stroke_width"], len(mob.points)
)
def set_style(
self,
fill_color: ParsableManimColor | Iterable[ParsableManimColor] | None = None,
@ -323,12 +253,6 @@ class OpenGLVMobject(OpenGLMobject):
recurse=False,
background=stroke_background,
)
if reflectiveness is not None:
mob.set_reflectiveness(reflectiveness, recurse=False)
if gloss is not None:
mob.set_gloss(gloss, recurse=False)
if shadow is not None:
mob.set_shadow(shadow, recurse=False)
return self
def get_style(self):
@ -337,9 +261,6 @@ class OpenGLVMobject(OpenGLMobject):
"stroke_color": self.stroke_color.copy(),
"stroke_width": self.stroke_width.copy(),
# "stroke_background": self.draw_stroke_behind_fill,
"reflectiveness": self.get_reflectiveness(),
"gloss": self.get_gloss(),
"shadow": self.get_shadow(),
}
def match_style(self, vmobject: OpenGLVMobject, recurse: bool = True):
@ -502,7 +423,7 @@ class OpenGLVMobject(OpenGLMobject):
else:
self.append_points([self.get_last_point(), handle, anchor])
def add_line_to(self, point: Sequence[float] | NDArray[float]) -> Self:
def add_line_to(self, point: Sequence[float] | npt.NDArray[float]) -> Self:
"""Add a straight line from the last point of OpenGLVMobject to the given point.
Parameters
@ -564,7 +485,7 @@ class OpenGLVMobject(OpenGLMobject):
return self.consider_points_equals(self.points[0], self.points[-1])
def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True):
vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()]
vmobs = [vm for vm in self.get_family(recurse=recurse) if vm.has_points()]
for vmob in vmobs:
new_points = []
for tup in vmob.get_bezier_tuples():
@ -619,7 +540,9 @@ class OpenGLVMobject(OpenGLMobject):
self.make_approximately_smooth()
return self
def change_anchor_mode(self, mode) -> Self:
def change_anchor_mode(
self, mode: Literal["jagged", "approx_smooth", "true_smooth"]
) -> Self:
"""Changes the anchor mode of the bezier curves. This will modify the handles.
There can be only three modes, "jagged", "approx_smooth" and "true_smooth".
@ -689,11 +612,10 @@ class OpenGLVMobject(OpenGLMobject):
if self.has_new_path_started():
# Remove last point, which is starting
# a new path
self.resize_data(len(self.points - 1))
self.points = self.points[:-1]
self.append_points(new_points)
return self
#
def consider_points_equals(self, p0, p1):
return np.linalg.norm(p1 - p0) < self.tolerance_for_point_equality
@ -1217,7 +1139,7 @@ class OpenGLVMobject(OpenGLMobject):
if self.get_num_points() == vmobject.get_num_points():
return
for mob in self, vmobject:
for mob in (self, vmobject):
# If there are no points, add one to
# where the "center" is
if not mob.has_points():

View file

@ -495,13 +495,12 @@ class VMobjectFromSVGPath(VMobject, metaclass=ConvertToOpenGL):
self.handle_commands()
if config.renderer == "opengl":
if self.should_subdivide_sharp_curves:
# For a healthy triangulation later
self.subdivide_sharp_curves()
if self.should_remove_null_curves:
# Get rid of any null curves
self.set_points(self.get_points_without_null_curves())
if self.should_subdivide_sharp_curves:
# For a healthy triangulation later
self.subdivide_sharp_curves()
if self.should_remove_null_curves:
# Get rid of any null curves
self.set_points(self.get_points_without_null_curves())
generate_points = init_points

View file

@ -33,8 +33,9 @@ from textwrap import dedent
from manim import config, logger
from manim.constants import *
from manim.mobject.geometry.line import Line
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.mobject.svg.svg_mobject import SVGMobject
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
from manim.mobject.types.vectorized_mobject import VGroup
from manim.utils.tex import TexTemplate
from manim.utils.tex_file_writing import tex_to_svg_file
@ -67,13 +68,13 @@ class SingleStringMathTex(SVGMobject):
if kwargs.get("color") is None:
# makes it so that color isn't explicitly passed for these mobs,
# and can instead inherit from the parent
kwargs["color"] = VMobject().color
kwargs["color"] = OpenGLVMobject().color
self._font_size = font_size
self.organize_left_to_right = organize_left_to_right
self.tex_environment = tex_environment
if tex_template is None:
tex_template = config["tex_template"]
tex_template = config.tex_template
self.tex_template = tex_template
assert isinstance(tex_string, str)
@ -289,6 +290,7 @@ class MathTex(SingleStringMathTex):
if self.organize_left_to_right:
self._organize_submobjects_left_to_right()
self.note_changed_family()
def _break_up_tex_strings(self, tex_strings):
# Separate out anything surrounded in double braces

View file

@ -135,8 +135,6 @@ class VMobject(Mobject):
cap_style: CapStyleType = CapStyleType.AUTO,
**kwargs,
):
self.fill_opacity = fill_opacity
self.stroke_opacity = stroke_opacity
self.stroke_width = stroke_width
if background_stroke_color is not None:
self.background_stroke_color: ManimColor = ManimColor(
@ -176,6 +174,11 @@ class VMobject(Mobject):
if stroke_color is not None:
self.stroke_color = ManimColor.parse(stroke_color)
if fill_opacity is not None:
self.fill_color = self.fill_color.set_opacity(fill_opacity)
if stroke_opacity is not None:
self.stroke_color = self.stroke_color.set_opacity(stroke_opacity)
def _assert_valid_submobjects(self, submobjects: Iterable[VMobject]) -> Self:
return self._assert_valid_submobjects_internal(submobjects, VMobject)

View file

@ -545,7 +545,6 @@ class GLVMobjectManager:
[mob.get_unit_normal()], points_length, axis=0
)
mob.renderer_data.bounding_box = compute_bounding_box(mob)
# print(mob.renderer_data)
@staticmethod
def read_uniforms(mob: OpenGLVMobject):