Skip to content

ENH: speed up wide DataFrame.line plots by using a single LineCollection #61764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ Performance improvements
- Performance improvement in :meth:`DataFrame.stack` when using ``future_stack=True`` and the DataFrame does not have a :class:`MultiIndex` (:issue:`58391`)
- Performance improvement in :meth:`DataFrame.where` when ``cond`` is a :class:`DataFrame` with many columns (:issue:`61010`)
- Performance improvement in :meth:`to_hdf` avoid unnecessary reopenings of the HDF5 file to speedup data addition to files with a very large number of groups . (:issue:`58248`)
- Performance improvement in ``DataFrame.plot(kind="line")``: very wide DataFrames (more than 200 columns) are now rendered with a single :class:`matplotlib.collections.LineCollection` instead of one ``Line2D`` per column, reducing draw time by roughly 7 × on a 2000-column frame. (:issue:`61532`)
- Performance improvement in ``DataFrameGroupBy.__len__`` and ``SeriesGroupBy.__len__`` (:issue:`57595`)
- Performance improvement in indexing operations for string dtypes (:issue:`56997`)
- Performance improvement in unary methods on a :class:`RangeIndex` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57825`)
Expand Down
28 changes: 28 additions & 0 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@
Series,
)

import itertools

from matplotlib.collections import LineCollection


def holds_integer(column: Index) -> bool:
return column.inferred_type in {"integer", "mixed-integer"}
Expand Down Expand Up @@ -1549,6 +1553,30 @@ def __init__(self, data, **kwargs) -> None:
self.data = self.data.fillna(value=0)

def _make_plot(self, fig: Figure) -> None:
threshold = 200 # switch when DataFrame has more than this many columns
can_use_lc = (
not self._is_ts_plot() # not a TS plot
and not self.stacked # stacking not requested
and not com.any_not_none(*self.errors.values()) # no error bars
and len(self.data.columns) > threshold
)
if can_use_lc:
Comment on lines +1556 to +1563
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer not to have a special casing like this because it's difficult to maintain parity between a "fast path" and the existing path.

Is there a way to refactor our plotting here to generalize the plotting to this form rather than the iterative approach below?

ax = self._get_ax(0)
x = self._get_xticks()
segments = [
np.column_stack((x, self.data[col].values)) for col in self.data.columns
]
base_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
colors = list(itertools.islice(itertools.cycle(base_colors), len(segments)))
lc = LineCollection(
segments,
colors=colors,
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]),
)
ax.add_collection(lc)
ax.margins(0.05)
return # skip the per-column Line2D loop

if self._is_ts_plot():
data = maybe_convert_index(self._get_ax(0), self.data)

Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/plotting/frame/test_linecollection_speedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Ensure wide DataFrame.line plots use a single LineCollection
instead of one Line2D per column (PR #61764).
"""

import numpy as np
import pytest

import pandas as pd

# Skip this entire module if matplotlib is not installed
mpl = pytest.importorskip("matplotlib")
plt = pytest.importorskip("matplotlib.pyplot")
from matplotlib.collections import LineCollection


def test_linecollection_used_for_wide_dataframe():
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((10, 201)).cumsum(axis=0))

ax = df.plot(legend=False)

# exactly one LineCollection, and no Line2D artists
assert sum(isinstance(c, LineCollection) for c in ax.collections) == 1
assert len(ax.lines) == 0

plt.close(ax.figure)
Loading