Skip to content

Commit f7628a9

Browse files
authored
Inference worker inform backend on safety intervention (LAION-AI#2505)
1 parent a18aa70 commit f7628a9

File tree

10 files changed

+135
-14
lines changed

10 files changed

+135
-14
lines changed

docker/inference/Dockerfile.safety

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib
5555
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .
5656

5757

58-
CMD python3 __main__.py
58+
CMD python3 main.py
5959

6060
FROM base-env as prod
6161
ARG APP_USER

inference/safety/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
git+https://github.com/LAION-AI/blade2blade@8fd43bcbc5ff35fd59663c77ef08b3ec6c239dd4#egg=blade2blade
1+
blade2blade
22
fastapi
33
loguru
44
pydantic
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Add safe_content to message
2+
3+
Revision ID: ea19bbc743f9
4+
Revises: 401eef162771
5+
Create Date: 2023-04-14 22:37:41.373382
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "ea19bbc743f9"
14+
down_revision = "401eef162771"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
    """Apply the migration: add nullable safety-intervention columns to ``message``.

    All four columns are nullable so existing rows stay valid without a backfill.
    """
    string_type = sqlmodel.sql.sqltypes.AutoString
    op.add_column("message", sa.Column("safe_content", string_type(), nullable=True))
    op.add_column("message", sa.Column("safety_level", sa.Integer(), nullable=True))
    op.add_column("message", sa.Column("safety_label", string_type(), nullable=True))
    op.add_column("message", sa.Column("safety_rots", string_type(), nullable=True))
26+
27+
28+
def downgrade() -> None:
    """Revert the migration: drop the safety-intervention columns from ``message``."""
    # Dropped in the same order the upgrade added them.
    for column_name in ("safe_content", "safety_level", "safety_label", "safety_rots"):
        op.drop_column("message", column_name)

inference/server/oasst_inference_server/chat_repository.py

+9
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ async def get_assistant_message_by_id(self, message_id: str) -> models.DbMessage
2626
message = (await self.session.exec(query)).one()
2727
return message
2828

29+
async def get_prompter_message_by_id(self, message_id: str) -> models.DbMessage:
    """Fetch the prompter-role message with the given id, eagerly loading its reports.

    Mirrors ``get_assistant_message_by_id`` but filters on ``role == "prompter"``.
    ``.one()`` raises if no row (or more than one) matches.
    """
    stmt = (
        sqlmodel.select(models.DbMessage)
        .options(sqlalchemy.orm.selectinload(models.DbMessage.reports))
        .where(models.DbMessage.id == message_id, models.DbMessage.role == "prompter")
    )
    result = await self.session.exec(stmt)
    return result.one()
37+
2938
async def start_work(
3039
self, *, message_id: str, worker_id: str, worker_config: inference.WorkerConfig
3140
) -> models.DbMessage:

inference/server/oasst_inference_server/models/chat.py

+9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ class DbMessage(SQLModel, table=True):
2323
content: str | None = Field(None)
2424
error: str | None = Field(None)
2525

26+
safe_content: str | None = Field(None)
27+
safety_level: int | None = Field(None)
28+
safety_label: str | None = Field(None)
29+
safety_rots: str | None = Field(None)
30+
2631
state: inference.MessageState = Field(inference.MessageState.manual)
2732
work_parameters: inference.WorkParameters = Field(None, sa_column=sa.Column(pg.JSONB))
2833
work_begin_at: datetime.datetime | None = Field(None)
@@ -59,6 +64,10 @@ def to_read(self) -> inference.MessageRead:
5964
score=self.score,
6065
work_parameters=self.work_parameters,
6166
reports=[r.to_read() for r in self.reports],
67+
safe_content=self.safe_content,
68+
safety_level=self.safety_level,
69+
safety_label=self.safety_label,
70+
safety_rots=self.safety_rots,
6271
)
6372

6473

inference/server/oasst_inference_server/routes/chats.py

+8
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,14 @@ async def event_generator(chat_id: str, message_id: str, worker_compat_hash: str
234234
)
235235
break
236236

237+
if response_packet.response_type == "safe_prompt":
238+
logger.info(f"Received safety intervention for {chat_id}")
239+
yield {
240+
"data": chat_schema.SafePromptResponseEvent(
241+
safe_prompt=response_packet.safe_prompt,
242+
).json(),
243+
}
244+
237245
if response_packet.response_type == "internal_error":
238246
yield {
239247
"data": chat_schema.ErrorResponseEvent(

inference/server/oasst_inference_server/routes/workers.py

+28
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,13 @@ def _add_receive(ftrs: set):
211211
response=worker_response,
212212
)
213213
await _update_session(worker_response.metrics)
214+
case "safe_prompt":
215+
logger.info("Received safe prompt response")
216+
worker_response = cast(inference.SafePromptResponse, worker_response)
217+
await handle_safe_prompt_response(
218+
response=worker_response,
219+
work_request_map=work_request_map,
220+
)
214221
case _:
215222
raise RuntimeError(f"Unknown response type: {worker_response.response_type}")
216223
finally:
@@ -387,6 +394,27 @@ async def handle_general_error_response(
387394
logger.warning(f"Got general error {response=}")
388395

389396

397+
async def handle_safe_prompt_response(
    response: inference.SafePromptResponse,
    work_request_map: WorkRequestContainerMap,
):
    """Persist a worker-reported safety intervention.

    The worker informs the server that the safety model rewrote the user's
    prompt. Look up the assistant message being worked on, walk to its parent
    (the prompter message), and store the safe prompt plus safety metadata
    there.
    """
    container = get_work_request_container(work_request_map, response.request_id)

    async with deps.manual_create_session() as session:
        repo = chat_repository.ChatRepository(session=session)
        assistant_message = await repo.get_assistant_message_by_id(container.message_id)
        prompter_message = await repo.get_prompter_message_by_id(assistant_message.parent_id)
        prompter_message.safe_content = response.safe_prompt
        prompter_message.safety_level = response.safety_parameters.level
        prompter_message.safety_label = response.safety_label
        prompter_message.safety_rots = response.safety_rots
        await session.commit()
416+
417+
390418
async def handle_timeout(message: inference.MessageRead):
391419
response = inference.InternalErrorResponse(
392420
error="Timeout",

inference/server/oasst_inference_server/schemas/chat.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,15 @@ class MessageResponseEvent(pydantic.BaseModel):
3838
message: inference.MessageRead
3939

4040

41+
class SafePromptResponseEvent(pydantic.BaseModel):
    """SSE event sent to the client when the safety model rewrote the prompt."""

    event_type: Literal["safe_prompt"] = "safe_prompt"
    # The user prompt after the safety intervention was applied.
    safe_prompt: str
    # BUG FIX: the producer in routes/chats.py constructs this event with only
    # `safe_prompt`; a required `message` field would make that construction
    # raise a pydantic ValidationError. Default to None so the event is valid.
    message: inference.MessageRead | None = None
45+
46+
4147
ResponseEvent = Annotated[
42-
Union[TokenResponseEvent, ErrorResponseEvent, MessageResponseEvent], pydantic.Field(discriminator="event_type")
48+
Union[TokenResponseEvent, ErrorResponseEvent, MessageResponseEvent, SafePromptResponseEvent],
49+
pydantic.Field(discriminator="event_type"),
4350
]
4451

4552

inference/worker/work.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -80,24 +80,22 @@ def _prepare_message(message: inference.MessageRead) -> str:
8080
return prompt, parameters
8181

8282

83-
def prepare_safe_prompt(prompt: str, label: str, rots: str):
83+
def prepare_safe_prompt(prompt: str, label: str, rots: str) -> str:
    """Inject a safety instruction in front of the final user turn of *prompt*.

    The prompt is split on the prompter prefix so only the last user segment
    receives the instruction; earlier turns are left untouched.
    """
    pre_prompt = f"Answer the following request with {label} as responsible chatbot that believes that {rots}: "
    segments = prompt.split(V2_PROMPTER_PREFIX)
    segments[-1] = pre_prompt + segments[-1]
    return V2_PROMPTER_PREFIX.join(segments)
8888

8989

90-
def get_safety_opinion(prompt: str, safety_opinion: str, safety_level: int):
90+
def is_safety_triggered(safety_label: str, safety_level: int) -> bool:
    """Decide whether the safety model's verdict warrants rewriting the prompt.

    "caution" labels trigger only at level 2+; "intervention" labels trigger at
    any level above 0. A label containing both keywords triggers if either
    condition holds.
    """
    caution_triggered = "caution" in safety_label and safety_level > 1
    intervention_triggered = "intervention" in safety_label and safety_level > 0
    return caution_triggered or intervention_triggered
92+
93+
94+
def parse_safety_response(safety_opinion: str) -> tuple[str, str]:
    """Split a raw safety-model response into its label and rules-of-thumb.

    The raw text looks like ``<pad>LABEL<sep>ROT1.<sep>ROT2.</s>``: special
    tokens are stripped, the first ``<sep>`` segment is the label, and the
    remaining segments are joined into a single rules-of-thumb string.
    """
    cleaned = re.sub(r"<pad>|</s>", "", safety_opinion)
    label, *rot_parts = cleaned.split("<sep>")
    # NOTE(review): joined with a bare "and" (no surrounding spaces) — this
    # matches the original behavior; confirm it is the intended prompt format.
    rots = "and".join(part.strip(".") for part in rot_parts)
    return label.strip(), rots
10199

102100

103101
def handle_work_request(
@@ -115,8 +113,23 @@ def handle_work_request(
115113
if settings.enable_safety and work_request.safety_parameters.level:
116114
safety_request = inference.SafetyRequest(inputs=prompt, parameters=work_request.safety_parameters)
117115
safety_response = get_safety_server_response(safety_request)
118-
prompt = get_safety_opinion(prompt, safety_response.outputs, work_request.safety_parameters.level)
119-
logger.debug(f"Safe prompt: {prompt}")
116+
safety_label, safety_rots = parse_safety_response(safety_response.outputs)
117+
118+
if is_safety_triggered(safety_label, work_request.safety_parameters.level):
119+
prompt = prepare_safe_prompt(prompt, safety_label, safety_rots)
120+
121+
utils.send_response(
122+
ws,
123+
inference.SafePromptResponse(
124+
request_id=work_request.id,
125+
safe_prompt=prompt,
126+
safety_parameters=work_request.safety_parameters,
127+
safety_label=safety_label,
128+
safety_rots=safety_rots,
129+
),
130+
)
131+
132+
logger.debug(f"Safe prompt: {prompt}")
120133

121134
stream_response = None
122135
token_buffer = utils.TokenBuffer(stop_sequences=parameters.stop)

oasst-shared/oasst_shared/schemas/inference.py

+13
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ class MessageRead(pydantic.BaseModel):
169169
reports: list[Report] = []
170170
# work parameters will be None on user prompts
171171
work_parameters: WorkParameters | None
172+
safe_content: str | None
173+
safety_level: int | None
174+
safety_label: str | None
175+
safety_rots: str | None
172176

173177
@property
174178
def is_assistant(self) -> bool:
@@ -240,6 +244,14 @@ class PongResponse(WorkerResponseBase):
240244
metrics: WorkerMetricsInfo | None = None
241245

242246

247+
class SafePromptResponse(WorkerResponseBase):
    """Worker -> server notification that the safety model modified the user prompt."""

    event_type_doc = None  # (no extra class attrs; discriminator below)
    response_type: Literal["safe_prompt"] = "safe_prompt"
    # Full prompt text after the safety rewrite was applied.
    safe_prompt: str
    # The safety configuration the worker ran with.
    safety_parameters: SafetyParameters
    # Label produced by the safety model (used by the worker to decide triggering).
    safety_label: str
    # Rules-of-thumb string extracted from the safety model's response.
    safety_rots: str
253+
254+
243255
class TokenResponse(WorkerResponseBase):
244256
response_type: Literal["token"] = "token"
245257
text: str
@@ -298,6 +310,7 @@ class GeneralErrorResponse(WorkerResponseBase):
298310
PongResponse,
299311
InternalFinishedMessageResponse,
300312
InternalErrorResponse,
313+
SafePromptResponse,
301314
],
302315
pydantic.Field(discriminator="response_type"),
303316
]

0 commit comments

Comments
 (0)