Skip to content

fix: fix broken multiindex loc cases #467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 57 additions & 100 deletions bigframes/core/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import typing
from typing import List, Tuple, Union
from typing import Tuple, Union

import ibis
import pandas as pd
Expand Down Expand Up @@ -147,19 +147,22 @@ def __getitem__(
...

def __getitem__(self, key):
# TODO(swast): If the DataFrame has a MultiIndex, we'll need to
# disambiguate this from a single row selection.
# TODO(tbergeron): Pandas will try both splitting 2-tuple into row, index or as 2-part
# row key. We must choose one, so bias towards treating as multi-part row label
if isinstance(key, tuple) and len(key) == 2:
df = typing.cast(
bigframes.dataframe.DataFrame,
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
)
is_row_multi_index = self._dataframe.index.nlevels > 1
is_first_item_tuple = isinstance(key[0], tuple)
if not is_row_multi_index or is_first_item_tuple:
df = typing.cast(
bigframes.dataframe.DataFrame,
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
)

columns = key[1]
if isinstance(columns, pd.Series) and columns.dtype == "bool":
columns = df.columns[columns]
columns = key[1]
if isinstance(columns, pd.Series) and columns.dtype == "bool":
columns = df.columns[columns]

return df[columns]
return df[columns]

return typing.cast(
bigframes.dataframe.DataFrame,
Expand Down Expand Up @@ -283,94 +286,40 @@ def _loc_getitem_series_or_dataframe(
pd.Series,
bigframes.core.scalar.Scalar,
]:
if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
return series_or_dataframe[key]
elif isinstance(key, bigframes.series.Series):
temp_name = guid.generate_guid(prefix="temp_series_name_")
if len(series_or_dataframe.index.names) > 1:
temp_name = series_or_dataframe.index.names[0]
key = key.rename(temp_name)
keys_df = key.to_frame()
keys_df = keys_df.set_index(temp_name, drop=True)
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif isinstance(key, bigframes.core.indexes.Index):
block = key._block
block = block.select_columns(())
keys_df = bigframes.dataframe.DataFrame(block)
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif pd.api.types.is_list_like(key):
key = typing.cast(List, key)
if len(key) == 0:
return typing.cast(
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
series_or_dataframe.iloc[0:0],
)
if pd.api.types.is_list_like(key[0]):
original_index_names = series_or_dataframe.index.names
num_index_cols = len(original_index_names)

entry_col_count_correct = [len(entry) == num_index_cols for entry in key]
if not all(entry_col_count_correct):
# pandas usually throws TypeError in these cases- tuple causes IndexError, but that
# seems like unintended behavior
raise TypeError(
"All entries must be of equal length when indexing by list of listlikes"
)
temporary_index_names = [
guid.generate_guid(prefix="temp_loc_index_")
for _ in range(len(original_index_names))
]
index_cols_dict = {}
for i in range(num_index_cols):
index_name = temporary_index_names[i]
values = [entry[i] for entry in key]
index_cols_dict[index_name] = values
keys_df = bigframes.dataframe.DataFrame(
index_cols_dict, session=series_or_dataframe._get_block().expr.session
)
keys_df = keys_df.set_index(temporary_index_names, drop=True)
keys_df = keys_df.rename_axis(original_index_names)
else:
# We can't upload a DataFrame with None as the column name, so set it
# an arbitrary string.
index_name = series_or_dataframe.index.name
index_name_is_none = index_name is None
if index_name_is_none:
index_name = "unnamed_col"
keys_df = bigframes.dataframe.DataFrame(
{index_name: key},
session=series_or_dataframe._get_block().expr.session,
)
keys_df = keys_df.set_index(index_name, drop=True)
if index_name_is_none:
keys_df.index.name = None
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif isinstance(key, slice):
if isinstance(key, slice):
if (key.start is None) and (key.stop is None) and (key.step is None):
return series_or_dataframe.copy()
raise NotImplementedError(
f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}"
)
elif callable(key):
if callable(key):
raise NotImplementedError(
f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}"
)
elif pd.api.types.is_scalar(key):
index_name = "unnamed_col"
keys_df = bigframes.dataframe.DataFrame(
{index_name: [key]}, session=series_or_dataframe._get_block().expr.session
)
keys_df = keys_df.set_index(index_name, drop=True)
keys_df.index.name = None
result = _perform_loc_list_join(series_or_dataframe, keys_df)
pandas_result = result.to_pandas()
# although loc[scalar_key] returns multiple results when scalar_key
# is not unique, we download the results here and return the computed
# individual result (as a scalar or pandas series) when the key is unique,
# since we expect unique index keys to be more common. loc[[scalar_key]]
# can be used to retrieve one-item DataFrames or Series.
if len(pandas_result) == 1:
return pandas_result.iloc[0]
elif isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
return series_or_dataframe[key]
elif (
isinstance(key, bigframes.series.Series)
or isinstance(key, indexes.Index)
or (pd.api.types.is_list_like(key) and not isinstance(key, tuple))
):
index = indexes.Index(key, session=series_or_dataframe._session)
index.names = series_or_dataframe.index.names[: index.nlevels]
return _perform_loc_list_join(series_or_dataframe, index)
elif pd.api.types.is_scalar(key) or isinstance(key, tuple):
index = indexes.Index([key], session=series_or_dataframe._session)
index.names = series_or_dataframe.index.names[: index.nlevels]
result = _perform_loc_list_join(series_or_dataframe, index, drop_levels=True)

if index.nlevels == series_or_dataframe.index.nlevels:
pandas_result = result.to_pandas()
# although loc[scalar_key] returns multiple results when scalar_key
# is not unique, we download the results here and return the computed
# individual result (as a scalar or pandas series) when the key is unique,
# since we expect unique index keys to be more common. loc[[scalar_key]]
# can be used to retrieve one-item DataFrames or Series.
if len(pandas_result) == 1:
return pandas_result.iloc[0]
# when the key is not unique, we return a bigframes data type
# as usual for methods that return dataframes/series
return result
Expand All @@ -385,39 +334,47 @@ def _loc_getitem_series_or_dataframe(
@typing.overload
def _perform_loc_list_join(
series_or_dataframe: bigframes.series.Series,
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> bigframes.series.Series:
...


@typing.overload
def _perform_loc_list_join(
series_or_dataframe: bigframes.dataframe.DataFrame,
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> bigframes.dataframe.DataFrame:
...


def _perform_loc_list_join(
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
# right join based on the old index so that the matching rows from the user's
# original dataframe will be duplicated and reordered appropriately
original_index_names = series_or_dataframe.index.names
if isinstance(series_or_dataframe, bigframes.series.Series):
original_name = series_or_dataframe.name
name = series_or_dataframe.name if series_or_dataframe.name is not None else "0"
result = typing.cast(
bigframes.series.Series,
series_or_dataframe.to_frame()._perform_join_by_index(keys_df, how="right")[
name
],
series_or_dataframe.to_frame()._perform_join_by_index(
keys_index, how="right"
)[name],
)
result = result.rename(original_name)
else:
result = series_or_dataframe._perform_join_by_index(keys_df, how="right") # type: ignore
result = result.rename_axis(original_index_names)
result = series_or_dataframe._perform_join_by_index(keys_index, how="right") # type: ignore

if drop_levels and series_or_dataframe.index.nlevels > keys_index.nlevels:
# drop common levels
levels_to_drop = [
name for name in series_or_dataframe.index.names if name in keys_index.names
]
result = result.droplevel(levels_to_drop) # type: ignore
return result


Expand Down
3 changes: 2 additions & 1 deletion bigframes/core/indexes/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
dtype=None,
*,
name=None,
session=None,
):
import bigframes.dataframe as df
import bigframes.series as series
Expand All @@ -75,7 +76,7 @@ def __init__(
else:
pd_index = pandas.Index(data=data, dtype=dtype, name=name)
pd_df = pandas.DataFrame(index=pd_index)
block = df.DataFrame(pd_df)._block
block = df.DataFrame(pd_df, session=session)._block
self._query_job = None
self._block: blocks.Block = block

Expand Down
4 changes: 3 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,7 +2310,9 @@ def join(

return left._perform_join_by_index(right, how=how)

def _perform_join_by_index(self, other: DataFrame, *, how: str = "left"):
def _perform_join_by_index(
self, other: Union[DataFrame, indexes.Index], *, how: str = "left"
):
block, _ = self._block.join(other._block, how=how, block_identity_join=True)
return DataFrame(block)

Expand Down
4 changes: 4 additions & 0 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ def T(self) -> Series:
def _info_axis(self) -> indexes.Index:
return self.index

@property
def _session(self) -> bigframes.Session:
return self._get_block().expr.session

def transpose(self) -> Series:
return self

Expand Down
25 changes: 22 additions & 3 deletions tests/system/small/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,34 @@ def test_concat_multi_indices_ignore_index(scalars_df_index, scalars_pandas_df_i
pandas.testing.assert_frame_equal(bf_result.to_pandas(), pd_result)


def test_multi_index_loc(scalars_df_index, scalars_pandas_df_index):
@pytest.mark.parametrize(
("key"),
[
(2),
([2, 0]),
([(2, "capitalize, This "), (-2345, "Hello, World!")]),
],
)
def test_multi_index_loc_multi_row(scalars_df_index, scalars_pandas_df_index, key):
bf_result = (
scalars_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]].to_pandas()
scalars_df_index.set_index(["int64_too", "string_col"]).loc[key].to_pandas()
)
pd_result = scalars_pandas_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]]
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[key]

pandas.testing.assert_frame_equal(bf_result, pd_result)


def test_multi_index_loc_single_row(scalars_df_index, scalars_pandas_df_index):
bf_result = scalars_df_index.set_index(["int64_too", "string_col"]).loc[
(2, "capitalize, This ")
]
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[
(2, "capitalize, This ")
]

pandas.testing.assert_series_equal(bf_result, pd_result)


def test_multi_index_getitem_bool(scalars_df_index, scalars_pandas_df_index):
bf_frame = scalars_df_index.set_index(["int64_too", "bool_col"])
pd_frame = scalars_pandas_df_index.set_index(["int64_too", "bool_col"])
Expand Down