@@ -80,24 +80,22 @@ def _prepare_message(message: inference.MessageRead) -> str:
80
80
return prompt , parameters
81
81
82
82
83
def prepare_safe_prompt(prompt: str, label: str, rots: str) -> str:
    """Inject a safety preamble in front of the final prompter turn.

    The prompt is split on ``V2_PROMPTER_PREFIX``; a preamble built from
    the safety *label* and the rules-of-thumb string *rots* is prepended
    to the last prompter segment, and the pieces are re-joined.

    Args:
        prompt: Full conversation prompt containing prompter-prefix markers.
        label: Safety label emitted by the safety model.
        rots: Rules-of-thumb text the preamble should reference.

    Returns:
        The prompt with the safety preamble spliced into its final segment.
    """
    preamble = f"Answer the following request with {label} as responsible chatbot that believes that {rots}: "
    segments = prompt.split(V2_PROMPTER_PREFIX)
    # Only the most recent prompter message gets the preamble.
    return V2_PROMPTER_PREFIX.join(segments[:-1] + [preamble + segments[-1]])
88
88
89
89
90
def is_safety_triggered(safety_label: str, safety_level: int) -> bool:
    """Report whether the configured safety level requires prompt rewriting.

    An "intervention" label triggers at any level above 0; a "caution"
    label only triggers above level 1. Any other label never triggers.
    """
    if "intervention" in safety_label:
        return safety_level > 0
    if "caution" in safety_label:
        return safety_level > 1
    return False
94
def parse_safety_response(safety_opinion: str) -> tuple[str, str]:
    """Split a raw safety-model response into a label and rules of thumb.

    Special tokens (``<pad>``, ``</s>``) are removed, then the text is
    split on ``<sep>``: the first segment is the safety label and the
    remaining segments are dot-trimmed and concatenated into the
    rules-of-thumb string.

    Args:
        safety_opinion: Raw decoded output of the safety model.

    Returns:
        A ``(label, rots)`` pair of cleaned strings.
    """
    parts = re.sub(r"<pad>|</s>", "", safety_opinion).split("<sep>")
    # NOTE(review): joining with a bare "and" (no surrounding spaces) runs
    # consecutive rules of thumb together — confirm whether " and " was intended.
    rots = "and".join(segment.strip(".") for segment in parts[1:])
    label = parts[0].replace("<pad>", "").strip()
    return label, rots
101
99
102
100
103
101
def handle_work_request (
@@ -115,8 +113,23 @@ def handle_work_request(
115
113
if settings .enable_safety and work_request .safety_parameters .level :
116
114
safety_request = inference .SafetyRequest (inputs = prompt , parameters = work_request .safety_parameters )
117
115
safety_response = get_safety_server_response (safety_request )
118
- prompt = get_safety_opinion (prompt , safety_response .outputs , work_request .safety_parameters .level )
119
- logger .debug (f"Safe prompt: { prompt } " )
116
+ safety_label , safety_rots = parse_safety_response (safety_response .outputs )
117
+
118
+ if is_safety_triggered (safety_label , work_request .safety_parameters .level ):
119
+ prompt = prepare_safe_prompt (prompt , safety_label , safety_rots )
120
+
121
+ utils .send_response (
122
+ ws ,
123
+ inference .SafePromptResponse (
124
+ request_id = work_request .id ,
125
+ safe_prompt = prompt ,
126
+ safety_parameters = work_request .safety_parameters ,
127
+ safety_label = safety_label ,
128
+ safety_rots = safety_rots ,
129
+ ),
130
+ )
131
+
132
+ logger .debug (f"Safe prompt: { prompt } " )
120
133
121
134
stream_response = None
122
135
token_buffer = utils .TokenBuffer (stop_sequences = parameters .stop )
0 commit comments