Skip to content

Commit 45f2f3a

Browse files
authored
Validate text labels in text frontend (LAION-AI#495)
* Validate text labels in text frontend * Correct task type
1 parent 5b2cb5d commit 45f2f3a

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

backend/oasst_backend/api/v1/tasks.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ def tasks_interaction(
302302
logger.info(
303303
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
304304
)
305-
# TODO: check if the labels are valid?
305+
# Labels are implicitly validated when converting str -> TextLabel
306+
# So no need for explicit validation here
306307
pr.store_text_labels(interaction)
307308
return protocol_schema.TaskDone()
308309
case _:

text-frontend/__main__.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,17 @@ def _post(path: str, json: dict) -> dict:
211211
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
212212

213213
valid_labels = task["valid_labels"]
214-
labels_str: str = typer.prompt("Enter labels, separated by commas")
215-
labels = labels_str.lower().replace(" ", "").split(",")
216-
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
214+
215+
labels_dict = None
216+
while labels_dict is None:
217+
labels_str: str = typer.prompt("Enter labels, separated by commas")
218+
labels = labels_str.lower().replace(" ", "").split(",")
219+
220+
if all([label in valid_labels for label in labels]):
221+
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
222+
else:
223+
invalid_labels = [label for label in labels if label not in valid_labels]
224+
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
217225

218226
# send ranking
219227
new_task = _post(
@@ -240,9 +248,17 @@ def _post(path: str, json: dict) -> dict:
240248
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
241249

242250
valid_labels = task["valid_labels"]
243-
labels_str: str = typer.prompt("Enter labels, separated by commas")
244-
labels = labels_str.lower().replace(" ", "").split(",")
245-
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
251+
252+
labels_dict = None
253+
while labels_dict is None:
254+
labels_str: str = typer.prompt("Enter labels, separated by commas")
255+
labels = labels_str.lower().replace(" ", "").split(",")
256+
257+
if all([label in valid_labels for label in labels]):
258+
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
259+
else:
260+
invalid_labels = [label for label in labels if label not in valid_labels]
261+
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
246262

247263
# send ranking
248264
new_task = _post(

0 commit comments

Comments
 (0)