Skip to content

Commit 0772510

Browse files
authored
feat: (Series|DataFrame).plot.(line|area|scatter) (#431)
Fixing internal bugs: line: b/322177942 scatter: b/322178336 area: b/322178394
1 parent 7f3d41c commit 0772510

File tree

5 files changed

+396
-30
lines changed

5 files changed

+396
-30
lines changed

Diff for: bigframes/operations/_matplotlib/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
PLOT_CLASSES: dict[str, type[core.MPLPlot]] = {
1919
"hist": hist.HistPlot,
20+
"line": core.LinePlot,
21+
"area": core.AreaPlot,
22+
"scatter": core.ScatterPlot,
2023
}
2124

2225

Diff for: bigframes/operations/_matplotlib/core.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import abc
16+
import typing
1617

1718
import matplotlib.pyplot as plt
1819

@@ -28,3 +29,44 @@ def draw(self) -> None:
2829
@property
2930
def result(self):
3031
return self.axes
32+
33+
34+
class SamplingPlot(MPLPlot):
35+
@abc.abstractproperty
36+
def _kind(self):
37+
pass
38+
39+
def __init__(self, data, **kwargs) -> None:
40+
self.kwargs = kwargs
41+
self.data = self._compute_plot_data(data)
42+
43+
def generate(self) -> None:
44+
self.axes = self.data.plot(kind=self._kind, **self.kwargs)
45+
46+
def _compute_plot_data(self, data):
47+
# TODO: Cache the sampling data in the PlotAccessor.
48+
sampling_n = self.kwargs.pop("sampling_n", 100)
49+
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
50+
return (
51+
data.sample(n=sampling_n, random_state=sampling_random_state)
52+
.to_pandas()
53+
.sort_index()
54+
)
55+
56+
57+
class LinePlot(SamplingPlot):
58+
@property
59+
def _kind(self) -> typing.Literal["line"]:
60+
return "line"
61+
62+
63+
class AreaPlot(SamplingPlot):
64+
@property
65+
def _kind(self) -> typing.Literal["area"]:
66+
return "area"
67+
68+
69+
class ScatterPlot(SamplingPlot):
70+
@property
71+
def _kind(self) -> typing.Literal["scatter"]:
72+
return "scatter"

Diff for: bigframes/operations/plotting.py

+53-4
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,73 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional, Sequence
15+
import typing
1616

1717
import bigframes_vendored.pandas.plotting._core as vendordt
1818

1919
import bigframes.constants as constants
2020
import bigframes.operations._matplotlib as bfplt
2121

2222

23-
class PlotAccessor:
23+
class PlotAccessor(vendordt.PlotAccessor):
2424
__doc__ = vendordt.PlotAccessor.__doc__
2525

2626
def __init__(self, data) -> None:
2727
self._parent = data
2828

29-
def hist(self, by: Optional[Sequence[str]] = None, bins: int = 10, **kwargs):
29+
def hist(
30+
self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs
31+
):
3032
if kwargs.pop("backend", None) is not None:
3133
raise NotImplementedError(
3234
f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}"
3335
)
34-
# Calls matplotlib backend to plot the data.
3536
return bfplt.plot(self._parent.copy(), kind="hist", by=by, bins=bins, **kwargs)
37+
38+
def line(
39+
self,
40+
x: typing.Optional[typing.Hashable] = None,
41+
y: typing.Optional[typing.Hashable] = None,
42+
**kwargs,
43+
):
44+
return bfplt.plot(
45+
self._parent.copy(),
46+
kind="line",
47+
x=x,
48+
y=y,
49+
**kwargs,
50+
)
51+
52+
def area(
53+
self,
54+
x: typing.Optional[typing.Hashable] = None,
55+
y: typing.Optional[typing.Hashable] = None,
56+
stacked: bool = True,
57+
**kwargs,
58+
):
59+
return bfplt.plot(
60+
self._parent.copy(),
61+
kind="area",
62+
x=x,
63+
y=y,
64+
stacked=stacked,
65+
**kwargs,
66+
)
67+
68+
def scatter(
69+
self,
70+
x: typing.Optional[typing.Hashable] = None,
71+
y: typing.Optional[typing.Hashable] = None,
72+
s: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
73+
c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
74+
**kwargs,
75+
):
76+
return bfplt.plot(
77+
self._parent.copy(),
78+
kind="scatter",
79+
x=x,
80+
y=y,
81+
s=s,
82+
c=c,
83+
**kwargs,
84+
)

Diff for: tests/system/small/operations/test_plot.py renamed to tests/system/small/operations/test_plotting.py

+67
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
import pandas._testing as tm
1617
import pytest
1718

19+
import bigframes.pandas as bpd
20+
1821

1922
def _check_legend_labels(ax, labels):
2023
"""
@@ -166,3 +169,67 @@ def test_hist_kwargs_ticks_props(scalars_dfs):
166169
for i in range(len(pd_xlables)):
167170
tm.assert_almost_equal(ylabels[i].get_fontsize(), pd_ylables[i].get_fontsize())
168171
tm.assert_almost_equal(ylabels[i].get_rotation(), pd_ylables[i].get_rotation())
172+
173+
174+
def test_line(scalars_dfs):
175+
scalars_df, scalars_pandas_df = scalars_dfs
176+
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
177+
ax = scalars_df[col_names].plot.line()
178+
pd_ax = scalars_pandas_df[col_names].plot.line()
179+
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
180+
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
181+
for line, pd_line in zip(ax.lines, pd_ax.lines):
182+
# Compare y coordinates between the lines
183+
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])
184+
185+
186+
def test_area(scalars_dfs):
187+
scalars_df, scalars_pandas_df = scalars_dfs
188+
col_names = ["int64_col", "float64_col", "int64_too"]
189+
ax = scalars_df[col_names].plot.area(stacked=False)
190+
pd_ax = scalars_pandas_df[col_names].plot.area(stacked=False)
191+
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
192+
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
193+
for line, pd_line in zip(ax.lines, pd_ax.lines):
194+
# Compare y coordinates between the lines
195+
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])
196+
197+
198+
def test_scatter(scalars_dfs):
199+
scalars_df, scalars_pandas_df = scalars_dfs
200+
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
201+
ax = scalars_df[col_names].plot.scatter(x="int64_col", y="float64_col")
202+
pd_ax = scalars_pandas_df[col_names].plot.scatter(x="int64_col", y="float64_col")
203+
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
204+
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
205+
tm.assert_almost_equal(
206+
ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes()
207+
)
208+
209+
210+
def test_sampling_plot_args_n():
211+
df = bpd.DataFrame(np.arange(1000), columns=["one"])
212+
ax = df.plot.line()
213+
assert len(ax.lines) == 1
214+
# Default sampling_n is 100
215+
assert len(ax.lines[0].get_data()[1]) == 100
216+
217+
ax = df.plot.line(sampling_n=2)
218+
assert len(ax.lines) == 1
219+
assert len(ax.lines[0].get_data()[1]) == 2
220+
221+
222+
def test_sampling_plot_args_random_state():
223+
df = bpd.DataFrame(np.arange(1000), columns=["one"])
224+
ax_0 = df.plot.line()
225+
ax_1 = df.plot.line()
226+
ax_2 = df.plot.line(sampling_random_state=100)
227+
ax_3 = df.plot.line(sampling_random_state=100)
228+
229+
# Setting a fixed sampling_random_state guarantees reproducible plotted sampling.
230+
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_1.lines[0].get_data()[1])
231+
tm.assert_almost_equal(ax_2.lines[0].get_data()[1], ax_3.lines[0].get_data()[1])
232+
233+
msg = "numpy array are different"
234+
with pytest.raises(AssertionError, match=msg):
235+
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])

0 commit comments

Comments
 (0)