Skip to content

Commit 860f866

Browse files
author
James Melvin
committed
feature: added standalone script to update messages which don't have an embedding or toxicity
1 parent 736ea30 commit 860f866

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

backend/oasst_backend/celery_worker/scheduled_tasks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def toxicity(text, message_id, api_client):
6363
pr.insert_toxicity(
6464
message_id=message_id, model=model_name, score=toxicity["score"], label=toxicity["label"]
6565
)
66+
session.commit()
6667

6768
except Exception as e:
6869
logger.error(f"Could not compute toxicity for text reply to {message_id=} with {text=} by.error {str(e)}")
@@ -82,6 +83,7 @@ def hf_feature_extraction(text, message_id, api_client):
8283
pr.insert_message_embedding(
8384
message_id=message_id, model=HfEmbeddingModel.MINILM.value, embedding=embedding
8485
)
86+
session.commit()
8587

8688
except Exception as e:
8789
logger.error(f"Could not extract embedding for text reply to {message_id=} with {text=} by.error {str(e)}")

backend/update_message_attributes.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import time
2+
3+
from loguru import logger
4+
from oasst_backend.celery_worker.scheduled_tasks import hf_feature_extraction, toxicity
5+
from oasst_backend.models import ApiClient, Message
6+
from oasst_backend.utils.database_utils import default_session_factory
7+
from sqlmodel import text
8+
9+
# from sqlmodel import Session, select
10+
11+
12+
def get_messageids_without_toxicity():
    """Return the ids of all messages that have no toxicity row.

    The query LEFT JOINs ``message`` against ``message_toxicity`` and keeps
    only the messages for which the join produced no match.
    """
    with default_session_factory() as session:
        query = text(
            """
            SELECT m.id FROM message as m
            left join message_toxicity mt on mt.message_id = m.id
            where mt.message_id is NULL
            """
        )
        rows = session.execute(query).all()
        # Every row carries a single selected column: the message id.
        return [row[0] for row in rows]
28+
29+
30+
def get_messageids_without_embedding():
    """Return the ids of all messages that have no embedding row.

    The query LEFT JOINs ``message`` against ``message_embedding`` and keeps
    only the messages for which the join produced no match.
    """
    with default_session_factory() as session:
        query = text(
            """
            SELECT m.id FROM message as m
            left join message_embedding mt on mt.message_id = m.id
            where mt.message_id is NULL
            """
        )
        rows = session.execute(query).all()
        # Every row carries a single selected column: the message id.
        return [row[0] for row in rows]
46+
47+
48+
def find_and_update_embeddings(message_ids, delay_seconds=10):
    """Compute and store an embedding for each of the given messages.

    For every id in ``message_ids`` the message and its API client are loaded;
    when both the client and the message text exist, the text is handed to
    ``hf_feature_extraction``.

    Errors are handled per message: the failure is logged together with the
    message id and processing continues with the next id, so one broken
    message no longer aborts the rest of the batch.

    Args:
        message_ids: Iterable of message ids to process.
        delay_seconds: Seconds to wait after each ``hf_feature_extraction``
            call so the HF API does not rate limit the script.
    """
    try:
        with default_session_factory() as session:
            for message_id in message_ids:
                try:
                    message = session.query(Message).filter(Message.id == message_id).first()
                    if message is None:
                        continue
                    # Named message_text so the imported sqlmodel ``text`` is not shadowed.
                    message_text = message.payload.payload.text
                    api_client = session.query(ApiClient).filter(ApiClient.id == message.api_client_id).first()
                    if api_client is None or message_text is None:
                        continue
                    hf_feature_extraction(text=message_text, message_id=message_id, api_client=api_client.__dict__)
                except Exception as e:
                    logger.error(f"Could not update embedding for {message_id=}: {str(e)}")
                    # A failed statement can leave the transaction unusable;
                    # roll back so the following messages can still be queried.
                    session.rollback()
                    continue
                # to not get rate limited from HF
                time.sleep(delay_seconds)
    except Exception as e:
        # Failures outside the per-message work (e.g. opening the session).
        logger.error(str(e))
    logger.debug("Done: find_and_update_embeddings")
64+
65+
66+
def find_and_update_toxicity(message_ids, delay_seconds=10):
    """Compute and store a toxicity score for each of the given messages.

    For every id in ``message_ids`` the message and its API client are loaded;
    when both the client and the message text exist, the text is handed to
    ``toxicity``.

    Errors are handled per message: the failure is logged together with the
    message id and processing continues with the next id, so one broken
    message no longer aborts the rest of the batch.

    Args:
        message_ids: Iterable of message ids to process.
        delay_seconds: Seconds to wait after each ``toxicity`` call so the HF
            API does not rate limit the script.
    """
    try:
        with default_session_factory() as session:
            for message_id in message_ids:
                try:
                    message = session.query(Message).filter(Message.id == message_id).first()
                    if message is None:
                        continue
                    # Named message_text so the imported sqlmodel ``text`` is not shadowed.
                    message_text = message.payload.payload.text
                    api_client = session.query(ApiClient).filter(ApiClient.id == message.api_client_id).first()
                    if api_client is None or message_text is None:
                        continue
                    toxicity(text=message_text, message_id=message_id, api_client=api_client.__dict__)
                except Exception as e:
                    logger.error(f"Could not update toxicity for {message_id=}: {str(e)}")
                    # A failed statement can leave the transaction unusable;
                    # roll back so the following messages can still be queried.
                    session.rollback()
                    continue
                # to not get rate limited from HF
                time.sleep(delay_seconds)
    except Exception as e:
        # Failures outside the per-message work (e.g. opening the session).
        logger.error(str(e))
    logger.debug("Done: find_and_update_toxicity")
82+
83+
84+
def main():
    """Backfill toxicity first, then embeddings, for messages that lack them."""
    find_and_update_toxicity(message_ids=get_messageids_without_toxicity())
    find_and_update_embeddings(message_ids=get_messageids_without_embedding())
90+
91+
92+
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)