Upgrade to modern Python syntax (#1956)

* Upgrade to modern Python syntax

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Christian Clauss 2021-08-24 14:28:55 +02:00 committed by GitHub
commit 66d26380e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 179 additions and 181 deletions

View file

@ -202,7 +202,7 @@ class ManimDirective(Directive):
".. code-block:: python",
"",
" from manim import *\n",
*[" " + line for line in self.content],
*(" " + line for line in self.content),
]
source_block = "\n".join(source_block)
@ -286,7 +286,7 @@ def _write_rendering_stats(scene_name, run_time, file_name):
with open(rendering_times_file_path, "a") as file:
csv.writer(file).writerow(
[
re.sub("^(reference\/)|(manim\.)", "", file_name),
re.sub(r"^(reference\/)|(manim\.)", "", file_name),
scene_name,
"%.3f" % run_time,
]
@ -302,7 +302,7 @@ def _log_rendering_times(*args):
print("\nRendering Summary\n-----------------\n")
max_file_length = max([len(row[0]) for row in data])
max_file_length = max(len(row[0]) for row in data)
for key, group in it.groupby(data, key=lambda row: row[0]):
key = key.ljust(max_file_length + 1, ".")
group = list(group)
@ -310,7 +310,7 @@ def _log_rendering_times(*args):
row = group[0]
print(f"{key}{row[2].rjust(7, '.')}s {row[1]}")
continue
time_sum = sum([float(row[2]) for row in group])
time_sum = sum(float(row[2]) for row in group)
print(
f"{key}{f'{time_sum:.3f}'.rjust(7, '.')}s => {len(group)} EXAMPLES"
)

View file

@ -29,7 +29,7 @@ class OpeningManim(Scene):
transform_title.to_corner(UP + LEFT)
self.play(
Transform(title, transform_title),
LaggedStart(*[FadeOut(obj, shift=DOWN) for obj in basel]),
LaggedStart(*(FadeOut(obj, shift=DOWN) for obj in basel)),
)
self.wait()

View file

@ -106,11 +106,9 @@ class Animation:
if func is not None:
anim = func(mobject, *args, **kwargs)
logger.debug(
(
f"The {cls.__name__} animation has been is overridden for "
f"{type(mobject).__name__} mobjects. use_override = False can "
f" be used as keyword argument to prevent animation overriding."
)
f"The {cls.__name__} animation has been is overridden for "
f"{type(mobject).__name__} mobjects. use_override = False can "
f" be used as keyword argument to prevent animation overriding."
)
return anim
return super().__new__(cls)
@ -224,7 +222,7 @@ class Animation:
def get_all_families_zipped(self) -> Iterable[Tuple]:
return zip(
*[mob.family_members_with_points() for mob in self.get_all_mobjects()]
*(mob.family_members_with_points() for mob in self.get_all_mobjects())
)
def update_mobjects(self, dt: float) -> None:

View file

@ -500,12 +500,12 @@ class AddTextWordByWord(Succession):
self.time_per_char = time_per_char
tpc = self.time_per_char
anims = it.chain(
*[
*(
[
ShowIncreasingSubsets(word, run_time=tpc * len(word)),
Animation(word, run_time=0.005 * len(word) ** 1.5),
]
for word in text_mobject
]
)
)
super().__init__(*anims, **kwargs)

View file

@ -327,7 +327,7 @@ class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup):
max_time_width = kwargs.pop("time_width", self.time_width)
AnimationGroup.__init__(
self,
*[
*(
ShowPassingFlash(
vmobject.deepcopy().set_stroke(width=stroke_width),
time_width=time_width,
@ -337,7 +337,7 @@ class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup):
np.linspace(0, max_stroke_width, self.n_segments),
np.linspace(max_time_width, 0, self.n_segments),
)
],
),
)

View file

@ -136,7 +136,7 @@ class Transform(Animation):
self.starting_mobject,
self.target_copy,
]
return zip(*[mob.family_members_with_points() for mob in mobs])
return zip(*(mob.family_members_with_points() for mob in mobs))
def interpolate_submobject(
self,

View file

@ -678,7 +678,7 @@ class Camera:
else:
points = vmobject.get_gradient_start_and_end_points()
points = self.transform_points_pre_display(vmobject, points)
pat = cairo.LinearGradient(*it.chain(*[point[:2] for point in points]))
pat = cairo.LinearGradient(*it.chain(*(point[:2] for point in points)))
step = 1.0 / (len(rgbas) - 1)
offsets = np.arange(0, 1 + step, step)
for rgba, offset in zip(rgbas, offsets):

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: frameserver.proto
"""Generated protocol buffer code."""

View file

@ -4,7 +4,7 @@ import frameserver_pb2 as frameserver__pb2
import grpc
class FrameServerStub(object):
class FrameServerStub:
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
@ -30,7 +30,7 @@ class FrameServerStub(object):
)
class FrameServerServicer(object):
class FrameServerServicer:
"""Missing associated documentation comment in .proto file."""
def GetFrameAtTime(self, request, context):
@ -77,7 +77,7 @@ def add_FrameServerServicer_to_server(servicer, server):
# This class is part of an EXPERIMENTAL API.
class FrameServer(object):
class FrameServer:
"""Missing associated documentation comment in .proto file."""
@staticmethod

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: renderserver.proto
"""Generated protocol buffer code."""

View file

@ -4,7 +4,7 @@ import grpc
import renderserver_pb2 as renderserver__pb2
class RenderServerStub(object):
class RenderServerStub:
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
@ -20,7 +20,7 @@ class RenderServerStub(object):
)
class RenderServerServicer(object):
class RenderServerServicer:
"""Missing associated documentation comment in .proto file."""
def UpdateSceneData(self, request, context):
@ -45,7 +45,7 @@ def add_RenderServerServicer_to_server(servicer, server):
# This class is part of an EXPERIMENTAL API.
class RenderServer(object):
class RenderServer:
"""Missing associated documentation comment in .proto file."""
@staticmethod

View file

@ -51,7 +51,7 @@ from ..utils.space_ops import angle_of_vector
class CoordinateSystem:
"""
r"""
Abstract class for Axes and NumberPlane
Examples
@ -945,10 +945,10 @@ class CoordinateSystem:
x_range = x_range if x_range is not None else self.x_range
return VGroup(
*[
*(
self.get_vertical_line(self.i2gp(x, graph), **kwargs)
for x in np.linspace(x_range[0], x_range[1], num_lines)
]
)
)
def get_T_label(
@ -1171,7 +1171,7 @@ class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL):
Tuple
Coordinates of the point with respect to :class:`Axes`'s basis
"""
return tuple([axis.point_to_number(point) for axis in self.get_axes()])
return tuple(axis.point_to_number(point) for axis in self.get_axes())
def get_axes(self) -> VGroup:
"""Gets the axes.
@ -1260,10 +1260,10 @@ class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL):
if add_vertex_dots:
vertex_dot_style = vertex_dot_style or {}
vertex_dots = VGroup(
*[
*(
Dot(point=vertex, radius=vertex_dot_radius, **vertex_dot_style)
for vertex in vertices
]
)
)
line_graph["vertex_dots"] = vertex_dots

View file

@ -853,7 +853,7 @@ class AnnularSector(Arc):
)
def generate_points(self):
inner_arc, outer_arc = [
inner_arc, outer_arc = (
Arc(
start_angle=self.start_angle,
angle=self.angle,
@ -861,7 +861,7 @@ class AnnularSector(Arc):
arc_center=self.arc_center,
)
for radius in (self.inner_radius, self.outer_radius)
]
)
outer_arc.reverse_points()
self.append_points(inner_arc.get_points())
self.add_line_to(outer_arc.get_points()[0])
@ -1737,7 +1737,7 @@ class Polygram(VMobject, metaclass=ConvertToOpenGL):
self.start_new_path(first_vertex)
self.add_points_as_corners(
[*[np.array(vertex) for vertex in vertices], first_vertex]
[*(np.array(vertex) for vertex in vertices), first_vertex]
)
def get_vertices(self) -> np.ndarray:
@ -2452,28 +2452,28 @@ class Rectangle(Polygon):
grid_xstep = abs(grid_xstep)
count = int(width / grid_xstep)
grid = VGroup(
*[
*(
Line(
v[1] + i * grid_xstep * RIGHT,
v[1] + i * grid_xstep * RIGHT + height * DOWN,
color=color,
)
for i in range(1, count)
]
)
)
self.add(grid)
if grid_ystep is not None:
grid_ystep = abs(grid_ystep)
count = int(height / grid_ystep)
grid = VGroup(
*[
*(
Line(
v[1] + i * grid_ystep * DOWN,
v[1] + i * grid_ystep * DOWN + width * RIGHT,
color=color,
)
for i in range(1, count)
]
)
)
self.add(grid)

View file

@ -629,7 +629,7 @@ class Graph(VMobject, metaclass=ConvertToOpenGL):
animation = anim_args.pop("animation", Create)
vertex_mobjects = self.add_vertices(*args, **kwargs)
return AnimationGroup(*[animation(v, **anim_args) for v in vertex_mobjects])
return AnimationGroup(*(animation(v, **anim_args) for v in vertex_mobjects))
def _remove_vertex(self, vertex):
"""Remove a vertex (as well as all incident edges) from the graph.
@ -700,7 +700,7 @@ class Graph(VMobject, metaclass=ConvertToOpenGL):
animation = anim_args.pop("animation", Uncreate)
mobjects = self.remove_vertices(*vertices)
return AnimationGroup(*[animation(mobj, **anim_args) for mobj in mobjects])
return AnimationGroup(*(animation(mobj, **anim_args) for mobj in mobjects))
def _add_edge(
self,
@ -794,12 +794,12 @@ class Graph(VMobject, metaclass=ConvertToOpenGL):
edge_config = base_edge_config
added_mobjects = sum(
[
(
self._add_edge(
edge, edge_type=edge_type, edge_config=edge_config[edge]
).submobjects
for edge in edges
],
),
[],
)
return self.get_group_class()(*added_mobjects)
@ -811,7 +811,7 @@ class Graph(VMobject, metaclass=ConvertToOpenGL):
animation = anim_args.pop("animation", Create)
mobjects = self.add_edges(*args, **kwargs)
return AnimationGroup(*[animation(mobj, **anim_args) for mobj in mobjects])
return AnimationGroup(*(animation(mobj, **anim_args) for mobj in mobjects))
def _remove_edge(self, edge: Tuple[Hashable]):
"""Remove an edge from the graph.
@ -867,7 +867,7 @@ class Graph(VMobject, metaclass=ConvertToOpenGL):
animation = anim_args.pop("animation", Uncreate)
mobjects = self.remove_edges(*edges)
return AnimationGroup(*[animation(mobj, **anim_args) for mobj in mobjects])
return AnimationGroup(*(animation(mobj, **anim_args) for mobj in mobjects))
@staticmethod
def from_networkx(nxgraph: nx.classes.graph.Graph, **kwargs) -> "Graph":

View file

@ -267,10 +267,10 @@ class Matrix(VMobject, metaclass=ConvertToOpenGL):
"""
return VGroup(
*[
VGroup(*[row[i] for row in self.mob_matrix])
*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(len(self.mob_matrix[0]))
]
)
)
def set_column_colors(self, *colors):
@ -323,7 +323,7 @@ class Matrix(VMobject, metaclass=ConvertToOpenGL):
m0.add(SurroundingRectangle(m0.get_rows()[1]))
self.add(m0)
"""
return VGroup(*[VGroup(*row) for row in self.mob_matrix])
return VGroup(*(VGroup(*row) for row in self.mob_matrix))
def set_row_colors(self, *colors):
"""Set individual colors for each row of the matrix.

View file

@ -185,12 +185,10 @@ class Mobject:
cls.animation_overrides[animation_class] = override_func
else:
raise MultiAnimationOverrideException(
(
f"The animation {animation_class.__name__} for "
f"{cls.__name__} is overridden by more than one method: "
f"{cls.animation_overrides[animation_class].__qualname__} and "
f"{override_func.__qualname__}."
)
f"The animation {animation_class.__name__} for "
f"{cls.__name__} is overridden by more than one method: "
f"{cls.animation_overrides[animation_class].__qualname__} and "
f"{override_func.__qualname__}."
)
def init_gl_data(self):
@ -860,7 +858,7 @@ class Mobject:
return self.updaters
def get_family_updaters(self):
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()]))
return list(it.chain(*(sm.get_updaters() for sm in self.get_family())))
def add_updater(
self,
@ -2027,10 +2025,10 @@ class Mobject:
template.submobjects = []
alphas = np.linspace(0, 1, n_pieces + 1)
return Group(
*[
*(
template.copy().pointwise_become_partial(self, a1, a2)
for a1, a2 in zip(alphas[:-1], alphas[1:])
]
)
)
def get_z_index_reference_point(self):
@ -2324,7 +2322,7 @@ class Mobject:
# Use cell_alignment as fallback
return [cell_alignment * dir] * num
if len(alignments) != num:
raise ValueError("{}_alignments has a mismatching size.".format(name))
raise ValueError(f"{name}_alignments has a mismatching size.")
alignments = list(alignments)
for i in range(num):
alignments[i] = mapping[alignments[i]]
@ -2374,10 +2372,10 @@ class Mobject:
grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)]
measured_heigths = [
max([grid[r][c].height for c in range(cols)]) for r in range(rows)
max(grid[r][c].height for c in range(cols)) for r in range(rows)
]
measured_widths = [
max([grid[r][c].width for r in range(rows)]) for c in range(cols)
max(grid[r][c].width for r in range(rows)) for c in range(cols)
]
# Initialize row_heights / col_widths correctly using measurements as fallback
@ -2385,7 +2383,7 @@ class Mobject:
if sizes is None:
sizes = [None] * num
if len(sizes) != num:
raise ValueError("{} has a mismatching size.".format(name))
raise ValueError(f"{name} has a mismatching size.")
return [
sizes[i] if sizes[i] is not None else measures[i] for i in range(num)
]

View file

@ -95,7 +95,7 @@ class DecimalNumber(VMobject, metaclass=ConvertToOpenGL):
else:
num_string = num_string[1:]
self.add(*[SingleStringMathTex(char, **kwargs) for char in num_string])
self.add(*(SingleStringMathTex(char, **kwargs) for char in num_string))
# Add non-numerical bits
if self.show_ellipsis:

View file

@ -363,7 +363,7 @@ class OpenGLAnnularSector(OpenGLArc):
)
def init_points(self):
inner_arc, outer_arc = [
inner_arc, outer_arc = (
OpenGLArc(
start_angle=self.start_angle,
angle=self.angle,
@ -371,7 +371,7 @@ class OpenGLAnnularSector(OpenGLArc):
arc_center=self.arc_center,
)
for radius in (self.inner_radius, self.outer_radius)
]
)
outer_arc.reverse_points()
self.append_points(inner_arc.get_points())
self.add_line_to(outer_arc.get_points()[0])

View file

@ -637,7 +637,7 @@ class OpenGLMobject:
# Use cell_alignment as fallback
return [cell_alignment * dir] * num
if len(alignments) != num:
raise ValueError("{}_alignments has a mismatching size.".format(name))
raise ValueError(f"{name}_alignments has a mismatching size.")
alignments = list(alignments)
for i in range(num):
alignments[i] = mapping[alignments[i]]
@ -687,10 +687,10 @@ class OpenGLMobject:
grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)]
measured_heigths = [
max([grid[r][c].height for c in range(cols)]) for r in range(rows)
max(grid[r][c].height for c in range(cols)) for r in range(rows)
]
measured_widths = [
max([grid[r][c].width for r in range(rows)]) for c in range(cols)
max(grid[r][c].width for r in range(rows)) for c in range(cols)
]
# Initialize row_heights / col_widths correctly using measurements as fallback
@ -698,7 +698,7 @@ class OpenGLMobject:
if sizes is None:
sizes = [None] * num
if len(sizes) != num:
raise ValueError("{} has a mismatching size.".format(name))
raise ValueError(f"{name} has a mismatching size.")
return [
sizes[i] if sizes[i] is not None else measures[i] for i in range(num)
]
@ -783,7 +783,7 @@ class OpenGLMobject:
copy_mobject.uniforms = dict(self.uniforms)
copy_mobject.submobjects = []
copy_mobject.add(*[sm.copy() for sm in self.submobjects])
copy_mobject.add(*(sm.copy() for sm in self.submobjects))
copy_mobject.match_updaters(self)
copy_mobject.needs_new_bounding_box = self.needs_new_bounding_box
@ -864,7 +864,7 @@ class OpenGLMobject:
return self.time_based_updaters + self.non_time_updaters
def get_family_updaters(self):
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()]))
return list(it.chain(*(sm.get_updaters() for sm in self.get_family())))
def add_updater(self, update_function, index=None, call_updater=True):
if "dt" in get_parameters(update_function):
@ -1458,10 +1458,10 @@ class OpenGLMobject:
template.set_submobjects([])
alphas = np.linspace(0, 1, n_pieces + 1)
return OpenGLGroup(
*[
*(
template.copy().pointwise_become_partial(self, a1, a2)
for a1, a2 in zip(alphas[:-1], alphas[1:])
]
)
)
def get_z_index_reference_point(self):
@ -1777,7 +1777,7 @@ class OpenGLMobject:
def get_shader_wrapper_list(self):
shader_wrappers = it.chain(
[self.get_shader_wrapper()],
*[sm.get_shader_wrapper_list() for sm in self.submobjects],
*(sm.get_shader_wrapper_list() for sm in self.submobjects),
)
batches = batch_by_property(shader_wrappers, lambda sw: sw.get_id())
@ -1867,7 +1867,7 @@ class OpenGLMobject:
return self.event_listners
def get_family_event_listners(self):
return list(it.chain(*[sm.get_event_listners() for sm in self.get_family()]))
return list(it.chain(*(sm.get_event_listners() for sm in self.get_family())))
def get_has_event_listner(self):
return any(mob.get_event_listners() for mob in self.get_family())

View file

@ -196,7 +196,7 @@ class Code(VGroup):
self.file_name = file_name
if self.file_name:
self.ensure_valid_file()
with open(self.file_path, "r") as f:
with open(self.file_path) as f:
self.code_string = f.read()
elif code:
self.code_string = code
@ -292,7 +292,7 @@ class Code(VGroup):
f"From: {os.getcwd()}, could not find {self.file_name} at either "
+ f"of these locations: {possible_paths}"
)
raise IOError(error)
raise OSError(error)
def gen_line_numbers(self):
"""Function to generate line_numbers.

View file

@ -294,10 +294,10 @@ class OpenGLSingleStringMathTex(OpenGLSVGMobject):
tex = tex.replace("\\\\", "\\quad\\\\")
# Handle imbalanced \left and \right
num_lefts, num_rights = [
num_lefts, num_rights = (
len([s for s in tex.split(substr)[1:] if s and s[0] in "(){}[]|.\\"])
for substr in ("\\left", "\\right")
]
)
if num_lefts != num_rights:
tex = tex.replace("\\left", "\\big")
tex = tex.replace("\\right", "\\big")
@ -469,7 +469,7 @@ class OpenGLMathTex(OpenGLSingleStringMathTex):
return tex1 == tex2
return OpenGLVGroup(
*[m for m in self.submobjects if test(tex, m.get_tex_string())]
*(m for m in self.submobjects if test(tex, m.get_tex_string()))
)
def get_part_by_tex(self, tex, **kwargs):

View file

@ -91,9 +91,9 @@ def remove_invisible_chars(mobject):
if mobject[0].__class__ == VGroup:
for i in range(mobject.__len__()):
mobject_without_dots.add(VGroup())
mobject_without_dots[i].add(*[k for k in mobject[i] if k.__class__ != Dot])
mobject_without_dots[i].add(*(k for k in mobject[i] if k.__class__ != Dot))
else:
mobject_without_dots.add(*[k for k in mobject if k.__class__ != Dot])
mobject_without_dots.add(*(k for k in mobject if k.__class__ != Dot))
if iscode:
code.code = mobject_without_dots
return code
@ -163,7 +163,7 @@ class OpenGLParagraph(OpenGLVGroup):
[self.alignment for _ in range(chars_lines_text_list.__len__())]
)
OpenGLVGroup.__init__(
self, *[self.lines[0][i] for i in range(self.lines[0].__len__())], **config
self, *(self.lines[0][i] for i in range(self.lines[0].__len__())), **config
)
self.move_to(np.array([0, 0, 0]))
if self.alignment:

View file

@ -124,7 +124,7 @@ class SVGMobject(VMobject, metaclass=ConvertToOpenGL):
self.file_path = path
return
error = f"From: {os.getcwd()}, could not find {self.file_name} at either of these locations: {possible_paths}"
raise IOError(error)
raise OSError(error)
def generate_points(self):
"""Called by the Mobject abstract base class. Responsible for generating
@ -181,12 +181,12 @@ class SVGMobject(VMobject, metaclass=ConvertToOpenGL):
pass # TODO, handle style
elif element.tagName in ["g", "svg", "symbol", "defs"]:
result += it.chain(
*[
*(
self.get_mobjects_from(
child, style, within_defs=within_defs or is_defs
)
for child in element.childNodes
]
)
)
elif element.tagName == "path":
temp = element.getAttribute("d")
@ -331,12 +331,12 @@ class SVGMobject(VMobject, metaclass=ConvertToOpenGL):
Line
A Line VMobject
"""
x1, y1, x2, y2 = [
x1, y1, x2, y2 = (
self.attribute_to_float(line_element.getAttribute(key))
if line_element.hasAttribute(key)
else 0.0
for key in ("x1", "y1", "x2", "y2")
]
)
return Line([x1, -y1, 0], [x2, -y2, 0], **parse_style(style))
def rect_to_mobject(self, rect_element: MinidomElement, style: dict):
@ -404,12 +404,12 @@ class SVGMobject(VMobject, metaclass=ConvertToOpenGL):
Circle
A Circle VMobject
"""
x, y, r = [
x, y, r = (
self.attribute_to_float(circle_element.getAttribute(key))
if circle_element.hasAttribute(key)
else 0.0
for key in ("cx", "cy", "r")
]
)
return Circle(radius=r, **parse_style(style)).shift(x * RIGHT + y * DOWN)
def ellipse_to_mobject(self, circle_element: MinidomElement, style: dict):
@ -429,12 +429,12 @@ class SVGMobject(VMobject, metaclass=ConvertToOpenGL):
Circle
A Circle VMobject
"""
x, y, rx, ry = [
x, y, rx, ry = (
self.attribute_to_float(circle_element.getAttribute(key))
if circle_element.hasAttribute(key)
else 0.0
for key in ("cx", "cy", "rx", "ry")
]
)
return (
Circle(**parse_style(style))
.scale(rx * RIGHT + ry * UP)

View file

@ -148,10 +148,10 @@ class SingleStringMathTex(SVGMobject):
tex = tex.replace("\\\\", "\\quad\\\\")
# Handle imbalanced \left and \right
num_lefts, num_rights = [
num_lefts, num_rights = (
len([s for s in tex.split(substr)[1:] if s and s[0] in "(){}[]|.\\"])
for substr in ("\\left", "\\right")
]
)
if num_lefts != num_rights:
tex = tex.replace("\\left", "\\big")
tex = tex.replace("\\right", "\\big")
@ -296,7 +296,7 @@ class MathTex(SingleStringMathTex):
patterns = []
patterns.extend(
[
"({})".format(re.escape(ss))
f"({re.escape(ss)})"
for ss in it.chain(
self.substrings_to_isolate, self.tex_to_color_map.keys()
)
@ -353,7 +353,7 @@ class MathTex(SingleStringMathTex):
else:
return tex1 == tex2
return VGroup(*[m for m in self.submobjects if test(tex, m.get_tex_string())])
return VGroup(*(m for m in self.submobjects if test(tex, m.get_tex_string())))
def get_part_by_tex(self, tex, **kwargs):
all_parts = self.get_parts_by_tex(tex, **kwargs)

View file

@ -96,9 +96,9 @@ def remove_invisible_chars(mobject):
if mobject[0].__class__ == VGroup:
for i in range(mobject.__len__()):
mobject_without_dots.add(VGroup())
mobject_without_dots[i].add(*[k for k in mobject[i] if k.__class__ != Dot])
mobject_without_dots[i].add(*(k for k in mobject[i] if k.__class__ != Dot))
else:
mobject_without_dots.add(*[k for k in mobject if k.__class__ != Dot])
mobject_without_dots.add(*(k for k in mobject if k.__class__ != Dot))
if iscode:
code.code = mobject_without_dots
return code

View file

@ -471,10 +471,10 @@ class Table(VGroup):
self.add(table)
"""
return VGroup(
*[
VGroup(*[row[i] for row in self.mob_table])
*(
VGroup(*(row[i] for row in self.mob_table))
for i in range(len(self.mob_table[0]))
]
)
)
def get_rows(self) -> VGroup:
@ -501,7 +501,7 @@ class Table(VGroup):
table.add(SurroundingRectangle(table.get_rows()[1]))
self.add(table)
"""
return VGroup(*[VGroup(*row) for row in self.mob_table])
return VGroup(*(VGroup(*row) for row in self.mob_table))
def set_column_colors(self, *colors: Iterable[Color]) -> "Table":
"""Set individual colors for each column of the table.

View file

@ -434,7 +434,7 @@ class OpenGLVMobject(OpenGLMobject):
nppc = self.n_points_per_curve
points = np.array(points)
self.set_anchors_and_handles(
*[interpolate(points[:-1], points[1:], a) for a in np.linspace(0, 1, nppc)]
*(interpolate(points[:-1], points[1:], a) for a in np.linspace(0, 1, nppc))
)
return self
@ -1248,12 +1248,12 @@ class OpenGLDashedVMobject(OpenGLVMobject):
void_len = (1 - r) / (n - 1)
self.add(
*[
*(
vmobject.get_subcurve(
i * (dash_len + void_len), i * (dash_len + void_len) + dash_len
)
for i in range(n)
]
)
)
# Family is already taken care of by get_subcurve
# implementation

View file

@ -194,7 +194,7 @@ class PMobject(Mobject):
return self
def pointwise_become_partial(self, mobject, a, b):
lower_index, upper_index = [int(x * mobject.get_num_points()) for x in (a, b)]
lower_index, upper_index = (int(x * mobject.get_num_points()) for x in (a, b))
for attr in self.get_array_attrs():
full_array = getattr(mobject, attr)
partial_array = full_array[lower_index:upper_index]

View file

@ -670,10 +670,10 @@ class VMobject(Mobject):
"""
nppcc = self.n_points_per_cubic_curve
self.add_cubic_bezier_curve_to(
*[
*(
interpolate(self.get_last_point(), point, a)
for a in np.linspace(0, 1, nppcc)[1:]
]
)
)
return self
@ -757,7 +757,7 @@ class VMobject(Mobject):
# This will set the handles aligned with the anchors.
# Id est, a bezier curve will be the segment from the two anchors such that the handles belongs to this segment.
self.set_anchors_and_handles(
*[interpolate(points[:-1], points[1:], a) for a in np.linspace(0, 1, nppcc)]
*(interpolate(points[:-1], points[1:], a) for a in np.linspace(0, 1, nppcc))
)
return self
@ -1197,7 +1197,7 @@ class VMobject(Mobject):
def get_points_defining_boundary(self):
# Probably returns all anchors, but this is weird regarding the name of the method.
return np.array(list(it.chain(*[sm.get_anchors() for sm in self.get_family()])))
return np.array(list(it.chain(*(sm.get_anchors() for sm in self.get_family()))))
def get_arc_length(self, sample_points_per_curve: Optional[int] = None) -> float:
"""Return the approximated length of the whole curve.
@ -2125,12 +2125,12 @@ class DashedVMobject(VMobject, metaclass=ConvertToOpenGL):
void_len = (1 - r) / (n - 1)
self.add(
*[
*(
vmobject.get_subcurve(
i * (dash_len + void_len), i * (dash_len + void_len) + dash_len
)
for i in range(n)
]
)
)
# Family is already taken care of by get_subcurve
# implementation

View file

@ -25,7 +25,7 @@ __all__ = [
def get_shader_code_from_file(file_path):
if file_path in file_path_to_code_map:
return file_path_to_code_map[file_path]
with open(file_path, "r") as f:
with open(file_path) as f:
source = f.read()
include_lines = re.finditer(
r"^#include (?P<include_path>.*\.glsl)$", source, flags=re.MULTILINE

View file

@ -29,10 +29,10 @@ def find_file(file_name, directories=None):
return path
else:
logger.debug(f"{path} does not exist.")
raise IOError(f"{file_name} not Found")
raise OSError(f"{file_name} not Found")
class ShaderWrapper(object):
class ShaderWrapper:
def __init__(
self,
vert_data=None,
@ -102,10 +102,8 @@ class ShaderWrapper(object):
def create_program_id(self):
return hash(
"".join(
(
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)
)
@ -148,7 +146,7 @@ class ShaderWrapper(object):
self.vert_data = np.hstack(data_list)
else:
self.vert_data = np.hstack(
[self.vert_data, *[sw.vert_data for sw in shader_wrappers]]
[self.vert_data, *(sw.vert_data for sw in shader_wrappers)]
)
return self
@ -168,10 +166,10 @@ def get_shader_code_from_file(filename):
filename,
directories=[get_shader_dir(), "/"],
)
except IOError:
except OSError:
return None
with open(filepath, "r") as f:
with open(filepath) as f:
result = f.read()
# To share functionality between shaders, some functions are read in

View file

@ -159,9 +159,9 @@ def triangulate_mobject(mob):
)
texture_mode = np.hstack(
(
np.ones((concave_triangle_indices.shape[0])),
-1 * np.ones((convex_triangle_indices.shape[0])),
np.zeros((inner_tri_indices.shape[0])),
np.ones(concave_triangle_indices.shape[0]),
-1 * np.ones(convex_triangle_indices.shape[0]),
np.zeros(inner_tri_indices.shape[0]),
),
)
@ -199,7 +199,7 @@ def render_mobject_strokes_with_matrix(renderer, model_matrix, mobjects):
points = np.empty((total_size, 3))
colors = np.empty((total_size, 4))
widths = np.empty((total_size))
widths = np.empty(total_size)
write_offset = 0
for submob in mobjects:

View file

@ -93,7 +93,7 @@ class SampleSpaceScene(Scene):
def get_prior_rectangles(self):
return VGroup(
*[self.sample_space.horizontal_parts[i].vertical_parts[0] for i in range(2)]
*(self.sample_space.horizontal_parts[i].vertical_parts[0] for i in range(2))
)
def get_posterior_rectangles(self, buff=MED_LARGE_BUFF):

View file

@ -357,7 +357,7 @@ class Scene:
families = [m.get_family() for m in self.mobjects]
def is_top_level(mobject):
num_families = sum([(mobject in family) for family in families])
num_families = sum((mobject in family) for family in families)
return num_families == 1
return list(filter(is_top_level, self.mobjects))

View file

@ -30,7 +30,7 @@ from ..utils.file_ops import (
from ..utils.sounds import get_full_sound_file_path
class SceneFileWriter(object):
class SceneFileWriter:
"""
SceneFileWriter is the object that actually writes the animations
played, into video files, using FFMPEG.

View file

@ -214,10 +214,10 @@ class VectorScene(Scene):
VGroup of the Vector Mobjects representing the basis vectors.
"""
return VGroup(
*[
*(
Vector(vect, color=color, stroke_width=self.basis_vector_stroke_width)
for vect, color in [([1, 0], i_hat_color), ([0, 1], j_hat_color)]
]
)
)
def get_basis_vector_labels(self, **kwargs):
@ -238,7 +238,7 @@ class VectorScene(Scene):
"""
i_hat, j_hat = self.get_basis_vectors()
return VGroup(
*[
*(
self.get_vector_label(
vect, label, color=color, label_scale_factor=1, **kwargs
)
@ -246,7 +246,7 @@ class VectorScene(Scene):
(i_hat, "\\hat{\\imath}", X_COLOR),
(j_hat, "\\hat{\\jmath}", Y_COLOR),
]
]
)
)
def get_vector_label(
@ -402,7 +402,7 @@ class VectorScene(Scene):
FadeOut(array.get_brackets()),
]
self.play(*animations)
y_coord, _ = [anim.mobject for anim in animations]
y_coord, _ = (anim.mobject for anim in animations)
self.play(Create(y_line))
self.play(Create(arrow))
self.wait()
@ -487,11 +487,11 @@ class VectorScene(Scene):
x_max = int(config["frame_x_radius"] + abs(vector[0]))
y_max = int(config["frame_y_radius"] + abs(vector[1]))
dots = VMobject(
*[
*(
Dot(x * RIGHT + y * UP)
for x in range(-x_max, x_max)
for y in range(-y_max, y_max)
]
)
)
dots.set_fill(BLACK, opacity=0)
dots_halfway = dots.copy().shift(vector / 2).set_fill(WHITE, 1)
@ -852,7 +852,7 @@ class LinearTransformationScene(VectorScene):
if new_label:
label_mob.target_text = new_label
else:
label_mob.target_text = "%s(%s)" % (
label_mob.target_text = "{}({})".format(
transformation_name,
label_mob.get_tex_string(),
)
@ -945,7 +945,7 @@ class LinearTransformationScene(VectorScene):
The animation of the movement.
"""
start = VGroup(*pieces)
target = VGroup(*[mob.target for mob in pieces])
target = VGroup(*(mob.target for mob in pieces))
if self.leave_ghost_vectors:
self.add(start.copy().fade(0.7))
return Transform(start, target, lag_ratio=0)

View file

@ -44,10 +44,8 @@ def bezier(
"""
n = len(points) - 1
return lambda t: sum(
[
((1 - t) ** (n - k)) * (t ** k) * choose(n, k) * point
for k, point in enumerate(points)
]
((1 - t) ** (n - k)) * (t ** k) * choose(n, k) * point
for k, point in enumerate(points)
)

View file

@ -23,7 +23,7 @@ def merge_dicts_recursively(*dicts):
When values are dictionaries, it is applied recursively
"""
result = {}
all_items = it.chain(*[d.items() for d in dicts])
all_items = it.chain(*(d.items() for d in dicts))
for key, value in all_items:
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = merge_dicts_recursively(result[key], value)
@ -41,7 +41,7 @@ def update_dict_recursively(current_dict, *others):
# (and less in keeping with all other attr accesses) dict["x"]
class DictAsObject(object):
class DictAsObject:
def __init__(self, dictin):
self.__dict__ = dictin

View file

@ -29,7 +29,7 @@ def extract_mobject_family_members(
else:
method = Mobject.get_family
extracted_mobjects = remove_list_redundancies(
list(it.chain(*[method(m) for m in mobjects]))
list(it.chain(*(method(m) for m in mobjects)))
)
if use_z_index:
return sorted(extracted_mobjects, key=lambda m: m.z_index)

View file

@ -2,7 +2,7 @@ import itertools as it
def extract_mobject_family_members(mobject_list, only_those_with_points=False):
result = list(it.chain(*[mob.get_family() for mob in mobject_list]))
result = list(it.chain(*(mob.get_family() for mob in mobject_list)))
if only_those_with_points:
result = [mob for mob in result if mob.has_points()]
return result

View file

@ -141,7 +141,7 @@ def seek_full_path_from_defaults(file_name, default_dir, extensions):
if os.path.exists(path):
return path
error = f"From: {os.getcwd()}, could not find {file_name} at either of these locations: {possible_paths}"
raise IOError(error)
raise OSError(error)
def modify_atime(file_path):

View file

@ -16,9 +16,12 @@ from .. import config, logger
# Sometimes there are elements that are not suitable for hashing (too long or run-dependent)
# This is used to filter them out.
KEYS_TO_FILTER_OUT = set(
["original_id", "background", "pixel_array", "pixel_array_to_cairo_context"]
)
KEYS_TO_FILTER_OUT = {
"original_id",
"background",
"pixel_array",
"pixel_array_to_cairo_context",
}
class _Memoizer:
@ -321,10 +324,10 @@ def get_hash_from_play_call(
camera_json = get_json(camera_object)
animations_list_json = [get_json(x) for x in sorted(animations_list, key=str)]
current_mobjects_list_json = [get_json(x) for x in current_mobjects_list]
hash_camera, hash_animations, hash_current_mobjects = [
hash_camera, hash_animations, hash_current_mobjects = (
zlib.crc32(repr(json_val).encode())
for json_val in [camera_json, animations_list_json, current_mobjects_list_json]
]
)
hash_complete = f"{hash_camera}_{hash_animations}_{hash_current_mobjects}"
t_end = perf_counter()
logger.debug("Hashing done in %(time)s s.", {"time": str(t_end - t_start)[:8]})

View file

@ -26,7 +26,7 @@ else:
@magics_class
class ManimMagic(Magics):
def __init__(self, shell):
super(ManimMagic, self).__init__(shell)
super().__init__(shell)
self.rendered_files = {}
@needs_local_scope

View file

@ -82,7 +82,7 @@ def all_elements_are_instances(iterable, Class):
def adjacent_n_tuples(objects, n):
return zip(*[[*objects[k:], *objects[:k]] for k in range(n)])
return zip(*([*objects[k:], *objects[:k]] for k in range(n)))
def adjacent_pairs(objects):

View file

@ -93,7 +93,7 @@ def binary_search(function, target, lower_bound, upper_bound, tolerance=1e-4):
rh = upper_bound
while abs(rh - lh) > tolerance:
mh = np.mean([lh, rh])
lx, mx, rx = [function(h) for h in (lh, mh, rh)]
lx, mx, rx = (function(h) for h in (lh, mh, rh))
if lx == target:
return lx
if rx == target:

View file

@ -688,7 +688,7 @@ def earclip_triangulation(verts: np.ndarray, ring_ends: list) -> list:
loop_connections = {}
while detached_rings:
i_range, j_range = [
i_range, j_range = (
list(
filter(
# Ignore indices that are already being
@ -698,7 +698,7 @@ def earclip_triangulation(verts: np.ndarray, ring_ends: list) -> list:
)
)
for ring_group in (attached_rings, detached_rings)
]
)
# Closest point on the attached rings to an estimated midpoint
# of the detached rings

View file

@ -74,9 +74,9 @@ def split_string_list_to_isolate_substrings(string_list, *substrings_to_isolate)
"""
return list(
it.chain(
*[
*(
split_string_to_isolate_substrings(s, *substrings_to_isolate)
for s in string_list
]
)
)
)

View file

@ -256,7 +256,7 @@ class TexTemplateFromFile(TexTemplate):
super().__init__(**kwargs)
def _rebuild(self):
with open(self.template_file, "r") as infile:
with open(self.template_file) as infile:
self.body = infile.read()
def file_not_mutable(self):

View file

@ -80,7 +80,7 @@ def generate_tex_file(expression, environment=None, tex_template=None):
result = os.path.join(tex_dir, tex_hash(output)) + ".tex"
if not os.path.exists(result):
logger.info('Writing "%s" to %s' % ("".join(expression), result))
logger.info('Writing "{}" to {}'.format("".join(expression), result))
with open(result, "w", encoding="utf-8") as outfile:
outfile.write(output)
return result
@ -141,7 +141,7 @@ def tex_compilation_command(tex_compiler, output_format, tex_file, tex_dir):
def insight_inputenc_error(match):
code_point = chr(int(match[1], 16))
name = unicodedata.name(code_point)
yield "TexTemplate does not support character '{}' (U+{})".format(name, match[1])
yield f"TexTemplate does not support character '{name}' (U+{match[1]})"
yield "See the documentation for manim.mobject.svg.tex_mobject for details on using a custom TexTemplate"
@ -187,13 +187,13 @@ def compile_tex(tex_file, tex_compiler, output_format):
f"{tex_compiler} failed but did not produce a log file. "
"Check your LaTeX installation."
)
with open(log_file, "r") as f:
with open(log_file) as f:
log = f.readlines()
error_pos = [
index for index, line in enumerate(log) if line.startswith("!")
]
if error_pos:
with open(tex_file, "r") as g:
with open(tex_file) as g:
tex = g.readlines()
for log_index in error_pos:
logger.error(

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""A library of LaTeX templates."""
__all__ = [
"TexTemplateLibrary",
@ -47,7 +46,7 @@ _3b1b_preamble = r"""
# TexTemplateLibrary
#
class TexTemplateLibrary(object):
class TexTemplateLibrary:
"""
A collection of basic TeX template objects
@ -918,7 +917,7 @@ italichelveticaf.add_to_preamble(
)
class TexFontTemplates(object):
class TexFontTemplates:
"""
A collection of TeX templates for the fonts described at http://jf.burnol.free.fr/showcase.html

View file

@ -6,7 +6,7 @@ from manim import BraceLabel, Mobject, config
def test_mobject_copy():
"""Test that a copy is a deepcopy."""
orig = Mobject()
orig.add(*[Mobject() for _ in range(10)])
orig.add(*(Mobject() for _ in range(10)))
copy = orig.copy()
assert orig is orig

View file

@ -43,23 +43,29 @@ def test_graph_add_edges():
added_mobjects = G.add_edges((1, 3))
assert str(added_mobjects.submobjects) == "[Line]"
assert str(G) == "Graph on 5 vertices and 3 edges"
assert set(G.vertices.keys()) == set([1, 2, 3, 4, 5])
assert set(G.edges.keys()) == set([(1, 2), (2, 3), (1, 3)])
assert set(G.vertices.keys()) == {1, 2, 3, 4, 5}
assert set(G.edges.keys()) == {(1, 2), (2, 3), (1, 3)}
added_mobjects = G.add_edges((1, 42))
assert str(added_mobjects.submobjects) == "[Dot, Line]"
assert str(G) == "Graph on 6 vertices and 4 edges"
assert set(G.vertices.keys()) == set([1, 2, 3, 4, 5, 42])
assert set(G.edges.keys()) == set([(1, 2), (2, 3), (1, 3), (1, 42)])
assert set(G.vertices.keys()) == {1, 2, 3, 4, 5, 42}
assert set(G.edges.keys()) == {(1, 2), (2, 3), (1, 3), (1, 42)}
added_mobjects = G.add_edges((4, 5), (5, 6), (6, 7))
assert len(added_mobjects) == 5
assert str(G) == "Graph on 8 vertices and 7 edges"
assert set(G.vertices.keys()) == set([1, 2, 3, 4, 5, 42, 6, 7])
assert set(G.vertices.keys()) == {1, 2, 3, 4, 5, 42, 6, 7}
assert set(G._graph.nodes()) == set(G.vertices.keys())
assert set(G.edges.keys()) == set(
[(1, 2), (2, 3), (1, 3), (1, 42), (4, 5), (5, 6), (6, 7)]
)
assert set(G.edges.keys()) == {
(1, 2),
(2, 3),
(1, 3),
(1, 42),
(4, 5),
(5, 6),
(6, 7),
}
assert set(G._graph.edges()) == set(G.edges.keys())
@ -68,7 +74,7 @@ def test_graph_remove_edges():
removed_mobjects = G.remove_edges((1, 2))
assert str(removed_mobjects.submobjects) == "[Line]"
assert str(G) == "Graph on 5 vertices and 4 edges"
assert set(G.edges.keys()) == set([(2, 3), (3, 4), (4, 5), (1, 5)])
assert set(G.edges.keys()) == {(2, 3), (3, 4), (4, 5), (1, 5)}
assert set(G._graph.edges()) == set(G.edges.keys())
removed_mobjects = G.remove_edges((2, 3), (3, 4), (4, 5), (5, 1))

View file

@ -20,7 +20,7 @@ def test_TransformFromCopy(scene):
@frames_comparison(last_frame=False)
def test_FullRotation(scene):
s = VGroup(*[Square() for _ in range(4)]).arrange()
s = VGroup(*(Square() for _ in range(4))).arrange()
scene.play(
Rotate(s[0], -2 * TAU),
Rotate(s[1], -1 * TAU),

View file

@ -101,7 +101,7 @@ def test_vgroup_remove_dunder():
obj = VGroup(a, b)
assert len(obj.submobjects) == 2
assert len(b.submobjects) == 1
assert len((obj - a)) == 1
assert len(obj - a) == 1
assert len(obj.submobjects) == 2
obj -= a
b -= c

View file

@ -5,8 +5,8 @@ from functools import wraps
def _check_logs(reference_logfile, generated_logfile):
with open(reference_logfile, "r") as reference_logs, open(
generated_logfile, "r"
with open(reference_logfile) as reference_logs, open(
generated_logfile
) as generated_logs:
reference_logs = reference_logs.readlines()
generated_logs = generated_logs.readlines()

View file

@ -25,7 +25,7 @@ def _get_config_from_video(path_to_video):
def _load_video_data(path_to_data):
return json.load(open(path_to_data, "r"))
return json.load(open(path_to_data))
def _check_video_data(path_control_data, path_to_video_generated):