forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompliance.py
149 lines (127 loc) · 6.61 KB
/
compliance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import datetime
from typing import cast
import fastapi
import sqlmodel
from loguru import logger
from oasst_inference_server import database, deps, models, worker_utils
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from sqlalchemy.sql.functions import random as sql_random
from sqlmodel import not_, or_
async def find_compliance_work_request_message(
session: database.AsyncSession, worker_config: inference.WorkerConfig, worker_id: str
) -> models.DbMessage | None:
compat_hash = worker_config.compat_hash
query = (
sqlmodel.select(models.DbMessage)
.where(
models.DbMessage.role == "assistant",
models.DbMessage.state == inference.MessageState.complete,
models.DbMessage.worker_compat_hash == compat_hash,
models.DbMessage.worker_id != worker_id,
)
.order_by(sql_random())
)
message = (await session.exec(query)).first()
return message
async def should_do_compliance_check(session: database.AsyncSession, worker_id: str) -> bool:
worker = await worker_utils.get_worker(worker_id, session)
if worker.trusted:
return False
if worker.in_compliance_check:
return False
if worker.next_compliance_check is None:
return True
if worker.next_compliance_check < datetime.datetime.utcnow():
return True
return False
async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, worker_config: inference.WorkerConfig):
async with deps.manual_create_session() as session:
try:
worker = await worker_utils.get_worker(worker_id, session)
if worker.in_compliance_check:
logger.info(f"Worker {worker.id} is already in compliance check")
return
worker.in_compliance_check_since = datetime.datetime.utcnow()
finally:
await session.commit()
logger.info(f"Running compliance check for worker {worker_id}")
async with deps.manual_create_session(autoflush=False) as session:
compliance_check = models.DbWorkerComplianceCheck(worker_id=worker_id)
try:
message = await find_compliance_work_request_message(session, worker_config, worker_id)
if message is None:
logger.warning(
f"Could not find message for compliance check for worker {worker_id} with config {worker_config}"
)
return
compliance_check.compare_worker_id = message.worker_id
compliance_work_request = await worker_utils.build_work_request(session, message.id)
logger.info(f"Found work request for compliance check for worker {worker_id}: {compliance_work_request}")
await worker_utils.send_worker_request(websocket, compliance_work_request)
response = None
while True:
response = await worker_utils.receive_worker_response(websocket)
if response.response_type == "error":
compliance_check.responded = True
compliance_check.error = response.error
logger.warning(f"Worker {worker_id} errored during compliance check: {response.error}")
return
if response.response_type == "generated_text":
break
if response is None:
logger.warning(f"Worker {worker_id} did not respond to compliance check")
return
compliance_check.responded = True
response = cast(inference.GeneratedTextResponse, response)
passes = response.text == message.content
compliance_check.passed = passes
logger.info(f"Worker {worker_id} passed compliance check: {passes}")
finally:
compliance_check.end_time = datetime.datetime.utcnow()
session.add(compliance_check)
worker = await worker_utils.get_worker(worker_id, session)
worker.next_compliance_check = datetime.datetime.utcnow() + datetime.timedelta(
seconds=settings.compliance_check_interval
)
worker.in_compliance_check_since = None
logger.info(f"set next compliance check for worker {worker_id} to {worker.next_compliance_check}")
await session.commit()
await session.flush()
async def maybe_do_compliance_check(websocket, worker_id, worker_config, worker_session_id):
async with deps.manual_create_session() as session:
should_check = await should_do_compliance_check(session, worker_id)
if should_check:
logger.info(f"Worker {worker_id} needs compliance check")
try:
await worker_utils.update_worker_session_status(
worker_session_id, worker_utils.WorkerSessionStatus.compliance_check
)
await run_compliance_check(websocket, worker_id, worker_config)
finally:
await worker_utils.update_worker_session_status(worker_session_id, worker_utils.WorkerSessionStatus.waiting)
async def compute_worker_compliance_score(worker_id: str) -> float:
"""
Compute a float between 0 and 1 (inclusive) representing the compliance score of the worker.
Workers are rewarded for passing compliance checks, and penalised for failing to respond to a check, erroring during a check, or failing a check.
In-progress checks are ignored.
"""
async with deps.manual_create_session() as session:
query = sqlmodel.select(models.DbWorkerComplianceCheck).where(
or_(
models.DbWorkerComplianceCheck.worker_id == worker_id,
models.DbWorkerComplianceCheck.compare_worker_id == worker_id,
),
not_(models.DbWorkerComplianceCheck.end_time.is_(None)),
)
worker_checks: list[models.DbWorkerComplianceCheck] = (await session.exec(query)).all()
# Rudimentary scoring algorithm, we may want to add weightings or other factors
total_count = len(worker_checks)
checked = [c for c in worker_checks if c.worker_id == worker_id]
compared = [c for c in worker_checks if c.compare_worker_id == worker_id]
pass_count = sum(1 for _ in filter(lambda c: c.passed, checked))
error_count = sum(1 for _ in filter(lambda c: c.error is not None, checked))
no_response_count = sum(1 for _ in filter(lambda c: not c.responded, checked))
compare_fail_count = sum(1 for _ in filter(lambda c: not c.passed, compared))
fail_count = len(checked) - pass_count - error_count - no_response_count
return (fail_count + compare_fail_count) / total_count