Porting and drastical rewrite to reach compatibility with 3b1b/manim again (#3107)

* ported functionality of Mobject from 3b1b to OpenGLMobject

* ported functionality of VMobject from 3b1b to OpenVGLMobject

* first working render

* first step to dump old scene structure

* copied scene without adapting
This commit is contained in:
Tristan Schulz 2023-01-03 12:46:56 +01:00 committed by GitHub
commit 8b345d9a07
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 3105 additions and 2467 deletions

View file

@ -15,6 +15,8 @@ import argparse
import configparser
import copy
import errno
import importlib
import inspect
import logging
import os
import re

View file

@ -194,6 +194,9 @@ TAU: float = 2 * PI
DEGREES: float = TAU / 360
"""The exchange rate between radians and degrees."""
RADIANS: float = 1.0
"""Just a default to select for camera."""
# Video qualities
QUALITIES: dict[str, dict[str, str | int | None]] = {
"fourk_quality": {

View file

@ -0,0 +1,5 @@
from manim.event_handler.event_dispatcher import EventDispatcher
# This is supposed to be a Singleton
# i.e., during runtime there should be only one object of Event Dispatcher
EVENT_DISPATCHER = EventDispatcher()

View file

@ -0,0 +1,91 @@
from __future__ import annotations
import numpy as np
from manim.event_handler.event_listener import EventListener
from manim.event_handler.event_type import EventType
class EventDispatcher:
def __init__(self):
self.event_listeners: dict[EventType, list[EventListener]] = {
event_type: [] for event_type in EventType
}
self.mouse_point = np.array((0.0, 0.0, 0.0))
self.mouse_drag_point = np.array((0.0, 0.0, 0.0))
self.pressed_keys: set[int] = set()
self.draggable_object_listeners: list[EventListener] = []
def add_listener(self, event_listener: EventListener):
assert isinstance(event_listener, EventListener)
self.event_listeners[event_listener.event_type].append(event_listener)
return self
def remove_listener(self, event_listener: EventListener):
assert isinstance(event_listener, EventListener)
try:
while event_listener in self.event_listeners[event_listener.event_type]:
self.event_listeners[event_listener.event_type].remove(event_listener)
except Exception:
# raise ValueError("Handler is not handling this event, so cannot remove it.")
pass
return self
def dispatch(self, event_type: EventType, **event_data):
if event_type == EventType.MouseMotionEvent:
self.mouse_point = event_data["point"]
elif event_type == EventType.MouseDragEvent:
self.mouse_drag_point = event_data["point"]
elif event_type == EventType.KeyPressEvent:
self.pressed_keys.add(event_data["symbol"]) # Modifiers?
elif event_type == EventType.KeyReleaseEvent:
self.pressed_keys.difference_update({event_data["symbol"]}) # Modifiers?
elif event_type == EventType.MousePressEvent:
self.draggable_object_listeners = [
listener
for listener in self.event_listeners[EventType.MouseDragEvent]
if listener.mobject.is_point_touching(self.mouse_point)
]
elif event_type == EventType.MouseReleaseEvent:
self.draggable_object_listeners = []
propagate_event = None
if event_type == EventType.MouseDragEvent:
for listener in self.draggable_object_listeners:
assert isinstance(listener, EventListener)
propagate_event = listener.callback(listener.mobject, event_data)
if propagate_event is not None and propagate_event is False:
return propagate_event
elif event_type.value.startswith("mouse"):
for listener in self.event_listeners[event_type]:
if listener.mobject.is_point_touching(self.mouse_point):
propagate_event = listener.callback(listener.mobject, event_data)
if propagate_event is not None and propagate_event is False:
return propagate_event
elif event_type.value.startswith("key"):
for listener in self.event_listeners[event_type]:
propagate_event = listener.callback(listener.mobject, event_data)
if propagate_event is not None and propagate_event is False:
return propagate_event
return propagate_event
def get_listeners_count(self) -> int:
return sum([len(value) for key, value in self.event_listeners.items()])
def get_mouse_point(self) -> np.ndarray:
return self.mouse_point
def get_mouse_drag_point(self) -> np.ndarray:
return self.mouse_drag_point
def is_key_pressed(self, symbol: int) -> bool:
return symbol in self.pressed_keys
__iadd__ = add_listener
__isub__ = remove_listener
__call__ = dispatch
__len__ = get_listeners_count

View file

@ -0,0 +1,34 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable
import manim.mobject.opengl.opengl_mobject as glmob
from manim.event_handler.event_type import EventType
class EventListener:
def __init__(
self,
mobject: glmob.OpenGLMobject,
event_type: EventType,
event_callback: Callable[[glmob.OpenGLMobject, dict[str, str]], None],
):
self.mobject = mobject
self.event_type = event_type
self.callback = event_callback
def __eq__(self, o: object) -> bool:
return_val = False
if isinstance(o, EventListener):
try:
return_val = (
self.callback == o.callback
and self.mobject == o.mobject
and self.event_type == o.event_type
)
except Exception:
pass
return return_val

View file

@ -0,0 +1,11 @@
from enum import Enum
class EventType(Enum):
MouseMotionEvent = "mouse_motion_event"
MousePressEvent = "mouse_press_event"
MouseReleaseEvent = "mouse_release_event"
MouseDragEvent = "mouse_drag_event"
MouseScrollEvent = "mouse_scroll_event"
KeyPressEvent = "key_press_event"
KeyReleaseEvent = "key_release_event"

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -2410,7 +2410,6 @@ class DashedVMobject(VMobject, metaclass=ConvertToOpenGL):
equal_lengths=True,
**kwargs,
):
self.dashed_ratio = dashed_ratio
self.num_dashes = num_dashes
super().__init__(color=color, **kwargs)

View file

@ -1,9 +1,14 @@
from __future__ import annotations
import itertools as it
import math
import sys
import time
from typing import Any
from typing import Any, Iterable
from manim.renderer.shader_wrapper import ShaderWrapper
from ..constants import RADIANS
if sys.version_info < (3, 8):
from backports.cached_property import cached_property
@ -12,22 +17,25 @@ else:
import moderngl
import numpy as np
import OpenGL.GL as gl
from PIL import Image
from scipy.spatial.transform import Rotation
from manim import config, logger
from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.utils.caching import handle_caching_play
from manim.utils.color import color_to_rgba
from manim.utils.color import BLACK, color_to_rgba
from manim.utils.exceptions import EndSceneEarlyException
from ..constants import *
from ..scene.scene_file_writer import SceneFileWriter
from ..utils import opengl
from ..utils.config_ops import _Data
from ..utils.simple_functions import clip
from ..utils.simple_functions import clip, fdiv
from ..utils.space_ops import (
angle_of_vector,
normalize,
quaternion_from_angle_axis,
quaternion_mult,
rotation_matrix_transpose,
@ -40,7 +48,495 @@ from .vectorized_mobject_rendering import (
)
class OpenGLCamera(OpenGLMobject):
class OpenGLCameraFrame(OpenGLMobject):
def __init__(
self,
frame_shape: tuple[float, float] = (config.frame_width, config.frame_height),
center_point: np.ndarray = ORIGIN,
focal_dist_to_height: float = 2.0,
**kwargs,
):
self.frame_shape = frame_shape
self.center_point = center_point
self.focal_dist_to_height = focal_dist_to_height
super().__init__(**kwargs)
def init_uniforms(self):
super().init_uniforms()
# as a quarternion
self.uniforms["orientation"] = Rotation.identity().as_quat()
self.uniforms["focal_dist_to_height"] = self.focal_dist_to_height
def init_points(self) -> None:
self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP])
self.set_width(self.frame_shape[0], stretch=True)
self.set_height(self.frame_shape[1], stretch=True)
self.move_to(self.center_point)
def set_orientation(self, rotation: Rotation):
self.uniforms["orientation"] = rotation.as_quat()
return self
def get_orientation(self):
return Rotation.from_quat(self.uniforms["orientation"])
def to_default_state(self):
self.center()
self.set_height(config.frame_width)
self.set_width(config.frame_height)
self.set_orientation(Rotation.identity())
return self
def get_euler_angles(self):
return self.get_orientation().as_euler("zxz")[::-1]
def get_theta(self):
return self.get_euler_angles()[0]
def get_phi(self):
return self.get_euler_angles()[1]
def get_gamma(self):
return self.get_euler_angles()[2]
def get_inverse_camera_rotation_matrix(self):
return self.get_orientation().as_matrix().T
def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs): # type: ignore
rot = Rotation.from_rotvec(axis * normalize(axis)) # type: ignore
self.set_orientation(rot * self.get_orientation())
def set_euler_angles(
self,
theta: float | None = None,
phi: float | None = None,
gamma: float | None = None,
units: float = RADIANS,
):
eulers = self.get_euler_angles() # theta, phi, gamma
for i, var in enumerate([theta, phi, gamma]):
if var is not None:
eulers[i] = var * units
self.set_orientation(Rotation.from_euler("zxz", eulers[::-1]))
return self
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
"""
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta: float):
return self.set_euler_angles(theta=theta)
def set_phi(self, phi: float):
return self.set_euler_angles(phi=phi)
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta: float):
self.rotate(dtheta, OUT)
return self
def increment_phi(self, dphi: float):
self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0])
return self
def increment_gamma(self, dgamma: float):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2])
return self
def set_focal_distance(self, focal_distance: float):
self.uniforms["focal_dist_to_height"] = focal_distance / self.get_height()
return self
def set_field_of_view(self, field_of_view: float):
self.uniforms["focal_dist_to_height"] = 2 * math.tan(field_of_view / 2)
return self
def get_shape(self):
return (self.get_width(), self.get_height())
def get_center(self) -> np.ndarray:
# Assumes first point is at the center
return self.points[0]
def get_width(self) -> float:
points = self.points
return points[2, 0] - points[1, 0]
def get_height(self) -> float:
points = self.points
return points[4, 1] - points[3, 1]
def get_focal_distance(self) -> float:
return self.uniforms["focal_dist_to_height"] * self.get_height() # type: ignore
def get_field_of_view(self) -> float:
return 2 * math.atan(self.uniforms["focal_dist_to_height"] / 2)
def get_implied_camera_location(self) -> np.ndarray:
to_camera = self.get_inverse_camera_rotation_matrix()[2]
dist = self.get_focal_distance()
return self.get_center() + dist * to_camera
class OpenGLCamera:
def __init__(
self,
ctx: moderngl.Context | None = None,
background_image: str | None = None,
frame_config: dict = {},
pixel_width: int = config.pixel_width,
pixel_height: int = config.pixel_height,
fps: int = config.frame_rate,
# Note: frame height and width will be resized to match the pixel aspect rati
background_color=BLACK,
background_opacity: float = 1.0,
# Points in vectorized mobjects with norm greater
# than this value will be rescaled
max_allowable_norm: float = 1.0,
image_mode: str = "RGBA",
n_channels: int = 4,
pixel_array_dtype: type = np.uint8,
light_source_position: np.ndarray = np.array([-10, 10, 10]),
# Although vector graphics handle antialiasing fine
# without multisampling, for 3d scenes one might want
# to set samples to be greater than 0.
samples: int = 0,
) -> None:
self.background_image = background_image
self.pixel_width = pixel_width
self.pixel_height = pixel_height
self.fps = fps
self.max_allowable_norm = max_allowable_norm
self.image_mode = image_mode
self.n_channels = n_channels
self.pixel_array_dtype = pixel_array_dtype
self.light_source_position = light_source_position
self.samples = samples
self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max
self.background_color: list[float] = list(
color_to_rgba(background_color, background_opacity)
)
self.init_frame(**frame_config)
self.init_context(ctx)
self.init_shaders()
self.init_textures()
self.init_light_source()
self.refresh_perspective_uniforms()
# A cached map from mobjects to their associated list of render groups
# so that these render groups are not regenerated unnecessarily for static
# mobjects
self.mob_to_render_groups: dict = {}
def init_frame(self, **config) -> None:
self.frame = OpenGLCameraFrame(**config)
def init_context(self, ctx: moderngl.Context | None = None) -> None:
if ctx is None:
ctx = moderngl.create_standalone_context()
fbo = self.get_fbo(ctx, 0)
else:
fbo = ctx.detect_framebuffer()
self.ctx = ctx
self.fbo = fbo
self.set_ctx_blending()
# For multisample antisampling
fbo_msaa = self.get_fbo(ctx, self.samples)
fbo_msaa.use()
self.fbo_msaa = fbo_msaa
def set_ctx_blending(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.BLEND)
else:
self.ctx.disable(moderngl.BLEND)
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
def init_light_source(self) -> None:
self.light_source = OpenGLPoint(self.light_source_position)
# Methods associated with the frame buffer
def get_fbo(self, ctx: moderngl.Context, samples: int = 0) -> moderngl.Framebuffer:
pw = self.pixel_width
ph = self.pixel_height
return ctx.framebuffer(
color_attachments=ctx.texture(
(pw, ph), components=self.n_channels, samples=samples
),
depth_attachment=ctx.depth_renderbuffer((pw, ph), samples=samples),
)
def clear(self) -> None:
self.fbo.clear(*self.background_color)
self.fbo_msaa.clear(*self.background_color)
def reset_pixel_shape(self, new_width: int, new_height: int) -> None:
self.pixel_width = new_width
self.pixel_height = new_height
self.refresh_perspective_uniforms()
def get_raw_fbo_data(self, dtype: str = "f1") -> bytes:
# Copy blocks from the fbo_msaa to the drawn fbo using Blit
pw, ph = (self.pixel_width, self.pixel_height)
gl.glBindFrameBuffer(gl.GL_READ_FRAMEBUFFER, self.fbo_msaa.glo)
gl.glBindFrameBuffer(gl.GL_DRAW_FRAMEBUFFER, self.fbo.glo)
gl.glBlitFramebuffer(
0, 0, pw, ph, 0, 0, pw, ph, gl.GL_COLOR_BUFFER_BIT, gl.GL_LINEAR
)
return self.fbo.read(
viewport=self.fbo.viewport,
components=self.n_channels,
dtype=dtype,
)
def get_image(self) -> Image.Image:
return Image.frombytes(
"RGBA",
self.get_pixel_shape(),
self.get_raw_fbo_data(),
"raw",
"RGBA",
0,
-1,
)
def get_pixel_array(self) -> np.ndarray:
raw = self.get_raw_fbo_data(dtype="f4")
flat_arr = np.frombuffer(raw, dtype="f4")
arr = flat_arr.reshape([*reversed(self.fbo.size), self.n_channels])
arr = arr[::-1]
# Convert from float
return (self.rgb_max_val * arr).astype(self.pixel_array_dtype)
def get_texture(self):
texture = self.ctx.texture(
size=self.fbo.size, components=4, data=self.get_raw_fbo_data(), dtype="f4"
)
return texture
# Getting camera attributes
def get_pixel_shape(self) -> tuple[int, int]:
return self.fbo.viewport[2:4]
# return (self.pixel_width, self.pixel_height)
def get_pixel_width(self) -> int:
return self.get_pixel_shape()[0]
def get_pixel_height(self) -> int:
return self.get_pixel_shape()[1]
def get_frame_height(self) -> float:
return self.frame.get_height()
def get_frame_width(self) -> float:
return self.frame.get_width()
def get_frame_shape(self) -> tuple[float, float]:
return (self.get_frame_width(), self.get_frame_height())
def get_frame_center(self) -> np.ndarray:
return self.frame.get_center()
def get_location(self) -> tuple[float, float, float] | np.ndarray:
return self.frame.get_implied_camera_location()
def resize_frame_shape(self, fixed_dimension: bool = False) -> None:
"""
Changes frame_shape to match the aspect ratio
of the pixels, where fixed_dimension determines
whether frame_height or frame_width
remains fixed while the other changes accordingly.
"""
pixel_height = self.get_pixel_height()
pixel_width = self.get_pixel_width()
frame_height = self.get_frame_height()
frame_width = self.get_frame_width()
aspect_ratio = fdiv(pixel_width, pixel_height)
if not fixed_dimension:
frame_height = frame_width / aspect_ratio
else:
frame_width = aspect_ratio * frame_height
self.frame.set_height(frame_height)
self.frame.set_width(frame_width)
# Rendering
def capture(self, *mobjects: OpenGLMobject) -> None:
self.refresh_perspective_uniforms()
for mobject in mobjects:
for render_group in self.get_render_group_list(mobject):
self.render(render_group)
def render(self, render_group: dict[str, Any]) -> None:
shader_wrapper: ShaderWrapper = render_group["shader_wrapper"]
shader_program = render_group["prog"]
self.set_shader_uniforms(shader_program, shader_wrapper)
self.set_ctx_depth_test(shader_wrapper.depth_test)
render_group["vao"].render(int(shader_wrapper.render_primitive))
if render_group["single_use"]:
self.release_render_group(render_group)
def get_render_group_list(self, mobject: OpenGLMobject) -> Iterable[dict[str, Any]]:
if mobject.is_changing():
return self.generate_render_group_list(mobject)
# Otherwise, cache result for later use
key = id(mobject)
if key not in self.mob_to_render_groups:
self.mob_to_render_groups[key] = list(
self.generate_render_group_list(mobject)
)
return self.mob_to_render_groups[key]
def generate_render_group_list(
self, mobject: OpenGLMobject
) -> Iterable[dict[str, Any]]:
return (
self.get_render_group(sw, single_use=mobject.is_changing())
for sw in mobject.get_shader_wrapper_list()
)
def get_render_group(
self, shader_wrapper: ShaderWrapper, single_use: bool = True
) -> dict[str, Any]:
# Data buffers
vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes())
if shader_wrapper.vert_indices is None:
ibo = None
else:
vert_index_data = shader_wrapper.vert_indices.astype("i4").tobytes()
if vert_index_data:
ibo = self.ctx.buffer(vert_index_data)
else:
ibo = None
# Program an vertex array
shader_program, vert_format = self.get_shader_program(shader_wrapper) # type: ignore
vao = self.ctx.vertex_array(
program=shader_program,
content=[(vbo, vert_format, *shader_wrapper.vert_attributes)],
index_buffer=ibo,
)
return {
"vbo": vbo,
"ibo": ibo,
"vao": vao,
"prog": shader_program,
"shader_wrapper": shader_wrapper,
"single_use": single_use,
}
def release_render_group(self, render_group: dict[str, Any]) -> None:
for key in ["vbo", "ibo", "vao"]:
if render_group[key] is not None:
render_group[key].release()
def refresh_static_mobjects(self) -> None:
for render_group in it.chain(*self.mob_to_render_groups.values()):
self.release_render_group(render_group)
self.mob_to_render_groups = {}
# Shaders
def init_shaders(self) -> None:
# Initialize with the null id going to None
self.id_to_shader_program: dict[int, tuple[moderngl.Program, str] | None] = {
hash(""): None
}
def get_shader_program(
self, shader_wrapper: ShaderWrapper
) -> tuple[moderngl.Program, str] | None:
sid = shader_wrapper.get_program_id()
if sid not in self.id_to_shader_program:
# Create shader program for the first time, then cache
# in the id_to_shader_program dictionary
program = self.ctx.program(**shader_wrapper.get_program_code())
vert_format = moderngl.detect_format(
program, shader_wrapper.vert_attributes
)
self.id_to_shader_program[sid] = (program, vert_format)
return self.id_to_shader_program[sid]
def set_shader_uniforms(
self,
shader: moderngl.Program,
shader_wrapper: ShaderWrapper,
) -> None:
for name, path in shader_wrapper.texture_paths.items():
tid = self.get_texture_id(path)
shader[name].value = tid
for name, value in it.chain(
self.perspective_uniforms.items(), shader_wrapper.uniforms.items()
):
if name in shader:
if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value)
shader[name].value = value
def refresh_perspective_uniforms(self) -> None:
frame = self.frame
# Orient light
rotation = frame.get_inverse_camera_rotation_matrix()
offset = frame.get_center()
light_pos = np.dot(rotation, self.light_source.get_location() + offset)
cam_pos = self.frame.get_implied_camera_location() # TODO
self.perspective_uniforms = {
"frame_shape": frame.get_shape(),
"pixel_shape": self.get_pixel_shape(),
"camera_offset": tuple(offset),
"camera_rotation": tuple(np.array(rotation).T.flatten()),
"camera_position": tuple(cam_pos),
"light_source_position": tuple(light_pos),
"focal_distance": frame.get_focal_distance(),
}
def init_textures(self) -> None:
self.n_textures: int = 0
self.path_to_texture: dict[str, tuple[int, moderngl.Texture]] = {}
def get_texture_id(self, path: str) -> int:
if path not in self.path_to_texture:
if self.n_textures == 15: # I have no clue why this is needed
self.n_textures += 1
tid = self.n_textures
self.n_textures += 1
im = Image.open(path).convert("RGBA")
texture = self.ctx.texture(
size=im.size,
components=len(im.getbands()),
data=im.tobytes(),
)
texture.use(location=tid)
self.path_to_texture[path] = (tid, texture)
return self.path_to_texture[path][0]
def release_texture(self, path: str):
tid_and_texture = self.path_to_texture.pop(path, None)
if tid_and_texture:
tid_and_texture[1].release()
return self
class OpenGLCameraLegacy(OpenGLMobject):
euler_angles = _Data()
def __init__(
@ -465,8 +961,6 @@ class OpenGLRenderer:
self.refresh_perspective_uniforms(scene.camera)
for mobject in scene.mobjects:
if not mobject.should_render:
continue
self.render_mobject(mobject)
for obj in scene.meshes:

View file

@ -2,11 +2,14 @@ from __future__ import annotations
import copy
import re
from functools import lru_cache
from pathlib import Path
import moderngl
import numpy as np
from manim.utils.iterables import resize_array
from .. import logger
# Mobjects that should be rendered with
@ -55,6 +58,29 @@ class ShaderWrapper:
self.init_program_code()
self.refresh_id()
def __eq__(self, shader_wrapper: object):
if not isinstance(shader_wrapper, ShaderWrapper):
raise TypeError(
f"Cannot compare ShaderWrapper with non-ShaderWrapper object of type {type(shader_wrapper)}"
)
return all(
(
np.all(self.vert_data == shader_wrapper.vert_data),
np.all(self.vert_indices == shader_wrapper.vert_indices),
self.shader_folder == shader_wrapper.shader_folder,
all(
np.all(self.uniforms[key] == shader_wrapper.uniforms[key])
for key in self.uniforms
),
all(
self.texture_paths[key] == shader_wrapper.texture_paths[key]
for key in self.texture_paths
),
self.depth_test == shader_wrapper.depth_test,
self.render_primitive == shader_wrapper.render_primitive,
)
)
def copy(self):
result = copy.copy(self)
result.vert_data = np.array(self.vert_data)
@ -125,30 +151,34 @@ class ShaderWrapper:
def replace_code(self, old, new):
code_map = self.program_code
for (name, _code) in code_map.items():
for name, _code in code_map.items():
if code_map[name] is None:
continue
code_map[name] = re.sub(old, new, code_map[name])
self.refresh_id()
def combine_with(self, *shader_wrappers):
# Assume they are of the same type
if len(shader_wrappers) == 0:
return
def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
self.read_in(self.copy(), *shader_wrappers)
return self
def read_in(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
# Assume all are of the same type
total_len = sum(len(sw.vert_data) for sw in shader_wrappers)
self.vert_data = resize_array(self.vert_data, total_len)
if self.vert_indices is not None:
num_verts = len(self.vert_data)
indices_list = [self.vert_indices]
data_list = [self.vert_data]
for sw in shader_wrappers:
indices_list.append(sw.vert_indices + num_verts)
data_list.append(sw.vert_data)
num_verts += len(sw.vert_data)
self.vert_indices = np.hstack(indices_list)
self.vert_data = np.hstack(data_list)
else:
self.vert_data = np.hstack(
[self.vert_data, *(sw.vert_data for sw in shader_wrappers)],
)
total_verts = sum(len(sw.vert_indices) for sw in shader_wrappers)
self.vert_indices = resize_array(self.vert_indices, total_verts)
n_points = 0
n_verts = 0
for sw in shader_wrappers:
new_n_points = n_points + len(sw.vert_data)
self.vert_data[n_points:new_n_points] = sw.vert_data
if self.vert_indices is not None and sw.vert_indices is not None:
new_n_verts = n_verts + len(sw.vert_indices)
self.vert_indices[n_verts:new_n_verts] = sw.vert_indices + n_points
n_verts = new_n_verts
n_points = new_n_points
return self
@ -156,6 +186,7 @@ class ShaderWrapper:
filename_to_code_map: dict = {}
@lru_cache(maxsize=12)
def get_shader_code_from_file(filename: Path) -> str | None:
if filename in filename_to_code_map:
return filename_to_code_map[filename]

File diff suppressed because it is too large Load diff

View file

@ -288,6 +288,37 @@ def match_interpolate(
# Figuring out which bezier curves most smoothly connect a sequence of points
def get_smooth_quadratic_bezier_handle_points(points: FloatArray) -> FloatArray:
"""
Figuring out which bezier curves most smoothly connect a sequence of points.
Given three successive points, P0, P1 and P2, you can compute that by defining
h = (1/4) P0 + P1 - (1/4)P2, the bezier curve defined by (P0, h, P1) will pass
through the point P2.
So for a given set of four successive points, P0, P1, P2, P3, if we want to add
a handle point h between P1 and P2 so that the quadratic bezier (P1, h, P2) is
part of a smooth curve passing through all four points, we calculate one solution
for h that would produce a parbola passing through P3, call it smooth_to_right, and
another that would produce a parabola passing through P0, call it smooth_to_left,
and use the midpoint between the two.
"""
if len(points) == 2:
return midpoint(*points)
smooth_to_right, smooth_to_left = (
0.25 * ps[0:-2] + ps[1:-1] - 0.25 * ps[2:] for ps in (points, points[::-1])
)
if np.isclose(points[0], points[-1]).all():
last_str = 0.25 * points[-2] + points[-1] - 0.25 * points[1]
last_stl = 0.25 * points[1] + points[0] - 0.25 * points[-2]
else:
last_str = smooth_to_left[0]
last_stl = smooth_to_right[0]
handles = 0.5 * np.vstack([smooth_to_right, [last_str]])
handles += 0.5 * np.vstack([last_stl, smooth_to_left[::-1]])
return handles
def get_smooth_cubic_bezier_handle_points(points):
points = np.array(points)
num_handles = len(points) - 1

View file

@ -2,6 +2,8 @@
from __future__ import annotations
from manim.utils.iterables import resize_with_interpolation
__all__ = [
"color_to_rgb",
"color_to_rgba",
@ -550,3 +552,28 @@ def get_shaded_rgb(
factor *= 0.5
result = rgb + factor
return result
COLORMAP_3B1B: list[Color] = [BLUE_E, GREEN, YELLOW, RED]
def get_colormap_list(map_name: str = "viridis", n_colors: int = 9) -> np.ndarray:
"""
Options for map_name:
3b1b_colormap
magma
inferno
plasma
viridis
cividis
twilight
twilight_shifted
turbo
"""
from matplotlib.cm import get_cmap
if map_name == "3b1b_colormap":
rgbs = np.array([color_to_rgb(color) for color in COLORMAP_3B1B])
else:
rgbs = get_cmap(map_name).colors # Make more general?
return resize_with_interpolation(np.array(rgbs), n_colors)

View file

@ -0,0 +1,50 @@
from __future__ import annotations
import os
from manim._config import config
from manim.utils.file_ops import guarantee_existence
def get_directories() -> dict[str, str]:
return config["directories"]
def get_temp_dir() -> str:
return get_directories()["temporary_storage"]
def get_tex_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Tex"))
def get_text_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Text"))
def get_mobject_data_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
def get_downloads_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))
def get_output_dir() -> str:
return guarantee_existence(get_directories()["output"])
def get_raster_image_dir() -> str:
return get_directories()["raster_images"]
def get_vector_image_dir() -> str:
return get_directories()["vector_images"]
def get_sound_dir() -> str:
return get_directories()["sounds"]
def get_shader_dir() -> str:
return get_directories()["shaders"]

View file

@ -27,8 +27,11 @@ from pathlib import Path
from shutil import copyfile
from typing import TYPE_CHECKING
import validators
if TYPE_CHECKING:
from ..scene.scene_file_writer import SceneFileWriter
from typing import Iterable
from manim import __version__, config, logger

View file

@ -117,6 +117,19 @@ def clip(a, min_a, max_a):
return a
def fdiv(
a: Scalable, b: Scalable, zero_over_zero_value: Scalable | None = None
) -> Scalable:
if zero_over_zero_value is not None:
out = np.full_like(a, zero_over_zero_value)
where = np.logical_or(a != 0, b != 0)
else:
out = None
where = True
return np.true_divide(a, b, out=out, where=where)
def get_parameters(function: Callable) -> MappingProxyType[str, inspect.Parameter]:
"""Return the parameters of ``function`` as an ordered mapping of parameters'
names to their corresponding ``Parameter`` objects.

View file

@ -285,6 +285,32 @@ def rotation_about_z(angle: float) -> np.ndarray:
)
def get_norm(vector: np.ndarray) -> float:
"""Returns the norm of the vector.
Parameters
----------
vector
The vector for which you want to find the norm.
Returns
-------
float
The norm of the vector.
"""
return np.linalg.norm(vector)
def normalize(vect: list[float], fall_back: list[float] | None = None) -> np.ndarray:
norm = get_norm(vect)
if norm > 0:
return np.array(vect) / norm
elif fall_back is not None:
return np.array(fall_back)
else:
return np.zeros(len(vect))
def z_to_vector(vector: np.ndarray) -> np.ndarray:
"""
Returns some matrix in SO(3) which takes the z-axis to the