forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate_message_attributes.py
91 lines (78 loc) · 3.21 KB
/
update_message_attributes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import time
from loguru import logger
from oasst_backend.models import ApiClient, Message
from oasst_backend.scheduled_tasks import hf_feature_extraction, toxicity
from oasst_backend.utils.database_utils import default_session_factory
from sqlmodel import text
def get_messageids_without_toxicity() -> list:
    """Return the ids of all messages that have no toxicity row yet.

    The query LEFT JOINs ``message`` against ``message_toxicity`` and keeps
    only the messages for which no ``message_toxicity`` row matched.

    Returns:
        A list holding the first column (``m.id``) of every matching row.
        The list is empty when every message already has a toxicity row.
    """
    sql = (
        "\n"
        "SELECT m.id FROM message as m\n"
        "left join message_toxicity mt on mt.message_id = m.id\n"
        "where mt.message_id is NULL\n"
    )
    with default_session_factory() as session:
        rows = session.execute(text(sql)).all()
        # Each row has a single column, the message id.
        return [row[0] for row in rows]
def get_messageids_without_embedding() -> list:
    """Return the ids of all messages that have no embedding row yet.

    The query LEFT JOINs ``message`` against ``message_embedding`` and keeps
    only the messages for which no ``message_embedding`` row matched.

    Returns:
        A list holding the first column (``m.id``) of every matching row.
        The list is empty when every message already has an embedding row.
    """
    sql = (
        "\n"
        "SELECT m.id FROM message as m\n"
        "left join message_embedding mt on mt.message_id = m.id\n"
        "where mt.message_id is NULL\n"
    )
    with default_session_factory() as session:
        rows = session.execute(text(sql)).all()
        # Each row has a single column, the message id.
        return [row[0] for row in rows]
def _update_embedding_for_message(session, message_id, delay_seconds):
    """Run feature extraction for one message; skip it when data is missing."""
    message = session.query(Message).filter(Message.id == message_id).first()
    if message is None:
        return

    api_client_id = message.api_client_id
    message_text = message.payload.payload.text
    api_client = session.query(ApiClient).filter(ApiClient.id == api_client_id).first()
    if api_client is None or message_text is None:
        return

    try:
        # The task receives the API client as a plain dict of its attributes.
        hf_feature_extraction(text=message_text, message_id=message_id, api_client=api_client.__dict__)
    finally:
        # to not get rate limited from HF (pause after every attempt, even a
        # failed one, so errors do not turn into a burst of requests)
        time.sleep(delay_seconds)


def find_and_update_embeddings(message_ids, delay_seconds=10):
    """Call ``hf_feature_extraction`` for every message in ``message_ids``.

    A message is skipped when it does not exist, when its API client does not
    exist, or when its payload text is ``None``. A failure while handling one
    message is logged with its traceback and the remaining messages are still
    processed. The function itself never raises ``Exception`` subclasses.

    Args:
        message_ids: Iterable of message ids to process.
        delay_seconds: Pause after each feature-extraction call, in seconds.
    """
    try:
        with default_session_factory() as session:
            for message_id in message_ids:
                try:
                    _update_embedding_for_message(session, message_id, delay_seconds)
                except Exception:
                    logger.exception("Could not update embedding for message {}", message_id)
                    # Reset the session so a failed query cannot poison the
                    # following iterations. If the rollback itself fails, the
                    # outer handler logs it and the batch stops.
                    session.rollback()
    except Exception as e:
        logger.exception(str(e))
    logger.debug("Done: find_and_update_embeddings")
def _update_toxicity_for_message(session, message_id, delay_seconds):
    """Run the toxicity task for one message; skip it when data is missing."""
    message = session.query(Message).filter(Message.id == message_id).first()
    if message is None:
        return

    api_client_id = message.api_client_id
    message_text = message.payload.payload.text
    api_client = session.query(ApiClient).filter(ApiClient.id == api_client_id).first()
    if api_client is None or message_text is None:
        return

    try:
        # The task receives the API client as a plain dict of its attributes.
        toxicity(text=message_text, message_id=message_id, api_client=api_client.__dict__)
    finally:
        # to not get rate limited from HF (pause after every attempt, even a
        # failed one, so errors do not turn into a burst of requests)
        time.sleep(delay_seconds)


def find_and_update_toxicity(message_ids, delay_seconds=10):
    """Call ``toxicity`` for every message in ``message_ids``.

    A message is skipped when it does not exist, when its API client does not
    exist, or when its payload text is ``None``. A failure while handling one
    message is logged with its traceback and the remaining messages are still
    processed. The function itself never raises ``Exception`` subclasses.

    Args:
        message_ids: Iterable of message ids to process.
        delay_seconds: Pause after each toxicity call, in seconds.
    """
    try:
        with default_session_factory() as session:
            for message_id in message_ids:
                try:
                    _update_toxicity_for_message(session, message_id, delay_seconds)
                except Exception:
                    logger.exception("Could not update toxicity for message {}", message_id)
                    # Reset the session so a failed query cannot poison the
                    # following iterations. If the rollback itself fails, the
                    # outer handler logs it and the batch stops.
                    session.rollback()
    except Exception as e:
        logger.exception(str(e))
    logger.debug("Done: find_and_update_toxicity")
def main():
    """Backfill missing toxicity rows first, then missing embedding rows.

    Each id list is fetched immediately before its update pass runs.
    """
    find_and_update_toxicity(message_ids=get_messageids_without_toxicity())
    find_and_update_embeddings(message_ids=get_messageids_without_embedding())
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()