Fix graph animations for remove_vertices and remove_edges (TODO: add_vertices animation is still broken)

This commit is contained in:
Francisco Manríquez Novoa 2026-02-23 00:07:09 -03:00
commit 6b3c191dd1
2 changed files with 108 additions and 65 deletions

View file

@ -59,7 +59,21 @@ class Test(Scene):
positions={4: [2, 1, 0], 5: [2, -1, 0]},
)
)
self.play( # TODO: this animation is currently broken
graph.animate.add_edges(
(2, 4),
(3, 5),
(4, 5),
edge_config={
(2, 4): {"stroke_color": GREEN},
(3, 5): {"stroke_color": GREEN},
(4, 5): {"stroke_color": YELLOW},
},
)
)
self.wait(1)
self.play(graph.animate.remove_vertices(1))
self.play(graph.animate.remove_edges((4, 5)))
self.play(Uncreate(graph))

View file

@ -26,12 +26,10 @@ from manim.animation.composition import AnimationGroup
from manim.animation.creation import Create, Uncreate
from manim.mobject.geometry.arc import Dot, LabeledDot
from manim.mobject.geometry.line import Line
from manim.mobject.opengl.opengl_mobject import (
OpenGLMobject as Mobject,
)
from manim.mobject.opengl.opengl_mobject import (
override_animate,
)
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup as VGroup
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject as VMobject
from manim.mobject.text.tex_mobject import MathTex
from manim.utils.color import BLACK
@ -571,10 +569,10 @@ class GenericGraph(VMobject):
layout: LayoutName | dict[Hashable, Point3DLike] | LayoutFunction = "spring",
layout_scale: float | tuple[float, float, float] = 2,
layout_config: dict | None = None,
vertex_type: type[Mobject] = Dot,
vertex_type: type[VMobject] = Dot,
vertex_config: dict | None = None,
vertex_mobjects: dict | None = None,
edge_type: type[Mobject] = Line,
edge_type: type[VMobject] = Line,
partitions: Sequence[Sequence[Hashable]] | None = None,
root_vertex: Hashable | None = None,
edge_config: dict | None = None,
@ -666,12 +664,12 @@ class GenericGraph(VMobject):
raise NotImplementedError("To be implemented in concrete subclasses")
def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[VMobject]
):
"""Helper method for populating the edges of the graph."""
raise NotImplementedError("To be implemented in concrete subclasses")
def __getitem__(self: Graph, v: Hashable) -> Mobject:
def __getitem__(self: Graph, v: Hashable) -> VMobject:
return self.vertices[v]
def _create_vertex(
@ -680,10 +678,10 @@ class GenericGraph(VMobject):
position: Point3DLike | None = None,
label: bool = False,
label_fill_color: str = BLACK,
vertex_type: type[Mobject] = Dot,
vertex_type: type[VMobject] = Dot,
vertex_config: dict | None = None,
vertex_mobject: dict | None = None,
) -> tuple[Hashable, Point3D, dict, Mobject]:
) -> tuple[Hashable, Point3D, dict, VMobject]:
np_position: Point3D = (
self.get_center() if position is None else np.asarray(position)
)
@ -700,7 +698,7 @@ class GenericGraph(VMobject):
label = MathTex(vertex, color=label_fill_color)
elif vertex in self._labels:
label = self._labels[vertex]
elif not isinstance(label, Mobject):
elif not isinstance(label, VMobject):
label = None
base_vertex_config = copy(self.default_vertex_config)
@ -724,8 +722,8 @@ class GenericGraph(VMobject):
vertex: Hashable,
position: Point3DLike,
vertex_config: dict,
vertex_mobject: Mobject,
) -> Mobject:
vertex_mobject: VMobject,
) -> VMobject:
if vertex in self.vertices:
raise ValueError(
f"Vertex identifier '{vertex}' is already used for a vertex in this graph.",
@ -751,10 +749,10 @@ class GenericGraph(VMobject):
position: Point3DLike | None = None,
label: bool = False,
label_fill_color: str = BLACK,
vertex_type: type[Mobject] = Dot,
vertex_type: type[VMobject] = Dot,
vertex_config: dict | None = None,
vertex_mobject: dict | None = None,
) -> Mobject:
) -> VMobject:
"""Add a vertex to the graph.
Parameters
@ -769,7 +767,7 @@ class GenericGraph(VMobject):
Controls whether or not the vertex is labeled. If ``False`` (the default),
the vertex is not labeled; if ``True`` it is labeled using its
names (as specified in ``vertex``) via :class:`~.MathTex`. Alternatively,
any :class:`~.Mobject` can be passed to be used as the label.
any :class:`~.VMobject` can be passed to be used as the label.
label_fill_color
Sets the fill color of the default labels generated when ``labels``
is set to ``True``. Has no effect for other values of ``label``.
@ -800,10 +798,10 @@ class GenericGraph(VMobject):
positions: dict | None = None,
labels: bool = False,
label_fill_color: str = BLACK,
vertex_type: type[Mobject] = Dot,
vertex_type: type[VMobject] = Dot,
vertex_config: dict | None = None,
vertex_mobjects: dict | None = None,
) -> Iterable[tuple[Hashable, Point3D, dict, Mobject]]:
) -> Iterable[tuple[Hashable, Point3D, dict, VMobject]]:
if positions is None:
positions = {}
if vertex_mobjects is None:
@ -854,10 +852,10 @@ class GenericGraph(VMobject):
positions: dict | None = None,
labels: bool = False,
label_fill_color: str = BLACK,
vertex_type: type[Mobject] = Dot,
vertex_type: type[VMobject] = Dot,
vertex_config: dict | None = None,
vertex_mobjects: dict | None = None,
) -> list[Mobject]:
) -> VGroup:
"""Add a list of vertices to the graph.
Parameters
@ -872,7 +870,7 @@ class GenericGraph(VMobject):
Controls whether or not the vertex is labeled. If ``False`` (the default),
the vertex is not labeled; if ``True`` it is labeled using its
names (as specified in ``vertex``) via :class:`~.MathTex`. Alternatively,
any :class:`~.Mobject` can be passed to be used as the label.
any :class:`~.VMobject` can be passed to be used as the label.
label_fill_color
Sets the fill color of the default labels generated when ``labels``
is set to ``True``. Has no effect for other values of ``labels``.
@ -886,7 +884,7 @@ class GenericGraph(VMobject):
values are mobjects that should be used as vertices. Overrides
all other vertex customization options.
"""
return [
return VGroup(
self._add_created_vertex(*v)
for v in self._create_vertices(
*vertices,
@ -897,7 +895,7 @@ class GenericGraph(VMobject):
vertex_config=vertex_config,
vertex_mobjects=vertex_mobjects,
)
]
)
@override_animate(add_vertices)
def _add_vertices_animation(
@ -906,19 +904,21 @@ class GenericGraph(VMobject):
anim_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> AnimationGroup:
vertex_mobjects = self.add_vertices(*vertices, **kwargs)
# Use introducer=False to prevent re-adding the vertices when animating them
base_anim_args = {"animation": Create, "introducer": False}
if anim_args is not None:
base_anim_args.update(anim_args)
animation = base_anim_args.pop("animation")
vertex_mobjects = self.add_vertices(*vertices, **kwargs)
return AnimationGroup(
animation(vertex_mobject, **base_anim_args)
for vertex_mobject in vertex_mobjects
*(
animation(vertex_mobject, **base_anim_args)
for vertex_mobject in vertex_mobjects
),
)
def _remove_vertex(self, vertex):
def _remove_vertex(self, vertex: Hashable) -> VGroup:
"""Remove a vertex (as well as all incident edges) from the graph.
Parameters
@ -952,9 +952,9 @@ class GenericGraph(VMobject):
to_remove.append(self.vertices.pop(vertex))
self.remove(*to_remove)
return self.get_group_class()(*to_remove)
return VGroup(*to_remove)
def remove_vertices(self, *vertices):
def remove_vertices(self, *vertices: Hashable) -> VGroup:
"""Remove several vertices from the graph.
Parameters
@ -978,26 +978,32 @@ class GenericGraph(VMobject):
mobjects = []
for v in vertices:
mobjects.extend(self._remove_vertex(v).submobjects)
return self.get_group_class()(*mobjects)
return VGroup(*mobjects)
@override_animate(remove_vertices)
def _remove_vertices_animation(self, *vertices, anim_args=None):
if anim_args is None:
anim_args = {}
def _remove_vertices_animation(
self, *vertices: Hashable, anim_args: dict[str, Any] | None = None
) -> AnimationGroup:
base_anim_args = {"animation": Uncreate}
if anim_args is not None:
base_anim_args.update(anim_args)
animation = base_anim_args.pop("animation")
animation = anim_args.pop("animation", Uncreate)
mobjects = self.remove_vertices(*vertices)
vertex_and_edge_mobjects = self.remove_vertices(*vertices)
return AnimationGroup(
*(animation(mobj, **anim_args) for mobj in mobjects), group=self
*(
animation(vertex_or_edge_mobject, **anim_args)
for vertex_or_edge_mobject in vertex_and_edge_mobjects
),
introducer=True, # Reintroduce vertices and edges temporarily to animate them
)
def _add_edge(
self,
edge: tuple[Hashable, Hashable],
edge_type: type[Mobject] = Line,
edge_type: type[VMobject] = Line,
edge_config: dict | None = None,
):
) -> VGroup:
"""Add a new edge to the graph.
Parameters
@ -1041,15 +1047,17 @@ class GenericGraph(VMobject):
self.add(edge_mobject)
added_mobjects.append(edge_mobject)
return self.get_group_class()(*added_mobjects)
return VGroup(*added_mobjects)
def add_edges(
self,
*edges: tuple[Hashable, Hashable],
edge_type: type[Mobject] = Line,
edge_config: dict | None = None,
**kwargs,
):
edge_type: type[VMobject] = Line,
edge_config: dict[str, Any]
| dict[tuple[Hashable, Hashable], dict[str, Any]]
| None = None,
**kwargs: Any,
) -> VGroup:
"""Add new edges to the graph.
Parameters
@ -1102,20 +1110,33 @@ class GenericGraph(VMobject):
),
added_vertices,
)
return self.get_group_class()(*added_mobjects)
return VGroup(*added_mobjects)
@override_animate(add_edges)
def _add_edges_animation(self, *args, anim_args=None, **kwargs):
if anim_args is None:
anim_args = {}
animation = anim_args.pop("animation", Create)
def _add_edges_animation(
self,
*edges: tuple[Hashable, Hashable],
anim_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> AnimationGroup:
# TODO: the animation is broken with introducer=False, but not passing it
# disbands the graph upon re-adding the edges and vertices. Fix this
mobjects = self.add_edges(*args, **kwargs)
# Use introducer=False to prevent re-adding the edges and vertices when animating
base_anim_args = {"animation": Create, "introducer": False}
if anim_args is not None:
base_anim_args.update(anim_args)
animation = base_anim_args.pop("animation")
edge_and_vertex_mobjects = self.add_edges(*edges, **kwargs)
return AnimationGroup(
*(animation(mobj, **anim_args) for mobj in mobjects), group=self
*(
animation(edge_or_vertex_mobject, **base_anim_args)
for edge_or_vertex_mobject in edge_and_vertex_mobjects
)
)
def _remove_edge(self, edge: tuple[Hashable]):
def _remove_edge(self, edge: tuple[Hashable]) -> VMobject:
"""Remove an edge from the graph.
Parameters
@ -1127,7 +1148,7 @@ class GenericGraph(VMobject):
Returns
-------
Mobject
VMobject
The removed edge.
"""
@ -1142,7 +1163,7 @@ class GenericGraph(VMobject):
self.remove(edge_mobject)
return edge_mobject
def remove_edges(self, *edges: tuple[Hashable]):
def remove_edges(self, *edges: tuple[Hashable]) -> VGroup:
"""Remove several edges from the graph.
Parameters
@ -1157,17 +1178,25 @@ class GenericGraph(VMobject):
"""
edge_mobjects = [self._remove_edge(edge) for edge in edges]
return self.get_group_class()(*edge_mobjects)
return VGroup(*edge_mobjects)
@override_animate(remove_edges)
def _remove_edges_animation(self, *edges, anim_args=None):
if anim_args is None:
anim_args = {}
def _remove_edges_animation(
self, *edges: tuple[Hashable, Hashable], anim_args: dict[str, Any] | None = None
) -> AnimationGroup:
base_anim_args = {"animation": Uncreate}
if anim_args is not None:
base_anim_args.update(anim_args)
animation = base_anim_args.pop("animation")
animation = anim_args.pop("animation", Uncreate)
mobjects = self.remove_edges(*edges)
return AnimationGroup(*(animation(mobj, **anim_args) for mobj in mobjects))
edge_and_vertex_mobjects = self.remove_edges(*edges)
return AnimationGroup(
*(
animation(edge_or_vertex_mobject, **anim_args)
for edge_or_vertex_mobject in edge_and_vertex_mobjects
),
introducer=True, # Reintroduce edges and vertices temporarily to animate them
)
@classmethod
def from_networkx(
@ -1539,7 +1568,7 @@ class Graph(GenericGraph):
return nx.Graph()
def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[VMobject]
):
self.edges = {
(u, v): edge_type(
@ -1749,7 +1778,7 @@ class DiGraph(GenericGraph):
return nx.DiGraph()
def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[VMobject]
):
self.edges = {
(u, v): edge_type(
@ -1772,7 +1801,7 @@ class DiGraph(GenericGraph):
"""
for (u, v), edge in graph.edges.items():
tip = edge.pop_tips()[0]
# Passing the Mobject instead of the vertex makes the tip
# Passing the VMobject instead of the vertex makes the tip
# stop on the bounding box of the vertex.
edge.set_points_by_ends(
graph[u],