Skip to content

Commit 1c3e668

Browse files
authored
feat: (Series|DataFrame).plot (#438)
1 parent 91bd39e commit 1c3e668

File tree

3 files changed

+93
-30
lines changed

3 files changed

+93
-30
lines changed

bigframes/operations/plotting.py

+28-29
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,45 @@
2323
class PlotAccessor(vendordt.PlotAccessor):
2424
__doc__ = vendordt.PlotAccessor.__doc__
2525

26+
_common_kinds = ("line", "area", "hist")
27+
_dataframe_kinds = ("scatter",)
28+
_all_kinds = _common_kinds + _dataframe_kinds
29+
30+
def __call__(self, **kwargs):
31+
import bigframes.series as series
32+
33+
if kwargs.pop("backend", None) is not None:
34+
raise NotImplementedError(
35+
f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}"
36+
)
37+
38+
kind = kwargs.pop("kind", "line")
39+
if kind not in self._all_kinds:
40+
raise NotImplementedError(
41+
f"{kind} is not a valid plot kind supported for now. {constants.FEEDBACK_LINK}"
42+
)
43+
44+
data = self._parent.copy()
45+
if kind in self._dataframe_kinds and isinstance(data, series.Series):
46+
raise ValueError(f"plot kind {kind} can only be used for data frames")
47+
48+
return bfplt.plot(data, kind=kind, **kwargs)
49+
2650
def __init__(self, data) -> None:
2751
self._parent = data
2852

2953
def hist(
3054
self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs
3155
):
32-
if kwargs.pop("backend", None) is not None:
33-
raise NotImplementedError(
34-
f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}"
35-
)
36-
return bfplt.plot(self._parent.copy(), kind="hist", by=by, bins=bins, **kwargs)
56+
return self(kind="hist", by=by, bins=bins, **kwargs)
3757

3858
def line(
3959
self,
4060
x: typing.Optional[typing.Hashable] = None,
4161
y: typing.Optional[typing.Hashable] = None,
4262
**kwargs,
4363
):
44-
return bfplt.plot(
45-
self._parent.copy(),
46-
kind="line",
47-
x=x,
48-
y=y,
49-
**kwargs,
50-
)
64+
return self(kind="line", x=x, y=y, **kwargs)
5165

5266
def area(
5367
self,
@@ -56,14 +70,7 @@ def area(
5670
stacked: bool = True,
5771
**kwargs,
5872
):
59-
return bfplt.plot(
60-
self._parent.copy(),
61-
kind="area",
62-
x=x,
63-
y=y,
64-
stacked=stacked,
65-
**kwargs,
66-
)
73+
return self(kind="area", x=x, y=y, stacked=stacked, **kwargs)
6774

6875
def scatter(
6976
self,
@@ -73,12 +80,4 @@ def scatter(
7380
c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
7481
**kwargs,
7582
):
76-
return bfplt.plot(
77-
self._parent.copy(),
78-
kind="scatter",
79-
x=x,
80-
y=y,
81-
s=s,
82-
c=c,
83-
**kwargs,
84-
)
83+
return self(kind="scatter", x=x, y=y, s=s, c=c, **kwargs)

tests/system/small/operations/test_plotting.py

+28
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,31 @@ def test_sampling_plot_args_random_state():
233233
msg = "numpy array are different"
234234
with pytest.raises(AssertionError, match=msg):
235235
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])
236+
237+
238+
@pytest.mark.parametrize(
239+
("kind", "col_names", "kwargs"),
240+
[
241+
pytest.param("hist", ["int64_col", "int64_too"], {}),
242+
pytest.param("line", ["int64_col", "int64_too"], {}),
243+
pytest.param("area", ["int64_col", "int64_too"], {"stacked": False}),
244+
pytest.param(
245+
"scatter", ["int64_col", "int64_too"], {"x": "int64_col", "y": "int64_too"}
246+
),
247+
pytest.param(
248+
"scatter",
249+
["int64_col"],
250+
{},
251+
marks=pytest.mark.xfail(raises=ValueError),
252+
),
253+
pytest.param(
254+
"uknown",
255+
["int64_col", "int64_too"],
256+
{},
257+
marks=pytest.mark.xfail(raises=NotImplementedError),
258+
),
259+
],
260+
)
261+
def test_plot_call(scalars_dfs, kind, col_names, kwargs):
262+
scalars_df, _ = scalars_dfs
263+
scalars_df[col_names].plot(kind=kind, **kwargs)

third_party/bigframes_vendored/pandas/plotting/_core.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,43 @@
44

55

66
class PlotAccessor:
7-
"""Make plots of Series or DataFrame with the `matplotlib` backend."""
7+
"""
8+
Make plots of Series or DataFrame with the `matplotlib` backend.
9+
10+
**Examples:**
11+
For Series:
12+
13+
>>> import bigframes.pandas as bpd
14+
>>> ser = bpd.Series([1, 2, 3, 3])
15+
>>> plot = ser.plot(kind='hist', title="My plot")
16+
17+
For DataFrame:
18+
19+
>>> df = bpd.DataFrame({'length': [1.5, 0.5, 1.2, 0.9, 3],
20+
... 'width': [0.7, 0.2, 0.15, 0.2, 1.1]},
21+
... index=['pig', 'rabbit', 'duck', 'chicken', 'horse'])
22+
>>> plot = df.plot(title="DataFrame Plot")
23+
24+
Args:
25+
data (Series or DataFrame):
26+
The object for which the method is called.
27+
kind (str):
28+
The kind of plot to produce:
29+
30+
- 'line' : line plot (default)
31+
- 'hist' : histogram
32+
- 'area' : area plot
33+
- 'scatter' : scatter plot (DataFrame only)
34+
35+
**kwargs:
36+
Options to pass to `pandas.DataFrame.plot` method. See pandas
37+
documentation online for more on these arguments.
38+
39+
Returns:
40+
matplotlib.axes.Axes or np.ndarray of them:
41+
An ndarray is returned with one :class:`matplotlib.axes.Axes`
42+
per column when ``subplots=True``.
43+
"""
844

945
def hist(
1046
self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs

0 commit comments

Comments
 (0)