Skip to content

Commit b519197

Browse files
fix: fix broken multiindex loc cases (#467)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent f55680c commit b519197

File tree

5 files changed

+88
-105
lines changed

5 files changed

+88
-105
lines changed

bigframes/core/indexers.py

+57-100
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import List, Tuple, Union
18+
from typing import Tuple, Union
1919

2020
import ibis
2121
import pandas as pd
@@ -147,19 +147,22 @@ def __getitem__(
147147
...
148148

149149
def __getitem__(self, key):
150-
# TODO(swast): If the DataFrame has a MultiIndex, we'll need to
151-
# disambiguate this from a single row selection.
150+
# TODO(tbergeron): Pandas will try interpreting a 2-tuple both as a (row, column) pair and as a 2-part
151+
# row key. We must choose one, so bias towards treating it as a multi-part row label.
152152
if isinstance(key, tuple) and len(key) == 2:
153-
df = typing.cast(
154-
bigframes.dataframe.DataFrame,
155-
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
156-
)
153+
is_row_multi_index = self._dataframe.index.nlevels > 1
154+
is_first_item_tuple = isinstance(key[0], tuple)
155+
if not is_row_multi_index or is_first_item_tuple:
156+
df = typing.cast(
157+
bigframes.dataframe.DataFrame,
158+
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
159+
)
157160

158-
columns = key[1]
159-
if isinstance(columns, pd.Series) and columns.dtype == "bool":
160-
columns = df.columns[columns]
161+
columns = key[1]
162+
if isinstance(columns, pd.Series) and columns.dtype == "bool":
163+
columns = df.columns[columns]
161164

162-
return df[columns]
165+
return df[columns]
163166

164167
return typing.cast(
165168
bigframes.dataframe.DataFrame,
@@ -283,94 +286,40 @@ def _loc_getitem_series_or_dataframe(
283286
pd.Series,
284287
bigframes.core.scalar.Scalar,
285288
]:
286-
if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
287-
return series_or_dataframe[key]
288-
elif isinstance(key, bigframes.series.Series):
289-
temp_name = guid.generate_guid(prefix="temp_series_name_")
290-
if len(series_or_dataframe.index.names) > 1:
291-
temp_name = series_or_dataframe.index.names[0]
292-
key = key.rename(temp_name)
293-
keys_df = key.to_frame()
294-
keys_df = keys_df.set_index(temp_name, drop=True)
295-
return _perform_loc_list_join(series_or_dataframe, keys_df)
296-
elif isinstance(key, bigframes.core.indexes.Index):
297-
block = key._block
298-
block = block.select_columns(())
299-
keys_df = bigframes.dataframe.DataFrame(block)
300-
return _perform_loc_list_join(series_or_dataframe, keys_df)
301-
elif pd.api.types.is_list_like(key):
302-
key = typing.cast(List, key)
303-
if len(key) == 0:
304-
return typing.cast(
305-
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
306-
series_or_dataframe.iloc[0:0],
307-
)
308-
if pd.api.types.is_list_like(key[0]):
309-
original_index_names = series_or_dataframe.index.names
310-
num_index_cols = len(original_index_names)
311-
312-
entry_col_count_correct = [len(entry) == num_index_cols for entry in key]
313-
if not all(entry_col_count_correct):
314-
# pandas usually throws TypeError in these cases- tuple causes IndexError, but that
315-
# seems like unintended behavior
316-
raise TypeError(
317-
"All entries must be of equal length when indexing by list of listlikes"
318-
)
319-
temporary_index_names = [
320-
guid.generate_guid(prefix="temp_loc_index_")
321-
for _ in range(len(original_index_names))
322-
]
323-
index_cols_dict = {}
324-
for i in range(num_index_cols):
325-
index_name = temporary_index_names[i]
326-
values = [entry[i] for entry in key]
327-
index_cols_dict[index_name] = values
328-
keys_df = bigframes.dataframe.DataFrame(
329-
index_cols_dict, session=series_or_dataframe._get_block().expr.session
330-
)
331-
keys_df = keys_df.set_index(temporary_index_names, drop=True)
332-
keys_df = keys_df.rename_axis(original_index_names)
333-
else:
334-
# We can't upload a DataFrame with None as the column name, so set it
335-
# an arbitrary string.
336-
index_name = series_or_dataframe.index.name
337-
index_name_is_none = index_name is None
338-
if index_name_is_none:
339-
index_name = "unnamed_col"
340-
keys_df = bigframes.dataframe.DataFrame(
341-
{index_name: key},
342-
session=series_or_dataframe._get_block().expr.session,
343-
)
344-
keys_df = keys_df.set_index(index_name, drop=True)
345-
if index_name_is_none:
346-
keys_df.index.name = None
347-
return _perform_loc_list_join(series_or_dataframe, keys_df)
348-
elif isinstance(key, slice):
289+
if isinstance(key, slice):
349290
if (key.start is None) and (key.stop is None) and (key.step is None):
350291
return series_or_dataframe.copy()
351292
raise NotImplementedError(
352293
f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}"
353294
)
354-
elif callable(key):
295+
if callable(key):
355296
raise NotImplementedError(
356297
f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}"
357298
)
358-
elif pd.api.types.is_scalar(key):
359-
index_name = "unnamed_col"
360-
keys_df = bigframes.dataframe.DataFrame(
361-
{index_name: [key]}, session=series_or_dataframe._get_block().expr.session
362-
)
363-
keys_df = keys_df.set_index(index_name, drop=True)
364-
keys_df.index.name = None
365-
result = _perform_loc_list_join(series_or_dataframe, keys_df)
366-
pandas_result = result.to_pandas()
367-
# although loc[scalar_key] returns multiple results when scalar_key
368-
# is not unique, we download the results here and return the computed
369-
# individual result (as a scalar or pandas series) when the key is unique,
370-
# since we expect unique index keys to be more common. loc[[scalar_key]]
371-
# can be used to retrieve one-item DataFrames or Series.
372-
if len(pandas_result) == 1:
373-
return pandas_result.iloc[0]
299+
elif isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
300+
return series_or_dataframe[key]
301+
elif (
302+
isinstance(key, bigframes.series.Series)
303+
or isinstance(key, indexes.Index)
304+
or (pd.api.types.is_list_like(key) and not isinstance(key, tuple))
305+
):
306+
index = indexes.Index(key, session=series_or_dataframe._session)
307+
index.names = series_or_dataframe.index.names[: index.nlevels]
308+
return _perform_loc_list_join(series_or_dataframe, index)
309+
elif pd.api.types.is_scalar(key) or isinstance(key, tuple):
310+
index = indexes.Index([key], session=series_or_dataframe._session)
311+
index.names = series_or_dataframe.index.names[: index.nlevels]
312+
result = _perform_loc_list_join(series_or_dataframe, index, drop_levels=True)
313+
314+
if index.nlevels == series_or_dataframe.index.nlevels:
315+
pandas_result = result.to_pandas()
316+
# although loc[scalar_key] returns multiple results when scalar_key
317+
# is not unique, we download the results here and return the computed
318+
# individual result (as a scalar or pandas series) when the key is unique,
319+
# since we expect unique index keys to be more common. loc[[scalar_key]]
320+
# can be used to retrieve one-item DataFrames or Series.
321+
if len(pandas_result) == 1:
322+
return pandas_result.iloc[0]
374323
# when the key is not unique, we return a bigframes data type
375324
# as usual for methods that return dataframes/series
376325
return result
@@ -385,39 +334,47 @@ def _loc_getitem_series_or_dataframe(
385334
@typing.overload
386335
def _perform_loc_list_join(
387336
series_or_dataframe: bigframes.series.Series,
388-
keys_df: bigframes.dataframe.DataFrame,
337+
keys_index: indexes.Index,
338+
drop_levels: bool = False,
389339
) -> bigframes.series.Series:
390340
...
391341

392342

393343
@typing.overload
394344
def _perform_loc_list_join(
395345
series_or_dataframe: bigframes.dataframe.DataFrame,
396-
keys_df: bigframes.dataframe.DataFrame,
346+
keys_index: indexes.Index,
347+
drop_levels: bool = False,
397348
) -> bigframes.dataframe.DataFrame:
398349
...
399350

400351

401352
def _perform_loc_list_join(
402353
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
403-
keys_df: bigframes.dataframe.DataFrame,
354+
keys_index: indexes.Index,
355+
drop_levels: bool = False,
404356
) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
405357
# right join based on the old index so that the matching rows from the user's
406358
# original dataframe will be duplicated and reordered appropriately
407-
original_index_names = series_or_dataframe.index.names
408359
if isinstance(series_or_dataframe, bigframes.series.Series):
409360
original_name = series_or_dataframe.name
410361
name = series_or_dataframe.name if series_or_dataframe.name is not None else "0"
411362
result = typing.cast(
412363
bigframes.series.Series,
413-
series_or_dataframe.to_frame()._perform_join_by_index(keys_df, how="right")[
414-
name
415-
],
364+
series_or_dataframe.to_frame()._perform_join_by_index(
365+
keys_index, how="right"
366+
)[name],
416367
)
417368
result = result.rename(original_name)
418369
else:
419-
result = series_or_dataframe._perform_join_by_index(keys_df, how="right") # type: ignore
420-
result = result.rename_axis(original_index_names)
370+
result = series_or_dataframe._perform_join_by_index(keys_index, how="right") # type: ignore
371+
372+
if drop_levels and series_or_dataframe.index.nlevels > keys_index.nlevels:
373+
# drop common levels
374+
levels_to_drop = [
375+
name for name in series_or_dataframe.index.names if name in keys_index.names
376+
]
377+
result = result.droplevel(levels_to_drop) # type: ignore
421378
return result
422379

423380

bigframes/core/indexes/index.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
dtype=None,
5050
*,
5151
name=None,
52+
session=None,
5253
):
5354
import bigframes.dataframe as df
5455
import bigframes.series as series
@@ -75,7 +76,7 @@ def __init__(
7576
else:
7677
pd_index = pandas.Index(data=data, dtype=dtype, name=name)
7778
pd_df = pandas.DataFrame(index=pd_index)
78-
block = df.DataFrame(pd_df)._block
79+
block = df.DataFrame(pd_df, session=session)._block
7980
self._query_job = None
8081
self._block: blocks.Block = block
8182

bigframes/dataframe.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2310,7 +2310,9 @@ def join(
23102310

23112311
return left._perform_join_by_index(right, how=how)
23122312

2313-
def _perform_join_by_index(self, other: DataFrame, *, how: str = "left"):
2313+
def _perform_join_by_index(
2314+
self, other: Union[DataFrame, indexes.Index], *, how: str = "left"
2315+
):
23142316
block, _ = self._block.join(other._block, how=how, block_identity_join=True)
23152317
return DataFrame(block)
23162318

bigframes/series.py

+4
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ def T(self) -> Series:
151151
def _info_axis(self) -> indexes.Index:
152152
return self.index
153153

154+
@property
155+
def _session(self) -> bigframes.Session:
156+
return self._get_block().expr.session
157+
154158
def transpose(self) -> Series:
155159
return self
156160

tests/system/small/test_multiindex.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,34 @@ def test_concat_multi_indices_ignore_index(scalars_df_index, scalars_pandas_df_i
169169
pandas.testing.assert_frame_equal(bf_result.to_pandas(), pd_result)
170170

171171

172-
def test_multi_index_loc(scalars_df_index, scalars_pandas_df_index):
172+
@pytest.mark.parametrize(
173+
("key"),
174+
[
175+
(2),
176+
([2, 0]),
177+
([(2, "capitalize, This "), (-2345, "Hello, World!")]),
178+
],
179+
)
180+
def test_multi_index_loc_multi_row(scalars_df_index, scalars_pandas_df_index, key):
173181
bf_result = (
174-
scalars_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]].to_pandas()
182+
scalars_df_index.set_index(["int64_too", "string_col"]).loc[key].to_pandas()
175183
)
176-
pd_result = scalars_pandas_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]]
184+
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[key]
177185

178186
pandas.testing.assert_frame_equal(bf_result, pd_result)
179187

180188

189+
def test_multi_index_loc_single_row(scalars_df_index, scalars_pandas_df_index):
190+
bf_result = scalars_df_index.set_index(["int64_too", "string_col"]).loc[
191+
(2, "capitalize, This ")
192+
]
193+
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[
194+
(2, "capitalize, This ")
195+
]
196+
197+
pandas.testing.assert_series_equal(bf_result, pd_result)
198+
199+
181200
def test_multi_index_getitem_bool(scalars_df_index, scalars_pandas_df_index):
182201
bf_frame = scalars_df_index.set_index(["int64_too", "bool_col"])
183202
pd_frame = scalars_pandas_df_index.set_index(["int64_too", "bool_col"])

0 commit comments

Comments
 (0)