
Commit c5eccad

chore: implement semantics topk (#1072)
* chore: implement semantics topk
* rename top_k
* fix mypy
* fix lint
1 parent 8821dd4 commit c5eccad
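
At a glance, the operator added in this commit is invoked as in the docstring example further down in the diff. A minimal sketch, assuming a BigFrames session already configured for BigQuery and Vertex AI access:

    import bigframes.pandas as bpd
    import bigframes.ml.llm as llm

    # Semantic operators are gated behind an experiment flag.
    bpd.options.experiments.semantic_operators = True

    # The docstring example below uses this Gemini model.
    model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")

    df = bpd.DataFrame({"Animals": ["Dog", "Bird", "Cat", "Horse"]})
    # Keep the 2 rows the model ranks highest for the instruction.
    top2 = df.semantics.top_k("{Animals} are more popular as pets", model=model, k=2)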

File tree

3 files changed: +487 -15 lines changed


bigframes/operations/semantics.py

+176 -9
@@ -17,17 +17,21 @@
 import typing
 from typing import List, Optional

-import bigframes
-import bigframes.core.guid
+import numpy as np
+
+import bigframes.core.guid as guid
 import bigframes.dtypes as dtypes


 class Semantics:
     def __init__(self, df) -> None:
+        import bigframes
+        import bigframes.dataframe
+
         if not bigframes.options.experiments.semantic_operators:
             raise NotImplementedError()

-        self._df = df
+        self._df: bigframes.dataframe.DataFrame = df

     def agg(
         self,
@@ -130,15 +134,15 @@ def agg(
                     f"{type(df[cluster_column])}"
                 )

-            num_cluster = len(df[cluster_column].unique())
+            num_cluster = df[cluster_column].unique().shape[0]
             df = df.sort_values(cluster_column)
         else:
-            cluster_column = bigframes.core.guid.generate_guid("pid")
+            cluster_column = guid.generate_guid("pid")
             df[cluster_column] = 0

-        aggregation_group_id = bigframes.core.guid.generate_guid("agg")
-        group_row_index = bigframes.core.guid.generate_guid("gid")
-        llm_prompt = bigframes.core.guid.generate_guid("prompt")
+        aggregation_group_id = guid.generate_guid("agg")
+        group_row_index = guid.generate_guid("gid")
+        llm_prompt = guid.generate_guid("prompt")
         df = (
             df.reset_index(drop=True)
             .reset_index()
@@ -609,6 +613,169 @@ def search(

         return typing.cast(bigframes.dataframe.DataFrame, search_result)

+    def top_k(self, instruction: str, model, k=10):
+        """
+        Ranks each tuple and returns the k best according to the instruction.
+
+        This method employs a quick select algorithm to efficiently compare the pivot
+        with all other items. By leveraging an LLM (Large Language Model), it then
+        identifies the top 'k' best answers from these comparisons.
+
+        **Examples:**
+
+            >>> import bigframes.pandas as bpd
+            >>> bpd.options.display.progress_bar = None
+            >>> bpd.options.experiments.semantic_operators = True
+
+            >>> import bigframes.ml.llm as llm
+            >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
+
+            >>> df = bpd.DataFrame({"Animals": ["Dog", "Bird", "Cat", "Horse"]})
+            >>> df.semantics.top_k("{Animals} are more popular as pets", model=model, k=2)
+              Animals
+            0     Dog
+            2     Cat
+            <BLANKLINE>
+            [2 rows x 1 columns]
+
+        Args:
+            instruction (str):
+                An instruction on how to map the data. This value must contain
+                column references by name enclosed in braces.
+                For example, to reference a column named "Animals", use "{Animals}" in the
+                instruction, like: "{Animals} are more popular as pets"
+
+            model (bigframes.ml.llm.GeminiTextGenerator):
+                A GeminiTextGenerator provided by the Bigframes ML package.
+
+            k (int, default 10):
+                The number of rows to return.
+
+        Returns:
+            bigframes.dataframe.DataFrame: A new DataFrame with the top k rows.
+
+        Raises:
+            NotImplementedError: when the semantic operator experiment is off.
+            ValueError: when the instruction refers to a non-existing column, or when no
+                columns are referred to.
+        """
+        self._validate_model(model)
+        columns = self._parse_columns(instruction)
+        for column in columns:
+            if column not in self._df.columns:
+                raise ValueError(f"Column {column} not found.")
+        if len(columns) > 1:
+            raise NotImplementedError(
+                "Semantic aggregations are limited to a single column."
+            )
+        column = columns[0]
+        if self._df[column].dtype != dtypes.STRING_DTYPE:
+            raise TypeError(
+                "Referred column must be a string type, not "
+                f"{type(self._df[column])}"
+            )
+        # `index` is reserved for the `reset_index` below.
+        if column == "index":
+            raise ValueError(
+                "Column name 'index' is reserved. Please choose a different name."
+            )
+
+        if k < 1:
+            raise ValueError("k must be an integer greater than or equal to 1.")
+
+        user_instruction = self._format_instruction(instruction, columns)
+
+        import bigframes.dataframe
+        import bigframes.series
+
+        df: bigframes.dataframe.DataFrame = self._df[columns].copy()
+        n = df.shape[0]
+
+        if k >= n:
+            return df
+
+        # Create a unique index and duplicate it as the "index" column. This workaround
+        # is needed for the select search algorithm due to unimplemented bigFrame methods.
+        df = df.reset_index().rename(columns={"index": "old_index"}).reset_index()
+
+        # Initialize a status column to track the selection status of each item.
+        # - None: Unknown/not yet processed
+        # - 1.0: Selected as part of the top-k items
+        # - -1.0: Excluded from the top-k items
+        status_column = guid.generate_guid("status")
+        df[status_column] = bigframes.series.Series(None, dtype=dtypes.FLOAT_DTYPE)
+
+        num_selected = 0
+        while num_selected < k:
+            df, num_new_selected = self._topk_partition(
+                df,
+                column,
+                status_column,
+                user_instruction,
+                model,
+                k - num_selected,
+            )
+            num_selected += num_new_selected
+
+        df = (
+            df[df[status_column] > 0]
+            .drop(["index", status_column], axis=1)
+            .rename(columns={"old_index": "index"})
+            .set_index("index")
+        )
+        df.index.name = None
+        return df
+
+    @staticmethod
+    def _topk_partition(
+        df, column: str, status_column: str, user_instruction: str, model, k
+    ):
+        output_instruction = (
+            "Given a question and two documents, choose the document that best answers "
+            "the question. Respond with 'Document 1' or 'Document 2'. You must choose "
+            "one, even if neither is ideal. "
+        )
+
+        # Random pivot selection for improved average quickselect performance.
+        pending_df = df[df[status_column].isna()]
+        pivot_iloc = np.random.randint(0, pending_df.shape[0] - 1)
+        pivot_index = pending_df.iloc[pivot_iloc]["index"]
+        pivot_df = pending_df[pending_df["index"] == pivot_index]
+
+        # Build a prompt to compare the pivot item's relevance to other pending items.
+        prompt_s = pending_df[pending_df["index"] != pivot_index][column]
+        prompt_s = (
+            f"{output_instruction}\n\nQuestion: {user_instruction}\n"
+            + "\nDocument 1: "
+            + pivot_df.iloc[0][column]
+            + "\nDocument 2: "
+            + prompt_s  # type:ignore
+        )
+
+        import bigframes.dataframe
+
+        predict_df = typing.cast(bigframes.dataframe.DataFrame, model.predict(prompt_s))
+
+        marks = predict_df["ml_generate_text_llm_result"].str.contains("2")
+        more_relavant: bigframes.dataframe.DataFrame = df[marks]
+        less_relavent: bigframes.dataframe.DataFrame = df[~marks]
+
+        num_more_relavant = more_relavant.shape[0]
+        if k < num_more_relavant:
+            less_relavent[status_column] = -1.0
+            pivot_df[status_column] = -1.0
+            df = df.combine_first(less_relavent).combine_first(pivot_df)
+            return df, 0
+        else:  # k >= num_more_relavant
+            more_relavant[status_column] = 1.0
+            df = df.combine_first(more_relavant)
+            if k >= num_more_relavant + 1:
+                pivot_df[status_column] = 1.0
+                df = df.combine_first(pivot_df)
+                return df, num_more_relavant + 1
+            else:
+                return df, num_more_relavant
+
     def sim_join(
         self,
         other,
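
For each quickselect round, the partition step above broadcasts one pairwise prompt per still-pending row. Purely as an illustration, with made-up example values standing in for the formatted instruction and the row texts, a single prompt has this shape:

    # Sketch of one prompt built by _topk_partition; the concrete values are hypothetical.
    output_instruction = (
        "Given a question and two documents, choose the document that best answers "
        "the question. Respond with 'Document 1' or 'Document 2'. You must choose "
        "one, even if neither is ideal. "
    )
    user_instruction = "Animals are more popular as pets"  # assumed rendering of the instruction
    pivot_text = "Dog"   # the randomly chosen pivot row
    other_text = "Cat"   # one of the remaining pending rows

    prompt = (
        f"{output_instruction}\n\nQuestion: {user_instruction}\n"
        + "\nDocument 1: "
        + pivot_text
        + "\nDocument 2: "
        + other_text
    )
    # A model reply containing "2" marks other_text as more relevant than the pivot.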
@@ -688,7 +855,7 @@ def sim_join(
                 f"Number of rows that need processing is {joined_table_rows}, which exceeds row limit {max_rows}."
             )

-        base_table_embedding_column = bigframes.core.guid.generate_guid()
+        base_table_embedding_column = guid.generate_guid()
         base_table = self._attach_embedding(
             other, right_on, base_table_embedding_column, model
         ).to_gbq()
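
The top_k flow in this commit is a quickselect over rows, with the model acting as the pairwise comparator and a status column (None / 1.0 / -1.0) tracking which rows are still pending, kept, or excluded. A rough plain-Python sketch of that partitioning loop, using local lists and a hypothetical prefers_second callback instead of BigFrames and the Gemini call:

    import random
    from typing import Callable, List

    def top_k_quickselect(
        items: List[str],
        k: int,
        prefers_second: Callable[[str, str], bool],
    ) -> List[str]:
        # `selected` plays the role of rows with status 1.0; `pending` of rows still None.
        if k >= len(items):
            return list(items)
        pending = list(items)
        selected: List[str] = []
        while len(selected) < k:
            pivot_pos = random.randrange(len(pending))  # random pivot, as in _topk_partition
            pivot = pending.pop(pivot_pos)
            # prefers_second(pivot, other) stands in for the model answering "Document 2",
            # i.e. `other` answers the question better than the pivot does.
            better = [it for it in pending if prefers_second(pivot, it)]
            worse = [it for it in pending if not prefers_second(pivot, it)]
            need = k - len(selected)
            if need < len(better):
                # Too many winners: drop the pivot and the losers (status -1.0)
                # and keep partitioning only the winners.
                pending = better
            else:
                # All winners fit (status 1.0); keep the pivot too if a slot remains,
                # then keep partitioning the losers for the leftover slots.
                selected.extend(better)
                if need >= len(better) + 1:
                    selected.append(pivot)
                pending = worse
        return selected

For instance, prefers_second=lambda pivot, other: len(other) > len(pivot) makes the sketch return the k longest strings; in the commit, that role is played by model.predict over the "Document 1"/"Document 2" prompts and the "2"-containing ml_generate_text_llm_result values.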
