Refactored play logic (#1019)

* refactored play logic

* fixed tests (and minor improvement)

* moved time progression generation

* fixed typo :/

* added tests

* fixed tests

* fixed imports

* fixed tests

* black

* fixed tests in  python 3.7

* fixed tests for python 3.7

* Update manim/renderer/cairo_renderer.py

* removed is_cached variable

* updated webgl renderer and removed redondent code

* Update manim/scene/scene.py

* Add self.duration back into compile_animation_data

* Fixed openGL ?

* black

* fixed merge conflict by adding fixture marker

* Called begin_animation to fix OpenGL

* Fixed original_skippingstatu sname for opengl renderer

* Update manim/utils/caching.py

Co-authored-by: Devin Neal <devin@eulertour.com>

Co-authored-by: Hugues Devimeux <hugues.devimeux@gmail.com>
Co-authored-by: Jason G. Villanueva <a@jsonvillanueva.com>
Co-authored-by: Devin Neal <devin@eulertour.com>
This commit is contained in:
Hugues Devimeux 2021-03-23 09:58:49 +01:00 committed by GitHub
commit fc6159b45d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 309 additions and 50 deletions

View file

@ -248,6 +248,7 @@ class Wait(Animation):
self.duration = duration
self.mobject = None
self.stop_condition = stop_condition
self.is_static_wait = False
super().__init__(None, **kwargs)
def begin(self) -> None:

View file

@ -1,18 +1,15 @@
import typing
import time
import numpy as np
from .. import config
from ..utils.iterables import list_update
from ..utils.exceptions import EndSceneEarlyException
from ..scene.scene_file_writer import SceneFileWriter
from ..utils.caching import handle_caching_play
from manim.utils.hashing import get_hash_from_play_call
from .. import config, logger
from ..camera.camera import Camera
def pass_scene_reference(func):
def wrapper(self, scene, *args, **kwargs):
func(self, scene, *args, **kwargs)
return wrapper
from ..scene.scene_file_writer import SceneFileWriter
from ..utils.exceptions import EndSceneEarlyException
from ..utils.iterables import list_update
from ..mobject.mobject import Mobject
def handle_play_like_call(func):
@ -36,6 +33,12 @@ def handle_play_like_call(func):
to the video file stream.
"""
# NOTE : This is only kept for OpenGL renderer.
# The play logic of the cairo renderer as been refactored and does not need this function anymore.
# When OpenGL renderer will have a proper testing system,
# the play logic of the latter has to be refactored in the same way the cairo renderer has been, and thus this
# method has to be deleted.
def wrapper(self, scene, *args, **kwargs):
self.animation_start_time = time.time()
self.file_writer.begin_animation(not self.skip_animations)
@ -60,7 +63,7 @@ class CairoRenderer:
self.file_writer = None
camera_cls = camera_class if camera_class is not None else Camera
self.camera = camera_cls()
self.original_skipping_status = skip_animations
self._original_skipping_status = skip_animations
self.skip_animations = skip_animations
self.animations_hashes = []
self.num_plays = 0
@ -73,12 +76,55 @@ class CairoRenderer:
scene.__class__.__name__,
)
@pass_scene_reference
@handle_caching_play
@handle_play_like_call
def play(self, scene, *args, **kwargs):
if scene.compile_animation_data(*args, **kwargs):
# Reset skip_animations to the original state.
# Needed when rendering only some animations, and skipping others.
self.skip_animations = self._original_skipping_status
self.update_skipping_status()
scene.compile_animation_data(*args, **kwargs)
# If skip_animations is already True, we can skip all the caching process.
if not config["disable_caching"] and not self.skip_animations:
hash_current_animation = get_hash_from_play_call(
scene, self.camera, scene.animations, scene.mobjects
)
if self.file_writer.is_already_cached(hash_current_animation):
logger.info(
f"Animation {self.num_plays} : Using cached data (hash : %(hash_current_animation)s)",
{"hash_current_animation": hash_current_animation},
)
self.skip_animations = True
else:
hash_current_animation = f"uncached_{self.num_plays:05}"
if self.skip_animations:
logger.debug(f"Skipping animation {self.num_plays}")
hash_current_animation = None
# adding None as a partial movie file will make file_writer ignore the latter.
self.file_writer.add_partial_movie_file(hash_current_animation)
self.animations_hashes.append(hash_current_animation)
logger.debug(
"List of the first few animation hashes of the scene: %(h)s",
{"h": str(self.animations_hashes[:5])},
)
# Save a static image, to avoid rendering non moving objects.
self.static_image = self.save_static_frame_data(scene, scene.static_mobjects)
self.file_writer.begin_animation(not self.skip_animations)
scene.begin_animations()
if scene.is_current_animation_frozen_frame():
self.update_frame(scene)
# self.duration stands for the total run time of all the animations.
# In this case, as there is only a wait, it will be the length of the wait.
self.freeze_current_frame(scene.duration)
else:
scene.play_internal()
self.file_writer.end_animation(not self.skip_animations)
self.num_plays += 1
def update_frame( # TODO Description in Docstring
self,
@ -154,6 +200,20 @@ class CairoRenderer:
for _ in range(num_frames):
self.file_writer.write_frame(frame)
def freeze_current_frame(self, duration: float):
"""Adds a static frame to the movie for a given duration. The static frame is the current frame.
Parameters
----------
duration : float
[description]
"""
dt = 1 / self.camera.frame_rate
self.add_frame(
self.get_frame(),
num_frames=int(duration / dt),
)
def show_frame(self):
"""
Opens the current frame in the Default Image Viewer
@ -162,7 +222,27 @@ class CairoRenderer:
self.update_frame(ignore_skipping=True)
self.camera.get_image().show()
def save_static_frame_data(self, scene, static_mobjects):
def save_static_frame_data(
self, scene, static_mobjects: typing.Iterable[Mobject]
) -> typing.Iterable[Mobject]:
"""Compute and save the static frame, that will be reused at each frame to avoid to unecesseraly computer
static mobjects.
Parameters
----------
scene : Scene
The scene played.
static_mobjects : typing.Iterable[Mobject]
Static mobjects of the scene. If None, self.static_image is set to None
Returns
-------
typing.Iterable[Mobject]
the static image computed.
"""
if static_mobjects == None or len(static_mobjects) == 0:
self.static_image = None
return
self.update_frame(scene, mobjects=static_mobjects)
self.static_image = self.get_frame()
return self.static_image
@ -175,6 +255,8 @@ class CairoRenderer:
the number of animations that need to be played, and
raises an EndSceneEarlyException if they don't correspond.
"""
if config["save_last_frame"]:
self.skip_animations = True
if config["from_animation_number"]:
if self.num_plays < config["from_animation_number"]:
self.skip_animations = True

View file

@ -1,6 +1,6 @@
from manim.utils.exceptions import EndSceneEarlyException
from manim.utils.caching import handle_caching_play
from manim.renderer.cairo_renderer import pass_scene_reference, handle_play_like_call
from manim.renderer.cairo_renderer import handle_play_like_call
from manim.utils.color import color_to_rgba
import moderngl
from .opengl_renderer_window import Window
@ -185,7 +185,7 @@ class OpenGLRenderer:
# Measured in pixel widths, used for vector graphics
self.anti_alias_width = 1.5
self.original_skipping_status = skip_animations
self._original_skipping_status = skip_animations
self.skip_animations = skip_animations
self.animations_hashes = []
self.num_plays = 0
@ -383,6 +383,7 @@ class OpenGLRenderer:
def play(self, scene, *args, **kwargs):
# TODO: Handle data locking / unlocking.
if scene.compile_animation_data(*args, **kwargs):
scene.begin_animations()
scene.play_internal()
def render(self, scene, frame_offset, moving_mobjects):

View file

@ -24,6 +24,7 @@ class WebGLRenderer:
self.skip_animations = False
break
s = scene.compile_animation_data(*args, skip_rendering=True, **kwargs)
scene.begin_animations()
self.skip_animations = True
scene_copy = copy.deepcopy(scene)

View file

@ -821,10 +821,25 @@ class Scene(Container):
"""
self.wait(max_time, stop_condition=stop_condition)
def compile_animation_data(self, *animations, skip_rendering=False, **play_kwargs):
def compile_animation_data(self, *animations: Animation, **play_kwargs):
"""Given a list of animations, compile statics and moving mobjects, duration from them.
This also begin the animations.
Parameters
----------
skip_rendering : bool, optional
Whether the rendering should be skipped, by default False
Returns
-------
self, None
None if there is nothing to play, or self otherwise.
"""
# NOTE TODO : returns statement of this method are wrong. It should return nothing, as it makes a little sense to get any information from this method.
# The return are kept to keep webgl renderer from breaking.
if len(animations) == 0:
warnings.warn("Called Scene.play with no animations")
return None
raise ValueError("Called Scene.play with no animations")
self.animations = self.compile_animations(*animations, **play_kwargs)
self.add_mobjects_from_animations(self.animations)
@ -833,18 +848,16 @@ class Scene(Container):
self.stop_condition = None
self.moving_mobjects = None
self.static_mobjects = None
if not config["use_opengl_renderer"]:
if len(self.animations) == 1 and isinstance(self.animations[0], Wait):
self.update_mobjects(dt=0) # Any problems with this?
if self.should_update_mobjects():
# TODO, be smart about setting a static image
# the same way Scene.play does
self.renderer.static_image = None
self.stop_condition = self.animations[0].stop_condition
else:
self.duration = self.animations[0].duration
if not skip_rendering:
self.add_static_frames(self.animations[0].duration)
# Static image logic when the wait is static is done by the renderer, not here.
self.animations[0].is_static_wait = True
return None
else:
# Paint all non-moving objects onto the screen, so they don't
@ -853,17 +866,21 @@ class Scene(Container):
self.moving_mobjects,
self.static_mobjects,
) = self.get_moving_and_static_mobjects(self.animations)
self.renderer.save_static_frame_data(self, self.static_mobjects)
self.duration = self.get_run_time(self.animations)
self.time_progression = self._get_animation_time_progression(
self.animations, self.duration
)
return self
def begin_animations(self) -> None:
"""Start the animations of the scene."""
for animation in self.animations:
animation.begin()
return self
def is_current_animation_frozen_frame(self) -> bool:
"""Returns wether the current animation produces a static frame (generally a Wait)."""
return (
isinstance(self.animations[0], Wait)
and len(self.animations) == 1
and self.animations[0].is_static_wait
)
def play_internal(self, skip_rendering=False):
"""
@ -879,6 +896,10 @@ class Scene(Container):
named parameters affecting what was passed in ``args``,
e.g. ``run_time``, ``lag_ratio`` and so on.
"""
self.duration = self.get_run_time(self.animations)
self.time_progression = self._get_animation_time_progression(
self.animations, self.duration
)
for t in self.time_progression:
self.update_to_time(t)
if not skip_rendering:
@ -893,6 +914,8 @@ class Scene(Container):
if not self.renderer.skip_animations:
self.update_mobjects(0)
self.renderer.static_image = None
# Closing the progress bar at the end of the play.
self.time_progression.close()
def interact(self):
self.quit_interaction = False
@ -954,14 +977,6 @@ class Scene(Container):
animation.interpolate(alpha)
self.update_mobjects(dt)
def add_static_frames(self, duration):
self.renderer.update_frame(self)
dt = 1 / self.renderer.camera.frame_rate
self.renderer.add_frame(
self.renderer.get_frame(),
num_frames=int(duration / dt),
)
def add_sound(self, sound_file, time_offset=0, gain=None, **kwargs):
"""
This method is used to add a sound to the animation.

View file

@ -17,8 +17,14 @@ def handle_caching_play(func):
Take the same parameters as `scene.play`.
"""
# NOTE : This is only kept for OpenGL renderer.
# The play logic of the cairo renderer as been refactored and does not need this function anymore.
# When OpenGL renderer will have a proper testing system,
# the play logic of the latter has to be refactored in the same way the cairo renderer has been, and thus this
# method has to be deleted.
def wrapper(self, scene, *args, **kwargs):
self.skip_animations = self.original_skipping_status
self.skip_animations = self._original_skipping_status
self.update_skipping_status()
animations = scene.compile_animations(*args, **kwargs)
scene.add_mobjects_from_animations(animations)

View file

@ -2,7 +2,7 @@
{"levelname": "DEBUG", "module": "hashing", "message": "Hashing ..."}
{"levelname": "DEBUG", "module": "hashing", "message": "Hashing done in <> s."}
{"levelname": "DEBUG", "module": "hashing", "message": "Hash generated : <>"}
{"levelname": "DEBUG", "module": "caching", "message": "List of the first few animation hashes of the scene: <>"}
{"levelname": "DEBUG", "module": "cairo_renderer", "message": "List of the first few animation hashes of the scene: <>"}
{"levelname": "INFO", "module": "scene_file_writer", "message": "Animation 0 : Partial movie file written in <>"}
{"levelname": "DEBUG", "module": "scene_file_writer", "message": "Partial movie files to combine (1 files): <>"}
{"levelname": "INFO", "module": "scene_file_writer", "message": "\nFile ready at <>\n"}

View file

@ -2,6 +2,8 @@ import pytest
from pathlib import Path
from manim import config, tempconfig
@pytest.fixture
def manim_cfg_file():
@ -13,6 +15,20 @@ def simple_scenes_path():
return str(Path(__file__).parent / "simple_scenes.py")
@pytest.fixture
def using_temp_config(tmpdir):
"""Standard fixture that makes tests use a standard_config.cfg with a temp dir."""
with tempconfig(config.digest_file(Path(__file__).parent / "standard_config.cfg")):
config.media_dir = tmpdir
yield
@pytest.fixture
def disabling_caching():
with tempconfig({"disable_caching": True}):
yield
@pytest.fixture
def infallible_scenes_path():
return str(Path(__file__).parent / "infallible_scenes.py")

View file

@ -13,8 +13,7 @@ class SceneWithMultipleCalls(Scene):
number = Integer(0)
self.add(number)
for i in range(10):
number.become(Integer(i))
self.play(Animation(number))
self.play(Animation(Square()))
class SceneWithMultipleWaitCalls(Scene):
@ -34,3 +33,18 @@ class NoAnimations(Scene):
dot = Dot().set_color(GREEN)
self.add(dot)
self.wait(1)
class SceneWithStaticWait(Scene):
def construct(self):
self.add(Square())
self.wait()
class SceneWithNonStaticWait(Scene):
def construct(self):
s = Square()
# Non static wait are triggered by mobject with time based updaters.
s.add_updater(lambda mob, dt: None)
self.add(s)
self.wait()

View file

@ -1,4 +1,5 @@
[CLI]
frame_rate = 15
pixel_height = 480
pixel_width = 300
pixel_width = 300
verbosity = DEBUG

View file

@ -0,0 +1,31 @@
from manim import *
import pytest
from .simple_scenes import *
from unittest.mock import Mock
def test_render(using_temp_config, disabling_caching):
scene = SquareToCircle()
renderer = scene.renderer
renderer.update_frame = Mock()
renderer.add_frame = Mock()
scene.render()
assert renderer.add_frame.call_count == config["frame_rate"]
assert renderer.update_frame.call_count == config["frame_rate"]
def test_skipping_status_with_from_to_and_up_to(using_temp_config, disabling_caching):
"""Test if skip_animations is well udpated when -n flag is passed"""
config.from_animation_number = 2
config.upto_animation_number = 6
class SceneWithMultipleCalls(Scene):
def construct(self):
number = Integer(0)
self.add(number)
for i in range(10):
self.play(Animation(Square()))
assert ((i >= 2) and (i <= 6)) or self.renderer.skip_animations
SceneWithMultipleCalls().render()

View file

@ -0,0 +1,89 @@
from functools import wraps
from pathlib import Path
from unittest.mock import Mock
from numpy.core.defchararray import asarray
import sys
import pytest
from manim import *
from manim import config
from .simple_scenes import (
SceneWithMultipleCalls,
SceneWithNonStaticWait,
SceneWithStaticWait,
SquareToCircle,
)
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mock object has a different implementation in python 3.7, which makes it broken with this logic.",
)
@pytest.mark.parametrize("frame_rate", argvalues=[15, 30, 60])
def test_t_values(using_temp_config, disabling_caching, frame_rate):
"""Test that the framerate corresponds to the number of t values generated"""
config.frame_rate = frame_rate
scene = SquareToCircle()
scene.update_to_time = Mock()
scene.render()
assert scene.update_to_time.call_count == config["frame_rate"]
np.testing.assert_allclose(
([call.args[0] for call in scene.update_to_time.call_args_list]),
np.arange(0, 1, 1 / config["frame_rate"]),
)
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mock object has a different implementation in python 3.7, which makes it broken with this logic.",
)
def test_t_values_with_skip_animations(using_temp_config, disabling_caching):
"""Test the behaviour of scene.skip_animations"""
scene = SquareToCircle()
scene.update_to_time = Mock()
scene.renderer._original_skipping_status = True
scene.render()
assert scene.update_to_time.call_count == 1
np.testing.assert_almost_equal(
scene.update_to_time.call_args.args[0],
1.0,
)
def test_static_wait_detection(using_temp_config, disabling_caching):
"""Test if a static wait (wait that freeze the frame) is correctly detected"""
scene = SceneWithStaticWait()
scene.render()
# Test is is_static_wait of the Wait animation has been set to True by compile_animation_ata
assert scene.animations[0].is_static_wait
assert scene.is_current_animation_frozen_frame()
def test_non_static_wait_detection(using_temp_config, disabling_caching):
scene = SceneWithNonStaticWait()
scene.render()
assert not scene.animations[0].is_static_wait
assert not scene.is_current_animation_frozen_frame()
def test_t_values_with_cached_data(using_temp_config):
"""Test the proper generation and use of the t values when an animation is cached."""
scene = SceneWithMultipleCalls()
# Mocking the file_writer will skip all the writing process.
scene.renderer.file_writer = Mock(scene.renderer.file_writer)
# Simulate that all animations are cached.
scene.renderer.file_writer.is_already_cached.return_value = True
scene.update_to_time = Mock()
scene.render()
assert scene.update_to_time.call_count == 10
def test_t_values_save_last_frame(using_temp_config):
"""Test that there is only one t value handled when only saving the last frame"""
config.save_last_frame = True
scene = SquareToCircle()
scene.update_to_time = Mock()
scene.render()
scene.update_to_time.assert_called_once_with(1)

View file

@ -21,7 +21,8 @@ def _check_logs(reference_logfile, generated_logfile):
msg_assert += f"Logs generated are LONGER than the expected logs.\n There are {diff} extra logs :\n"
for log in generated_logs[len(reference_logs) :]:
msg_assert += log
assert 0, msg_assert
msg_assert += f"\nPath of reference log: {reference_logfile}\nPath of generated logs: {generated_logfile}"
assert 0, msg_assert + reference_logfile + " " + generated_logfile
for index, ref, gen in zip(itertools.count(), reference_logs, generated_logs):
# As they are string, we only need to check if they are equal. If they are not, we then compute a more precise difference, to debug.
@ -35,9 +36,10 @@ def _check_logs(reference_logfile, generated_logfile):
# \n and \t don't not work in f-strings.
newline = "\n"
tab = "\t"
assert (
len(diff_keys) == 0
), f"Logs don't match at {index} log. : \n{newline.join([f'In {key} field, got -> {newline}{tab}{repr(gen_log[key])}. {newline}Expected : -> {newline}{tab}{repr(ref_log[key])}.' for key in diff_keys])}"
assert len(diff_keys) == 0, (
f"Logs don't match at {index} log. : \n{newline.join([f'In {key} field, got -> {newline}{tab}{repr(gen_log[key])}. {newline}Expected : -> {newline}{tab}{repr(ref_log[key])}.' for key in diff_keys])}"
+ f"\nPath of reference log: {reference_logfile}\nPath of generated logs: {generated_logfile}"
)
def logs_comparison(control_data_file, log_path_from_media_dir):