Skip to content

Commit 0f2a897

Browse files
committed
added tasks to act as user or assistant
1 parent 0846682 commit 0f2a897

File tree

3 files changed

+145
-24
lines changed

3 files changed

+145
-24
lines changed

backend/app/api/v1/tasks.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
1818
match (request.type):
19-
case protocol_schema.TaskRequestType.generic:
20-
logger.info("Frontend requested a generic task.")
21-
while request.type == protocol_schema.TaskRequestType.generic:
19+
case protocol_schema.TaskRequestType.random:
20+
logger.info("Frontend requested a random task.")
21+
while request.type == protocol_schema.TaskRequestType.random:
2222
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
2323
return generate_task(request)
2424
case protocol_schema.TaskRequestType.summarize_story:
@@ -38,14 +38,42 @@ def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
3838
task = protocol_schema.InitialPromptTask(
3939
hint="Ask the assistant about a current event." # this is optional
4040
)
41+
case protocol_schema.TaskRequestType.user_reply:
42+
logger.info("Generating a UserReplyTask.")
43+
task = protocol_schema.UserReplyTask(
44+
conversation=protocol_schema.Conversation(
45+
messages=[
46+
protocol_schema.ConversationMessage(
47+
text="Hey, assistant, what's going on in the world?",
48+
is_assistant=False,
49+
),
50+
protocol_schema.ConversationMessage(
51+
text="I'm not sure I understood correctly, could you rephrase that?",
52+
is_assistant=True,
53+
),
54+
],
55+
)
56+
)
57+
case protocol_schema.TaskRequestType.assistant_reply:
58+
logger.info("Generating a AssistantReplyTask.")
59+
task = protocol_schema.AssistantReplyTask(
60+
conversation=protocol_schema.Conversation(
61+
messages=[
62+
protocol_schema.ConversationMessage(
63+
text="Hey, assistant, write me an English essay about water.",
64+
is_assistant=False,
65+
),
66+
],
67+
)
68+
)
4169
case _:
4270
raise HTTPException(
4371
status_code=HTTP_400_BAD_REQUEST,
4472
detail="Invalid request type.",
4573
)
4674
logger.info(f"Generated {task=}.")
4775
if request.user is not None:
48-
task.addressed_users = [request.user]
76+
task.addressed_user = request.user
4977

5078
return task
5179

@@ -122,7 +150,7 @@ def post_interaction(
122150
# here we would store the text reply in the database
123151
return protocol_schema.TaskDone(
124152
reply_to_post_id=interaction.user_post_id,
125-
addressed_users=[interaction.user],
153+
addressed_user=interaction.user,
126154
)
127155
case protocol_schema.PostRating:
128156
logger.info(
@@ -132,7 +160,7 @@ def post_interaction(
132160
# here we would store the rating in the database
133161
return protocol_schema.TaskDone(
134162
reply_to_post_id=interaction.post_id,
135-
addressed_users=[interaction.user],
163+
addressed_user=interaction.user,
136164
)
137165
case _:
138166
raise HTTPException(

backend/app/schemas/protocol.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,37 @@
88

99

1010
class TaskRequestType(str, enum.Enum):
11-
generic = "generic"
11+
random = "random"
1212
summarize_story = "summarize_story"
1313
rate_summary = "rate_summary"
1414
initial_prompt = "initial_prompt"
15+
user_reply = "user_reply"
16+
assistant_reply = "assistant_reply"
1517

1618

1719
class User(BaseModel):
1820
id: str
19-
name: str
21+
display_name: str
22+
auth_method: Literal["discord", "local"]
23+
24+
25+
class ConversationMessage(BaseModel):
26+
"""Represents a message in a conversation between the user and the assistant."""
27+
28+
text: str
29+
is_assistant: bool
30+
31+
32+
class Conversation(BaseModel):
33+
"""Represents a conversation between the user and the assistant."""
34+
35+
messages: list[ConversationMessage] = []
2036

2137

2238
class TaskRequest(BaseModel):
2339
"""The frontend asks the backend for a task."""
2440

25-
type: TaskRequestType = TaskRequestType.generic
41+
type: TaskRequestType = TaskRequestType.random
2642
user: Optional[User] = None
2743

2844

@@ -31,7 +47,7 @@ class Task(BaseModel):
3147

3248
id: UUID = pydantic.Field(default_factory=uuid4)
3349
type: str
34-
addressed_users: Optional[list[User]] = None
50+
addressed_user: Optional[User] = None
3551

3652

3753
class TaskResponse(BaseModel):
@@ -91,6 +107,21 @@ class InitialPromptTask(Task):
91107
)
92108

93109

110+
class UserReplyTask(Task):
111+
"""A task to prompt the user to submit a reply to the assistant."""
112+
113+
type: Literal["user_reply"] = "user_reply"
114+
conversation: Conversation # the conversation so far
115+
hint: str | None = None # e.g. "Try to ask for clarification."
116+
117+
118+
class AssistantReplyTask(Task):
119+
"""A task to prompt the user to act as the assistant."""
120+
121+
type: Literal["assistant_reply"] = "assistant_reply"
122+
conversation: Conversation # the conversation so far
123+
124+
94125
class TaskDone(Task):
95126
"""Signals to the frontend that the task is done."""
96127

@@ -99,10 +130,12 @@ class TaskDone(Task):
99130

100131

101132
AnyTask = Union[
133+
TaskDone,
102134
SummarizeStoryTask,
103135
RateSummaryTask,
104136
InitialPromptTask,
105-
TaskDone,
137+
UserReplyTask,
138+
AssistantReplyTask,
106139
]
107140

108141

text-frontend/__main__.py

+73-13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
app = typer.Typer()
88

99

10+
# debug constants
11+
POST_ID = "1234"
12+
USER_POST_ID = "5678"
13+
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
14+
15+
16+
def _render_message(message: dict) -> str:
17+
"""Render a message to the user."""
18+
if message["is_assistant"]:
19+
return f"Assistant: {message['text']}"
20+
return f"User: {message['text']}"
21+
22+
1023
@app.command()
1124
def main(backend_url: str, api_key: str):
1225
"""Simple REPL frontend."""
@@ -17,7 +30,7 @@ def _post(path: str, json: dict) -> dict:
1730
return response.json()
1831

1932
typer.echo("Requesting work...")
20-
tasks = [_post("/api/v1/tasks/", {"type": "generic"})]
33+
tasks = [_post("/api/v1/tasks/", {"type": "random"})]
2134
while tasks:
2235
task = tasks.pop(0)
2336
match (task["type"]):
@@ -26,7 +39,7 @@ def _post(path: str, json: dict) -> dict:
2639
typer.echo(task["story"])
2740

2841
# acknowledge task
29-
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"})
42+
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
3043

3144
summary = typer.prompt("Enter your summary")
3245

@@ -35,10 +48,10 @@ def _post(path: str, json: dict) -> dict:
3548
"/api/v1/tasks/interaction",
3649
{
3750
"type": "text_reply_to_post",
38-
"post_id": "1234",
39-
"user_post_id": "5678",
51+
"post_id": POST_ID,
52+
"user_post_id": USER_POST_ID,
4053
"text": summary,
41-
"user": {"id": "1234", "name": "John Doe"},
54+
"user": USER,
4255
},
4356
)
4457
tasks.append(new_task)
@@ -50,17 +63,17 @@ def _post(path: str, json: dict) -> dict:
5063
typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}")
5164

5265
# acknowledge task
53-
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": "1234"})
66+
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "rating_created", "post_id": POST_ID})
5467

5568
rating = typer.prompt("Enter your rating", type=int)
5669
# send interaction
5770
new_task = _post(
5871
"/api/v1/tasks/interaction",
5972
{
6073
"type": "post_rating",
61-
"post_id": "1234",
74+
"post_id": POST_ID,
6275
"rating": rating,
63-
"user": {"id": "1234", "name": "John Doe"},
76+
"user": USER,
6477
},
6578
)
6679
tasks.append(new_task)
@@ -69,23 +82,70 @@ def _post(path: str, json: dict) -> dict:
6982
if task["hint"]:
7083
typer.echo(f"Hint: {task['hint']}")
7184
# acknowledge task
72-
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": "1234"})
85+
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
7386
prompt = typer.prompt("Enter your prompt")
7487
# send interaction
7588
new_task = _post(
7689
"/api/v1/tasks/interaction",
7790
{
7891
"type": "text_reply_to_post",
79-
"post_id": "1234",
80-
"user_post_id": "5678",
92+
"post_id": POST_ID,
93+
"user_post_id": USER_POST_ID,
8194
"text": prompt,
82-
"user": {"id": "1234", "name": "John Doe"},
95+
"user": USER,
96+
},
97+
)
98+
tasks.append(new_task)
99+
100+
case "user_reply":
101+
typer.echo("Please provide a reply to the assistant.")
102+
typer.echo("Here is the conversation so far:")
103+
for message in task["conversation"]["messages"]:
104+
typer.echo(_render_message(message))
105+
if task["hint"]:
106+
typer.echo(f"Hint: {task['hint']}")
107+
# acknowledge task
108+
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
109+
reply = typer.prompt("Enter your reply")
110+
# send interaction
111+
new_task = _post(
112+
"/api/v1/tasks/interaction",
113+
{
114+
"type": "text_reply_to_post",
115+
"post_id": POST_ID,
116+
"user_post_id": USER_POST_ID,
117+
"text": reply,
118+
"user": USER,
119+
},
120+
)
121+
tasks.append(new_task)
122+
123+
case "assistant_reply":
124+
typer.echo("Act as the assistant and reply to the user.")
125+
typer.echo("Here is the conversation so far:")
126+
for message in task["conversation"]["messages"]:
127+
typer.echo(_render_message(message))
128+
# acknowledge task
129+
_post(f"/api/v1/tasks/{task['id']}/ack", {"type": "post_created", "post_id": POST_ID})
130+
reply = typer.prompt("Enter your reply")
131+
# send interaction
132+
new_task = _post(
133+
"/api/v1/tasks/interaction",
134+
{
135+
"type": "text_reply_to_post",
136+
"post_id": POST_ID,
137+
"user_post_id": USER_POST_ID,
138+
"text": reply,
139+
"user": USER,
83140
},
84141
)
85142
tasks.append(new_task)
86143

87144
case "task_done":
88-
typer.echo("Task done!")
145+
if addressed_user := task["addressed_user"]:
146+
typer.echo(f"Hey, {addressed_user['display_name']}! Thank you!")
147+
else:
148+
typer.echo("Task done!")
89149
case _:
90150
typer.echo(f"Unknown task type {task['type']}")
91151

0 commit comments

Comments
 (0)