Implemented :class:.LineJointTypes for both Cairo and OpenGL renderer (#3016)

* LineJoins added

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added joint type enum, refactored proposed implementation

* added test for joint types

* added documentation

* let LineJointType.AUTO be rendered like before

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update added example in basic.py to reflect changed implementation

* fix RTD build

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* moved rendered example in documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Benjamin Hackl <devel@benjamin-hackl.at>
This commit is contained in:
Alexander Vázquez 2022-11-13 17:14:55 -06:00 committed by GitHub
commit 388504307a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 119 additions and 44 deletions

View file

@ -161,4 +161,19 @@ class SpiralInExample(Scene):
self.play(FadeOut(shapes))
Triangle.set_default(stroke_width=20)
class LineJoints(Scene):
def construct(self):
t1 = Triangle()
t2 = Triangle(line_join=LineJointType.ROUND)
t3 = Triangle(line_join=LineJointType.BEVEL)
grp = VGroup(t1, t2, t3).arrange(RIGHT)
grp.set(width=config.frame_width - 1)
self.add(grp)
# See many more examples at https://docs.manim.community/en/stable/examples.html

View file

@ -30,6 +30,13 @@ from ..utils.images import get_full_raster_image_path
from ..utils.iterables import list_difference_update
from ..utils.space_ops import angle_of_vector
LINE_JOIN_MAP = {
LineJointType.AUTO: None, # TODO: this could be improved
LineJointType.ROUND: cairo.LineJoin.ROUND,
LineJointType.BEVEL: cairo.LineJoin.BEVEL,
LineJointType.MITER: cairo.LineJoin.MITER,
}
class Camera:
"""Base camera class.
@ -768,6 +775,8 @@ class Camera:
# This ensures lines have constant width as you zoom in on them.
* (self.frame_width / self.frame_width),
)
if vmobject.joint_type != LineJointType.AUTO:
ctx.set_line_join(LINE_JOIN_MAP[vmobject.joint_type])
ctx.stroke_preserve()
return self

View file

@ -72,6 +72,7 @@ __all__ = [
"SHIFT_VALUE",
"CTRL_VALUE",
"RendererType",
"LineJointType",
]
# Messages
@ -264,3 +265,39 @@ class RendererType(Enum):
CAIRO = "cairo" #: A renderer based on the cairo backend.
OPENGL = "opengl" #: An OpenGL-based renderer.
class LineJointType(Enum):
"""Collection of available line joint types.
See the example below for a visual illustration of the different
joint types.
Examples
--------
.. manim:: LineJointVariants
:save_last_frame:
class LineJointVariants(Scene):
def construct(self):
mob = VMobject(stroke_width=20, color=GREEN).set_points_as_corners([
np.array([-2, 0, 0]),
np.array([0, 0, 0]),
np.array([-2, 1, 0]),
])
lines = VGroup(*[mob.copy() for _ in range(len(LineJointType))])
for line, joint_type in zip(lines, LineJointType):
line.joint_type = joint_type
lines.arrange(RIGHT, buff=1)
self.add(lines)
for line in lines:
label = Text(line.joint_type.name).next_to(line, DOWN)
self.add(label)
"""
AUTO = 0
ROUND = 1
BEVEL = 2
MITER = 3

View file

@ -35,13 +35,6 @@ from manim.utils.space_ops import (
z_to_vector,
)
JOINT_TYPE_MAP = {
"auto": 0,
"round": 1,
"bevel": 2,
"miter": 3,
}
def triggers_refreshed_triangulation(func):
@wraps(func)
@ -108,7 +101,7 @@ class OpenGLVMobject(OpenGLMobject):
should_subdivide_sharp_curves: bool = False,
should_remove_null_curves: bool = False,
# Could also be "bevel", "miter", "round"
joint_type: str = "auto",
joint_type: LineJointType | None = None,
flat_stroke: bool = True,
render_primitive=moderngl.TRIANGLES,
triangulation_locked: bool = False,
@ -134,7 +127,8 @@ class OpenGLVMobject(OpenGLMobject):
self.long_lines = long_lines
self.should_subdivide_sharp_curves = should_subdivide_sharp_curves
self.should_remove_null_curves = should_remove_null_curves
# Could also be "bevel", "miter", "round"
if joint_type is None:
joint_type = LineJointType.AUTO
self.joint_type = joint_type
self.flat_stroke = flat_stroke
self.render_primitive = render_primitive
@ -1564,7 +1558,7 @@ class OpenGLVMobject(OpenGLMobject):
def get_stroke_uniforms(self):
result = dict(super().get_shader_uniforms())
result["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
result["joint_type"] = self.joint_type.value
result["flat_stroke"] = float(self.flat_stroke)
return result

View file

@ -1,5 +1,6 @@
"""Mobjects that use vector graphics."""
from __future__ import annotations
__all__ = [
"VMobject",
@ -70,6 +71,10 @@ class VMobject(Mobject):
that it should count in parent mobject's path
tolerance_for_point_equality
This is within a pixel
joint_type
The line joint type used to connect the curve segments
of this vectorized mobject. See :class:`.LineJointType`
for options.
"""
sheen_factor = 0.0
@ -85,6 +90,7 @@ class VMobject(Mobject):
background_stroke_opacity=1.0,
background_stroke_width=0,
sheen_factor=0.0,
joint_type: LineJointType | None = None,
sheen_direction=UL,
close_new_points=False,
pre_function_handle_to_anchor_scale_factor=0.01,
@ -103,6 +109,9 @@ class VMobject(Mobject):
self.background_stroke_opacity = background_stroke_opacity
self.background_stroke_width = background_stroke_width
self.sheen_factor = sheen_factor
if joint_type is None:
joint_type = LineJointType.AUTO
self.joint_type = joint_type
self.sheen_direction = sheen_direction
self.close_new_points = close_new_points
self.pre_function_handle_to_anchor_scale_factor = (
@ -208,8 +217,8 @@ class VMobject(Mobject):
def set_fill(
self,
color: Optional[str] = None,
opacity: Optional[float] = None,
color: str | None = None,
opacity: float | None = None,
family: bool = True,
):
"""Set the fill color and fill opacity of a :class:`VMobject`.
@ -574,14 +583,14 @@ class VMobject(Mobject):
offset = np.dot(bases, direction)
return (c - offset, c + offset)
def color_using_background_image(self, background_image: Union[Image, str]):
def color_using_background_image(self, background_image: Image | str):
self.background_image = background_image
self.set_color(WHITE)
for submob in self.submobjects:
submob.color_using_background_image(background_image)
return self
def get_background_image(self) -> Union[Image, str]:
def get_background_image(self) -> Image | str:
return self.background_image
def match_background_image(self, vmobject):
@ -827,7 +836,7 @@ class VMobject(Mobject):
if not self.is_closed():
self.add_line_to(self.get_subpaths()[-1][0])
def add_points_as_corners(self, points: np.ndarray) -> "VMobject":
def add_points_as_corners(self, points: np.ndarray) -> VMobject:
for point in points:
self.add_line_to(point)
return points
@ -929,7 +938,7 @@ class VMobject(Mobject):
self,
angle: float,
axis: np.ndarray = OUT,
about_point: Optional[Sequence[float]] = None,
about_point: Sequence[float] | None = None,
**kwargs,
):
self.rotate_sheen_direction(angle, axis)
@ -1000,7 +1009,7 @@ class VMobject(Mobject):
def get_cubic_bezier_tuples_from_points(self, points):
return np.array(list(self.gen_cubic_bezier_tuples_from_points(points)))
def gen_cubic_bezier_tuples_from_points(self, points: np.ndarray) -> typing.Tuple:
def gen_cubic_bezier_tuples_from_points(self, points: np.ndarray) -> tuple:
"""Returns the bezier tuples from an array of points.
self.points is a list of the anchors and handles of the bezier curves of the mobject (ie [anchor1, handle1, handle2, anchor2, anchor3 ..])
@ -1031,7 +1040,7 @@ class VMobject(Mobject):
self,
points: np.ndarray,
filter_func: typing.Callable[[int], bool],
) -> typing.Tuple:
) -> tuple:
"""Given an array of points defining the bezier curves of the vmobject, return subpaths formed by these points.
Here, Two bezier curves form a path if at least two of their anchors are evaluated True by the relation defined by filter_func.
@ -1075,7 +1084,7 @@ class VMobject(Mobject):
lambda n: not self.consider_points_equals_2d(points[n - 1], points[n]),
)
def get_subpaths(self) -> typing.Tuple:
def get_subpaths(self) -> tuple:
"""Returns subpaths formed by the curves of the VMobject.
Subpaths are ranges of curves with each pair of consecutive curves having their end/start points coincident.
@ -1122,7 +1131,7 @@ class VMobject(Mobject):
def get_nth_curve_length_pieces(
self,
n: int,
sample_points: Optional[int] = None,
sample_points: int | None = None,
) -> np.ndarray:
"""Returns the array of short line lengths used for length approximation.
@ -1151,7 +1160,7 @@ class VMobject(Mobject):
def get_nth_curve_length(
self,
n: int,
sample_points: Optional[int] = None,
sample_points: int | None = None,
) -> float:
"""Returns the (approximate) length of the nth curve.
@ -1175,8 +1184,8 @@ class VMobject(Mobject):
def get_nth_curve_function_with_length(
self,
n: int,
sample_points: Optional[int] = None,
) -> typing.Tuple[typing.Callable[[float], np.ndarray], float]:
sample_points: int | None = None,
) -> tuple[typing.Callable[[float], np.ndarray], float]:
"""Returns the expression of the nth curve along with its (approximate) length.
Parameters
@ -1229,7 +1238,7 @@ class VMobject(Mobject):
def get_curve_functions_with_lengths(
self, **kwargs
) -> typing.Iterable[typing.Tuple[typing.Callable[[float], np.ndarray], float]]:
) -> typing.Iterable[tuple[typing.Callable[[float], np.ndarray], float]]:
"""Gets the functions and lengths of the curves for the mobject.
Parameters
@ -1294,7 +1303,7 @@ class VMobject(Mobject):
def proportion_from_point(
self,
point: typing.Iterable[typing.Union[float, int]],
point: typing.Iterable[float | int],
) -> float:
"""Returns the proportion along the path of the :class:`VMobject`
a particular given point is at.
@ -1401,7 +1410,7 @@ class VMobject(Mobject):
# Probably returns all anchors, but this is weird regarding the name of the method.
return np.array(list(it.chain(*(sm.get_anchors() for sm in self.get_family()))))
def get_arc_length(self, sample_points_per_curve: Optional[int] = None) -> float:
def get_arc_length(self, sample_points_per_curve: int | None = None) -> float:
"""Return the approximated length of the whole curve.
Parameters
@ -1423,7 +1432,7 @@ class VMobject(Mobject):
)
# Alignment
def align_points(self, vmobject: "VMobject"):
def align_points(self, vmobject: VMobject):
"""Adds points to self and vmobject so that they both have the same number of subpaths, with
corresponding subpaths each containing the same number of points.
@ -1614,7 +1623,7 @@ class VMobject(Mobject):
def pointwise_become_partial(
self,
vmobject: "VMobject",
vmobject: VMobject,
a: float,
b: float,
):
@ -1674,7 +1683,7 @@ class VMobject(Mobject):
)
return self
def get_subcurve(self, a: float, b: float) -> "VMobject":
def get_subcurve(self, a: float, b: float) -> VMobject:
"""Returns the subcurve of the VMobject between the interval [a, b].
The curve is a VMobject itself.
@ -1906,7 +1915,7 @@ class VGroup(VMobject, metaclass=ConvertToOpenGL):
def __isub__(self, vmobject):
return self.remove(vmobject)
def __setitem__(self, key: int, value: Union[VMobject, typing.Sequence[VMobject]]):
def __setitem__(self, key: int, value: VMobject | typing.Sequence[VMobject]):
"""Override the [] operator for item assignment.
Parameters
@ -2032,10 +2041,10 @@ class VDict(VMobject, metaclass=ConvertToOpenGL):
def __init__(
self,
mapping_or_iterable: Union[
typing.Mapping[typing.Hashable, VMobject],
typing.Iterable[typing.Tuple[typing.Hashable, VMobject]],
] = {},
mapping_or_iterable: (
typing.Mapping[typing.Hashable, VMobject]
| typing.Iterable[tuple[typing.Hashable, VMobject]]
) = {},
show_keys: bool = False,
**kwargs,
):
@ -2049,10 +2058,10 @@ class VDict(VMobject, metaclass=ConvertToOpenGL):
def add(
self,
mapping_or_iterable: Union[
typing.Mapping[typing.Hashable, VMobject],
typing.Iterable[typing.Tuple[typing.Hashable, VMobject]],
],
mapping_or_iterable: (
typing.Mapping[typing.Hashable, VMobject]
| typing.Iterable[tuple[typing.Hashable, VMobject]]
),
):
"""Adds the key-value pairs to the :class:`VDict` object.

View file

@ -218,12 +218,6 @@ class OpenGLCamera(OpenGLMobject):
points_per_curve = 3
JOINT_TYPE_MAP = {
"auto": 0,
"round": 1,
"bevel": 2,
"miter": 3,
}
class OpenGLRenderer:

View file

@ -36,3 +36,20 @@ def test_match_style(scene):
VGroup(square, circle).arrange()
circle.match_style(square)
scene.add(square, circle)
@frames_comparison
def test_vmobject_joint_types(scene):
angled_line = VMobject(stroke_width=20, color=GREEN).set_points_as_corners(
[
np.array([-2, 0, 0]),
np.array([0, 0, 0]),
np.array([-2, 1, 0]),
]
)
lines = VGroup(*[angled_line.copy() for _ in range(len(LineJointType))])
for line, joint_type in zip(lines, LineJointType):
line.joint_type = joint_type
lines.arrange(RIGHT, buff=1)
scene.add(lines)